@ai-sdk/prodia 1.0.5 → 1.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/CHANGELOG.md CHANGED
@@ -1,5 +1,11 @@
1
1
  # @ai-sdk/prodia
2
2
 
3
+ ## 1.0.6
4
+
5
+ ### Patch Changes
6
+
7
+ - 8dc54db: chore: add src folders to package bundle
8
+
3
9
  ## 1.0.5
4
10
 
5
11
  ### Patch Changes
package/dist/index.js CHANGED
@@ -376,7 +376,7 @@ var prodiaFailedResponseHandler = (0, import_provider_utils.createJsonErrorRespo
376
376
  });
377
377
 
378
378
  // src/version.ts
379
- var VERSION = true ? "1.0.5" : "0.0.0-test";
379
+ var VERSION = true ? "1.0.6" : "0.0.0-test";
380
380
 
381
381
  // src/prodia-provider.ts
382
382
  var defaultBaseURL = "https://inference.prodia.com/v2";
package/dist/index.mjs CHANGED
@@ -362,7 +362,7 @@ var prodiaFailedResponseHandler = createJsonErrorResponseHandler({
362
362
  });
363
363
 
364
364
  // src/version.ts
365
- var VERSION = true ? "1.0.5" : "0.0.0-test";
365
+ var VERSION = true ? "1.0.6" : "0.0.0-test";
366
366
 
367
367
  // src/prodia-provider.ts
368
368
  var defaultBaseURL = "https://inference.prodia.com/v2";
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@ai-sdk/prodia",
3
- "version": "1.0.5",
3
+ "version": "1.0.6",
4
4
  "license": "Apache-2.0",
5
5
  "sideEffects": false,
6
6
  "main": "./dist/index.js",
@@ -8,6 +8,7 @@
8
8
  "types": "./dist/index.d.ts",
9
9
  "files": [
10
10
  "dist/**/*",
11
+ "src",
11
12
  "CHANGELOG.md"
12
13
  ],
13
14
  "exports": {
@@ -27,7 +28,7 @@
27
28
  "tsup": "^8",
28
29
  "typescript": "5.8.3",
29
30
  "zod": "3.25.76",
30
- "@ai-sdk/test-server": "1.0.1",
31
+ "@ai-sdk/test-server": "1.0.2",
31
32
  "@vercel/ai-tsconfig": "0.0.0"
32
33
  },
