@ai-sdk/togetherai 0.0.0-64aae7dd-20260114144918 → 0.0.0-98261322-20260122142521
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +70 -5
- package/dist/index.js +2 -2
- package/dist/index.js.map +1 -1
- package/dist/index.mjs +2 -2
- package/dist/index.mjs.map +1 -1
- package/docs/24-togetherai.mdx +365 -0
- package/package.json +11 -6
- package/src/index.ts +9 -0
- package/src/reranking/__fixtures__/togetherai-reranking.1.json +22 -0
- package/src/reranking/togetherai-reranking-api.ts +43 -0
- package/src/reranking/togetherai-reranking-model.test.ts +245 -0
- package/src/reranking/togetherai-reranking-model.ts +101 -0
- package/src/reranking/togetherai-reranking-options.ts +27 -0
- package/src/togetherai-chat-options.ts +36 -0
- package/src/togetherai-completion-options.ts +9 -0
- package/src/togetherai-embedding-options.ts +11 -0
- package/src/togetherai-image-model.test.ts +488 -0
- package/src/togetherai-image-model.ts +188 -0
- package/src/togetherai-image-settings.ts +18 -0
- package/src/togetherai-provider.test.ts +196 -0
- package/src/togetherai-provider.ts +180 -0
- package/src/version.ts +6 -0
|
@@ -0,0 +1,488 @@
|
|
|
1
|
+
import { FetchFunction } from '@ai-sdk/provider-utils';
|
|
2
|
+
import { createTestServer } from '@ai-sdk/test-server/with-vitest';
|
|
3
|
+
import { describe, expect, it } from 'vitest';
|
|
4
|
+
import { TogetherAIImageModel } from './togetherai-image-model';
|
|
5
|
+
|
|
6
|
+
// Default text prompt shared by the text-to-image tests in this file.
const prompt = 'A cute baby sea otter';
|
|
7
|
+
|
|
8
|
+
/**
 * Builds a TogetherAIImageModel for 'stabilityai/stable-diffusion-xl' that
 * talks to the mock server at https://api.example.com.
 *
 * @param headers - header factory; when omitted, a factory returning
 *   `{ 'api-key': 'test-key' }` is used instead.
 * @param fetch - optional fetch implementation forwarded to the model config.
 * @param currentDate - optional clock forwarded via `_internal.currentDate`.
 */
