@ai-sdk/replicate 0.0.0-70e0935a-20260114150030 → 0.0.0-98261322-20260122142521
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +31 -4
- package/dist/index.js +1 -1
- package/dist/index.mjs +1 -1
- package/docs/60-replicate.mdx +241 -0
- package/package.json +10 -5
- package/src/index.ts +7 -0
- package/src/replicate-error.ts +13 -0
- package/src/replicate-image-model.test.ts +752 -0
- package/src/replicate-image-model.ts +268 -0
- package/src/replicate-image-settings.ts +36 -0
- package/src/replicate-provider.test.ts +24 -0
- package/src/replicate-provider.ts +99 -0
- package/src/version.ts +6 -0
|
@@ -0,0 +1,752 @@
|
|
|
1
|
+
import { createTestServer } from '@ai-sdk/test-server/with-vitest';
|
|
2
|
+
import { createReplicate } from './replicate-provider';
|
|
3
|
+
import { ReplicateImageModel } from './replicate-image-model';
|
|
4
|
+
import { describe, it, expect, vi } from 'vitest';
|
|
5
|
+
|
|
6
|
+
// Pin the package version so that user-agent assertions are deterministic
// regardless of the real published version.
vi.mock('./version', () => {
  return { VERSION: '0.0.0-test' };
});

// Prompt shared by most of the generation tests below.
const prompt = 'The Loch Ness monster getting a manicure';