33
34
  "peerDependencies": {
package/src/index.ts ADDED
@@ -0,0 +1,5 @@
1
// Public entry point of the Prodia provider package.

// Provider factory, default instance, and their configuration types.
export { createProdia, prodia } from './prodia-provider';
export type { ProdiaProvider, ProdiaProviderSettings } from './prodia-provider';

// Image model identifiers and per-call provider options.
export type { ProdiaImageModelId } from './prodia-image-settings';
export type { ProdiaImageProviderOptions } from './prodia-image-model';

// Package version string.
export { VERSION } from './version';
@@ -0,0 +1,502 @@
1
+ import type { FetchFunction } from '@ai-sdk/provider-utils';
2
+ import { createTestServer } from '@ai-sdk/test-server/with-vitest';
3
+ import { describe, expect, it } from 'vitest';
4
+ import { ProdiaImageModel } from './prodia-image-model';
5
+
6
const prompt = 'A cute baby sea otter';

/**
 * Builds a `ProdiaImageModel` for the flux-fast schnell job type that talks to the
 * mock server base URL. Headers, fetch and the clock can be overridden per test;
 * headers default to a bearer token for `test-key`.
 */
function createBasicModel(
  overrides: {
    headers?: () => Record<string, string | undefined>;
    fetch?: FetchFunction;
    currentDate?: () => Date;
  } = {},
) {
  const defaultHeaders = () => ({ Authorization: 'Bearer test-key' });

  return new ProdiaImageModel('inference.flux-fast.schnell.txt2img.v2', {
    provider: 'prodia.image',
    baseURL: 'https://api.example.com/v2',
    headers: overrides.headers ?? defaultHeaders,
    fetch: overrides.fetch,
    _internal: { currentDate: overrides.currentDate },
  });
}
27
+
28
/**
 * Builds a `multipart/form-data` body that mimics a Prodia job response: a JSON
 * `job` part followed by an `output` part holding `imageContent` as raw bytes.
 */
function createMultipartResponse(
  jobResult: Record<string, unknown>,
  imageContent: string = 'test-binary-content',
): { body: Buffer; contentType: string } {
  const boundary = 'test-boundary-12345';
  const crlf = '\r\n';

  // Everything up to (and including) the blank line that starts the image body.
  const preamble = [
    `--${boundary}`,
    'Content-Disposition: form-data; name="job"; filename="job.json"',
    'Content-Type: application/json',
    '',
    JSON.stringify(jobResult),
    `--${boundary}`,
    'Content-Disposition: form-data; name="output"; filename="output.png"',
    'Content-Type: image/png',
    '',
    '',
  ].join(crlf);
  const closing = `${crlf}--${boundary}--${crlf}`;

  return {
    body: Buffer.concat([
      Buffer.from(preamble),
      Buffer.from(imageContent),
      Buffer.from(closing),
    ]),
    contentType: `multipart/form-data; boundary=${boundary}`,
  };
}
59
+
60
const defaultJobResult = {
  id: 'job-123',
  created_at: '2025-01-01T00:00:00Z',
  updated_at: '2025-01-01T00:00:05Z',
  state: { current: 'completed' },
  config: { prompt, seed: 42 },
  metrics: { elapsed: 2.5, ips: 10.5 },
};

describe('ProdiaImageModel', () => {
  const jobUrl = 'https://api.example.com/v2/job';
  const modelId = 'inference.flux-fast.schnell.txt2img.v2';

  type CallOptions = Parameters<ProdiaImageModel['doGenerate']>[0];

  // Every optional input unset; each test overrides only what it exercises.
  const baseCallOptions: CallOptions = {
    prompt,
    files: undefined,
    mask: undefined,
    n: 1,
    size: undefined,
    aspectRatio: undefined,
    seed: undefined,
    providerOptions: {},
  };

  const defaultResponse = createMultipartResponse(defaultJobResult);

  const server = createTestServer({
    'https://api.example.com/v2/job': {
      response: {
        type: 'binary',
        body: defaultResponse.body,
        headers: {
          'content-type': defaultResponse.contentType,
        },
      },
    },
  });

  /** Runs `doGenerate` with the base options merged with `overrides`. */
  function generate(
    overrides: Partial<CallOptions> = {},
    model = createBasicModel(),
  ) {
    return model.doGenerate({ ...baseCallOptions, ...overrides });
  }

  /** Points the mock job endpoint at a multipart body built from `jobResult`. */
  function respondWithJob(jobResult: Record<string, unknown>) {
    const { body, contentType } = createMultipartResponse(jobResult);
    server.urls[jobUrl].response = {
      type: 'binary',
      body,
      headers: {
        'content-type': contentType,
      },
    };
  }

  /** Parsed JSON body of the first captured request. */
  const firstRequestBody = () => server.calls[0].requestBodyJson;

  describe('doGenerate', () => {
    it('passes the correct parameters including providerOptions', async () => {
      await generate({
        seed: 12345,
        providerOptions: { prodia: { steps: 4 } },
      });

      expect(await firstRequestBody()).toStrictEqual({
        type: modelId,
        config: { prompt, seed: 12345, steps: 4 },
      });
    });

    it('includes width and height when size is provided', async () => {
      await generate({ size: '1024x768' });

      expect(await firstRequestBody()).toStrictEqual({
        type: modelId,
        config: { prompt, width: 1024, height: 768 },
      });
    });

    it('provider options width/height take precedence over size', async () => {
      await generate({
        size: '1024x768',
        providerOptions: { prodia: { width: 512, height: 512 } },
      });

      expect(await firstRequestBody()).toStrictEqual({
        type: modelId,
        config: { prompt, width: 512, height: 512 },
      });
    });

    it('includes style_preset when stylePreset is provided', async () => {
      await generate({ providerOptions: { prodia: { stylePreset: 'anime' } } });

      expect(await firstRequestBody()).toStrictEqual({
        type: modelId,
        config: { prompt, style_preset: 'anime' },
      });
    });

    it('includes loras when provided', async () => {
      const loras = ['prodia/lora/flux/anime@v1', 'prodia/lora/flux/realism@v1'];

      await generate({ providerOptions: { prodia: { loras } } });

      expect(await firstRequestBody()).toStrictEqual({
        type: modelId,
        config: { prompt, loras },
      });
    });

    it('includes progressive when provided', async () => {
      await generate({ providerOptions: { prodia: { progressive: true } } });

      expect(await firstRequestBody()).toStrictEqual({
        type: modelId,
        config: { prompt, progressive: true },
      });
    });

    it('calls the correct endpoint', async () => {
      await generate();

      expect(server.calls[0].requestMethod).toBe('POST');
      expect(server.calls[0].requestUrl).toBe(jobUrl);
    });

    it('sends Accept: multipart/form-data header', async () => {
      await generate();

      expect(server.calls[0].requestHeaders.accept).toBe(
        'multipart/form-data; image/png',
      );
    });

    it('merges provider and request headers', async () => {
      const modelWithHeaders = createBasicModel({
        headers: () => ({
          'Custom-Provider-Header': 'provider-header-value',
          Authorization: 'Bearer test-key',
        }),
      });

      await generate(
        { headers: { 'Custom-Request-Header': 'request-header-value' } },
        modelWithHeaders,
      );

      expect(server.calls[0].requestHeaders).toMatchObject({
        'content-type': 'application/json',
        'custom-provider-header': 'provider-header-value',
        'custom-request-header': 'request-header-value',
        authorization: 'Bearer test-key',
        accept: 'multipart/form-data; image/png',
      });
    });

    it('returns image bytes from multipart response', async () => {
      const { images } = await generate();

      expect(images).toHaveLength(1);
      const [image] = images;
      expect(image).toBeInstanceOf(Uint8Array);
      expect(Buffer.from(image as Uint8Array<ArrayBufferLike>).toString()).toBe(
        'test-binary-content',
      );
    });

    it('returns provider metadata from job result', async () => {
      const { providerMetadata } = await generate();

      expect(providerMetadata?.prodia).toStrictEqual({
        images: [
          {
            jobId: 'job-123',
            seed: 42,
            elapsed: 2.5,
            iterationsPerSecond: 10.5,
            createdAt: '2025-01-01T00:00:00Z',
            updatedAt: '2025-01-01T00:00:05Z',
          },
        ],
      });
    });

    it('omits optional metadata fields when not present in job result', async () => {
      respondWithJob({
        id: 'job-456',
        state: { current: 'completed' },
        config: { prompt },
      });

      const { providerMetadata } = await generate();

      expect(providerMetadata?.prodia).toStrictEqual({
        images: [{ jobId: 'job-456' }],
      });
    });

    it('warns on invalid size format', async () => {
      const { warnings } = await generate({
        size: 'invalid' as `${number}x${number}`,
      });

      expect(warnings).toMatchInlineSnapshot(`
        [
          {
            "details": "Invalid size format: invalid. Expected format: WIDTHxHEIGHT (e.g., 1024x1024)",
            "feature": "size",
            "type": "unsupported",
          },
        ]
      `);
    });

    it('handles API errors', async () => {
      server.urls[jobUrl].response = {
        type: 'error',
        status: 400,
        body: JSON.stringify({
          message: 'Invalid prompt',
          detail: 'Prompt cannot be empty',
        }),
      };

      await expect(generate()).rejects.toMatchObject({
        message: 'Prompt cannot be empty',
        statusCode: 400,
        url: jobUrl,
      });
    });

    it('includes timestamp, headers, and modelId in response metadata', async () => {
      respondWithJob(defaultJobResult);
      const testDate = new Date('2025-01-01T00:00:00Z');

      const { response } = await generate(
        {},
        createBasicModel({ currentDate: () => testDate }),
      );

      expect(response).toStrictEqual({
        timestamp: testDate,
        modelId,
        headers: expect.any(Object),
      });
    });
  });

  describe('constructor', () => {
    it('exposes correct provider and model information', () => {
      const model = createBasicModel();

      expect(model.provider).toBe('prodia.image');
      expect(model.modelId).toBe(modelId);
      expect(model.specificationVersion).toBe('v3');
      expect(model.maxImagesPerCall).toBe(1);
    });
  });
});
@@ -0,0 +1,450 @@
1
+ import type { ImageModelV3, SharedV3Warning } from '@ai-sdk/provider';
2
+ import type { InferSchema, Resolvable } from '@ai-sdk/provider-utils';
3
+ import {
4
+ combineHeaders,
5
+ createJsonErrorResponseHandler,
6
+ type FetchFunction,
7
+ lazySchema,
8
+ parseProviderOptions,
9
+ postToApi,
10
+ resolve,
11
+ zodSchema,
12
+ } from '@ai-sdk/provider-utils';
13
+ import { z } from 'zod/v4';
14
+ import type { ProdiaImageModelId } from './prodia-image-settings';
15
+
16
/**
 * Image model for Prodia's v2 job API.
 *
 * Each `doGenerate` call POSTs a single job to `{baseURL}/job` and decodes the
 * multipart response into the generated image bytes plus job metadata.
 */
export class ProdiaImageModel implements ImageModelV3 {
  readonly specificationVersion = 'v3';
  /** One job produces one image, so each call returns exactly one image. */
  readonly maxImagesPerCall = 1;

  get provider(): string {
    return this.config.provider;
  }

  constructor(
    readonly modelId: ProdiaImageModelId,
    private readonly config: ProdiaImageModelConfig,
  ) {}

  /**
   * Translates SDK call options into the Prodia job request body.
   *
   * - `providerOptions.prodia.width`/`height` take precedence over `size`.
   * - `size` must be `WIDTHxHEIGHT` with two positive integers; anything else is
   *   dropped with an `unsupported` warning.
   * - Inputs the job request has no field for (`aspectRatio`, input `files`,
   *   `mask`) are reported as `unsupported` warnings instead of being ignored
   *   silently.
   *
   * @returns The JSON body (`{ type, config }`) and the collected warnings.
   */
  private async getArgs({
    prompt,
    size,
    aspectRatio,
    seed,
    files,
    mask,
    providerOptions,
  }: Parameters<ImageModelV3['doGenerate']>[0]) {
    const warnings: Array<SharedV3Warning> = [];

    const prodiaOptions = await parseProviderOptions({
      provider: 'prodia',
      providerOptions,
      schema: prodiaImageProviderOptionsSchema,
    });

    // None of these inputs is forwarded to the job request below.
    if (aspectRatio != null) {
      warnings.push({
        type: 'unsupported',
        feature: 'aspectRatio',
        details:
          'This model does not support aspect ratio. Use `size` or the `width`/`height` provider options instead.',
      });
    }
    if (files != null && files.length > 0) {
      warnings.push({
        type: 'unsupported',
        feature: 'files',
        details: 'This model does not support input images.',
      });
    }
    if (mask != null) {
      warnings.push({
        type: 'unsupported',
        feature: 'mask',
        details: 'This model does not support masks.',
      });
    }

    let width: number | undefined;
    let height: number | undefined;
    if (size) {
      const dimensions = size.split('x').map(Number);
      // Rejects NaN, zero, negative, fractional and extra-segment sizes
      // (e.g. "0x512", "1.5x2", "1x2x3").
      const isPositiveInteger = (value: number | undefined) =>
        value !== undefined && Number.isInteger(value) && value > 0;

      if (
        dimensions.length === 2 &&
        isPositiveInteger(dimensions[0]) &&
        isPositiveInteger(dimensions[1])
      ) {
        [width, height] = dimensions;
      } else {
        warnings.push({
          type: 'unsupported',
          feature: 'size',
          details: `Invalid size format: ${size}. Expected format: WIDTHxHEIGHT (e.g., 1024x1024)`,
        });
      }
    }

    const jobConfig: Record<string, unknown> = {
      prompt,
    };

    // Provider-specific dimensions win over the generic `size` option.
    if (prodiaOptions?.width !== undefined) {
      jobConfig.width = prodiaOptions.width;
    } else if (width !== undefined) {
      jobConfig.width = width;
    }

    if (prodiaOptions?.height !== undefined) {
      jobConfig.height = prodiaOptions.height;
    } else if (height !== undefined) {
      jobConfig.height = height;
    }

    if (seed !== undefined) {
      jobConfig.seed = seed;
    }
    if (prodiaOptions?.steps !== undefined) {
      jobConfig.steps = prodiaOptions.steps;
    }
    if (prodiaOptions?.stylePreset !== undefined) {
      // The API uses snake_case for this field.
      jobConfig.style_preset = prodiaOptions.stylePreset;
    }
    if (prodiaOptions?.loras !== undefined && prodiaOptions.loras.length > 0) {
      jobConfig.loras = prodiaOptions.loras;
    }
    if (prodiaOptions?.progressive !== undefined) {
      jobConfig.progressive = prodiaOptions.progressive;
    }

    const body = {
      type: this.modelId,
      config: jobConfig,
    };

    return { body, warnings };
  }

  /**
   * Submits one Prodia job and returns its image.
   *
   * @returns One image (raw bytes from the `output` part), warnings from
   *   `getArgs`, `providerMetadata.prodia.images[0]` with the job id plus any of
   *   seed / elapsed / iterationsPerSecond / createdAt / updatedAt present in the
   *   job record, and response metadata (model id, timestamp, headers).
   * @throws Rejects with the error built by `prodiaFailedResponseHandler` for
   *   failed responses, or with an `Error` when the multipart body lacks a
   *   boundary, job part or image part.
   */
  async doGenerate(
    options: Parameters<ImageModelV3['doGenerate']>[0],
  ): Promise<Awaited<ReturnType<ImageModelV3['doGenerate']>>> {
    const { body, warnings } = await this.getArgs(options);

    const currentDate = this.config._internal?.currentDate?.() ?? new Date();
    const combinedHeaders = combineHeaders(
      await resolve(this.config.headers),
      options.headers,
    );

    const { value: multipartResult, responseHeaders } = await postToApi({
      url: `${this.config.baseURL}/job`,
      headers: {
        ...combinedHeaders,
        Accept: 'multipart/form-data; image/png',
        'Content-Type': 'application/json',
      },
      body: {
        content: JSON.stringify(body),
        values: body,
      },
      failedResponseHandler: prodiaFailedResponseHandler,
      successfulResponseHandler: createMultipartResponseHandler(),
      abortSignal: options.abortSignal,
      fetch: this.config.fetch,
    });

    const { jobResult, imageBytes } = multipartResult;

    return {
      images: [imageBytes],
      warnings,
      providerMetadata: {
        prodia: {
          images: [
            {
              jobId: jobResult.id,
              // Optional fields are only included when the job record has them.
              ...(jobResult.config?.seed != null && {
                seed: jobResult.config.seed,
              }),
              ...(jobResult.metrics?.elapsed != null && {
                elapsed: jobResult.metrics.elapsed,
              }),
              ...(jobResult.metrics?.ips != null && {
                iterationsPerSecond: jobResult.metrics.ips,
              }),
              ...(jobResult.created_at != null && {
                createdAt: jobResult.created_at,
              }),
              ...(jobResult.updated_at != null && {
                updatedAt: jobResult.updated_at,
              }),
            },
          ],
        },
      },
      response: {
        modelId: this.modelId,
        timestamp: currentDate,
        headers: responseHeaders,
      },
    };
  }
}
170
+
171
// Allowed values for the `stylePreset` provider option (sent to Prodia as
// `style_preset`). Order is kept stable for schema error messages.
const stylePresets = [
  '3d-model', 'analog-film', 'anime', 'cinematic',
  'comic-book', 'digital-art', 'enhance', 'fantasy-art',
  'isometric', 'line-art', 'low-poly', 'neon-punk',
  'origami', 'photographic', 'pixel-art', 'texture',
  'craft-clay',
] as const;
190
+
191
/**
 * Schema for `providerOptions.prodia`, validated by `parseProviderOptions`
 * before the job request is built.
 */
export const prodiaImageProviderOptionsSchema = lazySchema(() => {
  // Width and height share the same pixel bounds; each call yields a fresh schema.
  const pixelDimension = () => z.number().int().min(256).max(1920).optional();

  return zodSchema(
    z.object({
      /**
       * Amount of computational iterations to run. More is typically higher quality.
       */
      steps: z.number().int().min(1).max(4).optional(),
      /**
       * Width of the output image in pixels.
       */
      width: pixelDimension(),
      /**
       * Height of the output image in pixels.
       */
      height: pixelDimension(),
      /**
       * Apply a visual theme to your output image.
       */
      stylePreset: z.enum(stylePresets).optional(),
      /**
       * Augment the output with a LoRa model.
       */
      loras: z.array(z.string()).max(3).optional(),
      /**
       * When using JPEG output, return a progressive JPEG.
       */
      progressive: z.boolean().optional(),
    }),
  );
});

/** Typed view of the options accepted under `providerOptions.prodia`. */
export type ProdiaImageProviderOptions = InferSchema<
  typeof prodiaImageProviderOptionsSchema
>;
225
+
226
/** Test-only hooks that are not part of the public provider settings. */
interface ProdiaImageModelInternals {
  /** Clock used for `response.timestamp`; defaults to `new Date()`. */
  currentDate?: () => Date;
}

/** Settings the provider hands to each `ProdiaImageModel`. */
interface ProdiaImageModelConfig {
  /** Provider identifier reported by `model.provider`. */
  provider: string;
  /** API base URL; `/job` is appended for generation requests. */
  baseURL: string;
  /** Headers (or a function returning them) resolved on every request. */
  headers?: Resolvable<Record<string, string | undefined>>;
  /** Optional fetch override passed through to `postToApi`. */
  fetch?: FetchFunction;
  _internal?: ProdiaImageModelInternals;
}
235
+
236
/**
 * Shape of the JSON `job` part in Prodia's multipart response.
 *
 * Only the fields used for provider metadata are typed. `config` keeps any
 * additional keys (loose object); other nested objects drop unknown keys.
 */
const prodiaJobResultSchema = z.object({
  id: z.string(),
  created_at: z.string().optional(),
  updated_at: z.string().optional(),
  expires_at: z.string().optional(),
  state: z
    .object({
      current: z.string(),
    })
    .optional(),
  config: z
    .object({
      seed: z.number().optional(),
    })
    // `.loose()` replaces the deprecated `.passthrough()` in zod v4 (same behavior).
    .loose()
    .optional(),
  metrics: z
    .object({
      elapsed: z.number().optional(),
      ips: z.number().optional(),
    })
    .optional(),
});

type ProdiaJobResult = z.infer<typeof prodiaJobResultSchema>;

/** Decoded multipart job response: the validated job record plus image bytes. */
interface MultipartResult {
  jobResult: ProdiaJobResult;
  imageBytes: Uint8Array;
}
266
+
267
/**
 * Builds the `postToApi` success handler that decodes Prodia's multipart job response.
 *
 * The body is expected to contain a `job` part (JSON, validated with
 * `prodiaJobResultSchema`) and an `output` part (the image). As a fallback, any
 * part whose content type starts with `image/` is taken as the image.
 *
 * @throws Error when the content type has no multipart boundary, or when the
 *   job or image part is missing.
 */
function createMultipartResponseHandler() {
  return async ({
    response,
  }: {
    response: Response;
  }): Promise<{
    value: MultipartResult;
    responseHeaders: Record<string, string>;
  }> => {
    const contentType = response.headers.get('content-type') ?? '';
    const responseHeaders: Record<string, string> = {};
    response.headers.forEach((value, key) => {
      responseHeaders[key] = value;
    });

    const boundary = extractMultipartBoundary(contentType);
    if (boundary === undefined) {
      throw new Error(
        `Prodia response missing multipart boundary in content-type: ${contentType}`,
      );
    }

    const arrayBuffer = await response.arrayBuffer();
    const bytes = new Uint8Array(arrayBuffer);

    const parts = parseMultipart(bytes, boundary);

    let jobResult: ProdiaJobResult | undefined;
    let imageBytes: Uint8Array | undefined;

    for (const part of parts) {
      const contentDisposition = part.headers['content-disposition'] ?? '';
      const partContentType = part.headers['content-type'] ?? '';

      if (contentDisposition.includes('name="job"')) {
        const jsonStr = new TextDecoder().decode(part.body);
        jobResult = prodiaJobResultSchema.parse(JSON.parse(jsonStr));
      } else if (contentDisposition.includes('name="output"')) {
        imageBytes = part.body;
      } else if (partContentType.startsWith('image/')) {
        imageBytes = part.body;
      }
    }

    if (!jobResult) {
      throw new Error('Prodia multipart response missing job part');
    }
    if (!imageBytes) {
      throw new Error('Prodia multipart response missing output image');
    }

    return {
      value: { jobResult, imageBytes },
      responseHeaders,
    };
  };
}

/**
 * Reads the `boundary` parameter from a multipart content type.
 *
 * Parameter names are case-insensitive and the value may be a quoted string
 * (e.g. `boundary="abc def"`); the quotes are not part of the boundary.
 *
 * @returns The unquoted boundary, or `undefined` when absent or empty.
 */
function extractMultipartBoundary(contentType: string): string | undefined {
  const match = /boundary=(?:"([^"]+)"|([^\s;"]+))/i.exec(contentType);
  return match?.[1] ?? match?.[2];
}
325
+
326
/** One decoded multipart part: lower-cased header map plus raw body bytes. */
interface MultipartPart {
  headers: Record<string, string>;
  body: Uint8Array;
}