function createBasicModel({
  headers,
  fetch,
  currentDate,
}: {
  headers?: () => Record<string, string>;
  fetch?: FetchFunction;
  currentDate?: () => Date;
} = {}) {
  const headerFactory = headers ?? (() => ({ 'api-key': 'test-key' }));
  const internalHooks = { currentDate };

  return new TogetherAIImageModel('stabilityai/stable-diffusion-xl', {
    provider: 'togetherai',
    baseURL: 'https://api.example.com',
    headers: headerFactory,
    fetch,
    _internal: internalHooks,
  });
}
|
|
27
|
+
|
|
28
|
+
// Mock Together AI endpoint used by the module-level tests: every request
// under https://api.example.com/ receives a JSON body containing a single
// base64 image. Individual tests overwrite
// `server.urls['https://api.example.com/*'].response` to simulate other replies.
const server = createTestServer({
  'https://api.example.com/*': {
    response: {
      type: 'json-value',
      body: {
        id: 'test-id',
        data: [{ index: 0, b64_json: 'test-base64-content' }],
        model: 'stabilityai/stable-diffusion-xl',
        object: 'list',
      },
    },
  },
});
|
|
41
|
+
|
|
42
|
+
// Tests for TogetherAIImageModel.doGenerate: request body shape, URL/method,
// header merging, error mapping, warnings, abort handling and response metadata.
describe('doGenerate', () => {
  it('should pass the correct parameters including size and seed', async () => {
    const model = createBasicModel();

    await model.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: '1024x1024',
      seed: 42,
      providerOptions: { togetherai: { additional_param: 'value' } },
      aspectRatio: undefined,
    });

    // `size` is sent as numeric width/height; the unknown provider option is
    // forwarded verbatim into the request body.
    expect(await server.calls[0].requestBodyJson).toStrictEqual({
      model: 'stabilityai/stable-diffusion-xl',
      prompt,
      seed: 42,
      width: 1024,
      height: 1024,
      response_format: 'base64',
      additional_param: 'value',
    });
  });

  it('should include n parameter when requesting multiple images', async () => {
    const model = createBasicModel();

    await model.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 3,
      size: '1024x1024',
      seed: 42,
      providerOptions: {},
      aspectRatio: undefined,
    });

    // `n` only appears in the body when more than one image is requested.
    expect(await server.calls[0].requestBodyJson).toStrictEqual({
      model: 'stabilityai/stable-diffusion-xl',
      prompt,
      seed: 42,
      n: 3,
      width: 1024,
      height: 1024,
      response_format: 'base64',
    });
  });

  it('should call the correct url', async () => {
    const model = createBasicModel();

    await model.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: '1024x1024',
      seed: 42,
      providerOptions: {},
      aspectRatio: undefined,
    });

    expect(server.calls[0].requestMethod).toStrictEqual('POST');
    expect(server.calls[0].requestUrl).toStrictEqual(
      'https://api.example.com/images/generations',
    );
  });

  it('should pass headers', async () => {
    // Supplying `headers` replaces the factory's default api-key header.
    const modelWithHeaders = createBasicModel({
      headers: () => ({
        'Custom-Provider-Header': 'provider-header-value',
      }),
    });

    await modelWithHeaders.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: undefined,
      seed: undefined,
      providerOptions: {},
      aspectRatio: undefined,
      headers: {
        'Custom-Request-Header': 'request-header-value',
      },
    });

    // Provider-level and per-request headers are both sent (names lower-cased).
    expect(server.calls[0].requestHeaders).toStrictEqual({
      'content-type': 'application/json',
      'custom-provider-header': 'provider-header-value',
      'custom-request-header': 'request-header-value',
    });
  });

  it('should handle API errors', async () => {
    server.urls['https://api.example.com/*'].response = {
      type: 'error',
      status: 400,
      body: JSON.stringify({
        error: {
          message: 'Bad Request',
        },
      }),
    };

    const model = createBasicModel();
    // The thrown error's message comes from the API's `error.message` field.
    await expect(
      model.doGenerate({
        prompt,
        files: undefined,
        mask: undefined,
        n: 1,
        size: undefined,
        seed: undefined,
        providerOptions: {},
        aspectRatio: undefined,
      }),
    ).rejects.toMatchObject({
      message: 'Bad Request',
    });
  });

  describe('warnings', () => {
    it('should return aspectRatio warning when aspectRatio is provided', async () => {
      const model = createBasicModel();

      const result = await model.doGenerate({
        prompt,
        files: undefined,
        mask: undefined,
        n: 1,
        size: '1024x1024',
        aspectRatio: '1:1',
        seed: 123,
        providerOptions: {},
      });

      expect(result.warnings).toMatchInlineSnapshot(`
        [
          {
            "details": "This model does not support the \`aspectRatio\` option. Use \`size\` instead.",
            "feature": "aspectRatio",
            "type": "unsupported",
          },
        ]
      `);
    });
  });

  it('should respect the abort signal', async () => {
    const model = createBasicModel();
    const controller = new AbortController();

    const generatePromise = model.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: undefined,
      seed: undefined,
      providerOptions: {},
      aspectRatio: undefined,
      abortSignal: controller.signal,
    });

    // Abort immediately after starting the request; the promise must reject.
    controller.abort();

    await expect(generatePromise).rejects.toThrow('This operation was aborted');
  });

  describe('response metadata', () => {
    it('should include timestamp, headers and modelId in response', async () => {
      // A fixed clock makes `response.timestamp` deterministic.
      const testDate = new Date('2024-01-01T00:00:00Z');
      const model = createBasicModel({
        currentDate: () => testDate,
      });

      const result = await model.doGenerate({
        prompt,
        files: undefined,
        mask: undefined,
        n: 1,
        size: undefined,
        seed: undefined,
        providerOptions: {},
        aspectRatio: undefined,
      });

      expect(result.response).toStrictEqual({
        timestamp: testDate,
        modelId: 'stabilityai/stable-diffusion-xl',
        headers: expect.any(Object),
      });
    });

    it('should include response headers from API call', async () => {
      server.urls['https://api.example.com/*'].response = {
        type: 'json-value',
        body: {
          id: 'test-id',
          data: [{ index: 0, b64_json: 'test-base64-content' }],
          model: 'stabilityai/stable-diffusion-xl',
          object: 'list',
        },
        headers: {
          'x-request-id': 'test-request-id',
          'content-length': '128',
        },
      };

      const model = createBasicModel();
      const result = await model.doGenerate({
        prompt,
        files: undefined,
        mask: undefined,
        n: 1,
        size: undefined,
        seed: undefined,
        providerOptions: {},
        aspectRatio: undefined,
      });

      // content-type is not configured above, so it is expected to come from
      // the mock server's JSON response.
      expect(result.response.headers).toStrictEqual({
        'x-request-id': 'test-request-id',
        'content-type': 'application/json',
        'content-length': '128',
      });
    });
  });
});
|
|
277
|
+
|
|
278
|
+
// Static metadata exposed by the model instance (no network access involved).
describe('constructor', () => {
  it('should expose correct provider and model information', () => {
    const model = createBasicModel();

    expect(model.provider).toBe('togetherai');
    expect(model.modelId).toBe('stabilityai/stable-diffusion-xl');
    expect(model.specificationVersion).toBe('v3');
    expect(model.maxImagesPerCall).toBe(1);
  });
});
|
|
288
|
+
|
|
289
|
+
// Tests for image-editing input: how `files` is mapped to `image_url`, the
// rejection of `mask`, and the single-input-image warning.
describe('Image Editing', () => {
  // Local mock server for this suite; it shadows the module-level `server`
  // inside this describe block.
  const server = createTestServer({
    'https://api.example.com/*': {
      response: {
        type: 'json-value',
        body: {
          id: 'test-id',
          data: [{ index: 0, b64_json: 'test-base64-content' }],
          model: 'black-forest-labs/FLUX.1-kontext-pro',
          object: 'list',
        },
      },
    },
  });

  it('should send image_url when URL file is provided', async () => {
    const model = createBasicModel();

    await model.doGenerate({
      prompt: 'Make the shirt yellow',
      files: [
        {
          type: 'url',
          url: 'https://example.com/input.jpg',
        },
      ],
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {},
    });

    // A URL file is passed through unchanged as `image_url`.
    const requestBody = await server.calls[0].requestBodyJson;
    expect(requestBody).toMatchInlineSnapshot(`
      {
        "image_url": "https://example.com/input.jpg",
        "model": "stabilityai/stable-diffusion-xl",
        "prompt": "Make the shirt yellow",
        "response_format": "base64",
      }
    `);
  });

  it('should convert Uint8Array file to data URI', async () => {
    const model = createBasicModel();
    // The 8-byte PNG file signature; enough to exercise binary encoding.
    const testImageData = new Uint8Array([137, 80, 78, 71, 13, 10, 26, 10]);

    await model.doGenerate({
      prompt: 'Transform this image',
      files: [
        {
          type: 'file',
          data: testImageData,
          mediaType: 'image/png',
        },
      ],
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {},
    });

    const requestBody = await server.calls[0].requestBodyJson;
    expect(requestBody.image_url).toMatch(/^data:image\/png;base64,/);
    expect(requestBody.prompt).toBe('Transform this image');
  });

  it('should convert file with base64 string data to data URI', async () => {
    const model = createBasicModel();

    await model.doGenerate({
      prompt: 'Edit this',
      files: [
        {
          type: 'file',
          data: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
          mediaType: 'image/png',
        },
      ],
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {},
    });

    // Base64 string data is prefixed with the media type, not re-encoded.
    const requestBody = await server.calls[0].requestBodyJson;
    expect(requestBody.image_url).toBe(
      'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
    );
  });

  it('should throw error when mask is provided', async () => {
    const model = createBasicModel();

    await expect(
      model.doGenerate({
        prompt: 'Inpaint this area',
        files: [
          {
            type: 'url',
            url: 'https://example.com/input.jpg',
          },
        ],
        mask: {
          type: 'url',
          url: 'https://example.com/mask.png',
        },
        n: 1,
        size: undefined,
        aspectRatio: undefined,
        seed: undefined,
        providerOptions: {},
      }),
    ).rejects.toThrow(
      'Together AI does not support mask-based image editing. ' +
        'Use FLUX Kontext models (e.g., black-forest-labs/FLUX.1-kontext-pro) ' +
        'with a reference image and descriptive prompt instead.',
    );
  });

  it('should warn when multiple files are provided', async () => {
    const model = createBasicModel();

    const result = await model.doGenerate({
      prompt: 'Edit multiple images',
      files: [
        {
          type: 'url',
          url: 'https://example.com/input1.jpg',
        },
        {
          type: 'url',
          url: 'https://example.com/input2.jpg',
        },
      ],
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {},
    });

    expect(result.warnings).toMatchInlineSnapshot(`
      [
        {
          "message": "Together AI only supports a single input image. Additional images are ignored.",
          "type": "other",
        },
      ]
    `);

    // Should only use the first image
    const requestBody = await server.calls[0].requestBodyJson;
    expect(requestBody.image_url).toBe('https://example.com/input1.jpg');
  });

  it('should pass provider options with image editing', async () => {
    const model = createBasicModel();

    await model.doGenerate({
      prompt: 'Transform the style',
      files: [
        {
          type: 'url',
          url: 'https://example.com/input.jpg',
        },
      ],
      mask: undefined,
      n: 1,
      size: undefined,
      aspectRatio: undefined,
      seed: undefined,
      providerOptions: {
        togetherai: {
          steps: 28,
          guidance: 3.5,
        },
      },
    });

    // Provider options are merged into the same body as `image_url`.
    const requestBody = await server.calls[0].requestBodyJson;
    expect(requestBody).toMatchInlineSnapshot(`
      {
        "guidance": 3.5,
        "image_url": "https://example.com/input.jpg",
        "model": "stabilityai/stable-diffusion-xl",
        "prompt": "Transform the style",
        "response_format": "base64",
        "steps": 28,
      }
    `);
  });
});
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
import { ImageModelV3, SharedV3Warning } from '@ai-sdk/provider';
|
|
2
|
+
import {
|
|
3
|
+
combineHeaders,
|
|
4
|
+
convertImageModelFileToDataUri,
|
|
5
|
+
createJsonResponseHandler,
|
|
6
|
+
createJsonErrorResponseHandler,
|
|
7
|
+
FetchFunction,
|
|
8
|
+
InferSchema,
|
|
9
|
+
lazySchema,
|
|
10
|
+
parseProviderOptions,
|
|
11
|
+
postJsonToApi,
|
|
12
|
+
zodSchema,
|
|
13
|
+
} from '@ai-sdk/provider-utils';
|
|
14
|
+
import { TogetherAIImageModelId } from './togetherai-image-settings';
|
|
15
|
+
import { z } from 'zod/v4';
|
|
16
|
+
|
|
17
|
+
/**
 * Construction-time configuration for the Together AI image model.
 */
