// Patch summary (leading text before the first `diff --git` header; not part of any hunk).
//
// index.js — Replicate.run(ref, options, progress):
//   * Adds an early guard: when `signal` is set and already aborted, calls
//     `signal.throwIfAborted()` before `ModelVersionIdentifier.parse(ref)` runs.
//   * Adds `signal` to the options object in both prediction-creation branches
//     (the `identifier.version` branch and the `owner/name` branch).
//   * In the later `signal.aborted` check, wraps `this.predictions.cancel(prediction.id)`
//     in a try/catch whose catch block is empty, then calls `signal.throwIfAborted()`,
//     so an aborted run now throws instead of continuing to the progress callback.
//
// index.test.ts:
//   * The existing abort test now captures the thrown error and expects its `name`
//     to be "AbortError" instead of expecting `output` to be undefined; the assertion
//     on a third `onProgress` call with status "canceled" is removed.
//   * Adds a test that passes an already-aborted `AbortController` signal and expects
//     an error named "AbortError", with no nock scope registered for that test.
//
// NOTE(review): `throwIfAborted()` throws the signal's abort reason; the "AbortError"
//   name presumably holds only for the default reason — confirm behaviour when a caller
//   uses `controller.abort(customReason)`.
// NOTE(review): the test still asserts `body?.["signal"]` is undefined while `signal` is
//   now passed into the create options — presumably `predictions.create` strips it from
//   the request body; verify against that method.
// NOTE(review): this copy of the patch has its line breaks collapsed onto two physical
//   lines; restore the original line structure before attempting to apply it.
diff --git a/index.js b/index.js index 3fc14dd5..57d0bb2a 100644 --- a/index.js +++ b/index.js @@ -145,6 +145,11 @@ class Replicate { async run(ref, options, progress) { const { wait = { mode: "block" }, signal, ...data } = options; + // Honour an already-aborted signal before any network work happens. + if (signal && signal.aborted) { + signal.throwIfAborted(); + } + const identifier = ModelVersionIdentifier.parse(ref); let prediction; @@ -153,12 +158,14 @@ class Replicate { ...data, version: identifier.version, wait: wait.mode === "block" ? wait.timeout ?? true : false, + signal, }); } else if (identifier.owner && identifier.name) { prediction = await this.predictions.create({ ...data, model: `${identifier.owner}/${identifier.name}`, wait: wait.mode === "block" ? wait.timeout ?? true : false, + signal, }); } else { throw new Error("Invalid model version identifier"); @@ -191,7 +198,16 @@ class Replicate { } if (signal && signal.aborted) { - prediction = await this.predictions.cancel(prediction.id); + // Best-effort cancel on Replicate's side so we don't keep billing the + // user for compute they no longer want, then surface the abort to the + // caller. Without the throw, the awaited promise would resolve with a + // half-cancelled prediction, which is silent on the consumer side. + try { + prediction = await this.predictions.cancel(prediction.id); + } catch { + // Ignore cancel failures — the abort is the higher-priority signal. + } + signal.throwIfAborted(); } // Call progress callback with the completed prediction object diff --git a/index.test.ts b/index.test.ts index 96f50db7..b5ce274a 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1598,20 +1598,28 @@ describe("Replicate client", () => { }); const onProgress = jest.fn(); - const output = await client.run( - "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - { - wait: { mode: "poll" }, - input: { text: "Hello, world!" 
}, - signal, - }, - onProgress - ); + let caught: unknown; + try { + await client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + wait: { mode: "poll" }, + input: { text: "Hello, world!" }, + signal, + }, + onProgress + ); + } catch (err) { + caught = err; + } expect(body).toBeDefined(); expect(body?.["signal"]).toBeUndefined(); expect(signal.aborted).toBe(true); - expect(output).toBeUndefined(); + // Regression for replicate-javascript#370: an aborted run() must throw + // an AbortError so the caller can detect cancellation, not silently + // resolve to `undefined` from the canceled prediction's empty output. + expect((caught as Error | undefined)?.name).toBe("AbortError"); expect(onProgress).toHaveBeenNthCalledWith( 1, @@ -1625,16 +1633,36 @@ describe("Replicate client", () => { status: "processing", }) ); - expect(onProgress).toHaveBeenNthCalledWith( - 3, - expect.objectContaining({ - status: "canceled", - }) - ); scope.done(); }); + test("throws AbortError immediately when signal is already aborted", async () => { + // Regression for replicate-javascript#370: a pre-aborted signal must + // short-circuit before any HTTP request — previously run() created the + // prediction and waited for it anyway. + const controller = new AbortController(); + controller.abort(); + + let caught: unknown; + try { + await client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + wait: { mode: "poll" }, + input: { text: "Hello, world!" }, + signal: controller.signal, + } + ); + } catch (err) { + caught = err; + } + + expect((caught as Error | undefined)?.name).toBe("AbortError"); + // No nock scope was registered — if any HTTP request fired, nock would + // throw a "Nock: No match for request" error instead of an AbortError. + }); + test("returns FileOutput for URLs when useFileOutput is true", async () => { client = new Replicate({ auth: "foo", useFileOutput: true });