/**
 * Splits a multipart body into its parts, following RFC 2046 delimiter rules.
 *
 * A delimiter is `--{boundary}` at the start of a line, so boundary-like bytes
 * inside a part body (e.g. binary image data) are not mistaken for a delimiter.
 * `--{boundary}--` closes the body; the preamble before the first delimiter and
 * the epilogue after the close delimiter are ignored. Both CRLF and bare LF
 * line endings are accepted, as is whitespace padding after a delimiter. Parts
 * with no blank line between headers and body are skipped.
 *
 * @param data Raw response bytes.
 * @param boundary Unquoted boundary value from the `Content-Type` header.
 * @returns Parts in order of appearance. Each body is a copy, not a view into `data`.
 */
function parseMultipart(data: Uint8Array, boundary: string): MultipartPart[] {
  const delimiter = new TextEncoder().encode(`--${boundary}`);
  const positions = findDelimiterPositions(data, delimiter);
  const parts: MultipartPart[] = [];

  for (let i = 0; i < positions.length - 1; i++) {
    let partStart = positions[i] + delimiter.length;

    // `--{boundary}--` is the close delimiter; everything after it is epilogue.
    if (data[partStart] === 0x2d && data[partStart + 1] === 0x2d) {
      break;
    }

    // Skip transport padding (spaces/tabs), then the line break ending the delimiter line.
    while (data[partStart] === 0x20 || data[partStart] === 0x09) {
      partStart++;
    }
    if (data[partStart] === 0x0d && data[partStart + 1] === 0x0a) {
      partStart += 2;
    } else if (data[partStart] === 0x0a) {
      partStart += 1;
    }

    // The line break before the next delimiter belongs to that delimiter.
    let partEnd = positions[i + 1];
    if (data[partEnd - 2] === 0x0d && data[partEnd - 1] === 0x0a) {
      partEnd -= 2;
    } else if (data[partEnd - 1] === 0x0a) {
      partEnd -= 1;
    }

    if (partEnd < partStart) {
      continue;
    }

    const part = parseMultipartPart(data.subarray(partStart, partEnd));
    if (part) {
      parts.push(part);
    }
  }

  return parts;
}