interface TogetherAIImageModelConfig {
  /** Provider identifier, returned as-is by the model's `provider` getter. */
  provider: string;
  /** API base URL; image requests are POSTed to `${baseURL}/images/generations`. */
  baseURL: string;
  /** Produces provider-level headers (e.g. credentials), combined with per-call headers. */
  headers: () => Record<string, string>;
  /** Optional fetch implementation handed to the HTTP helper. */
  fetch?: FetchFunction;
  /** Internal hooks, used by tests. */
  _internal?: {
    /** Overrides the clock used for `response.timestamp`. */
    currentDate?: () => Date;
  };
}
|
|
26
|
+
|
|
27
|
+
/**
 * Together AI image model backed by the `/images/generations` endpoint.
 *
 * Supports text-to-image generation and single-reference-image editing (the
 * first entry of `files` is sent as `image_url`). Mask-based editing is not
 * supported and is rejected with an error.
 */
export class TogetherAIImageModel implements ImageModelV3 {
  readonly specificationVersion = 'v3';
  readonly maxImagesPerCall = 1;

  get provider(): string {
    return this.config.provider;
  }

  constructor(
    readonly modelId: TogetherAIImageModelId,
    private config: TogetherAIImageModelConfig,
  ) {}

  /**
   * Generates (or edits) an image.
   *
   * @returns base64-encoded images, any warnings, and response metadata
   *   (timestamp, model id, response headers).
   * @throws Error when `mask` is provided (mask-based editing is unsupported).
   */
  async doGenerate({
    prompt,
    n,
    size,
    aspectRatio,
    seed,
    providerOptions,
    headers,
    abortSignal,
    files,
    mask,
  }: Parameters<ImageModelV3['doGenerate']>[0]): Promise<
    Awaited<ReturnType<ImageModelV3['doGenerate']>>
  > {
    const warnings: Array<SharedV3Warning> = [];

    if (mask != null) {
      throw new Error(
        'Together AI does not support mask-based image editing. ' +
          'Use FLUX Kontext models (e.g., black-forest-labs/FLUX.1-kontext-pro) ' +
          'with a reference image and descriptive prompt instead.',
      );
    }

    // The request only carries explicit width/height (derived from `size`), so
    // an `aspectRatio` cannot be honored. Warn only when aspectRatio was given;
    // `size` itself is fully supported and must not trigger this warning.
    if (aspectRatio != null) {
      warnings.push({
        type: 'unsupported',
        feature: 'aspectRatio',
        details:
          'This model does not support the `aspectRatio` option. Use `size` instead.',
      });
    }

    const currentDate = this.config._internal?.currentDate?.() ?? new Date();

    const togetheraiOptions = await parseProviderOptions({
      provider: 'togetherai',
      providerOptions,
      schema: togetheraiImageProviderOptionsSchema,
    });

    // Only the first input image is sent; extra images are reported and dropped.
    let imageUrl: string | undefined;
    if (files != null && files.length > 0) {
      imageUrl = convertImageModelFileToDataUri(files[0]);

      if (files.length > 1) {
        warnings.push({
          type: 'other',
          message:
            'Together AI only supports a single input image. Additional images are ignored.',
        });
      }
    }

    // `size` has the form "<width>x<height>".
    const splitSize = size?.split('x');
    // https://docs.together.ai/reference/post_images-generations
    const { value: response, responseHeaders } = await postJsonToApi({
      url: `${this.config.baseURL}/images/generations`,
      headers: combineHeaders(this.config.headers(), headers),
      body: {
        model: this.modelId,
        prompt,
        seed,
        ...(n > 1 ? { n } : {}),
        ...(splitSize && {
          width: Number.parseInt(splitSize[0], 10),
          height: Number.parseInt(splitSize[1], 10),
        }),
        ...(imageUrl != null ? { image_url: imageUrl } : {}),
        response_format: 'base64',
        // Spread last: provider options (including passthrough keys) are
        // forwarded as-is and can override the fields above.
        ...(togetheraiOptions ?? {}),
      },
      failedResponseHandler: createJsonErrorResponseHandler({
        errorSchema: togetheraiErrorSchema,
        errorToMessage: data => data.error.message,
      }),
      successfulResponseHandler: createJsonResponseHandler(
        togetheraiImageResponseSchema,
      ),
      abortSignal,
      fetch: this.config.fetch,
    });

    return {
      images: response.data.map(item => item.b64_json),
      warnings,
      response: {
        timestamp: currentDate,
        modelId: this.modelId,
        headers: responseHeaders,
      },
    };
  }
}
|
|
134
|
+
|
|
135
|
+
// Deliberately minimal: only the fields the implementation reads are
// validated, so unrelated additions to the API response do not break parsing.
const togetheraiImageResponseSchema = z.object({
  data: z.array(z.object({ b64_json: z.string() })),
});
|
|
144
|
+
|
|
145
|
+
// Deliberately minimal: only `error.message` is needed to build the thrown
// error's message, so nothing else in the error payload is validated.
const togetheraiErrorSchema = z.object({
  error: z.object({ message: z.string() }),
});
|
|
152
|
+
|
|
153
|
+
/**
 * Provider options schema for Together AI image generation
 * (`providerOptions.togetherai`).
 *
 * The object uses `.passthrough()`, so keys not listed here are kept; the
 * model spreads the parsed options into the request body, which forwards
 * them to the API unchanged.
 */