// Default provider/model pair exercised by the `doGenerate` suite.
const provider = createReplicate({ apiToken: 'test-api-token' });
const model = provider.image('black-forest-labs/flux-schnell');
|
|
14
|
+
|
|
15
|
+
describe('doGenerate', () => {
  // Fixed clock injected through `_internal.currentDate` so that response
  // timestamps can be asserted exactly.
  const testDate = new Date(2024, 0, 1);

  // The Replicate API route is configured per test (via `prepareResponse` or
  // inline). Image downloads from replicate.delivery always return fixed bytes.
  const server = createTestServer({
    'https://api.replicate.com/*': {},
    'https://replicate.delivery/*': {
      response: {
        type: 'binary',
        body: Buffer.from('test-binary-content'),
      },
    },
  });

  /**
   * Configures the mocked Replicate prediction response.
   *
   * @param output - value of the prediction's `output` field. Both a single URL
   *   string and an array of URLs are exercised by the tests below.
   */
  function prepareResponse({
    output = ['https://replicate.delivery/xezq/abc/out-0.webp'],
  }: { output?: string | Array<string> } = {}) {
    server.urls['https://api.replicate.com/*'].response = {
      type: 'json-value',
      body: {
        id: 's7x1e3dcmhrmc0cm8rbatcneec',
        model: 'black-forest-labs/flux-schnell',
        version: 'dp-4d0bcc010b3049749a251855f12800be',
        input: {
          num_outputs: 1,
          prompt: 'The Loch Ness Monster getting a manicure',
        },
        logs: '',
        output,
        data_removed: false,
        error: null,
        status: 'processing',
        created_at: '2025-01-08T13:24:38.692Z',
        urls: {
          cancel:
            'https://api.replicate.com/v1/predictions/s7x1e3dcmhrmc0cm8rbatcneec/cancel',
          get: 'https://api.replicate.com/v1/predictions/s7x1e3dcmhrmc0cm8rbatcneec',
          stream:
            'https://stream.replicate.com/v1/files/bcwr-3okdfv3o2wehstv5f2okyftwxy57hhypqsi6osiim5iaq5k7u24a',
        },
      },
    };
  }

  it('should pass the model and the settings', async () => {
    prepareResponse();

    await model.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: '1024x768',
      aspectRatio: '3:4',
      seed: 123,
      providerOptions: {
        replicate: {
          style: 'realistic_image',
        },
        other: {
          something: 'else',
        },
      },
    });

    // Only `replicate` provider options are merged into `input`; options
    // addressed to other providers (`other`) must not leak into the request.
    expect(await server.calls[0].requestBodyJson).toStrictEqual({
      input: {
        prompt,
        num_outputs: 1,
        aspect_ratio: '3:4',
        size: '1024x768',
        seed: 123,
        style: 'realistic_image',
      },
    });
  });

  it('should call the correct url', async () => {
    prepareResponse();

    await model.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {},
    });

    expect(server.calls[0].requestMethod).toStrictEqual('POST');
    expect(server.calls[0].requestUrl).toStrictEqual(
      'https://api.replicate.com/v1/models/black-forest-labs/flux-schnell/predictions',
    );
  });

  it('should pass headers and set the prefer header', async () => {
    prepareResponse();

    const provider = createReplicate({
      apiToken: 'test-api-token',
      headers: {
        'Custom-Provider-Header': 'provider-header-value',
      },
    });

    await provider.image('black-forest-labs/flux-schnell').doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {},
      headers: {
        'Custom-Request-Header': 'request-header-value',
      },
    });

    // Provider-level and request-level headers are merged. Without
    // `maxWaitTimeInSeconds`, the prefer header is a bare `wait`.
    expect(server.calls[0].requestHeaders).toStrictEqual({
      authorization: 'Bearer test-api-token',
      'content-type': 'application/json',
      'custom-provider-header': 'provider-header-value',
      'custom-request-header': 'request-header-value',
      prefer: 'wait',
    });

    expect(server.calls[0].requestUserAgent).toContain(
      `ai-sdk/replicate/0.0.0-test`,
    );
  });

  it('should set custom wait time in prefer header when maxWaitTimeInSeconds is specified', async () => {
    prepareResponse();

    await model.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {
        replicate: {
          maxWaitTimeInSeconds: 120,
        },
      },
    });

    expect(server.calls[0].requestHeaders.prefer).toBe('wait=120');
  });

  it('should not include maxWaitTimeInSeconds in request body', async () => {
    prepareResponse();

    await model.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {
        replicate: {
          maxWaitTimeInSeconds: 120,
          guidance_scale: 7.5,
        },
      },
    });

    // `maxWaitTimeInSeconds` only controls the prefer header; all other
    // replicate options are still forwarded as model input.
    const requestBody = await server.calls[0].requestBodyJson;
    expect(requestBody.input.maxWaitTimeInSeconds).toBeUndefined();
    expect(requestBody.input.guidance_scale).toBe(7.5);
  });

  it('should extract the generated image from array response', async () => {
    prepareResponse({
      output: ['https://replicate.delivery/xezq/abc/out-0.webp'],
    });

    const result = await model.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {},
    });

    expect(result.images).toStrictEqual([
      new Uint8Array(Buffer.from('test-binary-content')),
    ]);

    // The second call is the download of the output URL.
    expect(server.calls[1].requestMethod).toStrictEqual('GET');
    expect(server.calls[1].requestUrl).toStrictEqual(
      'https://replicate.delivery/xezq/abc/out-0.webp',
    );
  });

  it('should extract the generated image from string response', async () => {
    prepareResponse({
      output: 'https://replicate.delivery/xezq/abc/out-0.webp',
    });

    const result = await model.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {},
    });

    expect(result.images).toStrictEqual([
      new Uint8Array(Buffer.from('test-binary-content')),
    ]);

    expect(server.calls[1].requestMethod).toStrictEqual('GET');
    expect(server.calls[1].requestUrl).toStrictEqual(
      'https://replicate.delivery/xezq/abc/out-0.webp',
    );
  });

  it('should return response metadata', async () => {
    const modelWithTimestamp = new ReplicateImageModel(
      'black-forest-labs/flux-schnell',
      {
        provider: 'replicate',
        baseURL: 'https://api.replicate.com',
        _internal: { currentDate: () => testDate },
      },
    );
    prepareResponse({
      output: ['https://replicate.delivery/xezq/abc/out-0.webp'],
    });

    const result = await modelWithTimestamp.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {},
    });

    expect(result.response).toStrictEqual({
      timestamp: testDate,
      modelId: 'black-forest-labs/flux-schnell',
      headers: expect.any(Object),
    });
  });

  it('should include response headers in metadata', async () => {
    const modelWithTimestamp = new ReplicateImageModel(
      'black-forest-labs/flux-schnell',
      {
        provider: 'replicate',
        baseURL: 'https://api.replicate.com',
        _internal: {
          currentDate: () => testDate,
        },
      },
    );
    // NOTE: the asserted `content-length` below is derived from this exact
    // body; any change to the fixture requires updating that value.
    server.urls['https://api.replicate.com/*'].response = {
      type: 'json-value',
      headers: {
        'custom-response-header': 'response-header-value',
      },
      body: {
        id: 's7x1e3dcmhrmc0cm8rbatcneec',
        model: 'black-forest-labs/flux-schnell',
        version: 'dp-4d0bcc010b3049749a251855f12800be',
        input: {
          num_outputs: 1,
          prompt: 'The Loch Ness Monster getting a manicure',
        },
        logs: '',
        output: ['https://replicate.delivery/xezq/abc/out-0.webp'],
        data_removed: false,
        error: null,
        status: 'processing',
        created_at: '2025-01-08T13:24:38.692Z',
        urls: {
          cancel:
            'https://api.replicate.com/v1/predictions/s7x1e3dcmhrmc0cm8rbatcneec/cancel',
          get: 'https://api.replicate.com/v1/predictions/s7x1e3dcmhrmc0cm8rbatcneec',
          stream:
            'https://stream.replicate.com/v1/files/bcwr-3okdfv3o2wehstv5f2okyftwxy57hhypqsi6osiim5iaq5k7u24a',
        },
      },
    };

    const result = await modelWithTimestamp.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {},
    });

    expect(result.response).toStrictEqual({
      timestamp: testDate,
      modelId: 'black-forest-labs/flux-schnell',
      headers: {
        'content-length': '646',
        'content-type': 'application/json',
        'custom-response-header': 'response-header-value',
      },
    });
  });

  it('should set version in request body for versioned models', async () => {
    prepareResponse();

    // `owner/name:version` ids go to the generic predictions endpoint with
    // the version hash in the body instead of the model-scoped URL.
    const versionedModel = provider.image(
      'bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637',
    );

    await versionedModel.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {},
    });

    expect(server.calls[0].requestMethod).toStrictEqual('POST');
    expect(server.calls[0].requestUrl).toStrictEqual(
      'https://api.replicate.com/v1/predictions',
    );
    expect(await server.calls[0].requestBodyJson).toStrictEqual({
      input: {
        prompt,
        num_outputs: 1,
      },
      version:
        '5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637',
    });
  });

  describe('Image Editing', () => {
    it('should send image when URL file is provided', async () => {
      prepareResponse();

      await model.doGenerate({
        prompt: 'Add a hat to the person',
        files: [
          {
            type: 'url',
            url: 'https://example.com/input.jpg',
          },
        ],
        mask: undefined,
        n: 1,
        size: undefined,
        aspectRatio: undefined,
        seed: undefined,
        providerOptions: {},
      });

      const requestBody = await server.calls[0].requestBodyJson;
      expect(requestBody).toMatchInlineSnapshot(`
        {
          "input": {
            "image": "https://example.com/input.jpg",
            "num_outputs": 1,
            "prompt": "Add a hat to the person",
          },
        }
      `);
    });

    it('should convert Uint8Array file to data URI', async () => {
      prepareResponse();

      // The 8-byte PNG file signature; its base64 encoding is `iVBORw0KGgo=`.
      const testImageData = new Uint8Array([137, 80, 78, 71, 13, 10, 26, 10]);

      await model.doGenerate({
        prompt: 'Transform this image',
        files: [
          {
            type: 'file',
            data: testImageData,
            mediaType: 'image/png',
          },
        ],
        mask: undefined,
        n: 1,
        size: undefined,
        aspectRatio: undefined,
        seed: undefined,
        providerOptions: {},
      });

      const requestBody = await server.calls[0].requestBodyJson;
      // Pin the exact data URI (not just its prefix) so that a broken
      // byte-to-base64 conversion is caught.
      expect(requestBody.input.image).toBe('data:image/png;base64,iVBORw0KGgo=');
      expect(requestBody.input.prompt).toBe('Transform this image');
    });

    it('should convert file with base64 string data to data URI', async () => {
      prepareResponse();

      await model.doGenerate({
        prompt: 'Edit this',
        files: [
          {
            type: 'file',
            data: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
            mediaType: 'image/png',
          },
        ],
        mask: undefined,
        n: 1,
        size: undefined,
        aspectRatio: undefined,
        seed: undefined,
        providerOptions: {},
      });

      const requestBody = await server.calls[0].requestBodyJson;
      // Base64 string data must be wrapped as `data:<mediaType>;base64,<data>`.
      // The previous expectation of an empty string contradicted this test's
      // purpose and could only pass if the conversion was broken.
      expect(requestBody.input.image).toBe(
        'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
      );
    });

    it('should send mask for inpainting', async () => {
      prepareResponse();

      await model.doGenerate({
        prompt: 'Replace the masked area with a tree',
        files: [
          {
            type: 'url',
            url: 'https://example.com/input.jpg',
          },
        ],
        mask: {
          type: 'url',
          url: 'https://example.com/mask.png',
        },
        n: 1,
        size: undefined,
        aspectRatio: undefined,
        seed: undefined,
        providerOptions: {},
      });

      const requestBody = await server.calls[0].requestBodyJson;
      expect(requestBody).toMatchInlineSnapshot(`
        {
          "input": {
            "image": "https://example.com/input.jpg",
            "mask": "https://example.com/mask.png",
            "num_outputs": 1,
            "prompt": "Replace the masked area with a tree",
          },
        }
      `);
    });

    it('should warn when multiple files are provided', async () => {
      prepareResponse();

      const result = await model.doGenerate({
        prompt: 'Edit multiple images',
        files: [
          {
            type: 'url',
            url: 'https://example.com/input1.jpg',
          },
          {
            type: 'url',
            url: 'https://example.com/input2.jpg',
          },
        ],
        mask: undefined,
        n: 1,
        size: undefined,
        aspectRatio: undefined,
        seed: undefined,
        providerOptions: {},
      });

      expect(result.warnings).toMatchInlineSnapshot(`
        [
          {
            "message": "This Replicate model only supports a single input image. Additional images are ignored.",
            "type": "other",
          },
        ]
      `);

      // Only the first image is sent.
      const requestBody = await server.calls[0].requestBodyJson;
      expect(requestBody.input.image).toBe('https://example.com/input1.jpg');
    });

    it('should pass provider options with image editing', async () => {
      prepareResponse();

      await model.doGenerate({
        prompt: 'Inpaint this area',
        files: [
          {
            type: 'url',
            url: 'https://example.com/input.jpg',
          },
        ],
        mask: {
          type: 'url',
          url: 'https://example.com/mask.png',
        },
        n: 1,
        size: undefined,
        aspectRatio: undefined,
        seed: undefined,
        providerOptions: {
          replicate: {
            guidance_scale: 7.5,
            num_inference_steps: 30,
            negative_prompt: 'blurry, low quality',
          },
        },
      });

      const requestBody = await server.calls[0].requestBodyJson;
      expect(requestBody).toMatchInlineSnapshot(`
        {
          "input": {
            "guidance_scale": 7.5,
            "image": "https://example.com/input.jpg",
            "mask": "https://example.com/mask.png",
            "negative_prompt": "blurry, low quality",
            "num_inference_steps": 30,
            "num_outputs": 1,
            "prompt": "Inpaint this area",
          },
        }
      `);
    });
  });

  describe('Flux-2 Models', () => {
    const flux2Model = provider.image('black-forest-labs/flux-2-pro');

    it('should report maxImagesPerCall as 8 for Flux-2 models', () => {
      expect(flux2Model.maxImagesPerCall).toBe(8);
    });

    it('should report maxImagesPerCall as 1 for non-Flux-2 models', () => {
      expect(model.maxImagesPerCall).toBe(1);
    });

    it('should send single image as input_image for Flux-2 models', async () => {
      prepareResponse();

      await flux2Model.doGenerate({
        prompt: 'Generate image in similar style',
        files: [
          {
            type: 'url',
            url: 'https://example.com/reference.jpg',
          },
        ],
        mask: undefined,
        n: 1,
        size: undefined,
        aspectRatio: undefined,
        seed: undefined,
        providerOptions: {},
      });

      const requestBody = await server.calls[0].requestBodyJson;
      expect(requestBody).toMatchInlineSnapshot(`
        {
          "input": {
            "input_image": "https://example.com/reference.jpg",
            "num_outputs": 1,
            "prompt": "Generate image in similar style",
          },
        }
      `);
    });

    it('should send multiple images as input_image, input_image_2, etc. for Flux-2 models', async () => {
      prepareResponse();

      await flux2Model.doGenerate({
        prompt: 'Combine styles from reference images',
        files: [
          {
            type: 'url',
            url: 'https://example.com/reference1.jpg',
          },
          {
            type: 'url',
            url: 'https://example.com/reference2.jpg',
          },
          {
            type: 'url',
            url: 'https://example.com/reference3.jpg',
          },
        ],
        mask: undefined,
        n: 1,
        size: undefined,
        aspectRatio: undefined,
        seed: undefined,
        providerOptions: {},
      });

      // The first image uses the unsuffixed key; later images are numbered from 2.
      const requestBody = await server.calls[0].requestBodyJson;
      expect(requestBody).toMatchInlineSnapshot(`
        {
          "input": {
            "input_image": "https://example.com/reference1.jpg",
            "input_image_2": "https://example.com/reference2.jpg",
            "input_image_3": "https://example.com/reference3.jpg",
            "num_outputs": 1,
            "prompt": "Combine styles from reference images",
          },
        }
      `);
    });

    it('should warn when more than 8 images are provided for Flux-2 models', async () => {
      prepareResponse();

      const result = await flux2Model.doGenerate({
        prompt: 'Too many images',
        files: [
          { type: 'url', url: 'https://example.com/img1.jpg' },
          { type: 'url', url: 'https://example.com/img2.jpg' },
          { type: 'url', url: 'https://example.com/img3.jpg' },
          { type: 'url', url: 'https://example.com/img4.jpg' },
          { type: 'url', url: 'https://example.com/img5.jpg' },
          { type: 'url', url: 'https://example.com/img6.jpg' },
          { type: 'url', url: 'https://example.com/img7.jpg' },
          { type: 'url', url: 'https://example.com/img8.jpg' },
          { type: 'url', url: 'https://example.com/img9.jpg' },
        ],
        mask: undefined,
        n: 1,
        size: undefined,
        aspectRatio: undefined,
        seed: undefined,
        providerOptions: {},
      });

      expect(result.warnings).toMatchInlineSnapshot(`
        [
          {
            "message": "Flux-2 models support up to 8 input images. Additional images are ignored.",
            "type": "other",
          },
        ]
      `);

      // Only the first 8 images are sent; the 9th is dropped.
      const requestBody = await server.calls[0].requestBodyJson;
      expect(requestBody.input.input_image).toBe(
        'https://example.com/img1.jpg',
      );
      expect(requestBody.input.input_image_8).toBe(
        'https://example.com/img8.jpg',
      );
      expect(requestBody.input.input_image_9).toBeUndefined();
    });

    it('should warn and ignore mask for Flux-2 models', async () => {
      prepareResponse();

      const result = await flux2Model.doGenerate({
        prompt: 'Edit with mask',
        files: [
          {
            type: 'url',
            url: 'https://example.com/input.jpg',
          },
        ],
        mask: {
          type: 'url',
          url: 'https://example.com/mask.png',
        },
        n: 1,
        size: undefined,
        aspectRatio: undefined,
        seed: undefined,
        providerOptions: {},
      });

      expect(result.warnings).toMatchInlineSnapshot(`
        [
          {
            "message": "Flux-2 models do not support mask input. The mask will be ignored.",
            "type": "other",
          },
        ]
      `);

      const requestBody = await server.calls[0].requestBodyJson;
      expect(requestBody.input.mask).toBeUndefined();
    });

    it('should call correct URL for Flux-2 models', async () => {
      prepareResponse();

      await flux2Model.doGenerate({
        prompt: 'Generate something',
        files: undefined,
        mask: undefined,
        n: 1,
        size: undefined,
        aspectRatio: undefined,
        seed: undefined,
        providerOptions: {},
      });

      expect(server.calls[0].requestUrl).toStrictEqual(
        'https://api.replicate.com/v1/models/black-forest-labs/flux-2-pro/predictions',
      );
    });
  });
});
|