/** Offsets of every `delimiter` occurrence that starts a line (offset 0 or right after LF). */
function findDelimiterPositions(
  data: Uint8Array,
  delimiter: Uint8Array,
): number[] {
  const positions: number[] = [];
  for (let i = 0; i + delimiter.length <= data.length; i++) {
    if (i > 0 && data[i - 1] !== 0x0a) {
      continue;
    }
    let matches = true;
    for (let j = 0; j < delimiter.length; j++) {
      if (data[i + j] !== delimiter[j]) {
        matches = false;
        break;
      }
    }
    if (matches) {
      positions.push(i);
    }
  }
  return positions;
}

/**
 * Parses the bytes of one part (between delimiter lines) into headers and body.
 *
 * The header section ends at the first empty line (CRLF or LF); a part that
 * starts with an empty line has no headers.
 *
 * @returns The part, or `undefined` when no empty line is found.
 */
function parseMultipartPart(partData: Uint8Array): MultipartPart | undefined {
  let headerEnd = -1;
  let bodyStart = -1;
  for (let j = 0; j < partData.length && headerEnd === -1; j++) {
    const atLineStart = j === 0 || partData[j - 1] === 0x0a;
    if (!atLineStart) {
      continue;
    }
    if (partData[j] === 0x0d && partData[j + 1] === 0x0a) {
      headerEnd = j;
      bodyStart = j + 2;
    } else if (partData[j] === 0x0a) {
      headerEnd = j;
      bodyStart = j + 1;
    }
  }

  if (headerEnd === -1) {
    return undefined;
  }

  const headers: Record<string, string> = {};
  const headerText = new TextDecoder().decode(partData.subarray(0, headerEnd));
  for (const line of headerText.split(/\r?\n/)) {
    const colonIdx = line.indexOf(':');
    if (colonIdx > 0) {
      const key = line.slice(0, colonIdx).trim().toLowerCase();
      headers[key] = line.slice(colonIdx + 1).trim();
    }
  }

  // `slice` copies, so the body does not alias the full response buffer.
  return { headers, body: partData.slice(bodyStart) };
}
427
+
428
/** JSON error body returned by the Prodia API; every field is optional. */
const prodiaErrorSchema = z.object({
  message: z.string().optional(),
  detail: z.unknown().optional(),
  error: z.string().optional(),
});

