Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion index.js
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ class Replicate {
async run(ref, options, progress) {
const { wait = { mode: "block" }, signal, ...data } = options;

// Honour an already-aborted signal before any network work happens.
if (signal && signal.aborted) {
signal.throwIfAborted();
}

const identifier = ModelVersionIdentifier.parse(ref);

let prediction;
Expand All @@ -153,12 +158,14 @@ class Replicate {
...data,
version: identifier.version,
wait: wait.mode === "block" ? wait.timeout ?? true : false,
signal,
});
} else if (identifier.owner && identifier.name) {
prediction = await this.predictions.create({
...data,
model: `${identifier.owner}/${identifier.name}`,
wait: wait.mode === "block" ? wait.timeout ?? true : false,
signal,
});
} else {
throw new Error("Invalid model version identifier");
Expand Down Expand Up @@ -191,7 +198,16 @@ class Replicate {
}

if (signal && signal.aborted) {
prediction = await this.predictions.cancel(prediction.id);
// Best-effort cancel on Replicate's side so we don't keep billing the
// user for compute they no longer want, then surface the abort to the
// caller. Without the throw, the awaited promise would resolve with a
// half-cancelled prediction, which is silent on the consumer side.
try {
prediction = await this.predictions.cancel(prediction.id);
} catch {
// Ignore cancel failures — the abort is the higher-priority signal.
}
signal.throwIfAborted();
}

// Call progress callback with the completed prediction object
Expand Down
60 changes: 44 additions & 16 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1598,20 +1598,28 @@ describe("Replicate client", () => {
});

const onProgress = jest.fn();
const output = await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
wait: { mode: "poll" },
input: { text: "Hello, world!" },
signal,
},
onProgress
);
let caught: unknown;
try {
await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
wait: { mode: "poll" },
input: { text: "Hello, world!" },
signal,
},
onProgress
);
} catch (err) {
caught = err;
}

expect(body).toBeDefined();
expect(body?.["signal"]).toBeUndefined();
expect(signal.aborted).toBe(true);
expect(output).toBeUndefined();
// Regression for replicate-javascript#370: an aborted run() must throw
// an AbortError so the caller can detect cancellation, not silently
// resolve to `undefined` from the canceled prediction's empty output.
expect((caught as Error | undefined)?.name).toBe("AbortError");

expect(onProgress).toHaveBeenNthCalledWith(
1,
Expand All @@ -1625,16 +1633,36 @@ describe("Replicate client", () => {
status: "processing",
})
);
expect(onProgress).toHaveBeenNthCalledWith(
3,
expect.objectContaining({
status: "canceled",
})
);

scope.done();
});

test("throws AbortError immediately when signal is already aborted", async () => {
// Regression for replicate-javascript#370: a pre-aborted signal must
// short-circuit before any HTTP request — previously run() created the
// prediction and waited for it anyway.
const controller = new AbortController();
controller.abort();

let caught: unknown;
try {
await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
wait: { mode: "poll" },
input: { text: "Hello, world!" },
signal: controller.signal,
}
);
} catch (err) {
caught = err;
}

expect((caught as Error | undefined)?.name).toBe("AbortError");
// No nock scope was registered — if any HTTP request fired, nock would
// throw a "Nock: No match for request" error instead of an AbortError.
});

test("returns FileOutput for URLs when useFileOutput is true", async () => {
client = new Replicate({ auth: "foo", useFileOutput: true });

Expand Down