export const togetheraiImageProviderOptionsSchema = lazySchema(() =>
  zodSchema(
    z
      .object({
        /**
         * Number of generation steps. Higher values can improve quality.
         */
        steps: z.number().nullish(),

        /**
         * Guidance scale for image generation.
         */
        guidance: z.number().nullish(),

        /**
         * Negative prompt to guide what to avoid.
         */
        negative_prompt: z.string().nullish(),

        /**
         * Disable the safety checker for image generation.
         * When true, the API will not reject images flagged as potentially NSFW.
         * Not available for Flux Schnell Free and Flux Pro models.
         */
        disable_safety_checker: z.boolean().nullish(),
      })
      .passthrough(),
  ),
);
|
|
185
|
+
|
|
186
|
+
/**
 * Type of the options accepted under `providerOptions.togetherai` for
 * Together AI image models, inferred from the provider options schema.
 */
export type TogetherAIImageProviderOptions = InferSchema<
  typeof togetheraiImageProviderOptionsSchema
>;
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
// https://api.together.ai/models
/**
 * Known Together AI image model ids. The trailing `(string & {})` member
 * also accepts any other model id string without losing literal suggestions.
 */
export type TogetherAIImageModelId =
  // Text-to-image models
  | 'stabilityai/stable-diffusion-xl-base-1.0'
  | 'black-forest-labs/FLUX.1-dev'
  | 'black-forest-labs/FLUX.1-dev-lora'
  | 'black-forest-labs/FLUX.1-schnell'
  | 'black-forest-labs/FLUX.1-canny'
  | 'black-forest-labs/FLUX.1-depth'
  | 'black-forest-labs/FLUX.1-redux'
  | 'black-forest-labs/FLUX.1.1-pro'
  | 'black-forest-labs/FLUX.1-pro'
  | 'black-forest-labs/FLUX.1-schnell-Free'
  // FLUX Kontext models for image editing
  | 'black-forest-labs/FLUX.1-kontext-pro'
  | 'black-forest-labs/FLUX.1-kontext-max'
  | 'black-forest-labs/FLUX.1-kontext-dev'
  | (string & {});
|