/**
 * Turns Prodia error responses into API call errors.
 *
 * Message precedence: a string `detail`, then any other non-null `detail`
 * serialized as JSON, then `error`, then `message`, then a generic fallback.
 */
const prodiaFailedResponseHandler = createJsonErrorResponseHandler({
  errorSchema: prodiaErrorSchema,
  errorToMessage: error => {
    const fallback = 'Unknown Prodia error';
    const result = prodiaErrorSchema.safeParse(error);
    if (!result.success) {
      return fallback;
    }

    const { detail } = result.data;
    if (typeof detail === 'string') {
      return detail;
    }
    if (detail !== null && detail !== undefined) {
      try {
        return JSON.stringify(detail);
      } catch {
        // Unserializable detail: fall back to the remaining fields.
      }
    }

    return result.data.error ?? result.data.message ?? fallback;
  },
});
@@ -0,0 +1,7 @@
1
/**
 * Prodia job type used for image generation (sent as the job `type`).
 *
 * The listed text-to-image job types are offered for autocompletion; any other
 * job type string is accepted as well.
 */
export type ProdiaImageModelId =
  | 'inference.flux.schnell.txt2img.v2'
  | 'inference.flux-fast.schnell.txt2img.v2'
  | (string & {});
@@ -0,0 +1,123 @@
1
+ import { createTestServer } from '@ai-sdk/test-server/with-vitest';
2
+ import { describe, expect, it } from 'vitest';
3
+ import { createProdia } from './prodia-provider';
4
+
5
/**
 * Builds a Prodia-style multipart job response: a JSON `job` part followed by
 * an `output` part whose body is `imageContent`.
 */
