@ai-sdk/replicate 2.0.7 → 2.0.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +14 -0
- package/dist/index.js +1 -1
- package/dist/index.mjs +1 -1
- package/package.json +5 -4
- package/src/index.ts +7 -0
- package/src/replicate-error.ts +13 -0
- package/src/replicate-image-model.test.ts +752 -0
- package/src/replicate-image-model.ts +268 -0
- package/src/replicate-image-settings.ts +36 -0
- package/src/replicate-provider.test.ts +24 -0
- package/src/replicate-provider.ts +99 -0
- package/src/version.ts +6 -0
package/CHANGELOG.md
CHANGED
|
@@ -1,5 +1,19 @@
|
|
|
1
1
|
# @ai-sdk/replicate
|
|
2
2
|
|
|
3
|
+
## 2.0.9
|
|
4
|
+
|
|
5
|
+
### Patch Changes
|
|
6
|
+
|
|
7
|
+
- 8dc54db: chore: add src folders to package bundle
|
|
8
|
+
|
|
9
|
+
## 2.0.8
|
|
10
|
+
|
|
11
|
+
### Patch Changes
|
|
12
|
+
|
|
13
|
+
- Updated dependencies [5c090e7]
|
|
14
|
+
- @ai-sdk/provider@3.0.4
|
|
15
|
+
- @ai-sdk/provider-utils@4.0.8
|
|
16
|
+
|
|
3
17
|
## 2.0.7
|
|
4
18
|
|
|
5
19
|
### Patch Changes
|
package/dist/index.js
CHANGED
|
@@ -227,7 +227,7 @@ var replicateImageProviderOptionsSchema = (0, import_provider_utils2.lazySchema)
|
|
|
227
227
|
);
|
|
228
228
|
|
|
229
229
|
// src/version.ts
|
|
230
|
-
var VERSION = true ? "2.0.7" : "0.0.0-test";
|
|
230
|
+
var VERSION = true ? "2.0.9" : "0.0.0-test";
|
|
231
231
|
|
|
232
232
|
// src/replicate-provider.ts
|
|
233
233
|
function createReplicate(options = {}) {
|
package/dist/index.mjs
CHANGED
|
@@ -210,7 +210,7 @@ var replicateImageProviderOptionsSchema = lazySchema(
|
|
|
210
210
|
);
|
|
211
211
|
|
|
212
212
|
// src/version.ts
|
|
213
|
-
var VERSION = true ? "2.0.7" : "0.0.0-test";
|
|
213
|
+
var VERSION = true ? "2.0.9" : "0.0.0-test";
|
|
214
214
|
|
|
215
215
|
// src/replicate-provider.ts
|
|
216
216
|
function createReplicate(options = {}) {
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@ai-sdk/replicate",
|
|
3
|
-
"version": "2.0.7",
|
|
3
|
+
"version": "2.0.9",
|
|
4
4
|
"license": "Apache-2.0",
|
|
5
5
|
"sideEffects": false,
|
|
6
6
|
"main": "./dist/index.js",
|
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
"types": "./dist/index.d.ts",
|
|
9
9
|
"files": [
|
|
10
10
|
"dist/**/*",
|
|
11
|
+
"src",
|
|
11
12
|
"CHANGELOG.md",
|
|
12
13
|
"README.md"
|
|
13
14
|
],
|
|
@@ -20,15 +21,15 @@
|
|
|
20
21
|
}
|
|
21
22
|
},
|
|
22
23
|
"dependencies": {
|
|
23
|
-
"@ai-sdk/provider": "3.0.
|
|
24
|
-
"@ai-sdk/provider-utils": "4.0.
|
|
24
|
+
"@ai-sdk/provider": "3.0.4",
|
|
25
|
+
"@ai-sdk/provider-utils": "4.0.8"
|
|
25
26
|
},
|
|
26
27
|
"devDependencies": {
|
|
27
28
|
"@types/node": "20.17.24",
|
|
28
29
|
"tsup": "^8",
|
|
29
30
|
"typescript": "5.8.3",
|
|
30
31
|
"zod": "3.25.76",
|
|
31
|
-
"@ai-sdk/test-server": "1.0.
|
|
32
|
+
"@ai-sdk/test-server": "1.0.2",
|
|
32
33
|
"@vercel/ai-tsconfig": "0.0.0"
|
|
33
34
|
},
|
|
34
35
|
"peerDependencies": {
|
package/src/index.ts
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
export { createReplicate, replicate } from './replicate-provider';
|
|
2
|
+
export type {
|
|
3
|
+
ReplicateProvider,
|
|
4
|
+
ReplicateProviderSettings,
|
|
5
|
+
} from './replicate-provider';
|
|
6
|
+
export type { ReplicateImageProviderOptions } from './replicate-image-model';
|
|
7
|
+
export { VERSION } from './version';
|
|
package/src/replicate-error.ts
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import { createJsonErrorResponseHandler } from '@ai-sdk/provider-utils';
|
|
2
|
+
import { z } from 'zod/v4';
|
|
3
|
+
|
|
4
|
+
const replicateErrorSchema = z.object({
|
|
5
|
+
detail: z.string().optional(),
|
|
6
|
+
error: z.string().optional(),
|
|
7
|
+
});
|
|
8
|
+
|
|
9
|
+
export const replicateFailedResponseHandler = createJsonErrorResponseHandler({
|
|
10
|
+
errorSchema: replicateErrorSchema,
|
|
11
|
+
errorToMessage: error =>
|
|
12
|
+
error.detail ?? error.error ?? 'Unknown Replicate error',
|
|
13
|
+
});
|
|
package/src/replicate-image-model.test.ts
ADDED
|
@@ -0,0 +1,752 @@
|
|
|
1
|
+
import { createTestServer } from '@ai-sdk/test-server/with-vitest';
|
|
2
|
+
import { createReplicate } from './replicate-provider';
|
|
3
|
+
import { ReplicateImageModel } from './replicate-image-model';
|
|
4
|
+
import { describe, it, expect, vi } from 'vitest';
|
|
5
|
+
|
|
6
|
+
vi.mock('./version', () => ({
|
|
7
|
+
VERSION: '0.0.0-test',
|
|
8
|
+
}));
|
|
9
|
+
|
|
10
|
+
const prompt = 'The Loch Ness monster getting a manicure';
|
|
11
|
+
|
|
12
|
+
const provider = createReplicate({ apiToken: 'test-api-token' });
|
|
13
|
+
const model = provider.image('black-forest-labs/flux-schnell');
|
|
14
|
+
|
|
15
|
+
describe('doGenerate', () => {
|
|
16
|
+
const testDate = new Date(2024, 0, 1);
|
|
17
|
+
const server = createTestServer({
|
|
18
|
+
'https://api.replicate.com/*': {},
|
|
19
|
+
'https://replicate.delivery/*': {
|
|
20
|
+
response: {
|
|
21
|
+
type: 'binary',
|
|
22
|
+
body: Buffer.from('test-binary-content'),
|
|
23
|
+
},
|
|
24
|
+
},
|
|
25
|
+
});
|
|
26
|
+
|
|
27
|
+
function prepareResponse({
|
|
28
|
+
output = ['https://replicate.delivery/xezq/abc/out-0.webp'],
|
|
29
|
+
}: { output?: string | Array<string> } = {}) {
|
|
30
|
+
server.urls['https://api.replicate.com/*'].response = {
|
|
31
|
+
type: 'json-value',
|
|
32
|
+
body: {
|
|
33
|
+
id: 's7x1e3dcmhrmc0cm8rbatcneec',
|
|
34
|
+
model: 'black-forest-labs/flux-schnell',
|
|
35
|
+
version: 'dp-4d0bcc010b3049749a251855f12800be',
|
|
36
|
+
input: {
|
|
37
|
+
num_outputs: 1,
|
|
38
|
+
prompt: 'The Loch Ness Monster getting a manicure',
|
|
39
|
+
},
|
|
40
|
+
logs: '',
|
|
41
|
+
output,
|
|
42
|
+
data_removed: false,
|
|
43
|
+
error: null,
|
|
44
|
+
status: 'processing',
|
|
45
|
+
created_at: '2025-01-08T13:24:38.692Z',
|
|
46
|
+
urls: {
|
|
47
|
+
cancel:
|
|
48
|
+
'https://api.replicate.com/v1/predictions/s7x1e3dcmhrmc0cm8rbatcneec/cancel',
|
|
49
|
+
get: 'https://api.replicate.com/v1/predictions/s7x1e3dcmhrmc0cm8rbatcneec',
|
|
50
|
+
stream:
|
|
51
|
+
'https://stream.replicate.com/v1/files/bcwr-3okdfv3o2wehstv5f2okyftwxy57hhypqsi6osiim5iaq5k7u24a',
|
|
52
|
+
},
|
|
53
|
+
},
|
|
54
|
+
};
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
it('should pass the model and the settings', async () => {
|
|
58
|
+
prepareResponse();
|
|
59
|
+
|
|
60
|
+
await model.doGenerate({
|
|
61
|
+
prompt,
|
|
62
|
+
files: undefined,
|
|
63
|
+
mask: undefined,
|
|
64
|
+
n: 1,
|
|
65
|
+
size: '1024x768',
|
|
66
|
+
aspectRatio: '3:4',
|
|
67
|
+
seed: 123,
|
|
68
|
+
providerOptions: {
|
|
69
|
+
replicate: {
|
|
70
|
+
style: 'realistic_image',
|
|
71
|
+
},
|
|
72
|
+
other: {
|
|
73
|
+
something: 'else',
|
|
74
|
+
},
|
|
75
|
+
},
|
|
76
|
+
});
|
|
77
|
+
|
|
78
|
+
expect(await server.calls[0].requestBodyJson).toStrictEqual({
|
|
79
|
+
input: {
|
|
80
|
+
prompt,
|
|
81
|
+
num_outputs: 1,
|
|
82
|
+
aspect_ratio: '3:4',
|
|
83
|
+
size: '1024x768',
|
|
84
|
+
seed: 123,
|
|
85
|
+
style: 'realistic_image',
|
|
86
|
+
},
|
|
87
|
+
});
|
|
88
|
+
});
|
|
89
|
+
|
|
90
|
+
it('should call the correct url', async () => {
|
|
91
|
+
prepareResponse();
|
|
92
|
+
|
|
93
|
+
await model.doGenerate({
|
|
94
|
+
prompt,
|
|
95
|
+
files: undefined,
|
|
96
|
+
mask: undefined,
|
|
97
|
+
n: 1,
|
|
98
|
+
size: undefined,
|
|
99
|
+
aspectRatio: undefined,
|
|
100
|
+
seed: undefined,
|
|
101
|
+
providerOptions: {},
|
|
102
|
+
});
|
|
103
|
+
|
|
104
|
+
expect(server.calls[0].requestMethod).toStrictEqual('POST');
|
|
105
|
+
expect(server.calls[0].requestUrl).toStrictEqual(
|
|
106
|
+
'https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions',
|
|
107
|
+
);
|
|
108
|
+
});
|
|
109
|
+
|
|
110
|
+
it('should pass headers and set the prefer header', async () => {
|
|
111
|
+
prepareResponse();
|
|
112
|
+
|
|
113
|
+
const provider = createReplicate({
|
|
114
|
+
apiToken: 'test-api-token',
|
|
115
|
+
headers: {
|
|
116
|
+
'Custom-Provider-Header': 'provider-header-value',
|
|
117
|
+
},
|
|
118
|
+
});
|
|
119
|
+
|
|
120
|
+
await provider.image('black-forest-labs/flux-schnell').doGenerate({
|
|
121
|
+
prompt,
|
|
122
|
+
files: undefined,
|
|
123
|
+
mask: undefined,
|
|
124
|
+
n: 1,
|
|
125
|
+
size: undefined,
|
|
126
|
+
aspectRatio: undefined,
|
|
127
|
+
seed: undefined,
|
|
128
|
+
providerOptions: {},
|
|
129
|
+
headers: {
|
|
130
|
+
'Custom-Request-Header': 'request-header-value',
|
|
131
|
+
},
|
|
132
|
+
});
|
|
133
|
+
|
|
134
|
+
expect(server.calls[0].requestHeaders).toStrictEqual({
|
|
135
|
+
authorization: 'Bearer test-api-token',
|
|
136
|
+
'content-type': 'application/json',
|
|
137
|
+
'custom-provider-header': 'provider-header-value',
|
|
138
|
+
'custom-request-header': 'request-header-value',
|
|
139
|
+
prefer: 'wait',
|
|
140
|
+
});
|
|
141
|
+
|
|
142
|
+
expect(server.calls[0].requestUserAgent).toContain(
|
|
143
|
+
`ai-sdk/replicate/0.0.0-test`,
|
|
144
|
+
);
|
|
145
|
+
});
|
|
146
|
+
|
|
147
|
+
it('should set custom wait time in prefer header when maxWaitTimeInSeconds is specified', async () => {
|
|
148
|
+
prepareResponse();
|
|
149
|
+
|
|
150
|
+
await model.doGenerate({
|
|
151
|
+
prompt,
|
|
152
|
+
files: undefined,
|
|
153
|
+
mask: undefined,
|
|
154
|
+
n: 1,
|
|
155
|
+
size: undefined,
|
|
156
|
+
aspectRatio: undefined,
|
|
157
|
+
seed: undefined,
|
|
158
|
+
providerOptions: {
|
|
159
|
+
replicate: {
|
|
160
|
+
maxWaitTimeInSeconds: 120,
|
|
161
|
+
},
|
|
162
|
+
},
|
|
163
|
+
});
|
|
164
|
+
|
|
165
|
+
expect(server.calls[0].requestHeaders.prefer).toBe('wait=120');
|
|
166
|
+
});
|
|
167
|
+
|
|
168
|
+
it('should not include maxWaitTimeInSeconds in request body', async () => {
|
|
169
|
+
prepareResponse();
|
|
170
|
+
|
|
171
|
+
await model.doGenerate({
|
|
172
|
+
prompt,
|
|
173
|
+
files: undefined,
|
|
174
|
+
mask: undefined,
|
|
175
|
+
n: 1,
|
|
176
|
+
size: undefined,
|
|
177
|
+
aspectRatio: undefined,
|
|
178
|
+
seed: undefined,
|
|
179
|
+
providerOptions: {
|
|
180
|
+
replicate: {
|
|
181
|
+
maxWaitTimeInSeconds: 120,
|
|
182
|
+
guidance_scale: 7.5,
|
|
183
|
+
},
|
|
184
|
+
},
|
|
185
|
+
});
|
|
186
|
+
|
|
187
|
+
const requestBody = await server.calls[0].requestBodyJson;
|
|
188
|
+
expect(requestBody.input.maxWaitTimeInSeconds).toBeUndefined();
|
|
189
|
+
expect(requestBody.input.guidance_scale).toBe(7.5);
|
|
190
|
+
});
|
|
191
|
+
|
|
192
|
+
it('should extract the generated image from array response', async () => {
|
|
193
|
+
prepareResponse({
|
|
194
|
+
output: ['https://replicate.delivery/xezq/abc/out-0.webp'],
|
|
195
|
+
});
|
|
196
|
+
|
|
197
|
+
const result = await model.doGenerate({
|
|
198
|
+
prompt,
|
|
199
|
+
files: undefined,
|
|
200
|
+
mask: undefined,
|
|
201
|
+
n: 1,
|
|
202
|
+
size: undefined,
|
|
203
|
+
aspectRatio: undefined,
|
|
204
|
+
seed: undefined,
|
|
205
|
+
providerOptions: {},
|
|
206
|
+
});
|
|
207
|
+
|
|
208
|
+
expect(result.images).toStrictEqual([
|
|
209
|
+
new Uint8Array(Buffer.from('test-binary-content')),
|
|
210
|
+
]);
|
|
211
|
+
|
|
212
|
+
expect(server.calls[1].requestMethod).toStrictEqual('GET');
|
|
213
|
+
expect(server.calls[1].requestUrl).toStrictEqual(
|
|
214
|
+
'https://replicate.delivery/xezq/abc/out-0.webp',
|
|
215
|
+
);
|
|
216
|
+
});
|
|
217
|
+
|
|
218
|
+
it('should extract the generated image from string response', async () => {
|
|
219
|
+
prepareResponse({
|
|
220
|
+
output: 'https://replicate.delivery/xezq/abc/out-0.webp',
|
|
221
|
+
});
|
|
222
|
+
|
|
223
|
+
const result = await model.doGenerate({
|
|
224
|
+
prompt,
|
|
225
|
+
files: undefined,
|
|
226
|
+
mask: undefined,
|
|
227
|
+
n: 1,
|
|
228
|
+
size: undefined,
|
|
229
|
+
aspectRatio: undefined,
|
|
230
|
+
seed: undefined,
|
|
231
|
+
providerOptions: {},
|
|
232
|
+
});
|
|
233
|
+
|
|
234
|
+
expect(result.images).toStrictEqual([
|
|
235
|
+
new Uint8Array(Buffer.from('test-binary-content')),
|
|
236
|
+
]);
|
|
237
|
+
|
|
238
|
+
expect(server.calls[1].requestMethod).toStrictEqual('GET');
|
|
239
|
+
expect(server.calls[1].requestUrl).toStrictEqual(
|
|
240
|
+
'https://replicate.delivery/xezq/abc/out-0.webp',
|
|
241
|
+
);
|
|
242
|
+
});
|
|
243
|
+
|
|
244
|
+
it('should return response metadata', async () => {
|
|
245
|
+
const modelWithTimestamp = new ReplicateImageModel(
|
|
246
|
+
'black-forest-labs/flux-schnell',
|
|
247
|
+
{
|
|
248
|
+
provider: 'replicate',
|
|
249
|
+
baseURL: 'https://api.replicate.com',
|
|
250
|
+
_internal: { currentDate: () => testDate },
|
|
251
|
+
},
|
|
252
|
+
);
|
|
253
|
+
prepareResponse({
|
|
254
|
+
output: ['https://replicate.delivery/xezq/abc/out-0.webp'],
|
|
255
|
+
});
|
|
256
|
+
|
|
257
|
+
const result = await modelWithTimestamp.doGenerate({
|
|
258
|
+
prompt,
|
|
259
|
+
files: undefined,
|
|
260
|
+
mask: undefined,
|
|
261
|
+
n: 1,
|
|
262
|
+
size: undefined,
|
|
263
|
+
aspectRatio: undefined,
|
|
264
|
+
seed: undefined,
|
|
265
|
+
providerOptions: {},
|
|
266
|
+
});
|
|
267
|
+
|
|
268
|
+
expect(result.response).toStrictEqual({
|
|
269
|
+
timestamp: testDate,
|
|
270
|
+
modelId: 'black-forest-labs/flux-schnell',
|
|
271
|
+
headers: expect.any(Object),
|
|
272
|
+
});
|
|
273
|
+
});
|
|
274
|
+
|
|
275
|
+
it('should include response headers in metadata', async () => {
|
|
276
|
+
const modelWithTimestamp = new ReplicateImageModel(
|
|
277
|
+
'black-forest-labs/flux-schnell',
|
|
278
|
+
{
|
|
279
|
+
provider: 'replicate',
|
|
280
|
+
baseURL: 'https://api.replicate.com',
|
|
281
|
+
_internal: {
|
|
282
|
+
currentDate: () => testDate,
|
|
283
|
+
},
|
|
284
|
+
},
|
|
285
|
+
);
|
|
286
|
+
server.urls['https://api.replicate.com/*'].response = {
|
|
287
|
+
type: 'json-value',
|
|
288
|
+
headers: {
|
|
289
|
+
'custom-response-header': 'response-header-value',
|
|
290
|
+
},
|
|
291
|
+
body: {
|
|
292
|
+
id: 's7x1e3dcmhrmc0cm8rbatcneec',
|
|
293
|
+
model: 'black-forest-labs/flux-schnell',
|
|
294
|
+
version: 'dp-4d0bcc010b3049749a251855f12800be',
|
|
295
|
+
input: {
|
|
296
|
+
num_outputs: 1,
|
|
297
|
+
prompt: 'The Loch Ness Monster getting a manicure',
|
|
298
|
+
},
|
|
299
|
+
logs: '',
|
|
300
|
+
output: ['https://replicate.delivery/xezq/abc/out-0.webp'],
|
|
301
|
+
data_removed: false,
|
|
302
|
+
error: null,
|
|
303
|
+
status: 'processing',
|
|
304
|
+
created_at: '2025-01-08T13:24:38.692Z',
|
|
305
|
+
urls: {
|
|
306
|
+
cancel:
|
|
307
|
+
'https://api.replicate.com/v1/predictions/s7x1e3dcmhrmc0cm8rbatcneec/cancel',
|
|
308
|
+
get: 'https://api.replicate.com/v1/predictions/s7x1e3dcmhrmc0cm8rbatcneec',
|
|
309
|
+
stream:
|
|
310
|
+
'https://stream.replicate.com/v1/files/bcwr-3okdfv3o2wehstv5f2okyftwxy57hhypqsi6osiim5iaq5k7u24a',
|
|
311
|
+
},
|
|
312
|
+
},
|
|
313
|
+
};
|
|
314
|
+
|
|
315
|
+
const result = await modelWithTimestamp.doGenerate({
|
|
316
|
+
prompt,
|
|
317
|
+
files: undefined,
|
|
318
|
+
mask: undefined,
|
|
319
|
+
n: 1,
|
|
320
|
+
size: undefined,
|
|
321
|
+
aspectRatio: undefined,
|
|
322
|
+
seed: undefined,
|
|
323
|
+
providerOptions: {},
|
|
324
|
+
});
|
|
325
|
+
|
|
326
|
+
expect(result.response).toStrictEqual({
|
|
327
|
+
timestamp: testDate,
|
|
328
|
+
modelId: 'black-forest-labs/flux-schnell',
|
|
329
|
+
headers: {
|
|
330
|
+
'content-length': '646',
|
|
331
|
+
'content-type': 'application/json',
|
|
332
|
+
'custom-response-header': 'response-header-value',
|
|
333
|
+
},
|
|
334
|
+
});
|
|
335
|
+
});
|
|
336
|
+
|
|
337
|
+
it('should set version in request body for versioned models', async () => {
|
|
338
|
+
prepareResponse();
|
|
339
|
+
|
|
340
|
+
const versionedModel = provider.image(
|
|
341
|
+
'bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637',
|
|
342
|
+
);
|
|
343
|
+
|
|
344
|
+
await versionedModel.doGenerate({
|
|
345
|
+
prompt,
|
|
346
|
+
files: undefined,
|
|
347
|
+
mask: undefined,
|
|
348
|
+
n: 1,
|
|
349
|
+
size: undefined,
|
|
350
|
+
aspectRatio: undefined,
|
|
351
|
+
seed: undefined,
|
|
352
|
+
providerOptions: {},
|
|
353
|
+
});
|
|
354
|
+
|
|
355
|
+
expect(server.calls[0].requestMethod).toStrictEqual('POST');
|
|
356
|
+
expect(server.calls[0].requestUrl).toStrictEqual(
|
|
357
|
+
'https://api.replicate.com/v1/predictions',
|
|
358
|
+
);
|
|
359
|
+
expect(await server.calls[0].requestBodyJson).toStrictEqual({
|
|
360
|
+
input: {
|
|
361
|
+
prompt,
|
|
362
|
+
num_outputs: 1,
|
|
363
|
+
},
|
|
364
|
+
version:
|
|
365
|
+
'5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637',
|
|
366
|
+
});
|
|
367
|
+
});
|
|
368
|
+
|
|
369
|
+
describe('Image Editing', () => {
|
|
370
|
+
it('should send image when URL file is provided', async () => {
|
|
371
|
+
prepareResponse();
|
|
372
|
+
|
|
373
|
+
await model.doGenerate({
|
|
374
|
+
prompt: 'Add a hat to the person',
|
|
375
|
+
files: [
|
|
376
|
+
{
|
|
377
|
+
type: 'url',
|
|
378
|
+
url: 'https://example.com/input.jpg',
|
|
379
|
+
},
|
|
380
|
+
],
|
|
381
|
+
mask: undefined,
|
|
382
|
+
n: 1,
|
|
383
|
+
size: undefined,
|
|
384
|
+
aspectRatio: undefined,
|
|
385
|
+
seed: undefined,
|
|
386
|
+
providerOptions: {},
|
|
387
|
+
});
|
|
388
|
+
|
|
389
|
+
const requestBody = await server.calls[0].requestBodyJson;
|
|
390
|
+
expect(requestBody).toMatchInlineSnapshot(`
|
|
391
|
+
{
|
|
392
|
+
"input": {
|
|
393
|
+
"image": "https://example.com/input.jpg",
|
|
394
|
+
"num_outputs": 1,
|
|
395
|
+
"prompt": "Add a hat to the person",
|
|
396
|
+
},
|
|
397
|
+
}
|
|
398
|
+
`);
|
|
399
|
+
});
|
|
400
|
+
|
|
401
|
+
it('should convert Uint8Array file to data URI', async () => {
|
|
402
|
+
prepareResponse();
|
|
403
|
+
|
|
404
|
+
const testImageData = new Uint8Array([137, 80, 78, 71, 13, 10, 26, 10]);
|
|
405
|
+
|
|
406
|
+
await model.doGenerate({
|
|
407
|
+
prompt: 'Transform this image',
|
|
408
|
+
files: [
|
|
409
|
+
{
|
|
410
|
+
type: 'file',
|
|
411
|
+
data: testImageData,
|
|
412
|
+
mediaType: 'image/png',
|
|
413
|
+
},
|
|
414
|
+
],
|
|
415
|
+
mask: undefined,
|
|
416
|
+
n: 1,
|
|
417
|
+
size: undefined,
|
|
418
|
+
aspectRatio: undefined,
|
|
419
|
+
seed: undefined,
|
|
420
|
+
providerOptions: {},
|
|
421
|
+
});
|
|
422
|
+
|
|
423
|
+
const requestBody = await server.calls[0].requestBodyJson;
|
|
424
|
+
expect(requestBody.input.image).toMatch(/^data:image\/png;base64,/);
|
|
425
|
+
expect(requestBody.input.prompt).toBe('Transform this image');
|
|
426
|
+
});
|
|
427
|
+
|
|
428
|
+
it('should convert file with base64 string data to data URI', async () => {
|
|
429
|
+
prepareResponse();
|
|
430
|
+
|
|
431
|
+
await model.doGenerate({
|
|
432
|
+
prompt: 'Edit this',
|
|
433
|
+
files: [
|
|
434
|
+
{
|
|
435
|
+
type: 'file',
|
|
436
|
+
data: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
|
|
437
|
+
mediaType: 'image/png',
|
|
438
|
+
},
|
|
439
|
+
],
|
|
440
|
+
mask: undefined,
|
|
441
|
+
n: 1,
|
|
442
|
+
size: undefined,
|
|
443
|
+
aspectRatio: undefined,
|
|
444
|
+
seed: undefined,
|
|
445
|
+
providerOptions: {},
|
|
446
|
+
});
|
|
447
|
+
|
|
448
|
+
const requestBody = await server.calls[0].requestBodyJson;
|
|
449
|
+
expect(requestBody.input.image).toBe(
|
|
450
|
+
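'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',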
|
|
451
|
+
);
|
|
452
|
+
});
|
|
453
|
+
|
|
454
|
+
it('should send mask for inpainting', async () => {
|
|
455
|
+
prepareResponse();
|
|
456
|
+
|
|
457
|
+
await model.doGenerate({
|
|
458
|
+
prompt: 'Replace the masked area with a tree',
|
|
459
|
+
files: [
|
|
460
|
+
{
|
|
461
|
+
type: 'url',
|
|
462
|
+
url: 'https://example.com/input.jpg',
|
|
463
|
+
},
|
|
464
|
+
],
|
|
465
|
+
mask: {
|
|
466
|
+
type: 'url',
|
|
467
|
+
url: 'https://example.com/mask.png',
|
|
468
|
+
},
|
|
469
|
+
n: 1,
|
|
470
|
+
size: undefined,
|
|
471
|
+
aspectRatio: undefined,
|
|
472
|
+
seed: undefined,
|
|
473
|
+
providerOptions: {},
|
|
474
|
+
});
|
|
475
|
+
|
|
476
|
+
const requestBody = await server.calls[0].requestBodyJson;
|
|
477
|
+
expect(requestBody).toMatchInlineSnapshot(`
|
|
478
|
+
{
|
|
479
|
+
"input": {
|
|
480
|
+
"image": "https://example.com/input.jpg",
|
|
481
|
+
"mask": "https://example.com/mask.png",
|
|
482
|
+
"num_outputs": 1,
|
|
483
|
+
"prompt": "Replace the masked area with a tree",
|
|
484
|
+
},
|
|
485
|
+
}
|
|
486
|
+
`);
|
|
487
|
+
});
|
|
488
|
+
|
|
489
|
+
it('should warn when multiple files are provided', async () => {
|
|
490
|
+
prepareResponse();
|
|
491
|
+
|
|
492
|
+
const result = await model.doGenerate({
|
|
493
|
+
prompt: 'Edit multiple images',
|
|
494
|
+
files: [
|
|
495
|
+
{
|
|
496
|
+
type: 'url',
|
|
497
|
+
url: 'https://example.com/input1.jpg',
|
|
498
|
+
},
|
|
499
|
+
{
|
|
500
|
+
type: 'url',
|
|
501
|
+
url: 'https://example.com/input2.jpg',
|
|
502
|
+
},
|
|
503
|
+
],
|
|
504
|
+
mask: undefined,
|
|
505
|
+
n: 1,
|
|
506
|
+
size: undefined,
|
|
507
|
+
aspectRatio: undefined,
|
|
508
|
+
seed: undefined,
|
|
509
|
+
providerOptions: {},
|
|
510
|
+
});
|
|
511
|
+
|
|
512
|
+
expect(result.warnings).toMatchInlineSnapshot(`
|
|
513
|
+
[
|
|
514
|
+
{
|
|
515
|
+
"message": "This Replicate model only supports a single input image. Additional images are ignored.",
|
|
516
|
+
"type": "other",
|
|
517
|
+
},
|
|
518
|
+
]
|
|
519
|
+
`);
|
|
520
|
+
|
|
521
|
+
// Should only use the first image
|
|
522
|
+
const requestBody = await server.calls[0].requestBodyJson;
|
|
523
|
+
expect(requestBody.input.image).toBe('https://example.com/input1.jpg');
|
|
524
|
+
});
|
|
525
|
+
|
|
526
|
+
it('should pass provider options with image editing', async () => {
|
|
527
|
+
prepareResponse();
|
|
528
|
+
|
|
529
|
+
await model.doGenerate({
|
|
530
|
+
prompt: 'Inpaint this area',
|
|
531
|
+
files: [
|
|
532
|
+
{
|
|
533
|
+
type: 'url',
|
|
534
|
+
url: 'https://example.com/input.jpg',
|
|
535
|
+
},
|
|
536
|
+
],
|
|
537
|
+
mask: {
|
|
538
|
+
type: 'url',
|
|
539
|
+
url: 'https://example.com/mask.png',
|
|
540
|
+
},
|
|
541
|
+
n: 1,
|
|
542
|
+
size: undefined,
|
|
543
|
+
aspectRatio: undefined,
|
|
544
|
+
seed: undefined,
|
|
545
|
+
providerOptions: {
|
|
546
|
+
replicate: {
|
|
547
|
+
guidance_scale: 7.5,
|
|
548
|
+
num_inference_steps: 30,
|
|
549
|
+
negative_prompt: 'blurry, low quality',
|
|
550
|
+
},
|
|
551
|
+
},
|
|
552
|
+
});
|
|
553
|
+
|
|
554
|
+
const requestBody = await server.calls[0].requestBodyJson;
|
|
555
|
+
expect(requestBody).toMatchInlineSnapshot(`
|
|
556
|
+
{
|
|
557
|
+
"input": {
|
|
558
|
+
"guidance_scale": 7.5,
|
|
559
|
+
"image": "https://example.com/input.jpg",
|
|
560
|
+
"mask": "https://example.com/mask.png",
|
|
561
|
+
"negative_prompt": "blurry, low quality",
|
|
562
|
+
"num_inference_steps": 30,
|
|
563
|
+
"num_outputs": 1,
|
|
564
|
+
"prompt": "Inpaint this area",
|
|
565
|
+
},
|
|
566
|
+
}
|
|
567
|
+
`);
|
|
568
|
+
});
|
|
569
|
+
});
|
|
570
|
+
|
|
571
|
+
describe('Flux-2 Models', () => {
|
|
572
|
+
const flux2Model = provider.image('black-forest-labs/flux-2-pro');
|
|
573
|
+
|
|
574
|
+
it('should report maxImagesPerCall as 8 for Flux-2 models', () => {
|
|
575
|
+
expect(flux2Model.maxImagesPerCall).toBe(8);
|
|
576
|
+
});
|
|
577
|
+
|
|
578
|
+
it('should report maxImagesPerCall as 1 for non-Flux-2 models', () => {
|
|
579
|
+
expect(model.maxImagesPerCall).toBe(1);
|
|
580
|
+
});
|
|
581
|
+
|
|
582
|
+
it('should send single image as input_image for Flux-2 models', async () => {
|
|
583
|
+
prepareResponse();
|
|
584
|
+
|
|
585
|
+
await flux2Model.doGenerate({
|
|
586
|
+
prompt: 'Generate image in similar style',
|
|
587
|
+
files: [
|
|
588
|
+
{
|
|
589
|
+
type: 'url',
|
|
590
|
+
url: 'https://example.com/reference.jpg',
|
|
591
|
+
},
|
|
592
|
+
],
|
|
593
|
+
mask: undefined,
|
|
594
|
+
n: 1,
|
|
595
|
+
size: undefined,
|
|
596
|
+
aspectRatio: undefined,
|
|
597
|
+
seed: undefined,
|
|
598
|
+
providerOptions: {},
|
|
599
|
+
});
|
|
600
|
+
|
|
601
|
+
const requestBody = await server.calls[0].requestBodyJson;
|
|
602
|
+
expect(requestBody).toMatchInlineSnapshot(`
|
|
603
|
+
{
|
|
604
|
+
"input": {
|
|
605
|
+
"input_image": "https://example.com/reference.jpg",
|
|
606
|
+
"num_outputs": 1,
|
|
607
|
+
"prompt": "Generate image in similar style",
|
|
608
|
+
},
|
|
609
|
+
}
|
|
610
|
+
`);
|
|
611
|
+
});
|
|
612
|
+
|
|
613
|
+
it('should send multiple images as input_image, input_image_2, etc. for Flux-2 models', async () => {
|
|
614
|
+
prepareResponse();
|
|
615
|
+
|
|
616
|
+
await flux2Model.doGenerate({
|
|
617
|
+
prompt: 'Combine styles from reference images',
|
|
618
|
+
files: [
|
|
619
|
+
{
|
|
620
|
+
type: 'url',
|
|
621
|
+
url: 'https://example.com/reference1.jpg',
|
|
622
|
+
},
|
|
623
|
+
{
|
|
624
|
+
type: 'url',
|
|
625
|
+
url: 'https://example.com/reference2.jpg',
|
|
626
|
+
},
|
|
627
|
+
{
|
|
628
|
+
type: 'url',
|
|
629
|
+
url: 'https://example.com/reference3.jpg',
|
|
630
|
+
},
|
|
631
|
+
],
|
|
632
|
+
mask: undefined,
|
|
633
|
+
n: 1,
|
|
634
|
+
size: undefined,
|
|
635
|
+
aspectRatio: undefined,
|
|
636
|
+
seed: undefined,
|
|
637
|
+
providerOptions: {},
|
|
638
|
+
});
|
|
639
|
+
|
|
640
|
+
const requestBody = await server.calls[0].requestBodyJson;
|
|
641
|
+
expect(requestBody).toMatchInlineSnapshot(`
|
|
642
|
+
{
|
|
643
|
+
"input": {
|
|
644
|
+
"input_image": "https://example.com/reference1.jpg",
|
|
645
|
+
"input_image_2": "https://example.com/reference2.jpg",
|
|
646
|
+
"input_image_3": "https://example.com/reference3.jpg",
|
|
647
|
+
"num_outputs": 1,
|
|
648
|
+
"prompt": "Combine styles from reference images",
|
|
649
|
+
},
|
|
650
|
+
}
|
|
651
|
+
`);
|
|
652
|
+
});
|
|
653
|
+
|
|
654
|
+
it('should warn when more than 8 images are provided for Flux-2 models', async () => {
|
|
655
|
+
prepareResponse();
|
|
656
|
+
|
|
657
|
+
const result = await flux2Model.doGenerate({
|
|
658
|
+
prompt: 'Too many images',
|
|
659
|
+
files: [
|
|
660
|
+
{ type: 'url', url: 'https://example.com/img1.jpg' },
|
|
661
|
+
{ type: 'url', url: 'https://example.com/img2.jpg' },
|
|
662
|
+
{ type: 'url', url: 'https://example.com/img3.jpg' },
|
|
663
|
+
{ type: 'url', url: 'https://example.com/img4.jpg' },
|
|
664
|
+
{ type: 'url', url: 'https://example.com/img5.jpg' },
|
|
665
|
+
{ type: 'url', url: 'https://example.com/img6.jpg' },
|
|
666
|
+
{ type: 'url', url: 'https://example.com/img7.jpg' },
|
|
667
|
+
{ type: 'url', url: 'https://example.com/img8.jpg' },
|
|
668
|
+
{ type: 'url', url: 'https://example.com/img9.jpg' },
|
|
669
|
+
],
|
|
670
|
+
mask: undefined,
|
|
671
|
+
n: 1,
|
|
672
|
+
size: undefined,
|
|
673
|
+
aspectRatio: undefined,
|
|
674
|
+
seed: undefined,
|
|
675
|
+
providerOptions: {},
|
|
676
|
+
});
|
|
677
|
+
|
|
678
|
+
expect(result.warnings).toMatchInlineSnapshot(`
|
|
679
|
+
[
|
|
680
|
+
{
|
|
681
|
+
"message": "Flux-2 models support up to 8 input images. Additional images are ignored.",
|
|
682
|
+
"type": "other",
|
|
683
|
+
},
|
|
684
|
+
]
|
|
685
|
+
`);
|
|
686
|
+
|
|
687
|
+
// Should only include 8 images
|
|
688
|
+
const requestBody = await server.calls[0].requestBodyJson;
|
|
689
|
+
expect(requestBody.input.input_image).toBe(
|
|
690
|
+
'https://example.com/img1.jpg',
|
|
691
|
+
);
|
|
692
|
+
expect(requestBody.input.input_image_8).toBe(
|
|
693
|
+
'https://example.com/img8.jpg',
|
|
694
|
+
);
|
|
695
|
+
expect(requestBody.input.input_image_9).toBeUndefined();
|
|
696
|
+
});
|
|
697
|
+
|
|
698
|
+
it('should warn and ignore mask for Flux-2 models', async () => {
|
|
699
|
+
prepareResponse();
|
|
700
|
+
|
|
701
|
+
const result = await flux2Model.doGenerate({
|
|
702
|
+
prompt: 'Edit with mask',
|
|
703
|
+
files: [
|
|
704
|
+
{
|
|
705
|
+
type: 'url',
|
|
706
|
+
url: 'https://example.com/input.jpg',
|
|
707
|
+
},
|
|
708
|
+
],
|
|
709
|
+
mask: {
|
|
710
|
+
type: 'url',
|
|
711
|
+
url: 'https://example.com/mask.png',
|
|
712
|
+
},
|
|
713
|
+
n: 1,
|
|
714
|
+
size: undefined,
|
|
715
|
+
aspectRatio: undefined,
|
|
716
|
+
seed: undefined,
|
|
717
|
+
providerOptions: {},
|
|
718
|
+
});
|
|
719
|
+
|
|
720
|
+
expect(result.warnings).toMatchInlineSnapshot(`
|
|
721
|
+
[
|
|
722
|
+
{
|
|
723
|
+
"message": "Flux-2 models do not support mask input. The mask will be ignored.",
|
|
724
|
+
"type": "other",
|
|
725
|
+
},
|
|
726
|
+
]
|
|
727
|
+
`);
|
|
728
|
+
|
|
729
|
+
const requestBody = await server.calls[0].requestBodyJson;
|
|
730
|
+
expect(requestBody.input.mask).toBeUndefined();
|
|
731
|
+
});
|
|
732
|
+
|
|
733
|
+
it('should call correct URL for Flux-2 models', async () => {
|
|
734
|
+
prepareResponse();
|
|
735
|
+
|
|
736
|
+
await flux2Model.doGenerate({
|
|
737
|
+
prompt: 'Generate something',
|
|
738
|
+
files: undefined,
|
|
739
|
+
mask: undefined,
|
|
740
|
+
n: 1,
|
|
741
|
+
size: undefined,
|
|
742
|
+
aspectRatio: undefined,
|
|
743
|
+
seed: undefined,
|
|
744
|
+
providerOptions: {},
|
|
745
|
+
});
|
|
746
|
+
|
|
747
|
+
expect(server.calls[0].requestUrl).toStrictEqual(
|
|
748
|
+
'https://api.replicate.com/v1/models/black-forest-labs/flux-2-pro/predictions',
|
|
749
|
+
);
|
|
750
|
+
});
|
|
751
|
+
});
|
|
752
|
+
});
|
|
package/src/replicate-image-model.ts
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
import type { ImageModelV3, SharedV3Warning } from '@ai-sdk/provider';
|
|
2
|
+
import type { Resolvable } from '@ai-sdk/provider-utils';
|
|
3
|
+
import {
|
|
4
|
+
combineHeaders,
|
|
5
|
+
convertImageModelFileToDataUri,
|
|
6
|
+
createBinaryResponseHandler,
|
|
7
|
+
createJsonResponseHandler,
|
|
8
|
+
FetchFunction,
|
|
9
|
+
getFromApi,
|
|
10
|
+
InferSchema,
|
|
11
|
+
lazySchema,
|
|
12
|
+
parseProviderOptions,
|
|
13
|
+
postJsonToApi,
|
|
14
|
+
resolve,
|
|
15
|
+
zodSchema,
|
|
16
|
+
} from '@ai-sdk/provider-utils';
|
|
17
|
+
import { z } from 'zod/v4';
|
|
18
|
+
import { replicateFailedResponseHandler } from './replicate-error';
|
|
19
|
+
import { ReplicateImageModelId } from './replicate-image-settings';
|
|
20
|
+
|
|
21
|
+
interface ReplicateImageModelConfig {
|
|
22
|
+
provider: string;
|
|
23
|
+
baseURL: string;
|
|
24
|
+
headers?: Resolvable<Record<string, string | undefined>>;
|
|
25
|
+
fetch?: FetchFunction;
|
|
26
|
+
_internal?: {
|
|
27
|
+
currentDate?: () => Date;
|
|
28
|
+
};
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
// Flux-2 models support up to 8 input images with input_image, input_image_2, etc.
|
|
32
|
+
const FLUX_2_MODEL_PATTERN = /^black-forest-labs\/flux-2-/;
|
|
33
|
+
const MAX_FLUX_2_INPUT_IMAGES = 8;
|
|
34
|
+
|
|
35
|
+
/**
 * Image model backed by the Replicate predictions API.
 *
 * A call to `doGenerate` posts a prediction request and then downloads every
 * URL found in the prediction's `output` as binary image data.
 *
 * Model ids of the form `owner/model` are posted to
 * `{baseURL}/models/{owner/model}/predictions`; ids of the form
 * `owner/model:version` are posted to `{baseURL}/predictions` with the
 * version included in the request body.
 */
export class ReplicateImageModel implements ImageModelV3 {
  readonly specificationVersion = 'v3';

  constructor(
    readonly modelId: ReplicateImageModelId,
    private readonly config: ReplicateImageModelConfig,
  ) {}

  /** Provider name taken from the model configuration. */
  get provider(): string {
    return this.config.provider;
  }

  /** `MAX_FLUX_2_INPUT_IMAGES` for Flux-2 models, otherwise 1. */
  get maxImagesPerCall(): number {
    if (this.isFlux2Model) {
      return MAX_FLUX_2_INPUT_IMAGES;
    }
    return 1;
  }

  /** Whether the model id matches the Flux-2 naming pattern. */
  private get isFlux2Model(): boolean {
    return FLUX_2_MODEL_PATTERN.test(this.modelId);
  }

  /**
   * Runs a prediction and returns the downloaded images.
   *
   * Non-fatal problems (surplus input images, a mask passed to a Flux-2 model)
   * are reported through `warnings` instead of being thrown.
   */
  async doGenerate({
    prompt,
    n,
    aspectRatio,
    size,
    seed,
    providerOptions,
    headers,
    abortSignal,
    files,
    mask,
  }: Parameters<ImageModelV3['doGenerate']>[0]): Promise<
    Awaited<ReturnType<ImageModelV3['doGenerate']>>
  > {
    const warnings: Array<SharedV3Warning> = [];

    // "owner/model:version" -> ["owner/model", "version"]; version is undefined when absent.
    const [modelName, modelVersion] = this.modelId.split(':');

    const timestamp = this.config._internal?.currentDate?.() ?? new Date();

    const parsedOptions = await parseProviderOptions({
      provider: 'replicate',
      providerOptions,
      schema: replicateImageProviderOptionsSchema,
    });

    // Input images, keyed by the parameter name the model expects.
    const imageInputs: Record<string, string> = {};
    if (files != null && files.length > 0) {
      if (this.isFlux2Model) {
        // Flux-2: input_image, input_image_2, input_image_3, ...
        files.slice(0, MAX_FLUX_2_INPUT_IMAGES).forEach((file, index) => {
          const key = index === 0 ? 'input_image' : `input_image_${index + 1}`;
          imageInputs[key] = convertImageModelFileToDataUri(file);
        });

        if (files.length > MAX_FLUX_2_INPUT_IMAGES) {
          warnings.push({
            type: 'other',
            message: `Flux-2 models support up to ${MAX_FLUX_2_INPUT_IMAGES} input images. Additional images are ignored.`,
          });
        }
      } else {
        // All other models take a single `image` parameter.
        imageInputs.image = convertImageModelFileToDataUri(files[0]);

        if (files.length > 1) {
          warnings.push({
            type: 'other',
            message:
              'This Replicate model only supports a single input image. Additional images are ignored.',
          });
        }
      }
    }

    // Mask input is forwarded for every model except Flux-2, where it is dropped with a warning.
    const maskFields: { mask?: string } = {};
    if (mask != null) {
      if (this.isFlux2Model) {
        warnings.push({
          type: 'other',
          message:
            'Flux-2 models do not support mask input. The mask will be ignored.',
        });
      } else {
        maskFields.mask = convertImageModelFileToDataUri(mask);
      }
    }

    // maxWaitTimeInSeconds only controls the `prefer` header; everything else
    // in the provider options is forwarded as model input.
    const { maxWaitTimeInSeconds: waitSeconds, ...modelInput } =
      parsedOptions ?? {};

    // - not set: `prefer: wait`
    // - set:     `prefer: wait=N`
    const preferHeader: Record<string, string> = {
      prefer: waitSeconds != null ? `wait=${waitSeconds}` : 'wait',
    };

    // Versioned and unversioned models use different endpoints.
    const isVersioned = modelVersion != null;
    const predictionUrl = isVersioned
      ? `${this.config.baseURL}/predictions`
      : `${this.config.baseURL}/models/${modelName}/predictions`;

    const requestHeaders = combineHeaders(
      await resolve(this.config.headers),
      headers,
      preferHeader,
    );

    const requestBody = {
      input: {
        prompt,
        aspect_ratio: aspectRatio,
        size,
        seed,
        num_outputs: n,
        ...imageInputs,
        ...maskFields,
        ...modelInput,
      },
      // Versioned models carry the version in the body.
      ...(isVersioned ? { version: modelVersion } : {}),
    };

    const { value: prediction, responseHeaders } = await postJsonToApi({
      url: predictionUrl,
      headers: requestHeaders,
      body: requestBody,
      successfulResponseHandler: createJsonResponseHandler(
        replicateImageResponseSchema,
      ),
      failedResponseHandler: replicateFailedResponseHandler,
      abortSignal,
      fetch: this.config.fetch,
    });

    // `output` is either one URL or a list of URLs; normalize to a list.
    const outputUrls = Array.isArray(prediction.output)
      ? prediction.output
      : [prediction.output];

    const downloadImage = async (imageUrl: string) => {
      const { value } = await getFromApi({
        url: imageUrl,
        successfulResponseHandler: createBinaryResponseHandler(),
        failedResponseHandler: replicateFailedResponseHandler,
        abortSignal,
        fetch: this.config.fetch,
      });
      return value;
    };

    const images = await Promise.all(
      outputUrls.map(imageUrl => downloadImage(imageUrl)),
    );

    return {
      images,
      warnings,
      response: {
        timestamp,
        modelId: this.modelId,
        headers: responseHeaders,
      },
    };
  }
}
|
|
205
|
+
|
|
206
|
+
/**
 * Shape of the prediction response fields this module reads:
 * `output` is either a list of strings or a single string.
 */
const replicateImageResponseSchema = z.object({
  output: z.union([
    // several outputs
    z.array(z.string()),
    // a single output
    z.string(),
  ]),
});
|
|
209
|
+
|
|
210
|
+
/**
 * Schema for the `replicate` provider options of image generation calls.
 *
 * Replicate models differ in the parameters they accept. The fields below are
 * the common ones; because the object is a passthrough schema, any additional
 * model-specific parameters are accepted as well.
 */
export const replicateImageProviderOptionsSchema = lazySchema(() =>
  zodSchema(
    z
      .object({
        /**
         * Maximum number of seconds to wait for the prediction to finish in sync mode.
         * By default, Replicate uses sync mode with a 60-second timeout.
         *
         * - Omitted: default 60-second sync wait (`prefer: wait`)
         * - Positive number: wait for that duration (`prefer: wait=N`)
         */
        maxWaitTimeInSeconds: z.number().positive().nullish(),

        /**
         * Classifier-free guidance scale.
         * Higher values make the output follow the prompt more closely.
         */
        guidance_scale: z.number().nullish(),

        /**
         * Number of denoising steps. More steps give higher quality but take longer.
         */
        num_inference_steps: z.number().nullish(),

        /**
         * Negative prompt describing what the generation should avoid.
         */
        negative_prompt: z.string().nullish(),

        /**
         * Format of the output image.
         */
        output_format: z.enum(['png', 'jpg', 'webp']).nullish(),

        /**
         * Quality of the output image (1-100). Only applies to jpg and webp.
         */
        output_quality: z.number().min(1).max(100).nullish(),

        /**
         * Transformation strength for img2img (0-1).
         * Lower values preserve more of the original image.
         */
        strength: z.number().min(0).max(1).nullish(),
      })
      .passthrough(),
  ),
);
|
|
265
|
+
|
|
266
|
+
/**
 * Type inferred from `replicateImageProviderOptionsSchema`.
 */
export type ReplicateImageProviderOptions = InferSchema<
  typeof replicateImageProviderOptionsSchema
>;
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
/**
 * Identifier of a Replicate image model.
 *
 * The listed literals provide editor completion; the trailing `string & {}`
 * member keeps the type open so any other model id is accepted too.
 */
export type ReplicateImageModelId =
  // Text-to-image models
  | 'black-forest-labs/flux-1.1-pro'
  | 'black-forest-labs/flux-1.1-pro-ultra'
  | 'black-forest-labs/flux-dev'
  | 'black-forest-labs/flux-pro'
  | 'black-forest-labs/flux-schnell'
  | 'bytedance/sdxl-lightning-4step'
  | 'fofr/aura-flow'
  | 'fofr/latent-consistency-model'
  | 'fofr/realvisxl-v3-multi-controlnet-lora'
  | 'fofr/sdxl-emoji'
  | 'fofr/sdxl-multi-controlnet-lora'
  | 'ideogram-ai/ideogram-v2'
  | 'ideogram-ai/ideogram-v2-turbo'
  | 'lucataco/dreamshaper-xl-turbo'
  | 'lucataco/open-dalle-v1.1'
  | 'lucataco/realvisxl-v2.0'
  | 'lucataco/realvisxl2-lcm'
  | 'luma/photon'
  | 'luma/photon-flash'
  | 'nvidia/sana'
  | 'playgroundai/playground-v2.5-1024px-aesthetic'
  | 'recraft-ai/recraft-v3'
  | 'recraft-ai/recraft-v3-svg'
  | 'stability-ai/stable-diffusion-3.5-large'
  | 'stability-ai/stable-diffusion-3.5-large-turbo'
  | 'stability-ai/stable-diffusion-3.5-medium'
  | 'tstramer/material-diffusion'
  // Inpainting and image editing models
  | 'black-forest-labs/flux-fill-pro'
  | 'black-forest-labs/flux-fill-dev'
  // Flux-2 models (support up to 8 reference images)
  | 'black-forest-labs/flux-2-pro'
  | 'black-forest-labs/flux-2-dev'
  // Any other model id
  | (string & {});
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import { describe, it, expect } from 'vitest';
|
|
2
|
+
import { createReplicate } from './replicate-provider';
|
|
3
|
+
import { ReplicateImageModel } from './replicate-image-model';
|
|
4
|
+
|
|
5
|
+
// Smoke tests for the provider factory: it must build with minimal and custom
// settings, and `image()` must hand back a ReplicateImageModel.
describe('createReplicate', () => {
  const apiToken = 'test-token';

  it('creates a provider with required settings', () => {
    const { image } = createReplicate({ apiToken });

    expect(image).toBeDefined();
  });

  it('creates a provider with custom settings', () => {
    const settings = {
      apiToken,
      baseURL: 'https://custom.replicate.com',
    };

    const { image } = createReplicate(settings);

    expect(image).toBeDefined();
  });

  it('creates an image model instance', () => {
    const model = createReplicate({ apiToken }).image(
      'black-forest-labs/flux-schnell',
    );

    expect(model).toBeInstanceOf(ReplicateImageModel);
  });
});
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import { NoSuchModelError, ProviderV3 } from '@ai-sdk/provider';
|
|
2
|
+
import type { FetchFunction } from '@ai-sdk/provider-utils';
|
|
3
|
+
import { loadApiKey, withUserAgentSuffix } from '@ai-sdk/provider-utils';
|
|
4
|
+
import { ReplicateImageModel } from './replicate-image-model';
|
|
5
|
+
import { ReplicateImageModelId } from './replicate-image-settings';
|
|
6
|
+
import { VERSION } from './version';
|
|
7
|
+
|
|
8
|
+
/**
 * Options accepted by `createReplicate`.
 */
export interface ReplicateProviderSettings {
  /**
   * API token that is sent using the `Authorization` header.
   * Defaults to the `REPLICATE_API_TOKEN` environment variable.
   */
  apiToken?: string;

  /**
   * URL prefix for API calls, e.g. to route requests through a proxy server.
   * The default prefix is `https://api.replicate.com/v1`.
   */
  baseURL?: string;

  /**
   * Custom headers to include in the requests.
   */
  headers?: Record<string, string>;

  /**
   * Custom fetch implementation. It can act as middleware to intercept
   * requests, or replace fetch entirely, e.g. for testing.
   */
  fetch?: FetchFunction;
}
|
|
32
|
+
|
|
33
|
+
/**
 * Provider object returned by `createReplicate`.
 */
export interface ReplicateProvider extends ProviderV3 {
  /**
   * Creates a Replicate image generation model.
   */
  image(modelId: ReplicateImageModelId): ReplicateImageModel;

  /**
   * Creates a Replicate image generation model.
   */
  imageModel(modelId: ReplicateImageModelId): ReplicateImageModel;

  /**
   * Never returns a model (the return type is `never`).
   *
   * @deprecated Use `embeddingModel` instead.
   */
  textEmbeddingModel(modelId: string): never;
}
|
|
49
|
+
|
|
50
|
+
/**
 * Create a Replicate provider instance.
 *
 * Only image models are available: `image` / `imageModel` return a
 * `ReplicateImageModel`, while `languageModel`, `embeddingModel` and
 * `textEmbeddingModel` always throw `NoSuchModelError`.
 *
 * The API token is loaded when a model is created, not when the provider is
 * created, so building a provider without a token does not throw.
 *
 * @param options - Provider settings; every field is optional.
 * @returns The provider object.
 * @throws NoSuchModelError when a language or embedding model is requested.
 */
export function createReplicate(
  options: ReplicateProviderSettings = {},
): ReplicateProvider {
  // Strip a trailing slash so that a custom prefix such as
  // `https://proxy.example.com/v1/` does not produce request URLs that
  // contain a double slash (`.../v1//predictions`).
  const baseURL = (options.baseURL ?? 'https://api.replicate.com/v1').replace(
    /\/$/,
    '',
  );

  const createImageModel = (modelId: ReplicateImageModelId) =>
    new ReplicateImageModel(modelId, {
      provider: 'replicate',
      baseURL,
      headers: withUserAgentSuffix(
        {
          Authorization: `Bearer ${loadApiKey({
            apiKey: options.apiToken,
            environmentVariableName: 'REPLICATE_API_TOKEN',
            description: 'Replicate',
          })}`,
          // Custom headers come last so they can override the defaults above.
          ...options.headers,
        },
        `ai-sdk/replicate/${VERSION}`,
      ),
      fetch: options.fetch,
    });

  const embeddingModel = (modelId: string): never => {
    throw new NoSuchModelError({
      modelId,
      modelType: 'embeddingModel',
    });
  };

  const languageModel = (modelId: string): never => {
    throw new NoSuchModelError({
      modelId,
      modelType: 'languageModel',
    });
  };

  return {
    specificationVersion: 'v3' as const,
    image: createImageModel,
    imageModel: createImageModel,
    languageModel,
    embeddingModel,
    textEmbeddingModel: embeddingModel,
  };
}
|
|
95
|
+
|
|
96
|
+
/**
 * Default Replicate provider instance, created with no explicit settings.
 */
export const replicate = createReplicate();
|