@ai-sdk/fireworks 2.0.16 → 2.0.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +8 -0
- package/dist/index.js +1 -1
- package/dist/index.mjs +1 -1
- package/package.json +4 -3
- package/src/fireworks-chat-options.ts +20 -0
- package/src/fireworks-completion-options.ts +5 -0
- package/src/fireworks-embedding-options.ts +12 -0
- package/src/fireworks-image-model.test.ts +629 -0
- package/src/fireworks-image-model.ts +196 -0
- package/src/fireworks-image-options.ts +12 -0
- package/src/fireworks-provider.test.ts +198 -0
- package/src/fireworks-provider.ts +172 -0
- package/src/index.ts +13 -0
- package/src/version.ts +6 -0
package/CHANGELOG.md
CHANGED
package/dist/index.js
CHANGED
|
@@ -174,7 +174,7 @@ var import_provider_utils2 = require("@ai-sdk/provider-utils");
|
|
|
174
174
|
var import_v4 = require("zod/v4");
|
|
175
175
|
|
|
176
176
|
// src/version.ts
|
|
177
|
-
var VERSION = true ? "2.0.16" : "0.0.0-test";
|
|
177
|
+
var VERSION = true ? "2.0.17" : "0.0.0-test";
|
|
178
178
|
|
|
179
179
|
// src/fireworks-provider.ts
|
|
180
180
|
var fireworksErrorSchema = import_v4.z.object({
|
package/dist/index.mjs
CHANGED
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@ai-sdk/fireworks",
|
|
3
|
-
"version": "2.0.16",
|
|
3
|
+
"version": "2.0.17",
|
|
4
4
|
"license": "Apache-2.0",
|
|
5
5
|
"sideEffects": false,
|
|
6
6
|
"main": "./dist/index.js",
|
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
"types": "./dist/index.d.ts",
|
|
9
9
|
"files": [
|
|
10
10
|
"dist/**/*",
|
|
11
|
+
"src",
|
|
11
12
|
"CHANGELOG.md",
|
|
12
13
|
"README.md"
|
|
13
14
|
],
|
|
@@ -20,7 +21,7 @@
|
|
|
20
21
|
}
|
|
21
22
|
},
|
|
22
23
|
"dependencies": {
|
|
23
|
-
"@ai-sdk/openai-compatible": "2.0.
|
|
24
|
+
"@ai-sdk/openai-compatible": "2.0.17",
|
|
24
25
|
"@ai-sdk/provider": "3.0.4",
|
|
25
26
|
"@ai-sdk/provider-utils": "4.0.8"
|
|
26
27
|
},
|
|
@@ -29,7 +30,7 @@
|
|
|
29
30
|
"tsup": "^8",
|
|
30
31
|
"typescript": "5.8.3",
|
|
31
32
|
"zod": "3.25.76",
|
|
32
|
-
"@ai-sdk/test-server": "1.0.
|
|
33
|
+
"@ai-sdk/test-server": "1.0.2",
|
|
33
34
|
"@vercel/ai-tsconfig": "0.0.0"
|
|
34
35
|
},
|
|
35
36
|
"peerDependencies": {
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
// https://docs.fireworks.ai/docs/serverless-models#chat-models
|
|
2
|
+
// Below is just a subset of the available models.
|
|
3
|
+
export type FireworksChatModelId =
|
|
4
|
+
| 'accounts/fireworks/models/deepseek-v3'
|
|
5
|
+
| 'accounts/fireworks/models/llama-v3p3-70b-instruct'
|
|
6
|
+
| 'accounts/fireworks/models/llama-v3p2-3b-instruct'
|
|
7
|
+
| 'accounts/fireworks/models/llama-v3p1-405b-instruct'
|
|
8
|
+
| 'accounts/fireworks/models/llama-v3p1-8b-instruct'
|
|
9
|
+
| 'accounts/fireworks/models/mixtral-8x7b-instruct'
|
|
10
|
+
| 'accounts/fireworks/models/mixtral-8x22b-instruct'
|
|
11
|
+
| 'accounts/fireworks/models/mixtral-8x7b-instruct-hf'
|
|
12
|
+
| 'accounts/fireworks/models/qwen2p5-coder-32b-instruct'
|
|
13
|
+
| 'accounts/fireworks/models/qwen2p5-72b-instruct'
|
|
14
|
+
| 'accounts/fireworks/models/qwen-qwq-32b-preview'
|
|
15
|
+
| 'accounts/fireworks/models/qwen2-vl-72b-instruct'
|
|
16
|
+
| 'accounts/fireworks/models/llama-v3p2-11b-vision-instruct'
|
|
17
|
+
| 'accounts/fireworks/models/qwq-32b'
|
|
18
|
+
| 'accounts/fireworks/models/yi-large'
|
|
19
|
+
| 'accounts/fireworks/models/kimi-k2-instruct'
|
|
20
|
+
| (string & {});
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import { z } from 'zod/v4';
|
|
2
|
+
|
|
3
|
+
// Below is just a subset of the available models.
|
|
4
|
+
export type FireworksEmbeddingModelId =
|
|
5
|
+
| 'nomic-ai/nomic-embed-text-v1.5'
|
|
6
|
+
| (string & {});
|
|
7
|
+
"input_image": "data:image/png;base64,iVBORw==",
|
8
|
+
export const fireworksEmbeddingProviderOptions = z.object({});
|
|
9
|
+
|
|
10
|
+
export type FireworksEmbeddingProviderOptions = z.infer<
|
|
11
|
+
typeof fireworksEmbeddingProviderOptions
|
|
12
|
+
>;
|
|
@@ -0,0 +1,629 @@
|
|
|
1
|
+
import { FetchFunction } from '@ai-sdk/provider-utils';
|
|
2
|
+
import { createTestServer } from '@ai-sdk/test-server/with-vitest';
|
|
3
|
+
import { describe, expect, it, vi } from 'vitest';
|
|
4
|
+
import { FireworksImageModel } from './fireworks-image-model';
|
|
5
|
+
|
|
6
|
+
const prompt = 'A cute baby sea otter';
|
|
7
|
+
|
|
8
|
+
function createBasicModel({
|
|
9
|
+
headers,
|
|
10
|
+
fetch,
|
|
11
|
+
currentDate,
|
|
12
|
+
}: {
|
|
13
|
+
headers?: () => Record<string, string>;
|
|
14
|
+
fetch?: FetchFunction;
|
|
15
|
+
currentDate?: () => Date;
|
|
16
|
+
} = {}) {
|
|
17
|
+
return new FireworksImageModel('accounts/fireworks/models/flux-1-dev-fp8', {
|
|
18
|
+
provider: 'fireworks',
|
|
19
|
+
baseURL: 'https://api.example.com',
|
|
20
|
+
headers: headers ?? (() => ({ 'api-key': 'test-key' })),
|
|
21
|
+
fetch,
|
|
22
|
+
_internal: {
|
|
23
|
+
currentDate,
|
|
24
|
+
},
|
|
25
|
+
});
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
function createSizeModel() {
|
|
29
|
+
return new FireworksImageModel(
|
|
30
|
+
'accounts/fireworks/models/playground-v2-5-1024px-aesthetic',
|
|
31
|
+
{
|
|
32
|
+
provider: 'fireworks',
|
|
33
|
+
baseURL: 'https://api.size-example.com',
|
|
34
|
+
headers: () => ({ 'api-key': 'test-key' }),
|
|
35
|
+
},
|
|
36
|
+
);
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
describe('FireworksImageModel', () => {
|
|
40
|
+
const server = createTestServer({
|
|
41
|
+
'https://api.example.com/*': {
|
|
42
|
+
response: {
|
|
43
|
+
type: 'binary',
|
|
44
|
+
body: Buffer.from('test-binary-content'),
|
|
45
|
+
},
|
|
46
|
+
},
|
|
47
|
+
'https://api.size-example.com/*': {
|
|
48
|
+
response: {
|
|
49
|
+
type: 'binary',
|
|
50
|
+
body: Buffer.from('test-binary-content'),
|
|
51
|
+
},
|
|
52
|
+
},
|
|
53
|
+
});
|
|
54
|
+
|
|
55
|
+
describe('doGenerate', () => {
|
|
56
|
+
it('should pass the correct parameters including aspect ratio and seed', async () => {
|
|
57
|
+
const model = createBasicModel();
|
|
58
|
+
|
|
59
|
+
await model.doGenerate({
|
|
60
|
+
prompt,
|
|
61
|
+
files: undefined,
|
|
62
|
+
mask: undefined,
|
|
63
|
+
n: 1,
|
|
64
|
+
size: undefined,
|
|
65
|
+
aspectRatio: '16:9',
|
|
66
|
+
seed: 42,
|
|
67
|
+
providerOptions: { fireworks: { additional_param: 'value' } },
|
|
68
|
+
});
|
|
69
|
+
|
|
70
|
+
expect(await server.calls[0].requestBodyJson).toStrictEqual({
|
|
71
|
+
prompt,
|
|
72
|
+
aspect_ratio: '16:9',
|
|
73
|
+
seed: 42,
|
|
74
|
+
samples: 1,
|
|
75
|
+
additional_param: 'value',
|
|
76
|
+
});
|
|
77
|
+
});
|
|
78
|
+
|
|
79
|
+
it('should call the correct url', async () => {
|
|
80
|
+
const model = createBasicModel();
|
|
81
|
+
|
|
82
|
+
await model.doGenerate({
|
|
83
|
+
prompt,
|
|
84
|
+
files: undefined,
|
|
85
|
+
mask: undefined,
|
|
86
|
+
n: 1,
|
|
87
|
+
size: undefined,
|
|
88
|
+
aspectRatio: '16:9',
|
|
89
|
+
seed: 42,
|
|
90
|
+
providerOptions: { fireworks: { additional_param: 'value' } },
|
|
91
|
+
});
|
|
92
|
+
|
|
93
|
+
expect(server.calls[0].requestMethod).toStrictEqual('POST');
|
|
94
|
+
expect(server.calls[0].requestUrl).toStrictEqual(
|
|
95
|
+
'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image',
|
|
96
|
+
);
|
|
97
|
+
});
|
|
98
|
+
|
|
99
|
+
it('should pass headers', async () => {
|
|
100
|
+
const modelWithHeaders = createBasicModel({
|
|
101
|
+
headers: () => ({
|
|
102
|
+
'Custom-Provider-Header': 'provider-header-value',
|
|
103
|
+
}),
|
|
104
|
+
});
|
|
105
|
+
|
|
106
|
+
await modelWithHeaders.doGenerate({
|
|
107
|
+
prompt,
|
|
108
|
+
files: undefined,
|
|
109
|
+
mask: undefined,
|
|
110
|
+
n: 1,
|
|
111
|
+
size: undefined,
|
|
112
|
+
aspectRatio: undefined,
|
|
113
|
+
seed: undefined,
|
|
114
|
+
providerOptions: {},
|
|
115
|
+
headers: {
|
|
116
|
+
'Custom-Request-Header': 'request-header-value',
|
|
117
|
+
},
|
|
118
|
+
});
|
|
119
|
+
|
|
120
|
+
expect(server.calls[0].requestHeaders).toStrictEqual({
|
|
121
|
+
'content-type': 'application/json',
|
|
122
|
+
'custom-provider-header': 'provider-header-value',
|
|
123
|
+
'custom-request-header': 'request-header-value',
|
|
124
|
+
});
|
|
125
|
+
});
|
|
126
|
+
|
|
127
|
+
it('should handle empty response body', async () => {
|
|
128
|
+
server.urls['https://api.example.com/*'].response = {
|
|
129
|
+
type: 'empty',
|
|
130
|
+
};
|
|
131
|
+
|
|
132
|
+
const model = createBasicModel();
|
|
133
|
+
await expect(
|
|
134
|
+
model.doGenerate({
|
|
135
|
+
prompt,
|
|
136
|
+
files: undefined,
|
|
137
|
+
mask: undefined,
|
|
138
|
+
n: 1,
|
|
139
|
+
size: undefined,
|
|
140
|
+
aspectRatio: undefined,
|
|
141
|
+
seed: undefined,
|
|
142
|
+
providerOptions: {},
|
|
143
|
+
}),
|
|
144
|
+
).rejects.toMatchObject({
|
|
145
|
+
message: 'Response body is empty',
|
|
146
|
+
statusCode: 200,
|
|
147
|
+
url: 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image',
|
|
148
|
+
requestBodyValues: {
|
|
149
|
+
prompt: 'A cute baby sea otter',
|
|
150
|
+
},
|
|
151
|
+
});
|
|
152
|
+
});
|
|
153
|
+
|
|
154
|
+
it('should handle API errors', async () => {
|
|
155
|
+
server.urls['https://api.example.com/*'].response = {
|
|
156
|
+
type: 'error',
|
|
157
|
+
status: 400,
|
|
158
|
+
body: 'Bad Request',
|
|
159
|
+
};
|
|
160
|
+
|
|
161
|
+
const model = createBasicModel();
|
|
162
|
+
await expect(
|
|
163
|
+
model.doGenerate({
|
|
164
|
+
prompt,
|
|
165
|
+
files: undefined,
|
|
166
|
+
mask: undefined,
|
|
167
|
+
n: 1,
|
|
168
|
+
size: undefined,
|
|
169
|
+
aspectRatio: undefined,
|
|
170
|
+
seed: undefined,
|
|
171
|
+
providerOptions: {},
|
|
172
|
+
}),
|
|
173
|
+
).rejects.toMatchObject({
|
|
174
|
+
message: 'Bad Request',
|
|
175
|
+
statusCode: 400,
|
|
176
|
+
url: 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image',
|
|
177
|
+
requestBodyValues: {
|
|
178
|
+
prompt: 'A cute baby sea otter',
|
|
179
|
+
},
|
|
180
|
+
responseBody: 'Bad Request',
|
|
181
|
+
});
|
|
182
|
+
});
|
|
183
|
+
|
|
184
|
+
it('should handle size parameter for supported models', async () => {
|
|
185
|
+
const sizeModel = createSizeModel();
|
|
186
|
+
|
|
187
|
+
await sizeModel.doGenerate({
|
|
188
|
+
prompt,
|
|
189
|
+
files: undefined,
|
|
190
|
+
mask: undefined,
|
|
191
|
+
n: 1,
|
|
192
|
+
size: '1024x768',
|
|
193
|
+
aspectRatio: undefined,
|
|
194
|
+
seed: 42,
|
|
195
|
+
providerOptions: {},
|
|
196
|
+
});
|
|
197
|
+
|
|
198
|
+
expect(await server.calls[0].requestBodyJson).toStrictEqual({
|
|
199
|
+
prompt,
|
|
200
|
+
width: '1024',
|
|
201
|
+
height: '768',
|
|
202
|
+
seed: 42,
|
|
203
|
+
samples: 1,
|
|
204
|
+
});
|
|
205
|
+
});
|
|
206
|
+
|
|
207
|
+
describe('warnings', () => {
|
|
208
|
+
it('should return size warning on workflow model', async () => {
|
|
209
|
+
const model = createBasicModel();
|
|
210
|
+
|
|
211
|
+
const result1 = await model.doGenerate({
|
|
212
|
+
prompt,
|
|
213
|
+
files: undefined,
|
|
214
|
+
mask: undefined,
|
|
215
|
+
n: 1,
|
|
216
|
+
size: '1024x1024',
|
|
217
|
+
aspectRatio: '1:1',
|
|
218
|
+
seed: 123,
|
|
219
|
+
providerOptions: {},
|
|
220
|
+
});
|
|
221
|
+
|
|
222
|
+
expect(result1.warnings).toMatchInlineSnapshot(`
|
|
223
|
+
[
|
|
224
|
+
{
|
|
225
|
+
"details": "This model does not support the \`size\` option. Use \`aspectRatio\` instead.",
|
|
226
|
+
"feature": "size",
|
|
227
|
+
"type": "unsupported",
|
|
228
|
+
},
|
|
229
|
+
]
|
|
230
|
+
`);
|
|
231
|
+
});
|
|
232
|
+
|
|
233
|
+
it('should return aspectRatio warning on size-supporting model', async () => {
|
|
234
|
+
const sizeModel = createSizeModel();
|
|
235
|
+
|
|
236
|
+
const result2 = await sizeModel.doGenerate({
|
|
237
|
+
prompt,
|
|
238
|
+
files: undefined,
|
|
239
|
+
mask: undefined,
|
|
240
|
+
n: 1,
|
|
241
|
+
size: '1024x1024',
|
|
242
|
+
aspectRatio: '1:1',
|
|
243
|
+
seed: 123,
|
|
244
|
+
providerOptions: {},
|
|
245
|
+
});
|
|
246
|
+
|
|
247
|
+
expect(result2.warnings).toMatchInlineSnapshot(`
|
|
248
|
+
[
|
|
249
|
+
{
|
|
250
|
+
"details": "This model does not support the \`aspectRatio\` option.",
|
|
251
|
+
"feature": "aspectRatio",
|
|
252
|
+
"type": "unsupported",
|
|
253
|
+
},
|
|
254
|
+
]
|
|
255
|
+
`);
|
|
256
|
+
});
|
|
257
|
+
});
|
|
258
|
+
|
|
259
|
+
it('should respect the abort signal', async () => {
|
|
260
|
+
const model = createBasicModel();
|
|
261
|
+
const controller = new AbortController();
|
|
262
|
+
|
|
263
|
+
const generatePromise = model.doGenerate({
|
|
264
|
+
prompt,
|
|
265
|
+
files: undefined,
|
|
266
|
+
mask: undefined,
|
|
267
|
+
n: 1,
|
|
268
|
+
size: undefined,
|
|
269
|
+
aspectRatio: undefined,
|
|
270
|
+
seed: undefined,
|
|
271
|
+
providerOptions: {},
|
|
272
|
+
abortSignal: controller.signal,
|
|
273
|
+
});
|
|
274
|
+
|
|
275
|
+
controller.abort();
|
|
276
|
+
|
|
277
|
+
await expect(generatePromise).rejects.toThrow(
|
|
278
|
+
'This operation was aborted',
|
|
279
|
+
);
|
|
280
|
+
});
|
|
281
|
+
|
|
282
|
+
it('should use custom fetch function when provided', async () => {
|
|
283
|
+
const mockFetch = vi.fn().mockResolvedValue(
|
|
284
|
+
new Response(Buffer.from('mock-image-data'), {
|
|
285
|
+
status: 200,
|
|
286
|
+
}),
|
|
287
|
+
);
|
|
288
|
+
|
|
289
|
+
const model = createBasicModel({
|
|
290
|
+
fetch: mockFetch,
|
|
291
|
+
});
|
|
292
|
+
|
|
293
|
+
await model.doGenerate({
|
|
294
|
+
prompt,
|
|
295
|
+
files: undefined,
|
|
296
|
+
mask: undefined,
|
|
297
|
+
n: 1,
|
|
298
|
+
size: undefined,
|
|
299
|
+
aspectRatio: undefined,
|
|
300
|
+
seed: undefined,
|
|
301
|
+
providerOptions: {},
|
|
302
|
+
});
|
|
303
|
+
|
|
304
|
+
expect(mockFetch).toHaveBeenCalled();
|
|
305
|
+
});
|
|
306
|
+
|
|
307
|
+
it('should pass samples parameter to API', async () => {
|
|
308
|
+
const model = createBasicModel();
|
|
309
|
+
|
|
310
|
+
await model.doGenerate({
|
|
311
|
+
prompt,
|
|
312
|
+
files: undefined,
|
|
313
|
+
mask: undefined,
|
|
314
|
+
n: 42,
|
|
315
|
+
size: undefined,
|
|
316
|
+
aspectRatio: undefined,
|
|
317
|
+
seed: undefined,
|
|
318
|
+
providerOptions: {},
|
|
319
|
+
});
|
|
320
|
+
|
|
321
|
+
expect(await server.calls[0].requestBodyJson).toHaveProperty(
|
|
322
|
+
'samples',
|
|
323
|
+
42,
|
|
324
|
+
);
|
|
325
|
+
});
|
|
326
|
+
|
|
327
|
+
describe('response metadata', () => {
|
|
328
|
+
it('should include timestamp, headers and modelId in response', async () => {
|
|
329
|
+
const testDate = new Date('2024-01-01T00:00:00Z');
|
|
330
|
+
const model = createBasicModel({
|
|
331
|
+
currentDate: () => testDate,
|
|
332
|
+
});
|
|
333
|
+
|
|
334
|
+
const result = await model.doGenerate({
|
|
335
|
+
prompt,
|
|
336
|
+
files: undefined,
|
|
337
|
+
mask: undefined,
|
|
338
|
+
n: 1,
|
|
339
|
+
size: undefined,
|
|
340
|
+
aspectRatio: undefined,
|
|
341
|
+
seed: undefined,
|
|
342
|
+
providerOptions: {},
|
|
343
|
+
});
|
|
344
|
+
|
|
345
|
+
expect(result.response).toStrictEqual({
|
|
346
|
+
timestamp: testDate,
|
|
347
|
+
modelId: 'accounts/fireworks/models/flux-1-dev-fp8',
|
|
348
|
+
headers: expect.any(Object),
|
|
349
|
+
});
|
|
350
|
+
});
|
|
351
|
+
|
|
352
|
+
it('should include response headers from API call', async () => {
|
|
353
|
+
server.urls['https://api.example.com/*'].response = {
|
|
354
|
+
type: 'binary',
|
|
355
|
+
body: Buffer.from('test-binary-content'),
|
|
356
|
+
headers: {
|
|
357
|
+
'x-request-id': 'test-request-id',
|
|
358
|
+
'content-type': 'image/png',
|
|
359
|
+
},
|
|
360
|
+
};
|
|
361
|
+
|
|
362
|
+
const model = createBasicModel();
|
|
363
|
+
const result = await model.doGenerate({
|
|
364
|
+
prompt,
|
|
365
|
+
files: undefined,
|
|
366
|
+
mask: undefined,
|
|
367
|
+
n: 1,
|
|
368
|
+
size: undefined,
|
|
369
|
+
aspectRatio: undefined,
|
|
370
|
+
seed: undefined,
|
|
371
|
+
providerOptions: {},
|
|
372
|
+
});
|
|
373
|
+
|
|
374
|
+
expect(result.response.headers).toStrictEqual({
|
|
375
|
+
'content-length': '19',
|
|
376
|
+
'x-request-id': 'test-request-id',
|
|
377
|
+
'content-type': 'image/png',
|
|
378
|
+
});
|
|
379
|
+
});
|
|
380
|
+
});
|
|
381
|
+
});
|
|
382
|
+
|
|
383
|
+
describe('constructor', () => {
|
|
384
|
+
it('should expose correct provider and model information', () => {
|
|
385
|
+
const model = createBasicModel();
|
|
386
|
+
|
|
387
|
+
expect(model.provider).toBe('fireworks');
|
|
388
|
+
expect(model.modelId).toBe('accounts/fireworks/models/flux-1-dev-fp8');
|
|
389
|
+
expect(model.specificationVersion).toBe('v3');
|
|
390
|
+
expect(model.maxImagesPerCall).toBe(1);
|
|
391
|
+
});
|
|
392
|
+
});
|
|
393
|
+
|
|
394
|
+
describe('Image Editing', () => {
|
|
395
|
+
const editServer = createTestServer({
|
|
396
|
+
'https://api.edit.example.com/*': {
|
|
397
|
+
response: {
|
|
398
|
+
type: 'binary',
|
|
399
|
+
body: Buffer.from('edited-image-data'),
|
|
400
|
+
},
|
|
401
|
+
},
|
|
402
|
+
});
|
|
403
|
+
|
|
404
|
+
function createKontextModel() {
|
|
405
|
+
return new FireworksImageModel(
|
|
406
|
+
'accounts/fireworks/models/flux-kontext-pro',
|
|
407
|
+
{
|
|
408
|
+
provider: 'fireworks',
|
|
409
|
+
baseURL: 'https://api.edit.example.com',
|
|
410
|
+
headers: () => ({ 'api-key': 'test-key' }),
|
|
411
|
+
},
|
|
412
|
+
);
|
|
413
|
+
}
|
|
414
|
+
|
|
415
|
+
it('should send edit request with files as data URI', async () => {
|
|
416
|
+
const imageData = new Uint8Array([137, 80, 78, 71]); // PNG magic bytes
|
|
417
|
+
|
|
418
|
+
await createKontextModel().doGenerate({
|
|
419
|
+
prompt: 'Turn the cat into a dog',
|
|
420
|
+
files: [
|
|
421
|
+
{
|
|
422
|
+
type: 'file',
|
|
423
|
+
data: imageData,
|
|
424
|
+
mediaType: 'image/png',
|
|
425
|
+
},
|
|
426
|
+
],
|
|
427
|
+
mask: undefined,
|
|
428
|
+
n: 1,
|
|
429
|
+
size: undefined,
|
|
430
|
+
aspectRatio: undefined,
|
|
431
|
+
seed: undefined,
|
|
432
|
+
providerOptions: {},
|
|
433
|
+
});
|
|
434
|
+
|
|
435
|
+
const requestBody = await editServer.calls[0].requestBodyJson;
|
|
436
|
+
expect(requestBody).toMatchInlineSnapshot(`
|
|
437
|
+
{
|
|
438
|
+
"input_image": "",
|
|
439
|
+
"prompt": "Turn the cat into a dog",
|
|
440
|
+
"samples": 1,
|
|
441
|
+
}
|
|
442
|
+
`);
|
|
443
|
+
});
|
|
444
|
+
|
|
445
|
+
it('should use correct URL for Kontext model (no text_to_image suffix)', async () => {
|
|
446
|
+
const imageData = new Uint8Array([137, 80, 78, 71]);
|
|
447
|
+
|
|
448
|
+
await createKontextModel().doGenerate({
|
|
449
|
+
prompt: 'Edit this image',
|
|
450
|
+
files: [
|
|
451
|
+
{
|
|
452
|
+
type: 'file',
|
|
453
|
+
data: imageData,
|
|
454
|
+
mediaType: 'image/png',
|
|
455
|
+
},
|
|
456
|
+
],
|
|
457
|
+
mask: undefined,
|
|
458
|
+
n: 1,
|
|
459
|
+
size: undefined,
|
|
460
|
+
aspectRatio: undefined,
|
|
461
|
+
seed: undefined,
|
|
462
|
+
providerOptions: {},
|
|
463
|
+
});
|
|
464
|
+
|
|
465
|
+
expect(editServer.calls[0].requestUrl).toBe(
|
|
466
|
+
'https://api.edit.example.com/workflows/accounts/fireworks/models/flux-kontext-pro',
|
|
467
|
+
);
|
|
468
|
+
});
|
|
469
|
+
|
|
470
|
+
it('should send edit request with URL-based file', async () => {
|
|
471
|
+
await createKontextModel().doGenerate({
|
|
472
|
+
prompt: 'Edit this image',
|
|
473
|
+
files: [
|
|
474
|
+
{
|
|
475
|
+
type: 'url',
|
|
476
|
+
url: 'https://example.com/input.png',
|
|
477
|
+
},
|
|
478
|
+
],
|
|
479
|
+
mask: undefined,
|
|
480
|
+
n: 1,
|
|
481
|
+
size: undefined,
|
|
482
|
+
aspectRatio: undefined,
|
|
483
|
+
seed: undefined,
|
|
484
|
+
providerOptions: {},
|
|
485
|
+
});
|
|
486
|
+
|
|
487
|
+
const requestBody = await editServer.calls[0].requestBodyJson;
|
|
488
|
+
expect(requestBody).toMatchInlineSnapshot(`
|
|
489
|
+
{
|
|
490
|
+
"input_image": "https://example.com/input.png",
|
|
491
|
+
"prompt": "Edit this image",
|
|
492
|
+
"samples": 1,
|
|
493
|
+
}
|
|
494
|
+
`);
|
|
495
|
+
});
|
|
496
|
+
|
|
497
|
+
it('should send edit request with base64 string data', async () => {
|
|
498
|
+
await createKontextModel().doGenerate({
|
|
499
|
+
prompt: 'Edit this image',
|
|
500
|
+
files: [
|
|
501
|
+
{
|
|
502
|
+
type: 'file',
|
|
503
|
+
data: 'iVBORw0KGgoAAAANSUhEUgAAAAE=',
|
|
504
|
+
mediaType: 'image/png',
|
|
505
|
+
},
|
|
506
|
+
],
|
|
507
|
+
mask: undefined,
|
|
508
|
+
n: 1,
|
|
509
|
+
size: undefined,
|
|
510
|
+
aspectRatio: undefined,
|
|
511
|
+
seed: undefined,
|
|
512
|
+
providerOptions: {},
|
|
513
|
+
});
|
|
514
|
+
|
|
515
|
+
const requestBody = await editServer.calls[0].requestBodyJson;
|
|
516
|
+
expect(requestBody).toMatchInlineSnapshot(`
|
|
517
|
+
{
|
|
518
|
+
"input_image": "",
|
|
519
|
+
"prompt": "Edit this image",
|
|
520
|
+
"samples": 1,
|
|
521
|
+
}
|
|
522
|
+
`);
|
|
523
|
+
});
|
|
524
|
+
|
|
525
|
+
it('should warn when multiple files are provided', async () => {
|
|
526
|
+
const imageData = new Uint8Array([137, 80, 78, 71]);
|
|
527
|
+
|
|
528
|
+
const result = await createKontextModel().doGenerate({
|
|
529
|
+
prompt: 'Edit images',
|
|
530
|
+
files: [
|
|
531
|
+
{
|
|
532
|
+
type: 'file',
|
|
533
|
+
data: imageData,
|
|
534
|
+
mediaType: 'image/png',
|
|
535
|
+
},
|
|
536
|
+
{
|
|
537
|
+
type: 'file',
|
|
538
|
+
data: imageData,
|
|
539
|
+
mediaType: 'image/png',
|
|
540
|
+
},
|
|
541
|
+
],
|
|
542
|
+
mask: undefined,
|
|
543
|
+
n: 1,
|
|
544
|
+
size: undefined,
|
|
545
|
+
aspectRatio: undefined,
|
|
546
|
+
seed: undefined,
|
|
547
|
+
providerOptions: {},
|
|
548
|
+
});
|
|
549
|
+
|
|
550
|
+
expect(result.warnings).toContainEqual({
|
|
551
|
+
type: 'other',
|
|
552
|
+
message:
|
|
553
|
+
'Fireworks only supports a single input image. Additional images are ignored.',
|
|
554
|
+
});
|
|
555
|
+
});
|
|
556
|
+
|
|
557
|
+
it('should warn when mask is provided', async () => {
|
|
558
|
+
const imageData = new Uint8Array([137, 80, 78, 71]);
|
|
559
|
+
const maskData = new Uint8Array([255, 255, 255, 0]);
|
|
560
|
+
|
|
561
|
+
const result = await createKontextModel().doGenerate({
|
|
562
|
+
prompt: 'Edit with mask',
|
|
563
|
+
files: [
|
|
564
|
+
{
|
|
565
|
+
type: 'file',
|
|
566
|
+
data: imageData,
|
|
567
|
+
mediaType: 'image/png',
|
|
568
|
+
},
|
|
569
|
+
],
|
|
570
|
+
mask: {
|
|
571
|
+
type: 'file',
|
|
572
|
+
data: maskData,
|
|
573
|
+
mediaType: 'image/png',
|
|
574
|
+
},
|
|
575
|
+
n: 1,
|
|
576
|
+
size: undefined,
|
|
577
|
+
aspectRatio: undefined,
|
|
578
|
+
seed: undefined,
|
|
579
|
+
providerOptions: {},
|
|
580
|
+
});
|
|
581
|
+
|
|
582
|
+
expect(result.warnings).toContainEqual({
|
|
583
|
+
type: 'unsupported',
|
|
584
|
+
feature: 'mask',
|
|
585
|
+
details:
|
|
586
|
+
'Fireworks Kontext models do not support explicit masks. Use the prompt to describe the areas to edit.',
|
|
587
|
+
});
|
|
588
|
+
});
|
|
589
|
+
|
|
590
|
+
it('should pass provider options with edit request', async () => {
|
|
591
|
+
const imageData = new Uint8Array([137, 80, 78, 71]);
|
|
592
|
+
|
|
593
|
+
await createKontextModel().doGenerate({
|
|
594
|
+
prompt: 'Edit with options',
|
|
595
|
+
files: [
|
|
596
|
+
{
|
|
597
|
+
type: 'file',
|
|
598
|
+
data: imageData,
|
|
599
|
+
mediaType: 'image/png',
|
|
600
|
+
},
|
|
601
|
+
],
|
|
602
|
+
mask: undefined,
|
|
603
|
+
n: 1,
|
|
604
|
+
size: undefined,
|
|
605
|
+
aspectRatio: '16:9',
|
|
606
|
+
seed: 42,
|
|
607
|
+
providerOptions: {
|
|
608
|
+
fireworks: {
|
|
609
|
+
output_format: 'jpeg',
|
|
610
|
+
safety_tolerance: 2,
|
|
611
|
+
},
|
|
612
|
+
},
|
|
613
|
+
});
|
|
614
|
+
|
|
615
|
+
const requestBody = await editServer.calls[0].requestBodyJson;
|
|
616
|
+
expect(requestBody).toMatchInlineSnapshot(`
|
|
617
|
+
{
|
|
618
|
+
"aspect_ratio": "16:9",
|
|
619
|
+
"input_image": "",
|
|
620
|
+
"output_format": "jpeg",
|
|
621
|
+
"prompt": "Edit with options",
|
|
622
|
+
"safety_tolerance": 2,
|
|
623
|
+
"samples": 1,
|
|
624
|
+
"seed": 42,
|
|
625
|
+
}
|
|
626
|
+
`);
|
|
627
|
+
});
|
|
628
|
+
});
|
|
629
|
+
});
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
import { ImageModelV3, SharedV3Warning } from '@ai-sdk/provider';
|
|
2
|
+
import {
|
|
3
|
+
combineHeaders,
|
|
4
|
+
convertImageModelFileToDataUri,
|
|
5
|
+
createBinaryResponseHandler,
|
|
6
|
+
createStatusCodeErrorResponseHandler,
|
|
7
|
+
FetchFunction,
|
|
8
|
+
postJsonToApi,
|
|
9
|
+
} from '@ai-sdk/provider-utils';
|
|
10
|
+
import { FireworksImageModelId } from './fireworks-image-options';
|
|
11
|
+
|
|
12
|
+
interface FireworksImageModelBackendConfig {
|
|
13
|
+
urlFormat: 'workflows' | 'workflows_edit' | 'image_generation';
|
|
14
|
+
supportsSize?: boolean;
|
|
15
|
+
supportsEditing?: boolean;
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
const modelToBackendConfig: Partial<
|
|
19
|
+
Record<FireworksImageModelId, FireworksImageModelBackendConfig>
|
|
20
|
+
> = {
|
|
21
|
+
'accounts/fireworks/models/flux-1-dev-fp8': {
|
|
22
|
+
urlFormat: 'workflows',
|
|
23
|
+
},
|
|
24
|
+
'accounts/fireworks/models/flux-1-schnell-fp8': {
|
|
25
|
+
urlFormat: 'workflows',
|
|
26
|
+
},
|
|
27
|
+
'accounts/fireworks/models/flux-kontext-pro': {
|
|
28
|
+
urlFormat: 'workflows_edit',
|
|
29
|
+
supportsEditing: true,
|
|
30
|
+
},
|
|
31
|
+
'accounts/fireworks/models/flux-kontext-max': {
|
|
32
|
+
urlFormat: 'workflows_edit',
|
|
33
|
+
supportsEditing: true,
|
|
34
|
+
},
|
|
35
|
+
'accounts/fireworks/models/playground-v2-5-1024px-aesthetic': {
|
|
36
|
+
urlFormat: 'image_generation',
|
|
37
|
+
supportsSize: true,
|
|
38
|
+
},
|
|
39
|
+
'accounts/fireworks/models/japanese-stable-diffusion-xl': {
|
|
40
|
+
urlFormat: 'image_generation',
|
|
41
|
+
supportsSize: true,
|
|
42
|
+
},
|
|
43
|
+
'accounts/fireworks/models/playground-v2-1024px-aesthetic': {
|
|
44
|
+
urlFormat: 'image_generation',
|
|
45
|
+
supportsSize: true,
|
|
46
|
+
},
|
|
47
|
+
'accounts/fireworks/models/stable-diffusion-xl-1024-v1-0': {
|
|
48
|
+
urlFormat: 'image_generation',
|
|
49
|
+
supportsSize: true,
|
|
50
|
+
},
|
|
51
|
+
'accounts/fireworks/models/SSD-1B': {
|
|
52
|
+
urlFormat: 'image_generation',
|
|
53
|
+
supportsSize: true,
|
|
54
|
+
},
|
|
55
|
+
};
|
|
56
|
+
|
|
57
|
+
function getUrlForModel(
|
|
58
|
+
baseUrl: string,
|
|
59
|
+
modelId: FireworksImageModelId,
|
|
60
|
+
hasInputImage: boolean,
|
|
61
|
+
): string {
|
|
62
|
+
const config = modelToBackendConfig[modelId];
|
|
63
|
+
|
|
64
|
+
switch (config?.urlFormat) {
|
|
65
|
+
case 'image_generation':
|
|
66
|
+
return `${baseUrl}/image_generation/${modelId}`;
|
|
67
|
+
case 'workflows_edit':
|
|
68
|
+
// Kontext models: use base URL for editing (no suffix)
|
|
69
|
+
return `${baseUrl}/workflows/${modelId}`;
|
|
70
|
+
case 'workflows':
|
|
71
|
+
default:
|
|
72
|
+
// Standard FLUX models: use text_to_image for generation,
|
|
73
|
+
// but if input_image provided, some models may support editing
|
|
74
|
+
if (hasInputImage && config?.supportsEditing) {
|
|
75
|
+
return `${baseUrl}/workflows/${modelId}`;
|
|
76
|
+
}
|
|
77
|
+
return `${baseUrl}/workflows/${modelId}/text_to_image`;
|
|
78
|
+
}
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
interface FireworksImageModelConfig {
|
|
82
|
+
provider: string;
|
|
83
|
+
baseURL: string;
|
|
84
|
+
headers: () => Record<string, string>;
|
|
85
|
+
fetch?: FetchFunction;
|
|
86
|
+
_internal?: {
|
|
87
|
+
currentDate?: () => Date;
|
|
88
|
+
};
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
export class FireworksImageModel implements ImageModelV3 {
|
|
92
|
+
readonly specificationVersion = 'v3';
|
|
93
|
+
readonly maxImagesPerCall = 1;
|
|
94
|
+
|
|
95
|
+
get provider(): string {
|
|
96
|
+
return this.config.provider;
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
constructor(
|
|
100
|
+
readonly modelId: FireworksImageModelId,
|
|
101
|
+
private config: FireworksImageModelConfig,
|
|
102
|
+
) {}
|
|
103
|
+
|
|
104
|
+
async doGenerate({
|
|
105
|
+
prompt,
|
|
106
|
+
n,
|
|
107
|
+
size,
|
|
108
|
+
aspectRatio,
|
|
109
|
+
seed,
|
|
110
|
+
providerOptions,
|
|
111
|
+
headers,
|
|
112
|
+
abortSignal,
|
|
113
|
+
files,
|
|
114
|
+
mask,
|
|
115
|
+
}: Parameters<ImageModelV3['doGenerate']>[0]): Promise<
|
|
116
|
+
Awaited<ReturnType<ImageModelV3['doGenerate']>>
|
|
117
|
+
> {
|
|
118
|
+
const warnings: Array<SharedV3Warning> = [];
|
|
119
|
+
|
|
120
|
+
const backendConfig = modelToBackendConfig[this.modelId];
|
|
121
|
+
if (!backendConfig?.supportsSize && size != null) {
|
|
122
|
+
warnings.push({
|
|
123
|
+
type: 'unsupported',
|
|
124
|
+
feature: 'size',
|
|
125
|
+
details:
|
|
126
|
+
'This model does not support the `size` option. Use `aspectRatio` instead.',
|
|
127
|
+
});
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
// Use supportsSize as a proxy for whether the model does not support
|
|
131
|
+
// aspectRatio. This invariant holds for the current set of models.
|
|
132
|
+
if (backendConfig?.supportsSize && aspectRatio != null) {
|
|
133
|
+
warnings.push({
|
|
134
|
+
type: 'unsupported',
|
|
135
|
+
feature: 'aspectRatio',
|
|
136
|
+
details: 'This model does not support the `aspectRatio` option.',
|
|
137
|
+
});
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
// Handle files for image editing
|
|
141
|
+
const hasInputImage = files != null && files.length > 0;
|
|
142
|
+
let inputImage: string | undefined;
|
|
143
|
+
|
|
144
|
+
if (hasInputImage) {
|
|
145
|
+
inputImage = convertImageModelFileToDataUri(files[0]);
|
|
146
|
+
|
|
147
|
+
if (files.length > 1) {
|
|
148
|
+
warnings.push({
|
|
149
|
+
type: 'other',
|
|
150
|
+
message:
|
|
151
|
+
'Fireworks only supports a single input image. Additional images are ignored.',
|
|
152
|
+
});
|
|
153
|
+
}
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
// Warn about mask - Fireworks Kontext models don't support explicit masks
|
|
157
|
+
if (mask != null) {
|
|
158
|
+
warnings.push({
|
|
159
|
+
type: 'unsupported',
|
|
160
|
+
feature: 'mask',
|
|
161
|
+
details:
|
|
162
|
+
'Fireworks Kontext models do not support explicit masks. Use the prompt to describe the areas to edit.',
|
|
163
|
+
});
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
const splitSize = size?.split('x');
|
|
167
|
+
const currentDate = this.config._internal?.currentDate?.() ?? new Date();
|
|
168
|
+
const { value: response, responseHeaders } = await postJsonToApi({
|
|
169
|
+
url: getUrlForModel(this.config.baseURL, this.modelId, hasInputImage),
|
|
170
|
+
headers: combineHeaders(this.config.headers(), headers),
|
|
171
|
+
body: {
|
|
172
|
+
prompt,
|
|
173
|
+
aspect_ratio: aspectRatio,
|
|
174
|
+
seed,
|
|
175
|
+
samples: n,
|
|
176
|
+
...(inputImage && { input_image: inputImage }),
|
|
177
|
+
...(splitSize && { width: splitSize[0], height: splitSize[1] }),
|
|
178
|
+
...(providerOptions.fireworks ?? {}),
|
|
179
|
+
},
|
|
180
|
+
failedResponseHandler: createStatusCodeErrorResponseHandler(),
|
|
181
|
+
successfulResponseHandler: createBinaryResponseHandler(),
|
|
182
|
+
abortSignal,
|
|
183
|
+
fetch: this.config.fetch,
|
|
184
|
+
});
|
|
185
|
+
|
|
186
|
+
return {
|
|
187
|
+
images: [response],
|
|
188
|
+
warnings,
|
|
189
|
+
response: {
|
|
190
|
+
timestamp: currentDate,
|
|
191
|
+
modelId: this.modelId,
|
|
192
|
+
headers: responseHeaders,
|
|
193
|
+
},
|
|
194
|
+
};
|
|
195
|
+
}
|
|
196
|
+
}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
// https://fireworks.ai/models?type=image
|
|
2
|
+
export type FireworksImageModelId =
|
|
3
|
+
| 'accounts/fireworks/models/flux-1-dev-fp8'
|
|
4
|
+
| 'accounts/fireworks/models/flux-1-schnell-fp8'
|
|
5
|
+
| 'accounts/fireworks/models/flux-kontext-pro'
|
|
6
|
+
| 'accounts/fireworks/models/flux-kontext-max'
|
|
7
|
+
| 'accounts/fireworks/models/playground-v2-5-1024px-aesthetic'
|
|
8
|
+
| 'accounts/fireworks/models/japanese-stable-diffusion-xl'
|
|
9
|
+
| 'accounts/fireworks/models/playground-v2-1024px-aesthetic'
|
|
10
|
+
| 'accounts/fireworks/models/SSD-1B'
|
|
11
|
+
| 'accounts/fireworks/models/stable-diffusion-xl-1024-v1-0'
|
|
12
|
+
| (string & {});
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
|
|
2
|
+
import { createFireworks } from './fireworks-provider';
|
|
3
|
+
import { LanguageModelV3, EmbeddingModelV3 } from '@ai-sdk/provider';
|
|
4
|
+
import { loadApiKey } from '@ai-sdk/provider-utils';
|
|
5
|
+
import {
|
|
6
|
+
OpenAICompatibleChatLanguageModel,
|
|
7
|
+
OpenAICompatibleCompletionLanguageModel,
|
|
8
|
+
OpenAICompatibleEmbeddingModel,
|
|
9
|
+
} from '@ai-sdk/openai-compatible';
|
|
10
|
+
import { FireworksImageModel } from './fireworks-image-model';
|
|
11
|
+
|
|
12
|
+
// The imported class is replaced by the `vi.mock` factory below; this alias
// re-types it as a vitest `Mock` so tests can inspect its constructor calls.
const OpenAICompatibleChatLanguageModelMock =
  OpenAICompatibleChatLanguageModel as unknown as Mock;

vi.mock('@ai-sdk/openai-compatible', () => {
  // Builds a mock that can be invoked with `new`: it has to be a `function`
  // expression (not an arrow) so that `this` is the instance being constructed.
  const createMockConstructor = (providerName: string) =>
    vi.fn().mockImplementation(function (
      this: any,
      modelId: string,
      settings: any,
    ) {
      Object.assign(this, { provider: providerName, modelId, settings });
    });

  return {
    OpenAICompatibleChatLanguageModel: createMockConstructor('fireworks.chat'),
    OpenAICompatibleCompletionLanguageModel: createMockConstructor(
      'fireworks.completion',
    ),
    OpenAICompatibleEmbeddingModel: createMockConstructor(
      'fireworks.embedding',
    ),
  };
});

vi.mock('@ai-sdk/provider-utils', async () => {
  // Keep the real module and override only the two helpers the tests observe.
  const actualModule = await vi.importActual('@ai-sdk/provider-utils');
  const loadApiKeyMock = vi.fn().mockReturnValue('mock-api-key');
  const withoutTrailingSlashMock = vi.fn(url => url);

  return {
    ...actualModule,
    loadApiKey: loadApiKeyMock,
    withoutTrailingSlash: withoutTrailingSlashMock,
  };
});

vi.mock('./fireworks-image-model', () => {
  return { FireworksImageModel: vi.fn() };
});
|
|
54
|
+
|
|
55
|
+
// Provider-level tests. The model classes, `loadApiKey` and
// `withoutTrailingSlash` are mocked at module level, so these tests only
// verify which constructor is used and which configuration it receives.
describe('FireworksProvider', () => {
  beforeEach(() => {
    // Clear recorded calls so every test can safely read `mock.calls[0]`.
    vi.clearAllMocks();
  });

  describe('createFireworks', () => {
    it('should create a FireworksProvider instance with default options', () => {
      const provider = createFireworks();
      provider('model-id');

      // Second constructor argument is the model config built by the provider.
      const constructorCall =
        OpenAICompatibleChatLanguageModelMock.mock.calls[0];
      const config = constructorCall[1];
      // `loadApiKey` is only invoked when the headers are built.
      config.headers();

      expect(loadApiKey).toHaveBeenCalledWith({
        apiKey: undefined,
        environmentVariableName: 'FIREWORKS_API_KEY',
        description: 'Fireworks API key',
      });
    });

    it('should create a FireworksProvider instance with custom options', () => {
      const options = {
        apiKey: 'custom-key',
        baseURL: 'https://custom.url',
        headers: { 'Custom-Header': 'value' },
      };
      const provider = createFireworks(options);
      provider('model-id');

      const constructorCall =
        OpenAICompatibleChatLanguageModelMock.mock.calls[0];
      const config = constructorCall[1];
      config.headers();

      expect(loadApiKey).toHaveBeenCalledWith({
        apiKey: 'custom-key',
        environmentVariableName: 'FIREWORKS_API_KEY',
        description: 'Fireworks API key',
      });
      // The custom base URL must be the prefix of every request URL.
      expect(config.url({ path: '/chat/completions' })).toBe(
        'https://custom.url/chat/completions',
      );
    });

    it('should return a chat model when called as a function', () => {
      const provider = createFireworks();
      const modelId = 'foo-model-id';

      const model = provider(modelId);
      expect(model).toBeInstanceOf(OpenAICompatibleChatLanguageModel);
    });
  });

  describe('chatModel', () => {
    it('should construct a chat model with correct configuration', () => {
      const provider = createFireworks();
      const modelId = 'fireworks-chat-model';

      const model = provider.chatModel(modelId);

      expect(model).toBeInstanceOf(OpenAICompatibleChatLanguageModel);
    });
  });

  describe('completionModel', () => {
    it('should construct a completion model with correct configuration', () => {
      const provider = createFireworks();
      const modelId = 'fireworks-completion-model';

      const model = provider.completionModel(modelId);

      expect(model).toBeInstanceOf(OpenAICompatibleCompletionLanguageModel);
    });
  });

  describe('embeddingModel', () => {
    it('should construct a text embedding model with correct configuration', () => {
      const provider = createFireworks();
      const modelId = 'fireworks-embedding-model';

      const model = provider.embeddingModel(modelId);

      expect(model).toBeInstanceOf(OpenAICompatibleEmbeddingModel);
    });
  });

  describe('image', () => {
    it('should construct an image model with correct configuration', () => {
      const provider = createFireworks();
      const modelId = 'accounts/fireworks/models/flux-1-dev-fp8';

      const model = provider.image(modelId);

      expect(model).toBeInstanceOf(FireworksImageModel);
      expect(FireworksImageModel).toHaveBeenCalledWith(
        modelId,
        expect.objectContaining({
          provider: 'fireworks.image',
          baseURL: 'https://api.fireworks.ai/inference/v1',
        }),
      );
    });

    it('should use default settings when none provided', () => {
      const provider = createFireworks();
      const modelId = 'accounts/fireworks/models/flux-1-dev-fp8';

      const model = provider.image(modelId);

      expect(model).toBeInstanceOf(FireworksImageModel);
      expect(FireworksImageModel).toHaveBeenCalledWith(
        modelId,
        expect.any(Object),
      );
    });

    it('should respect custom baseURL', () => {
      const customBaseURL = 'https://custom.api.fireworks.ai';
      const provider = createFireworks({ baseURL: customBaseURL });
      const modelId = 'accounts/fireworks/models/flux-1-dev-fp8';

      provider.image(modelId);

      expect(FireworksImageModel).toHaveBeenCalledWith(
        modelId,
        expect.objectContaining({
          baseURL: customBaseURL,
        }),
      );
    });
  });
});
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
import {
|
|
2
|
+
OpenAICompatibleChatLanguageModel,
|
|
3
|
+
OpenAICompatibleCompletionLanguageModel,
|
|
4
|
+
OpenAICompatibleEmbeddingModel,
|
|
5
|
+
ProviderErrorStructure,
|
|
6
|
+
} from '@ai-sdk/openai-compatible';
|
|
7
|
+
import {
|
|
8
|
+
EmbeddingModelV3,
|
|
9
|
+
ImageModelV3,
|
|
10
|
+
LanguageModelV3,
|
|
11
|
+
ProviderV3,
|
|
12
|
+
} from '@ai-sdk/provider';
|
|
13
|
+
import {
|
|
14
|
+
FetchFunction,
|
|
15
|
+
loadApiKey,
|
|
16
|
+
withoutTrailingSlash,
|
|
17
|
+
withUserAgentSuffix,
|
|
18
|
+
} from '@ai-sdk/provider-utils';
|
|
19
|
+
import { z } from 'zod/v4';
|
|
20
|
+
import { FireworksChatModelId } from './fireworks-chat-options';
|
|
21
|
+
import { FireworksCompletionModelId } from './fireworks-completion-options';
|
|
22
|
+
import { FireworksEmbeddingModelId } from './fireworks-embedding-options';
|
|
23
|
+
import { FireworksImageModel } from './fireworks-image-model';
|
|
24
|
+
import { FireworksImageModelId } from './fireworks-image-options';
|
|
25
|
+
import { VERSION } from './version';
|
|
26
|
+
|
|
27
|
+
/** Error payload type, inferred from `fireworksErrorSchema`. */
export type FireworksErrorData = z.infer<typeof fireworksErrorSchema>;
|
|
28
|
+
|
|
29
|
+
// Zod schema for the error payload: an object with a string `error` field.
const fireworksErrorSchema = z.object({ error: z.string() });
|
|
32
|
+
|
|
33
|
+
// Error handling descriptor passed to the OpenAI-compatible models: the schema
// used to parse an error payload and how to turn it into a message.
const fireworksErrorStructure: ProviderErrorStructure<FireworksErrorData> = {
  errorSchema: fireworksErrorSchema,
  errorToMessage: ({ error }) => error,
};
|
|
37
|
+
|
|
38
|
+
/**
 * Settings accepted by `createFireworks`. Every field is optional.
 */
export interface FireworksProviderSettings {
  /**
   * Fireworks API key. When omitted, the value is taken from the
   * `FIREWORKS_API_KEY` environment variable.
   */
  apiKey?: string;

  /**
   * Base URL for the API calls.
   */
  baseURL?: string;

  /**
   * Custom headers to include in the requests.
   */
  headers?: Record<string, string>;

  /**
   * Custom fetch implementation. Can act as a middleware to intercept
   * requests, or replace fetch altogether (e.g. for testing).
   */
  fetch?: FetchFunction;
}
|
|
58
|
+
|
|
59
|
+
/**
 * Fireworks provider. It can be called directly with a chat model id and also
 * exposes one factory method per kind of model.
 */
export interface FireworksProvider extends ProviderV3 {
  /**
   * Creates a chat model for text generation (same as `chatModel`).
   */
  (modelId: FireworksChatModelId): LanguageModelV3;

  /**
   * Creates a chat model for text generation.
   */
  chatModel(modelId: FireworksChatModelId): LanguageModelV3;

  /**
   * Creates a completion model for text generation.
   */
  completionModel(modelId: FireworksCompletionModelId): LanguageModelV3;

  /**
   * Creates a chat model for text generation (same as `chatModel`).
   */
  languageModel(modelId: FireworksChatModelId): LanguageModelV3;

  /**
   * Creates a text embedding model.
   */
  embeddingModel(modelId: FireworksEmbeddingModelId): EmbeddingModelV3;

  /**
   * @deprecated Use `embeddingModel` instead.
   */
  textEmbeddingModel(modelId: FireworksEmbeddingModelId): EmbeddingModelV3;

  /**
   * Creates a model for image generation (same as `imageModel`).
   */
  image(modelId: FireworksImageModelId): ImageModelV3;

  /**
   * Creates a model for image generation.
   */
  imageModel(modelId: FireworksImageModelId): ImageModelV3;
}
|
|
100
|
+
|
|
101
|
+
// Base URL used when `FireworksProviderSettings.baseURL` is not provided.
const defaultBaseURL = 'https://api.fireworks.ai/inference/v1';
|
|
102
|
+
|
|
103
|
+
/**
 * Creates a Fireworks provider.
 *
 * Chat, completion and embedding models are backed by the OpenAI-compatible
 * model classes; image models use `FireworksImageModel`.
 *
 * @param options Optional API key, base URL, extra headers and fetch override.
 * @returns A provider that is callable with a chat model id and exposes the
 * per-kind model factories.
 */
export function createFireworks(
  options: FireworksProviderSettings = {},
): FireworksProvider {
  // Configuration shared by every model kind.
  interface CommonModelConfig {
    provider: string;
    url: ({ path }: { path: string }) => string;
    headers: () => Record<string, string>;
    fetch?: FetchFunction;
  }

  const baseURL = withoutTrailingSlash(options.baseURL ?? defaultBaseURL);

  // Resolved on each `headers()` call rather than at provider creation, so the
  // API key is only loaded at that point. Custom headers are spread last.
  function getHeaders() {
    const apiKey = loadApiKey({
      apiKey: options.apiKey,
      environmentVariableName: 'FIREWORKS_API_KEY',
      description: 'Fireworks API key',
    });

    return withUserAgentSuffix(
      {
        Authorization: `Bearer ${apiKey}`,
        ...options.headers,
      },
      `ai-sdk/fireworks/${VERSION}`,
    );
  }

  function getCommonModelConfig(modelType: string): CommonModelConfig {
    return {
      provider: `fireworks.${modelType}`,
      url: ({ path }) => `${baseURL}${path}`,
      headers: getHeaders,
      fetch: options.fetch,
    };
  }

  function createChatModel(modelId: FireworksChatModelId) {
    return new OpenAICompatibleChatLanguageModel(modelId, {
      ...getCommonModelConfig('chat'),
      errorStructure: fireworksErrorStructure,
    });
  }

  function createCompletionModel(modelId: FireworksCompletionModelId) {
    return new OpenAICompatibleCompletionLanguageModel(modelId, {
      ...getCommonModelConfig('completion'),
      errorStructure: fireworksErrorStructure,
    });
  }

  function createEmbeddingModel(modelId: FireworksEmbeddingModelId) {
    return new OpenAICompatibleEmbeddingModel(modelId, {
      ...getCommonModelConfig('embedding'),
      errorStructure: fireworksErrorStructure,
    });
  }

  function createImageModel(modelId: FireworksImageModelId) {
    return new FireworksImageModel(modelId, {
      ...getCommonModelConfig('image'),
      baseURL: baseURL ?? defaultBaseURL,
    });
  }

  // Calling the provider directly yields a chat model.
  const provider = (modelId: FireworksChatModelId) => createChatModel(modelId);

  return Object.assign(provider, {
    specificationVersion: 'v3' as const,
    completionModel: createCompletionModel,
    chatModel: createChatModel,
    languageModel: createChatModel,
    embeddingModel: createEmbeddingModel,
    textEmbeddingModel: createEmbeddingModel,
    image: createImageModel,
    imageModel: createImageModel,
  });
}
|
|
171
|
+
|
|
172
|
+
/** Default Fireworks provider instance, created with no explicit settings. */
export const fireworks = createFireworks();
|
package/src/index.ts
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
export type {
|
|
2
|
+
FireworksEmbeddingModelId,
|
|
3
|
+
FireworksEmbeddingProviderOptions,
|
|
4
|
+
} from './fireworks-embedding-options';
|
|
5
|
+
export { FireworksImageModel } from './fireworks-image-model';
|
|
6
|
+
export type { FireworksImageModelId } from './fireworks-image-options';
|
|
7
|
+
export { fireworks, createFireworks } from './fireworks-provider';
|
|
8
|
+
export type {
|
|
9
|
+
FireworksProvider,
|
|
10
|
+
FireworksProviderSettings,
|
|
11
|
+
FireworksErrorData,
|
|
12
|
+
} from './fireworks-provider';
|
|
13
|
+
export { VERSION } from './version';
|