function createMultipartResponse(
  jobResult: Record<string, unknown>,
  imageContent: string = 'test-image',
): { body: Buffer; contentType: string } {
  const boundary = 'test-boundary-12345';
  const delimiterLine = `--${boundary}\r\n`;

  const jobSection =
    delimiterLine +
    'Content-Disposition: form-data; name="job"; filename="job.json"\r\n' +
    'Content-Type: application/json\r\n' +
    '\r\n' +
    JSON.stringify(jobResult) +
    '\r\n';
  const imageHeaders =
    delimiterLine +
    'Content-Disposition: form-data; name="output"; filename="output.png"\r\n' +
    'Content-Type: image/png\r\n' +
    '\r\n';

  const chunks = [
    Buffer.from(jobSection + imageHeaders),
    Buffer.from(imageContent),
    Buffer.from(`\r\n--${boundary}--\r\n`),
  ];

  return {
    body: Buffer.concat(chunks),
    contentType: `multipart/form-data; boundary=${boundary}`,
  };
}
36
+
37
const defaultJobResult = {
  id: 'job-123',
  state: { current: 'completed' },
  config: { prompt: 'test' },
};

const { body: defaultBody, contentType: defaultContentType } =
  createMultipartResponse(defaultJobResult);

const server = createTestServer({
  'https://api.example.com/v2/job': {
    response: {
      type: 'binary',
      body: defaultBody,
      headers: {
        'content-type': defaultContentType,
      },
    },
  },
});

