@outputai/llm 0.6.1-next.5d7e612.0 → 0.6.1-next.fc6a93e.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +2 -2
- package/src/ai_sdk.js +16 -10
- package/src/ai_sdk.spec.js +142 -10
- package/src/ai_sdk_options.js +6 -2
- package/src/ai_sdk_options.spec.js +10 -4
- package/src/index.d.ts +7 -0
- package/src/utils/error_handler.js +104 -0
- package/src/utils/error_handler.spec.js +285 -0
- package/src/validations.js +43 -0
- package/src/validations.spec.js +77 -1
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@outputai/llm",
|
|
3
|
-
"version": "0.6.1-next.5d7e612.0",
|
|
3
|
+
"version": "0.6.1-next.fc6a93e.0",
|
|
4
4
|
"description": "Framework abstraction to interact with LLM models",
|
|
5
5
|
"type": "module",
|
|
6
6
|
"main": "src/index.js",
|
|
@@ -23,7 +23,7 @@
|
|
|
23
23
|
"gray-matter": "4.0.3",
|
|
24
24
|
"liquidjs": "10.25.7",
|
|
25
25
|
"undici": "8.1.0",
|
|
26
|
-
"@outputai/core": "0.6.1-next.5d7e612.0"
|
|
26
|
+
"@outputai/core": "0.6.1-next.fc6a93e.0"
|
|
27
27
|
},
|
|
28
28
|
"license": "Apache-2.0",
|
|
29
29
|
"publishConfig": {
|
package/src/ai_sdk.js
CHANGED
|
@@ -2,12 +2,13 @@ import { types as utilTypes } from 'node:util';
|
|
|
2
2
|
import * as AI from 'ai';
|
|
3
3
|
import { stepCountIs } from 'ai';
|
|
4
4
|
import { ValidationError } from '@outputai/core';
|
|
5
|
-
import { validateGenerateTextArgs, validateStreamTextArgs } from './validations.js';
|
|
5
|
+
import { validateGenerateTextArgs, validateStreamTextArgs, validateGenerateImageArgs } from './validations.js';
|
|
6
6
|
import { loadPrompt } from './prompt/loader.js';
|
|
7
7
|
import { startTrace, endTraceWithError } from './utils/trace.js';
|
|
8
8
|
import { wrapTextResponse, wrapStreamOnFinishResponse, wrapImageResponse } from './utils/response_wrappers.js';
|
|
9
9
|
import { loadAiSdkTextOptions, loadAiSdkImageOptions } from './ai_sdk_options.js';
|
|
10
10
|
import { prepareTextPrompt } from './prompt/prepare_text.js';
|
|
11
|
+
import { mapAiError } from './utils/error_handler.js';
|
|
11
12
|
|
|
12
13
|
export async function generateText( { prompt, variables, promptDir, skills = [], maxSteps = 10, ...aiSdkArgs } ) {
|
|
13
14
|
validateGenerateTextArgs( { prompt, variables, promptDir, skills, maxSteps } );
|
|
@@ -27,13 +28,14 @@ export async function generateText( { prompt, variables, promptDir, skills = [],
|
|
|
27
28
|
...( tools && !aiSdkArgs.stopWhen ? { stopWhen: stepCountIs( maxSteps ) } : {} )
|
|
28
29
|
} );
|
|
29
30
|
return wrapTextResponse( { traceId, modelId, response } );
|
|
30
|
-
} catch ( error ) {
|
|
31
|
+
} catch ( originalError ) {
|
|
32
|
+
const error = mapAiError( originalError );
|
|
31
33
|
endTraceWithError( { traceId, error } );
|
|
32
34
|
throw error;
|
|
33
35
|
}
|
|
34
36
|
}
|
|
35
37
|
|
|
36
|
-
export function streamText( { prompt, variables, promptDir, skills = [], maxSteps = 10, onFinish, onError, ...aiSdkArgs } ) {
|
|
38
|
+
export function streamText( { prompt, variables, promptDir, skills = [], maxSteps = 10, onFinish, onError: _onError, ...aiSdkArgs } ) {
|
|
37
39
|
validateStreamTextArgs( { prompt, variables, promptDir, skills, maxSteps } );
|
|
38
40
|
|
|
39
41
|
const parsedSkills = typeof skills === 'function' ? skills( variables ) : skills;
|
|
@@ -54,30 +56,34 @@ export function streamText( { prompt, variables, promptDir, skills = [], maxStep
|
|
|
54
56
|
...( tools && !aiSdkArgs.stopWhen ? { stopWhen: stepCountIs( maxSteps ) } : {} ),
|
|
55
57
|
...wrapStreamOnFinishResponse( { traceId, modelId, onFinish } ),
|
|
56
58
|
onError( event ) {
|
|
57
|
-
|
|
58
|
-
|
|
59
|
+
const error = mapAiError( event.error );
|
|
60
|
+
endTraceWithError( { traceId, error } );
|
|
61
|
+
_onError?.( { ...event, error } );
|
|
59
62
|
}
|
|
60
63
|
} );
|
|
61
|
-
} catch ( error ) {
|
|
64
|
+
} catch ( originalError ) {
|
|
65
|
+
const error = mapAiError( originalError );
|
|
62
66
|
endTraceWithError( { traceId, error } );
|
|
63
67
|
throw error;
|
|
64
68
|
}
|
|
65
69
|
}
|
|
66
70
|
|
|
67
|
-
export async function generateImage( { prompt, variables, promptDir, ...aiSdkArgs } ) {
|
|
68
|
-
|
|
71
|
+
export async function generateImage( { prompt, variables, promptDir, images, mask, ...aiSdkArgs } ) {
|
|
72
|
+
validateGenerateImageArgs( { prompt, variables, promptDir, images, mask } );
|
|
69
73
|
|
|
74
|
+
const loadedPrompt = loadPrompt( prompt, variables, promptDir );
|
|
70
75
|
const traceId = startTrace( { name: 'generateImage', prompt, variables, loadedPrompt } );
|
|
71
76
|
const { model: modelId } = loadedPrompt.config;
|
|
72
77
|
|
|
73
78
|
try {
|
|
74
79
|
const response = await AI.generateImage( {
|
|
75
|
-
...loadAiSdkImageOptions( loadedPrompt ),
|
|
80
|
+
...loadAiSdkImageOptions( { prompt: loadedPrompt, images, mask } ),
|
|
76
81
|
maxRetries: 0,
|
|
77
82
|
...aiSdkArgs
|
|
78
83
|
} );
|
|
79
84
|
return wrapImageResponse( { traceId, modelId, response } );
|
|
80
|
-
} catch ( error ) {
|
|
85
|
+
} catch ( originalError ) {
|
|
86
|
+
const error = mapAiError( originalError );
|
|
81
87
|
endTraceWithError( { traceId, error } );
|
|
82
88
|
throw error;
|
|
83
89
|
}
|
package/src/ai_sdk.spec.js
CHANGED
|
@@ -9,7 +9,8 @@ const aiFns = vi.hoisted( () => ( {
|
|
|
9
9
|
|
|
10
10
|
const validators = vi.hoisted( () => ( {
|
|
11
11
|
validateGenerateTextArgs: vi.fn(),
|
|
12
|
-
validateStreamTextArgs: vi.fn()
|
|
12
|
+
validateStreamTextArgs: vi.fn(),
|
|
13
|
+
validateGenerateImageArgs: vi.fn()
|
|
13
14
|
} ) );
|
|
14
15
|
|
|
15
16
|
const promptMocks = vi.hoisted( () => ( {
|
|
@@ -33,6 +34,10 @@ const wrapMocks = vi.hoisted( () => ( {
|
|
|
33
34
|
wrapImageResponse: vi.fn()
|
|
34
35
|
} ) );
|
|
35
36
|
|
|
37
|
+
const errorMocks = vi.hoisted( () => ( {
|
|
38
|
+
mapAiError: vi.fn( error => error )
|
|
39
|
+
} ) );
|
|
40
|
+
|
|
36
41
|
vi.mock( 'ai', () => aiFns );
|
|
37
42
|
|
|
38
43
|
vi.mock( './validations.js', () => validators );
|
|
@@ -61,6 +66,10 @@ vi.mock( './utils/response_wrappers.js', () => ( {
|
|
|
61
66
|
wrapImageResponse: ( ...args ) => wrapMocks.wrapImageResponse( ...args )
|
|
62
67
|
} ) );
|
|
63
68
|
|
|
69
|
+
vi.mock( './utils/error_handler.js', () => ( {
|
|
70
|
+
mapAiError: ( ...args ) => errorMocks.mapAiError( ...args )
|
|
71
|
+
} ) );
|
|
72
|
+
|
|
64
73
|
const importSut = async () => import( './ai_sdk.js' );
|
|
65
74
|
|
|
66
75
|
const loadedPrompt = {
|
|
@@ -86,15 +95,29 @@ const streamResult = {
|
|
|
86
95
|
fullStream: 'FULL_STREAM'
|
|
87
96
|
};
|
|
88
97
|
|
|
98
|
+
const imageOptions = {
|
|
99
|
+
model: 'IMAGE_MODEL',
|
|
100
|
+
prompt: {
|
|
101
|
+
text: 'Generate an image'
|
|
102
|
+
},
|
|
103
|
+
providerOptions: { openai: { quality: 'high' } }
|
|
104
|
+
};
|
|
105
|
+
|
|
106
|
+
const imageResponse = {
|
|
107
|
+
images: [ { mediaType: 'image/png', base64: 'aW1hZ2U=' } ],
|
|
108
|
+
usage: { inputTokens: 1, outputTokens: 2 }
|
|
109
|
+
};
|
|
110
|
+
|
|
89
111
|
describe( 'ai_sdk', () => {
|
|
90
112
|
beforeEach( () => {
|
|
91
113
|
aiFns.generateText.mockReset().mockResolvedValue( textResponse );
|
|
92
114
|
aiFns.streamText.mockReset().mockReturnValue( streamResult );
|
|
93
|
-
aiFns.generateImage.mockReset();
|
|
115
|
+
aiFns.generateImage.mockReset().mockResolvedValue( imageResponse );
|
|
94
116
|
aiFns.stepCountIs.mockReset().mockImplementation( count => ( { type: 'step-count', count } ) );
|
|
95
117
|
|
|
96
118
|
validators.validateGenerateTextArgs.mockReset();
|
|
97
119
|
validators.validateStreamTextArgs.mockReset();
|
|
120
|
+
validators.validateGenerateImageArgs.mockReset();
|
|
98
121
|
|
|
99
122
|
promptMocks.loadPrompt.mockReset().mockReturnValue( loadedPrompt );
|
|
100
123
|
promptMocks.prepareTextPrompt.mockReset().mockReturnValue( {
|
|
@@ -103,7 +126,7 @@ describe( 'ai_sdk', () => {
|
|
|
103
126
|
} );
|
|
104
127
|
|
|
105
128
|
optionMocks.loadAiSdkTextOptions.mockReset().mockReturnValue( textOptions );
|
|
106
|
-
optionMocks.loadAiSdkImageOptions.mockReset();
|
|
129
|
+
optionMocks.loadAiSdkImageOptions.mockReset().mockReturnValue( imageOptions );
|
|
107
130
|
|
|
108
131
|
traceMocks.startTrace.mockReset().mockReturnValue( 'trace-id' );
|
|
109
132
|
traceMocks.endTraceWithError.mockReset();
|
|
@@ -112,7 +135,9 @@ describe( 'ai_sdk', () => {
|
|
|
112
135
|
wrapMocks.wrapStreamOnFinishResponse.mockReset().mockReturnValue( {
|
|
113
136
|
onFinish: vi.fn()
|
|
114
137
|
} );
|
|
115
|
-
wrapMocks.wrapImageResponse.mockReset();
|
|
138
|
+
wrapMocks.wrapImageResponse.mockReset().mockResolvedValue( { wrapped: imageResponse } );
|
|
139
|
+
|
|
140
|
+
errorMocks.mapAiError.mockReset().mockImplementation( error => error );
|
|
116
141
|
} );
|
|
117
142
|
|
|
118
143
|
afterEach( async () => {
|
|
@@ -234,13 +259,16 @@ describe( 'ai_sdk', () => {
|
|
|
234
259
|
|
|
235
260
|
it( 'traces and rethrows AI SDK errors', async () => {
|
|
236
261
|
const error = new Error( 'Provider failed' );
|
|
262
|
+
const mappedError = new Error( 'Mapped provider failed' );
|
|
237
263
|
aiFns.generateText.mockRejectedValueOnce( error );
|
|
264
|
+
errorMocks.mapAiError.mockReturnValueOnce( mappedError );
|
|
238
265
|
const { generateText } = await importSut();
|
|
239
266
|
|
|
240
|
-
await expect( generateText( { prompt: 'test@v1' } ) ).rejects.toThrow( error );
|
|
267
|
+
await expect( generateText( { prompt: 'test@v1' } ) ).rejects.toThrow( mappedError );
|
|
268
|
+
expect( errorMocks.mapAiError ).toHaveBeenCalledWith( error );
|
|
241
269
|
expect( traceMocks.endTraceWithError ).toHaveBeenCalledWith( {
|
|
242
270
|
traceId: 'trace-id',
|
|
243
|
-
error
|
|
271
|
+
error: mappedError
|
|
244
272
|
} );
|
|
245
273
|
} );
|
|
246
274
|
} );
|
|
@@ -371,16 +399,19 @@ describe( 'ai_sdk', () => {
|
|
|
371
399
|
const { streamText } = await importSut();
|
|
372
400
|
const onError = vi.fn();
|
|
373
401
|
const error = new Error( 'Stream failed' );
|
|
402
|
+
const mappedError = new Error( 'Mapped stream failed' );
|
|
403
|
+
errorMocks.mapAiError.mockReturnValueOnce( mappedError );
|
|
374
404
|
|
|
375
405
|
streamText( { prompt: 'test@v1', onError } );
|
|
376
406
|
const callOptions = aiFns.streamText.mock.calls[0][0];
|
|
377
407
|
callOptions.onError( { error } );
|
|
378
408
|
|
|
409
|
+
expect( errorMocks.mapAiError ).toHaveBeenCalledWith( error );
|
|
379
410
|
expect( traceMocks.endTraceWithError ).toHaveBeenCalledWith( {
|
|
380
411
|
traceId: 'trace-id',
|
|
381
|
-
error
|
|
412
|
+
error: mappedError
|
|
382
413
|
} );
|
|
383
|
-
expect( onError ).toHaveBeenCalledWith( { error } );
|
|
414
|
+
expect( onError ).toHaveBeenCalledWith( { error: mappedError } );
|
|
384
415
|
} );
|
|
385
416
|
|
|
386
417
|
it( 'does not pass the raw onFinish or onError callbacks to AI SDK', async () => {
|
|
@@ -410,15 +441,116 @@ describe( 'ai_sdk', () => {
|
|
|
410
441
|
|
|
411
442
|
it( 'traces and rethrows synchronous AI SDK errors', async () => {
|
|
412
443
|
const error = new Error( 'Invalid model' );
|
|
444
|
+
const mappedError = new Error( 'Mapped invalid model' );
|
|
413
445
|
aiFns.streamText.mockImplementationOnce( () => {
|
|
414
446
|
throw error;
|
|
415
447
|
} );
|
|
448
|
+
errorMocks.mapAiError.mockReturnValueOnce( mappedError );
|
|
416
449
|
const { streamText } = await importSut();
|
|
417
450
|
|
|
418
|
-
expect( () => streamText( { prompt: 'test@v1' } ) ).toThrow( error );
|
|
451
|
+
expect( () => streamText( { prompt: 'test@v1' } ) ).toThrow( mappedError );
|
|
452
|
+
expect( errorMocks.mapAiError ).toHaveBeenCalledWith( error );
|
|
453
|
+
expect( traceMocks.endTraceWithError ).toHaveBeenCalledWith( {
|
|
454
|
+
traceId: 'trace-id',
|
|
455
|
+
error: mappedError
|
|
456
|
+
} );
|
|
457
|
+
} );
|
|
458
|
+
} );
|
|
459
|
+
|
|
460
|
+
describe( 'generateImage', () => {
|
|
461
|
+
it( 'validates, loads prompt, traces, calls AI SDK, and wraps the response', async () => {
|
|
462
|
+
const { generateImage } = await importSut();
|
|
463
|
+
const variables = { scene: 'race cars' };
|
|
464
|
+
const images = [ Buffer.from( 'image-bytes' ) ];
|
|
465
|
+
const mask = Buffer.from( 'mask-bytes' );
|
|
466
|
+
|
|
467
|
+
const result = await generateImage( {
|
|
468
|
+
prompt: 'image@v1',
|
|
469
|
+
variables,
|
|
470
|
+
promptDir: '/prompts',
|
|
471
|
+
images,
|
|
472
|
+
mask,
|
|
473
|
+
n: 2,
|
|
474
|
+
providerOptions: { openai: { background: 'transparent' } }
|
|
475
|
+
} );
|
|
476
|
+
|
|
477
|
+
expect( validators.validateGenerateImageArgs ).toHaveBeenCalledWith( {
|
|
478
|
+
prompt: 'image@v1',
|
|
479
|
+
variables,
|
|
480
|
+
promptDir: '/prompts',
|
|
481
|
+
images,
|
|
482
|
+
mask
|
|
483
|
+
} );
|
|
484
|
+
expect( promptMocks.loadPrompt ).toHaveBeenCalledWith( 'image@v1', variables, '/prompts' );
|
|
485
|
+
expect( traceMocks.startTrace ).toHaveBeenCalledWith( {
|
|
486
|
+
name: 'generateImage',
|
|
487
|
+
prompt: 'image@v1',
|
|
488
|
+
variables,
|
|
489
|
+
loadedPrompt
|
|
490
|
+
} );
|
|
491
|
+
expect( optionMocks.loadAiSdkImageOptions ).toHaveBeenCalledWith( {
|
|
492
|
+
prompt: loadedPrompt,
|
|
493
|
+
images,
|
|
494
|
+
mask
|
|
495
|
+
} );
|
|
496
|
+
expect( aiFns.generateImage ).toHaveBeenCalledWith( {
|
|
497
|
+
...imageOptions,
|
|
498
|
+
maxRetries: 0,
|
|
499
|
+
n: 2,
|
|
500
|
+
providerOptions: { openai: { background: 'transparent' } }
|
|
501
|
+
} );
|
|
502
|
+
expect( wrapMocks.wrapImageResponse ).toHaveBeenCalledWith( {
|
|
503
|
+
traceId: 'trace-id',
|
|
504
|
+
modelId: 'test-model',
|
|
505
|
+
response: imageResponse
|
|
506
|
+
} );
|
|
507
|
+
expect( result ).toEqual( { wrapped: imageResponse } );
|
|
508
|
+
} );
|
|
509
|
+
|
|
510
|
+
it( 'supports text-to-image calls without images or mask', async () => {
|
|
511
|
+
const { generateImage } = await importSut();
|
|
512
|
+
|
|
513
|
+
await generateImage( { prompt: 'image@v1' } );
|
|
514
|
+
|
|
515
|
+
expect( validators.validateGenerateImageArgs ).toHaveBeenCalledWith( {
|
|
516
|
+
prompt: 'image@v1',
|
|
517
|
+
variables: undefined,
|
|
518
|
+
promptDir: undefined,
|
|
519
|
+
images: undefined,
|
|
520
|
+
mask: undefined
|
|
521
|
+
} );
|
|
522
|
+
expect( optionMocks.loadAiSdkImageOptions ).toHaveBeenCalledWith( {
|
|
523
|
+
prompt: loadedPrompt,
|
|
524
|
+
images: undefined,
|
|
525
|
+
mask: undefined
|
|
526
|
+
} );
|
|
527
|
+
} );
|
|
528
|
+
|
|
529
|
+
it( 'propagates validation errors before loading or tracing', async () => {
|
|
530
|
+
const validationError = new Error( 'Invalid image args' );
|
|
531
|
+
validators.validateGenerateImageArgs.mockImplementationOnce( () => {
|
|
532
|
+
throw validationError;
|
|
533
|
+
} );
|
|
534
|
+
const { generateImage } = await importSut();
|
|
535
|
+
|
|
536
|
+
await expect( generateImage( { prompt: '' } ) ).rejects.toThrow( validationError );
|
|
537
|
+
expect( promptMocks.loadPrompt ).not.toHaveBeenCalled();
|
|
538
|
+
expect( traceMocks.startTrace ).not.toHaveBeenCalled();
|
|
539
|
+
expect( aiFns.generateImage ).not.toHaveBeenCalled();
|
|
540
|
+
} );
|
|
541
|
+
|
|
542
|
+
it( 'traces and rethrows AI SDK errors', async () => {
|
|
543
|
+
const error = new Error( 'Image provider failed' );
|
|
544
|
+
const mappedError = new Error( 'Mapped image provider failed' );
|
|
545
|
+
aiFns.generateImage.mockRejectedValueOnce( error );
|
|
546
|
+
errorMocks.mapAiError.mockReturnValueOnce( mappedError );
|
|
547
|
+
const { generateImage } = await importSut();
|
|
548
|
+
|
|
549
|
+
await expect( generateImage( { prompt: 'image@v1' } ) ).rejects.toThrow( mappedError );
|
|
550
|
+
expect( errorMocks.mapAiError ).toHaveBeenCalledWith( error );
|
|
419
551
|
expect( traceMocks.endTraceWithError ).toHaveBeenCalledWith( {
|
|
420
552
|
traceId: 'trace-id',
|
|
421
|
-
error
|
|
553
|
+
error: mappedError
|
|
422
554
|
} );
|
|
423
555
|
} );
|
|
424
556
|
} );
|
package/src/ai_sdk_options.js
CHANGED
|
@@ -39,13 +39,17 @@ export const loadAiSdkTextOptions = prompt => {
|
|
|
39
39
|
* @param {object} prompt - Loaded prompt object
|
|
40
40
|
* @returns {object} Options for AI SDK image calls
|
|
41
41
|
*/
|
|
42
|
-
export const loadAiSdkImageOptions = prompt => {
|
|
42
|
+
export const loadAiSdkImageOptions = ( { prompt, images, mask } ) => {
|
|
43
43
|
if ( !prompt.instructions ) {
|
|
44
44
|
throw new FatalError( `Prompt "${prompt.name}" has no instructions. Image prompts must use plain instructions.` );
|
|
45
45
|
}
|
|
46
46
|
const options = {
|
|
47
47
|
model: loadImageModel( prompt ),
|
|
48
|
-
prompt: prompt.instructions,
|
|
48
|
+
prompt: ( images || mask ) ? {
|
|
49
|
+
text: prompt.instructions,
|
|
50
|
+
...( images && { images } ),
|
|
51
|
+
...( mask && { mask } )
|
|
52
|
+
} : prompt.instructions,
|
|
49
53
|
providerOptions: prompt.config.providerOptions
|
|
50
54
|
};
|
|
51
55
|
for ( const key of [ 'n', 'maxImagesPerCall', 'size', 'aspectRatio', 'seed' ] ) {
|
|
package/src/ai_sdk_options.spec.js
CHANGED
|
@@ -100,6 +100,8 @@ describe( 'ai_sdk_options', () => {
|
|
|
100
100
|
} );
|
|
101
101
|
|
|
102
102
|
it( 'maps loaded prompts to AI SDK image options', async () => {
|
|
103
|
+
const images = [ Buffer.from( 'image-bytes' ) ];
|
|
104
|
+
const mask = Buffer.from( 'mask-bytes' );
|
|
103
105
|
const prompt = makeImagePrompt( {
|
|
104
106
|
n: 2,
|
|
105
107
|
maxImagesPerCall: 1,
|
|
@@ -112,14 +114,18 @@ describe( 'ai_sdk_options', () => {
|
|
|
112
114
|
} );
|
|
113
115
|
|
|
114
116
|
const { loadAiSdkImageOptions } = await importSut();
|
|
115
|
-
const result = loadAiSdkImageOptions( prompt );
|
|
117
|
+
const result = loadAiSdkImageOptions( { prompt, images, mask } );
|
|
116
118
|
|
|
117
119
|
expect( loadImageModelImpl ).toHaveBeenCalledWith( prompt );
|
|
118
120
|
expect( loadModelImpl ).not.toHaveBeenCalled();
|
|
119
121
|
expect( loadToolsImpl ).not.toHaveBeenCalled();
|
|
120
122
|
expect( result ).toEqual( {
|
|
121
123
|
model: 'IMAGE_MODEL',
|
|
122
|
-
prompt: 'Generate a cinematic image of a NASCAR race at sunset.',
|
|
124
|
+
prompt: {
|
|
125
|
+
text: 'Generate a cinematic image of a NASCAR race at sunset.',
|
|
126
|
+
images,
|
|
127
|
+
mask
|
|
128
|
+
},
|
|
123
129
|
providerOptions: prompt.config.providerOptions,
|
|
124
130
|
n: 2,
|
|
125
131
|
maxImagesPerCall: 1,
|
|
@@ -135,7 +141,7 @@ describe( 'ai_sdk_options', () => {
|
|
|
135
141
|
const prompt = makeImagePrompt( { seed: 0 } );
|
|
136
142
|
|
|
137
143
|
const { loadAiSdkImageOptions } = await importSut();
|
|
138
|
-
const result = loadAiSdkImageOptions( prompt );
|
|
144
|
+
const result = loadAiSdkImageOptions( { prompt } );
|
|
139
145
|
|
|
140
146
|
expect( result ).toEqual( {
|
|
141
147
|
model: 'IMAGE_MODEL',
|
|
@@ -150,7 +156,7 @@ describe( 'ai_sdk_options', () => {
|
|
|
150
156
|
|
|
151
157
|
const { loadAiSdkImageOptions } = await importSut();
|
|
152
158
|
|
|
153
|
-
expect( () => loadAiSdkImageOptions( prompt ) ).toThrow(
|
|
159
|
+
expect( () => loadAiSdkImageOptions( { prompt } ) ).toThrow(
|
|
154
160
|
'Prompt "test@v1" has no instructions.'
|
|
155
161
|
);
|
|
156
162
|
expect( loadImageModelImpl ).not.toHaveBeenCalled();
|
package/src/index.d.ts
CHANGED
|
@@ -200,6 +200,9 @@ export type StreamTextAiSdkOptions<
|
|
|
200
200
|
* `model` and `prompt` are omitted because Output supplies them from the prompt file.
|
|
201
201
|
*/
|
|
202
202
|
export type GenerateImageAiSdkOptions = Omit<Parameters<typeof aiGenerateImage>[0], 'model' | 'prompt'>;
|
|
203
|
+
type GenerateImagePrompt = Parameters<typeof aiGenerateImage>[0]['prompt'];
|
|
204
|
+
type GenerateImagePromptWithImages = Exclude<GenerateImagePrompt, string>;
|
|
205
|
+
type GenerateImageInput = GenerateImagePromptWithImages['images'][number];
|
|
203
206
|
|
|
204
207
|
/** Agent {@link Agent.stream} options: same as AI SDK plus wrapped `onFinish` (adds `cost`). */
|
|
205
208
|
export type OutputAgentStreamParameters = Omit<AgentStreamParameters<never, ToolSet>, 'onFinish'> & {
|
|
@@ -275,6 +278,10 @@ export type GenerateImageParameters = {
|
|
|
275
278
|
variables?: Record<string, string | number | boolean>;
|
|
276
279
|
/** Override the stack-resolved prompt directory */
|
|
277
280
|
promptDir?: string;
|
|
281
|
+
/** Runtime image inputs for image-to-image generation */
|
|
282
|
+
images?: GenerateImageInput[];
|
|
283
|
+
/** Optional mask for image editing */
|
|
284
|
+
mask?: GenerateImagePromptWithImages['mask'];
|
|
278
285
|
} & GenerateImageAiSdkOptions;
|
|
279
286
|
|
|
280
287
|
/** A source extracted from search tool results during multi-step LLM execution. */
|
|
package/src/utils/error_handler.js
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import {
|
|
2
|
+
APICallError,
|
|
3
|
+
InvalidArgumentError,
|
|
4
|
+
InvalidDataContentError,
|
|
5
|
+
InvalidPromptError,
|
|
6
|
+
LoadAPIKeyError,
|
|
7
|
+
LoadSettingError,
|
|
8
|
+
NoImageGeneratedError,
|
|
9
|
+
NoObjectGeneratedError,
|
|
10
|
+
NoSuchModelError,
|
|
11
|
+
NoSuchProviderError,
|
|
12
|
+
UnsupportedFunctionalityError
|
|
13
|
+
} from 'ai';
|
|
14
|
+
import { FatalError } from '@outputai/core';
|
|
15
|
+
|
|
16
|
+
/**
|
|
17
|
+
* Recursively search an error cause chain until finds an error which is instance of given prototype.
|
|
18
|
+
*
|
|
19
|
+
* @param {object} error - Error instance.
|
|
20
|
+
* @param {Function|string} _class - Target constructor or constructor name.
|
|
21
|
+
* @param {number} depth - Current depth, search up to 10 causes deep.
|
|
22
|
+
* @returns {object|null} - Error or null if not found.
|
|
23
|
+
*/
|
|
24
|
+
export const findInstanceInCauseChain = ( error, _class, depth = 0 ) => {
|
|
25
|
+
if ( !error || typeof error !== 'object' ) {
|
|
26
|
+
return null;
|
|
27
|
+
}
|
|
28
|
+
if ( typeof _class === 'string' && error.constructor.name === _class ) {
|
|
29
|
+
return error;
|
|
30
|
+
}
|
|
31
|
+
if ( typeof _class === 'function' && error instanceof _class ) {
|
|
32
|
+
return error;
|
|
33
|
+
}
|
|
34
|
+
if ( depth >= 10 ) {
|
|
35
|
+
return null;
|
|
36
|
+
}
|
|
37
|
+
return error.cause ? findInstanceInCauseChain( error.cause, _class, depth + 1 ) : null;
|
|
38
|
+
};
|
|
39
|
+
|
|
40
|
+
const toFatalError = ( error, extraMessage = '' ) => new FatalError(
|
|
41
|
+
`AI-SDK fatal error${extraMessage ? ` (${extraMessage})` : ''}: ${error.message}`,
|
|
42
|
+
{ cause: error }
|
|
43
|
+
);
|
|
44
|
+
|
|
45
|
+
export const mapAiError = error => {
|
|
46
|
+
if ( error instanceof FatalError ) {
|
|
47
|
+
return error;
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
// NoObjectGeneratedError can be thrown when the response doesn't match the schema
|
|
51
|
+
// This adds a wrapper to that error serializing the first zod validation in the message, to make it easier to debug.
|
|
52
|
+
if ( NoObjectGeneratedError.isInstance( error ) && error.message.includes( 'No object generated: response did not match schema.' ) ) {
|
|
53
|
+
const zodError = findInstanceInCauseChain( error, 'ZodError' );
|
|
54
|
+
if ( zodError && zodError.issues?.length > 0 ) {
|
|
55
|
+
const { path, message } = zodError.issues[0];
|
|
56
|
+
const wrapper = new Error( `${error.message} First issue is "${message}" at path [${path.join( ', ' )}].`, { cause: error } );
|
|
57
|
+
wrapper.name = 'NoObjectGeneratedError';
|
|
58
|
+
return wrapper;
|
|
59
|
+
}
|
|
60
|
+
return error;
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
if ( APICallError.isInstance( error ) && !error.isRetryable ) {
|
|
64
|
+
// Non-retryable API failures are already classified by AI SDK as permanent provider failures.
|
|
65
|
+
return toFatalError( error, error.statusCode ? `HTTP ${error.statusCode}` : '' );
|
|
66
|
+
}
|
|
67
|
+
if ( InvalidArgumentError.isInstance( error ) ) {
|
|
68
|
+
// Invalid call settings are deterministic caller bugs, so retrying the same activity cannot fix them.
|
|
69
|
+
return toFatalError( error );
|
|
70
|
+
}
|
|
71
|
+
if ( InvalidDataContentError.isInstance( error ) ) {
|
|
72
|
+
// Invalid media content has the wrong local shape/encoding and will fail again with the same input.
|
|
73
|
+
return toFatalError( error );
|
|
74
|
+
}
|
|
75
|
+
if ( InvalidPromptError.isInstance( error ) ) {
|
|
76
|
+
// Invalid prompt structure is a deterministic request-construction error.
|
|
77
|
+
return toFatalError( error );
|
|
78
|
+
}
|
|
79
|
+
if ( LoadAPIKeyError.isInstance( error ) ) {
|
|
80
|
+
// Missing or invalid API key configuration will not change during an activity retry.
|
|
81
|
+
return toFatalError( error );
|
|
82
|
+
}
|
|
83
|
+
if ( LoadSettingError.isInstance( error ) ) {
|
|
84
|
+
// Missing or invalid provider settings are deployment/configuration problems.
|
|
85
|
+
return toFatalError( error );
|
|
86
|
+
}
|
|
87
|
+
if ( NoImageGeneratedError.isInstance( error ) ) {
|
|
88
|
+
// Image generation completed provider calls but collected zero images; repeating identical input is not useful.
|
|
89
|
+
return toFatalError( error );
|
|
90
|
+
}
|
|
91
|
+
if ( NoSuchProviderError.isInstance( error ) ) {
|
|
92
|
+
// A missing provider id is a deterministic provider registry/configuration error.
|
|
93
|
+
return toFatalError( error );
|
|
94
|
+
}
|
|
95
|
+
if ( NoSuchModelError.isInstance( error ) ) {
|
|
96
|
+
// A missing model id is a deterministic provider/model configuration error.
|
|
97
|
+
return toFatalError( error );
|
|
98
|
+
}
|
|
99
|
+
if ( UnsupportedFunctionalityError.isInstance( error ) ) {
|
|
100
|
+
// The selected model/output mode does not support the requested feature.
|
|
101
|
+
return toFatalError( error );
|
|
102
|
+
}
|
|
103
|
+
return error;
|
|
104
|
+
};
|
|
package/src/utils/error_handler.spec.js
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
import { describe, expect, it } from 'vitest';
|
|
2
|
+
import {
|
|
3
|
+
APICallError,
|
|
4
|
+
InvalidArgumentError,
|
|
5
|
+
InvalidDataContentError,
|
|
6
|
+
InvalidMessageRoleError,
|
|
7
|
+
InvalidPromptError,
|
|
8
|
+
InvalidToolApprovalError,
|
|
9
|
+
InvalidToolInputError,
|
|
10
|
+
LoadAPIKeyError,
|
|
11
|
+
LoadSettingError,
|
|
12
|
+
MessageConversionError,
|
|
13
|
+
NoImageGeneratedError,
|
|
14
|
+
NoOutputGeneratedError,
|
|
15
|
+
NoObjectGeneratedError,
|
|
16
|
+
NoSuchModelError,
|
|
17
|
+
NoSuchProviderError,
|
|
18
|
+
ToolCallNotFoundForApprovalError,
|
|
19
|
+
ToolCallRepairError,
|
|
20
|
+
UnsupportedFunctionalityError
|
|
21
|
+
} from 'ai';
|
|
22
|
+
import { FatalError } from '@outputai/core';
|
|
23
|
+
import { findInstanceInCauseChain, mapAiError } from './error_handler.js';
|
|
24
|
+
|
|
25
|
+
const makeApiCallError = ( input = {} ) => new APICallError( {
|
|
26
|
+
message: 'Provider rejected the request',
|
|
27
|
+
url: 'https://provider.test/v1/generate',
|
|
28
|
+
requestBodyValues: {},
|
|
29
|
+
responseHeaders: {},
|
|
30
|
+
responseBody: '{"error":"bad request"}',
|
|
31
|
+
...input
|
|
32
|
+
} );
|
|
33
|
+
|
|
34
|
+
const fatalAiSdkErrors = [
|
|
35
|
+
[
|
|
36
|
+
'InvalidArgumentError',
|
|
37
|
+
() => new InvalidArgumentError( {
|
|
38
|
+
parameter: 'temperature',
|
|
39
|
+
value: 'hot',
|
|
40
|
+
message: 'temperature must be a number'
|
|
41
|
+
} )
|
|
42
|
+
],
|
|
43
|
+
[
|
|
44
|
+
'InvalidDataContentError',
|
|
45
|
+
() => new InvalidDataContentError( { content: { bad: true } } )
|
|
46
|
+
],
|
|
47
|
+
[
|
|
48
|
+
'InvalidPromptError',
|
|
49
|
+
() => new InvalidPromptError( { prompt: {}, message: 'prompt or messages must be defined' } )
|
|
50
|
+
],
|
|
51
|
+
[
|
|
52
|
+
'LoadAPIKeyError',
|
|
53
|
+
() => new LoadAPIKeyError( { message: 'Missing API key' } )
|
|
54
|
+
],
|
|
55
|
+
[
|
|
56
|
+
'LoadSettingError',
|
|
57
|
+
() => new LoadSettingError( { message: 'Missing setting' } )
|
|
58
|
+
],
|
|
59
|
+
[
|
|
60
|
+
'NoImageGeneratedError',
|
|
61
|
+
() => new NoImageGeneratedError( { responses: [] } )
|
|
62
|
+
],
|
|
63
|
+
[
|
|
64
|
+
'NoSuchModelError',
|
|
65
|
+
() => new NoSuchModelError( { modelId: 'missing-model', modelType: 'languageModel' } )
|
|
66
|
+
],
|
|
67
|
+
[
|
|
68
|
+
'NoSuchProviderError',
|
|
69
|
+
() => new NoSuchProviderError( {
|
|
70
|
+
modelId: 'missing-provider:model',
|
|
71
|
+
modelType: 'languageModel',
|
|
72
|
+
providerId: 'missing-provider',
|
|
73
|
+
availableProviders: [ 'openai' ]
|
|
74
|
+
} )
|
|
75
|
+
],
|
|
76
|
+
[
|
|
77
|
+
'UnsupportedFunctionalityError',
|
|
78
|
+
() => new UnsupportedFunctionalityError( { functionality: 'image masks' } )
|
|
79
|
+
]
|
|
80
|
+
];
|
|
81
|
+
|
|
82
|
+
const preservedAiSdkErrors = [
|
|
83
|
+
[
|
|
84
|
+
'InvalidMessageRoleError',
|
|
85
|
+
() => new InvalidMessageRoleError( { role: 'critic' } )
|
|
86
|
+
],
|
|
87
|
+
[
|
|
88
|
+
'InvalidToolApprovalError',
|
|
89
|
+
() => new InvalidToolApprovalError( { approvalId: 'approval-1' } )
|
|
90
|
+
],
|
|
91
|
+
[
|
|
92
|
+
'InvalidToolInputError',
|
|
93
|
+
() => new InvalidToolInputError( {
|
|
94
|
+
toolName: 'search',
|
|
95
|
+
toolInput: '{bad json',
|
|
96
|
+
cause: new Error( 'parse failed' )
|
|
97
|
+
} )
|
|
98
|
+
],
|
|
99
|
+
[
|
|
100
|
+
'MessageConversionError',
|
|
101
|
+
() => new MessageConversionError( {
|
|
102
|
+
originalMessage: { role: 'critic', content: 'bad role' },
|
|
103
|
+
message: 'Unsupported role'
|
|
104
|
+
} )
|
|
105
|
+
],
|
|
106
|
+
[
|
|
107
|
+
'NoObjectGeneratedError',
|
|
108
|
+
() => new NoObjectGeneratedError( {
|
|
109
|
+
text: 'not json',
|
|
110
|
+
cause: new Error( 'parse failed' )
|
|
111
|
+
} )
|
|
112
|
+
],
|
|
113
|
+
[
|
|
114
|
+
'NoOutputGeneratedError',
|
|
115
|
+
() => new NoOutputGeneratedError()
|
|
116
|
+
],
|
|
117
|
+
[
|
|
118
|
+
'ToolCallNotFoundForApprovalError',
|
|
119
|
+
() => new ToolCallNotFoundForApprovalError( {
|
|
120
|
+
toolCallId: 'tool-call-1',
|
|
121
|
+
approvalId: 'approval-1'
|
|
122
|
+
} )
|
|
123
|
+
],
|
|
124
|
+
[
|
|
125
|
+
'ToolCallRepairError',
|
|
126
|
+
() => new ToolCallRepairError( {
|
|
127
|
+
cause: new Error( 'repair failed' ),
|
|
128
|
+
originalError: new Error( 'invalid tool input' )
|
|
129
|
+
} )
|
|
130
|
+
]
|
|
131
|
+
];
|
|
132
|
+
|
|
133
|
+
describe( 'findInstanceInCauseChain', () => {
|
|
134
|
+
class FirstCustomError extends Error {}
|
|
135
|
+
class SecondCustomError extends Error {}
|
|
136
|
+
|
|
137
|
+
it( 'returns the input error when it matches the target constructor', () => {
|
|
138
|
+
const error = new FirstCustomError( 'first' );
|
|
139
|
+
|
|
140
|
+
expect( findInstanceInCauseChain( error, FirstCustomError ) ).toBe( error );
|
|
141
|
+
} );
|
|
142
|
+
|
|
143
|
+
it( 'returns the input error when it matches the target constructor name', () => {
|
|
144
|
+
const error = new FirstCustomError( 'first' );
|
|
145
|
+
|
|
146
|
+
expect( findInstanceInCauseChain( error, 'FirstCustomError' ) ).toBe( error );
|
|
147
|
+
} );
|
|
148
|
+
|
|
149
|
+
it( 'walks the cause chain to find an error by constructor', () => {
|
|
150
|
+
const target = new SecondCustomError( 'second' );
|
|
151
|
+
const wrapper = new FirstCustomError( 'first', { cause: target } );
|
|
152
|
+
|
|
153
|
+
expect( findInstanceInCauseChain( wrapper, SecondCustomError ) ).toBe( target );
|
|
154
|
+
} );
|
|
155
|
+
|
|
156
|
+
it( 'walks the cause chain to find an error by constructor name', () => {
|
|
157
|
+
const target = new SecondCustomError( 'second' );
|
|
158
|
+
const wrapper = new FirstCustomError( 'first', { cause: target } );
|
|
159
|
+
|
|
160
|
+
expect( findInstanceInCauseChain( wrapper, 'SecondCustomError' ) ).toBe( target );
|
|
161
|
+
} );
|
|
162
|
+
|
|
163
|
+
it( 'returns null when the target is not found', () => {
|
|
164
|
+
const error = new FirstCustomError( 'first', { cause: new Error( 'root' ) } );
|
|
165
|
+
|
|
166
|
+
expect( findInstanceInCauseChain( error, SecondCustomError ) ).toBeNull();
|
|
167
|
+
} );
|
|
168
|
+
|
|
169
|
+
it( 'returns null for empty or non-object inputs', () => {
|
|
170
|
+
expect( findInstanceInCauseChain( null, Error ) ).toBeNull();
|
|
171
|
+
expect( findInstanceInCauseChain( 'not an error', Error ) ).toBeNull();
|
|
172
|
+
} );
|
|
173
|
+
|
|
174
|
+
// Buries the target under 11 wrappers ("level 1" innermost, "level 11" outermost).
// NOTE(review): presumably the lookup gives up after 10 levels — confirm against error_handler.js.
it( 'stops searching after the depth limit', () => {
  let chain = new SecondCustomError( 'target' );
  for ( let level = 1; level <= 11; level++ ) {
    chain = new FirstCustomError( `level ${level}`, { cause: chain } );
  }

  expect( findInstanceInCauseChain( chain, SecondCustomError ) ).toBeNull();
} );
|
|
181
|
+
} );
|
|
182
|
+
|
|
183
|
+
// Specs for mapAiError: which errors come back untouched, which are wrapped in a FatalError,
// and how a schema-mismatch NoObjectGeneratedError gets its message extended.
describe( 'mapAiError', () => {
  it( 'preserves existing FatalError instances', () => {
    const alreadyFatal = new FatalError( 'Already fatal' );
    const mapped = mapAiError( alreadyFatal );

    expect( mapped ).toBe( alreadyFatal );
  } );

  it( 'adds first schema issue details to NoObjectGeneratedError schema mismatches', () => {
    // Local stand-in carrying an `issues` array.
    // NOTE(review): presumably matched by its class name "ZodError" — confirm against error_handler.js.
    class ZodError extends Error {
      constructor( issues ) {
        super( 'schema failed' );
        this.issues = issues;
      }
    }
    const firstIssue = {
      path: [ 'items', 0, 'title' ],
      message: 'Expected string'
    };
    const schemaFailure = new ZodError( [ firstIssue ] );
    const validationFailure = new Error( 'validation failed', { cause: schemaFailure } );
    const original = new NoObjectGeneratedError( {
      message: 'No object generated: response did not match schema.',
      text: '{"items":[{}]}',
      cause: validationFailure
    } );
    const expectedMessage =
      'No object generated: response did not match schema. First issue is "Expected string" at path [items, 0, title].';

    const mapped = mapAiError( original );

    // A new error is produced: same name, richer message, and the original kept as its cause.
    expect( mapped ).not.toBe( original );
    expect( mapped.name ).toBe( 'NoObjectGeneratedError' );
    expect( mapped.message ).toBe( expectedMessage );
    expect( mapped.cause ).toBe( original );
  } );

  it( 'preserves NoObjectGeneratedError schema mismatches when no schema issue is available', () => {
    // The cause here is a plain Error with no `issues`, so there is nothing to append.
    const plainCause = new Error( 'validation failed' );
    const original = new NoObjectGeneratedError( {
      message: 'No object generated: response did not match schema.',
      text: '{"items":[{}]}',
      cause: plainCause
    } );

    expect( mapAiError( original ) ).toBe( original );
  } );

  it( 'maps non-retryable APICallError instances to FatalError', () => {
    const apiError = makeApiCallError( { statusCode: 400, isRetryable: false } );

    const mapped = mapAiError( apiError );

    expect( mapped ).toBeInstanceOf( FatalError );
    expect( mapped.message ).toBe( 'AI-SDK fatal error (HTTP 400): Provider rejected the request' );
    expect( mapped.cause ).toBe( apiError );
  } );

  it( 'maps non-retryable APICallError instances without status codes to FatalError', () => {
    const apiError = makeApiCallError( { isRetryable: false } );

    const mapped = mapAiError( apiError );

    // With no status code the "(HTTP …)" fragment is absent from the message.
    expect( mapped ).toBeInstanceOf( FatalError );
    expect( mapped.message ).toBe( 'AI-SDK fatal error: Provider rejected the request' );
    expect( mapped.cause ).toBe( apiError );
  } );

  it( 'preserves retryable APICallError instances', () => {
    const apiError = makeApiCallError( { statusCode: 429, isRetryable: true } );

    expect( mapAiError( apiError ) ).toBe( apiError );
  } );

  it.each( fatalAiSdkErrors )( 'maps %s to FatalError', ( _name, makeError ) => {
    const sdkError = makeError();

    const mapped = mapAiError( sdkError );

    expect( mapped ).toBeInstanceOf( FatalError );
    expect( mapped.message ).toBe( `AI-SDK fatal error: ${sdkError.message}` );
    expect( mapped.cause ).toBe( sdkError );
  } );

  it.each( preservedAiSdkErrors )( 'preserves %s for now', ( _name, makeError ) => {
    const sdkError = makeError();

    expect( mapAiError( sdkError ) ).toBe( sdkError );
  } );

  it( 'preserves ordinary errors', () => {
    const ordinary = new Error( 'Network exploded' );

    expect( mapAiError( ordinary ) ).toBe( ordinary );
  } );
} );
|
package/src/validations.js
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import { Buffer } from 'node:buffer';
|
|
1
2
|
import { ValidationError, z } from '@outputai/core';
|
|
2
3
|
|
|
3
4
|
const skillArgSchema = z.object( {
|
|
@@ -14,6 +15,44 @@ const generateTextArgsSchema = z.object( {
|
|
|
14
15
|
maxSteps: z.number().int().positive().optional()
|
|
15
16
|
} );
|
|
16
17
|
|
|
18
|
+
// Zero or more 4-character base64 groups, optionally followed by one final group of
// 2 characters (with or without "==") or 3 characters (with or without "=").
const rawBase64Pattern = /^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}(?:==)?|[A-Za-z0-9+/]{3}=?)?$/;

// Non-empty string made only of standard-alphabet base64 characters; padding is optional.
const base64StringSchema = z
  .string()
  .min( 1 )
  .regex( rawBase64Pattern, 'Image strings must be raw base64 data.' );
|
|
24
|
+
|
|
25
|
+
// Raw image payload: one of the binary container types (checked in this order),
// or a base64-encoded string.
const imageDataSchema = z.union( [
  ...[ Buffer, Uint8Array, ArrayBuffer ].map( binaryType => z.instanceof( binaryType ) ),
  base64StringSchema
] );
|
|
31
|
+
|
|
32
|
+
// Object form of an image input: the payload plus an optional, non-empty media type.
// `.strict()` rejects any additional keys.
const imageWithMediaTypeSchema = z.object( {
  data: imageDataSchema,
  mediaType: z.string().min( 1 ).optional()
} ).strict();

// An image input is either the bare payload or the object form above.
const imageInputSchema = z.union( [ imageDataSchema, imageWithMediaTypeSchema ] );
|
|
39
|
+
|
|
40
|
+
// Argument shape for generateImage(): a required prompt, optional variables / promptDir,
// an optional non-empty list of input images, and an optional mask.
const generateImageArgsSchema = z.object( {
  prompt: z.string().min( 1 ),
  variables: z.any().optional(),
  promptDir: z.string().min( 1 ).optional(),
  images: z.array( imageInputSchema ).min( 1 ).optional(),
  mask: imageInputSchema.optional()
} ).superRefine( ( { mask, images }, ctx ) => {
  // Cross-field rule: nothing to report unless a mask was given without any images.
  if ( !mask || images ) {
    return;
  }

  ctx.addIssue( {
    code: 'custom',
    path: [ 'mask' ],
    message: 'mask requires images.'
  } );
} );
|
|
55
|
+
|
|
17
56
|
function validateSchema( schema, input, errorPrefix ) {
|
|
18
57
|
const result = schema.safeParse( input );
|
|
19
58
|
if ( !result.success ) {
|
|
@@ -28,3 +67,7 @@ export function validateGenerateTextArgs( args ) {
|
|
|
28
67
|
/**
 * Checks streamText() arguments against the same schema used for generateText().
 *
 * @param {unknown} args - Arguments object to validate.
 * @returns {void}
 */
export function validateStreamTextArgs( args ) {
  const errorPrefix = 'Invalid streamText() arguments';

  validateSchema( generateTextArgsSchema, args, errorPrefix );
}
|
|
70
|
+
|
|
71
|
+
/**
 * Checks generateImage() arguments against generateImageArgsSchema.
 *
 * @param {unknown} args - Arguments object to validate.
 * @returns {void}
 */
export function validateGenerateImageArgs( args ) {
  const errorPrefix = 'Invalid generateImage() arguments';

  validateSchema( generateImageArgsSchema, args, errorPrefix );
}
|
package/src/validations.spec.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { describe, it, expect } from 'vitest';
|
|
2
2
|
import { ValidationError } from '@outputai/core';
|
|
3
|
-
import { validateGenerateTextArgs, validateStreamTextArgs } from './validations.js';
|
|
3
|
+
import { validateGenerateTextArgs, validateStreamTextArgs, validateGenerateImageArgs } from './validations.js';
|
|
4
4
|
|
|
5
5
|
describe( 'validateGenerateTextArgs', () => {
|
|
6
6
|
it( 'accepts a prompt with optional variables', () => {
|
|
@@ -88,3 +88,79 @@ describe( 'validateStreamTextArgs', () => {
|
|
|
88
88
|
} ) ).toThrow( ValidationError );
|
|
89
89
|
} );
|
|
90
90
|
} );
|
|
91
|
+
|
|
92
|
+
// Specs for validateGenerateImageArgs: accepted argument shapes, rejected ones,
// and the rule that a mask needs images.
describe( 'validateGenerateImageArgs', () => {
  it( 'accepts text-to-image args without images or mask', () => {
    const args = {
      prompt: 'image@v1',
      variables: { topic: 'race cars' },
      promptDir: '/prompts'
    };

    expect( () => validateGenerateImageArgs( args ) ).not.toThrow();
  } );

  it( 'accepts all supported image input shapes', () => {
    const buffer = Buffer.from( 'image-bytes' );
    const uint8Array = new Uint8Array( [ 1, 2, 3 ] );
    const arrayBuffer = new ArrayBuffer( 3 );
    const paddedBase64 = 'aW1hZ2UtYnl0ZXM=';
    const unpaddedBase64 = 'aW1hZ2U';

    // Bare payloads first, then the `{ data, mediaType? }` object form.
    const barePayloads = [ buffer, uint8Array, arrayBuffer, paddedBase64, unpaddedBase64 ];
    const wrappedPayloads = [
      { data: buffer, mediaType: 'image/png' },
      { data: uint8Array },
      { data: arrayBuffer, mediaType: 'image/jpeg' },
      { data: paddedBase64, mediaType: 'image/webp' }
    ];
    const args = {
      prompt: 'image@v1',
      images: [ ...barePayloads, ...wrappedPayloads ],
      mask: { data: Buffer.from( 'mask-bytes' ), mediaType: 'image/png' }
    };

    expect( () => validateGenerateImageArgs( args ) ).not.toThrow();
  } );

  it( 'throws ValidationError for invalid image args', () => {
    // Empty prompt: the message must carry the generateImage() prefix.
    const emptyPrompt = { prompt: '' };
    expect( () => validateGenerateImageArgs( emptyPrompt ) ).toThrow( /Invalid generateImage\(\) arguments/ );

    // An empty images array is not allowed.
    const noImages = { prompt: 'image@v1', images: [] };
    expect( () => validateGenerateImageArgs( noImages ) ).toThrow( ValidationError );

    // `data` must be a real payload.
    const nullData = { prompt: 'image@v1', images: [ { data: null } ] };
    expect( () => validateGenerateImageArgs( nullData ) ).toThrow( ValidationError );

    // `mediaType`, when present, must be non-empty.
    const emptyMediaType = { prompt: 'image@v1', images: [ { data: 'aW1hZ2U=', mediaType: '' } ] };
    expect( () => validateGenerateImageArgs( emptyMediaType ) ).toThrow( ValidationError );
  } );

  it( 'rejects image strings that are not raw base64 data', () => {
    const nonBase64Strings = [
      'https://example.com/image.png',
      'data:image/png;base64,aW1hZ2U=',
      'not base64',
      'abcde'
    ];

    nonBase64Strings.forEach( image => {
      const args = { prompt: 'image@v1', images: [ image ] };

      expect( () => validateGenerateImageArgs( args ) ).toThrow( /Image strings must be raw base64 data/ );
    } );
  } );

  it( 'requires images when mask is provided', () => {
    const maskOnly = {
      prompt: 'image@v1',
      mask: Buffer.from( 'mask-bytes' )
    };

    expect( () => validateGenerateImageArgs( maskOnly ) ).toThrow( /mask requires images/ );
  } );
} );
|