describe('Prodia provider', () => {
  const fluxFastId = 'inference.flux-fast.schnell.txt2img.v2';
  const fluxId = 'inference.flux.schnell.txt2img.v2';

  it('creates image models via .image and .imageModel', () => {
    const provider = createProdia();

    const viaImage = provider.image(fluxFastId);
    const viaImageModel = provider.imageModel(fluxId);

    expect(viaImage.provider).toBe('prodia.image');
    expect(viaImage.modelId).toBe(fluxFastId);
    expect(viaImageModel.modelId).toBe(fluxId);
    expect(viaImage.specificationVersion).toBe('v3');
  });

  it('configures baseURL and headers correctly', async () => {
    const prompt = 'A serene mountain landscape at sunset';
    const model = createProdia({
      apiKey: 'test-api-key',
      baseURL: 'https://api.example.com/v2',
      headers: {
        'x-extra-header': 'extra',
      },
    }).image(fluxFastId);

    await model.doGenerate({
      prompt,
      files: undefined,
      mask: undefined,
      n: 1,
      size: undefined,
      seed: undefined,
      aspectRatio: undefined,
      providerOptions: {},
    });

    const call = server.calls[0];
    expect(call.requestUrl).toBe('https://api.example.com/v2/job');
    expect(call.requestMethod).toBe('POST');
    expect(call.requestHeaders.authorization).toBe('Bearer test-api-key');
    expect(call.requestHeaders['x-extra-header']).toBe('extra');
    expect(call.requestHeaders.accept).toBe('multipart/form-data; image/png');
    expect(await call.requestBodyJson).toMatchObject({
      type: fluxFastId,
      config: { prompt },
    });
    expect(call.requestUserAgent).toContain('ai-sdk/prodia/');
  });

  it('throws NoSuchModelError for unsupported model types', () => {
    const provider = createProdia();

    expect(() => provider.languageModel('some-id')).toThrowError(
      'No such languageModel',
    );
    expect(() => provider.embeddingModel('some-id')).toThrowError(
      'No such embeddingModel',
    );
  });
});
@@ -0,0 +1,107 @@
1
+ import {
2
+ type ImageModelV3,
3
+ NoSuchModelError,
4
+ type ProviderV3,
5
+ } from '@ai-sdk/provider';
6
+ import type { FetchFunction } from '@ai-sdk/provider-utils';
7
+ import {
8
+ loadApiKey,
9
+ withoutTrailingSlash,
10
+ withUserAgentSuffix,
11
+ } from '@ai-sdk/provider-utils';
12
+ import { ProdiaImageModel } from './prodia-image-model';
13
+ import type { ProdiaImageModelId } from './prodia-image-settings';
14
+ import { VERSION } from './version';
15
+
16
/**
 * Options accepted by `createProdia`.
 */
export interface ProdiaProviderSettings {
  /**
   * API key sent as a bearer token. When omitted, it is read from the
   * `PRODIA_TOKEN` environment variable.
   */
  apiKey?: string;

  /**
   * API base URL. Defaults to `https://inference.prodia.com/v2`.
   */
  baseURL?: string;

  /**
   * Extra headers sent with every request.
   */
  headers?: Record<string, string>;

  /**
   * Custom fetch implementation, e.g. to intercept requests as middleware or
   * to stub the network in tests.
   */
  fetch?: FetchFunction;
}

/**
 * Prodia provider. Only image generation is supported.
 */
export interface ProdiaProvider extends ProviderV3 {
  /**
   * Creates an image generation model for the given Prodia job type.
   */
  image(modelId: ProdiaImageModelId): ImageModelV3;

  /**
   * Creates an image generation model for the given Prodia job type
   * (same as `image`).
   */
  imageModel(modelId: ProdiaImageModelId): ImageModelV3;

  /**
   * @deprecated Use `embeddingModel` instead.
   */
  textEmbeddingModel(modelId: string): never;
}
55
+
56
const defaultBaseURL = 'https://inference.prodia.com/v2';

/**
 * Creates a Prodia provider instance.
 *
 * Request headers, including the API key from `options.apiKey` or the
 * `PRODIA_TOKEN` environment variable, are built per request. Creating a
 * provider therefore does not load the key; a missing key only surfaces when a
 * request is made. Language and embedding models are not available and throw
 * `NoSuchModelError`.
 */
export function createProdia(
  options: ProdiaProviderSettings = {},
): ProdiaProvider {
  // `withoutTrailingSlash` is typed as possibly undefined, hence the second fallback.
  const baseURL =
    withoutTrailingSlash(options.baseURL ?? defaultBaseURL) ?? defaultBaseURL;

  function getHeaders() {
    const apiKey = loadApiKey({
      apiKey: options.apiKey,
      environmentVariableName: 'PRODIA_TOKEN',
      description: 'Prodia',
    });
    return withUserAgentSuffix(
      {
        Authorization: `Bearer ${apiKey}`,
        ...options.headers,
      },
      `ai-sdk/prodia/${VERSION}`,
    );
  }

  function createImageModel(modelId: ProdiaImageModelId) {
    return new ProdiaImageModel(modelId, {
      provider: 'prodia.image',
      baseURL,
      headers: getHeaders,
      fetch: options.fetch,
    });
  }

  /** Factory for model kinds this provider does not offer. */
  function unavailable(modelType: 'languageModel' | 'embeddingModel') {
    return (modelId: string): never => {
      throw new NoSuchModelError({ modelId, modelType });
    };
  }

  const embeddingModel = unavailable('embeddingModel');

  return {
    specificationVersion: 'v3',
    imageModel: createImageModel,
    image: createImageModel,
    languageModel: unavailable('languageModel'),
    embeddingModel,
    textEmbeddingModel: embeddingModel,
  };
}

/** Default provider instance configured from the environment. */
export const prodia = createProdia();
package/src/version.ts ADDED
@@ -0,0 +1,6 @@
1
/**
 * Version of this package. The build injects the published version through
 * `__PACKAGE_VERSION__`; when it is not defined (e.g. running from source in
 * tests), the placeholder `0.0.0-test` is used.
 */
declare const __PACKAGE_VERSION__: string | undefined;
export const VERSION: string =
  typeof __PACKAGE_VERSION__ === 'undefined'
    ? '0.0.0-test'
    : __PACKAGE_VERSION__;