@ai-sdk/openai-compatible 2.0.15 → 2.0.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +12 -0
- package/dist/index.d.mts +5 -0
- package/dist/index.d.ts +5 -0
- package/dist/index.js +23 -6
- package/dist/index.js.map +1 -1
- package/dist/index.mjs +23 -6
- package/dist/index.mjs.map +1 -1
- package/package.json +3 -2
- package/src/chat/convert-openai-compatible-chat-usage.ts +55 -0
- package/src/chat/convert-to-openai-compatible-chat-messages.test.ts +1238 -0
- package/src/chat/convert-to-openai-compatible-chat-messages.ts +246 -0
- package/src/chat/get-response-metadata.ts +15 -0
- package/src/chat/map-openai-compatible-finish-reason.ts +19 -0
- package/src/chat/openai-compatible-api-types.ts +86 -0
- package/src/chat/openai-compatible-chat-language-model.test.ts +3292 -0
- package/src/chat/openai-compatible-chat-language-model.ts +830 -0
- package/src/chat/openai-compatible-chat-options.ts +34 -0
- package/src/chat/openai-compatible-metadata-extractor.ts +48 -0
- package/src/chat/openai-compatible-prepare-tools.test.ts +336 -0
- package/src/chat/openai-compatible-prepare-tools.ts +98 -0
- package/src/completion/convert-openai-compatible-completion-usage.ts +46 -0
- package/src/completion/convert-to-openai-compatible-completion-prompt.ts +93 -0
- package/src/completion/get-response-metadata.ts +15 -0
- package/src/completion/map-openai-compatible-finish-reason.ts +19 -0
- package/src/completion/openai-compatible-completion-language-model.test.ts +773 -0
- package/src/completion/openai-compatible-completion-language-model.ts +390 -0
- package/src/completion/openai-compatible-completion-options.ts +33 -0
- package/src/embedding/openai-compatible-embedding-model.test.ts +171 -0
- package/src/embedding/openai-compatible-embedding-model.ts +166 -0
- package/src/embedding/openai-compatible-embedding-options.ts +21 -0
- package/src/image/openai-compatible-image-model.test.ts +494 -0
- package/src/image/openai-compatible-image-model.ts +205 -0
- package/src/image/openai-compatible-image-settings.ts +1 -0
- package/src/index.ts +27 -0
- package/src/internal/index.ts +4 -0
- package/src/openai-compatible-error.ts +30 -0
- package/src/openai-compatible-provider.test.ts +329 -0
- package/src/openai-compatible-provider.ts +189 -0
- package/src/version.ts +5 -0
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
import {
|
|
2
|
+
EmbeddingModelV3,
|
|
3
|
+
SharedV3Warning,
|
|
4
|
+
TooManyEmbeddingValuesForCallError,
|
|
5
|
+
} from '@ai-sdk/provider';
|
|
6
|
+
import {
|
|
7
|
+
combineHeaders,
|
|
8
|
+
createJsonErrorResponseHandler,
|
|
9
|
+
createJsonResponseHandler,
|
|
10
|
+
FetchFunction,
|
|
11
|
+
parseProviderOptions,
|
|
12
|
+
postJsonToApi,
|
|
13
|
+
} from '@ai-sdk/provider-utils';
|
|
14
|
+
import { z } from 'zod/v4';
|
|
15
|
+
import {
|
|
16
|
+
OpenAICompatibleEmbeddingModelId,
|
|
17
|
+
openaiCompatibleEmbeddingProviderOptions,
|
|
18
|
+
} from './openai-compatible-embedding-options';
|
|
19
|
+
import {
|
|
20
|
+
defaultOpenAICompatibleErrorStructure,
|
|
21
|
+
ProviderErrorStructure,
|
|
22
|
+
} from '../openai-compatible-error';
|
|
23
|
+
|
|
24
|
+
type OpenAICompatibleEmbeddingConfig = {
  /**
   * Provider identifier exposed by the model. The segment before the first
   * `.` is also used as a providerOptions key.
   */
  provider: string;

  /**
   * Builds the request URL for the given model id and API path
   * (the embedding model passes `/embeddings`).
   */
  url: (options: { modelId: string; path: string }) => string;

  /**
   * Returns provider-level headers; these are merged with per-call headers.
   */
  headers: () => Record<string, string | undefined>;

  /**
   * Optional custom fetch implementation forwarded to the HTTP helper.
   */
  fetch?: FetchFunction;

  /**
   * Error schema and message extractor used for failed responses. When
   * omitted, the default OpenAI-compatible error structure is used.
   */
  errorStructure?: ProviderErrorStructure<any>;

  /**
   * Caps how many values a single embedding call may contain
   * (the model falls back to 2048 when this is unset).
   */
  maxEmbeddingsPerCall?: number;

  /**
   * Whether embedding calls may be issued in parallel
   * (the model falls back to `true` when this is unset).
   */
  supportsParallelCalls?: boolean;
};
|
|
41
|
+
|
|
42
|
+
/**
 * Embedding model for providers that expose an OpenAI-compatible
 * `/embeddings` endpoint.
 *
 * Provider options are merged from three providerOptions keys, later ones
 * taking precedence: the deprecated `openai-compatible` key, `openaiCompatible`,
 * and the provider's own name (the part of `config.provider` before the
 * first `.`).
 */
export class OpenAICompatibleEmbeddingModel implements EmbeddingModelV3 {
  readonly specificationVersion = 'v3';
  readonly modelId: OpenAICompatibleEmbeddingModelId;

  private readonly config: OpenAICompatibleEmbeddingConfig;

  get provider(): string {
    return this.config.provider;
  }

  /** Largest number of values accepted per call; 2048 unless configured. */
  get maxEmbeddingsPerCall(): number {
    const configured = this.config.maxEmbeddingsPerCall;
    return configured ?? 2048;
  }

  /** Whether calls may run concurrently; `true` unless configured. */
  get supportsParallelCalls(): boolean {
    const configured = this.config.supportsParallelCalls;
    return configured ?? true;
  }

  constructor(
    modelId: OpenAICompatibleEmbeddingModelId,
    config: OpenAICompatibleEmbeddingConfig,
  ) {
    this.modelId = modelId;
    this.config = config;
  }

  /** providerOptions key derived from the provider id (text before the first `.`). */
  private get providerOptionsName(): string {
    const segments = this.config.provider.split('.');
    return segments[0].trim();
  }

  /**
   * Embeds `values` with a single POST to `/embeddings`.
   *
   * Throws `TooManyEmbeddingValuesForCallError` when `values` exceeds
   * `maxEmbeddingsPerCall`. Emits a warning when the deprecated
   * `openai-compatible` providerOptions key is present.
   */
  async doEmbed({
    values,
    headers,
    abortSignal,
    providerOptions,
  }: Parameters<EmbeddingModelV3['doEmbed']>[0]): Promise<
    Awaited<ReturnType<EmbeddingModelV3['doEmbed']>>
  > {
    const warnings: SharedV3Warning[] = [];
    const schema = openaiCompatibleEmbeddingProviderOptions;

    // Legacy key: still honoured, but flagged so callers can migrate.
    const legacyOptions = await parseProviderOptions({
      provider: 'openai-compatible',
      providerOptions,
      schema,
    });
    if (legacyOptions != null) {
      warnings.push({
        type: 'other',
        message: `The 'openai-compatible' key in providerOptions is deprecated. Use 'openaiCompatible' instead.`,
      });
    }

    const genericOptions = await parseProviderOptions({
      provider: 'openaiCompatible',
      providerOptions,
      schema,
    });
    const providerSpecificOptions = await parseProviderOptions({
      provider: this.providerOptionsName,
      providerOptions,
      schema,
    });

    // Precedence (lowest to highest): legacy, openaiCompatible, provider-specific.
    const compatibleOptions = {
      ...legacyOptions,
      ...genericOptions,
      ...providerSpecificOptions,
    };

    const limit = this.maxEmbeddingsPerCall;
    if (values.length > limit) {
      throw new TooManyEmbeddingValuesForCallError({
        provider: this.provider,
        modelId: this.modelId,
        maxEmbeddingsPerCall: limit,
        values,
      });
    }

    const {
      responseHeaders,
      value: parsed,
      rawValue,
    } = await postJsonToApi({
      url: this.config.url({
        path: '/embeddings',
        modelId: this.modelId,
      }),
      headers: combineHeaders(this.config.headers(), headers),
      body: {
        model: this.modelId,
        input: values,
        encoding_format: 'float',
        dimensions: compatibleOptions.dimensions,
        user: compatibleOptions.user,
      },
      failedResponseHandler: createJsonErrorResponseHandler(
        this.config.errorStructure ?? defaultOpenAICompatibleErrorStructure,
      ),
      successfulResponseHandler: createJsonResponseHandler(
        openaiTextEmbeddingResponseSchema,
      ),
      abortSignal,
      fetch: this.config.fetch,
    });

    const { data, usage, providerMetadata } = parsed;

    return {
      warnings,
      embeddings: data.map(({ embedding }) => embedding),
      usage: usage != null ? { tokens: usage.prompt_tokens } : undefined,
      providerMetadata,
      response: { headers: responseHeaders, body: rawValue },
    };
  }
}
|
|
157
|
+
|
|
158
|
+
// Only the response fields this implementation reads are modelled here.
// Keeping the schema minimal makes parsing tolerant of extra or changed
// fields in provider responses and keeps validation cheap.
const openaiTextEmbeddingResponseSchema = z.object({
  data: z.array(
    z.object({
      embedding: z.number().array(),
    }),
  ),
  usage: z
    .object({
      prompt_tokens: z.number(),
    })
    .nullish(),
  providerMetadata: z
    .record(z.string(), z.record(z.string(), z.any()))
    .optional(),
});
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import { z } from 'zod/v4';
|
|
2
|
+
|
|
3
|
+
/**
 * Identifier of an embedding model. Typed as a plain string because each
 * OpenAI-compatible provider defines its own model names.
 */
export type OpenAICompatibleEmbeddingModelId = string;

/**
 * Schema for the provider options accepted by the OpenAI-compatible
 * embedding model.
 */
export const openaiCompatibleEmbeddingProviderOptions = z.object({
  /**
   * Requested dimensionality of the output embeddings. Whether this is
   * honoured depends on the provider and model (with OpenAI itself, only
   * text-embedding-3 and later models support it).
   */
  dimensions: z.optional(z.number()),

  /**
   * A unique identifier for your end-user, which providers can use to
   * monitor and detect abuse.
   */
  user: z.optional(z.string()),
});

/** Parsed (output) type of {@link openaiCompatibleEmbeddingProviderOptions}. */
export type OpenAICompatibleEmbeddingProviderOptions = z.infer<
  typeof openaiCompatibleEmbeddingProviderOptions
>;
|
|
@@ -0,0 +1,494 @@
|
|
|
1
|
+
import { describe, it, expect } from 'vitest';
|
|
2
|
+
import { FetchFunction } from '@ai-sdk/provider-utils';
|
|
3
|
+
import { createTestServer } from '@ai-sdk/test-server/with-vitest';
|
|
4
|
+
import { OpenAICompatibleImageModel } from './openai-compatible-image-model';
|
|
5
|
+
import { z } from 'zod/v4';
|
|
6
|
+
import { ProviderErrorStructure } from '../openai-compatible-error';
|
|
7
|
+
import { ImageModelV3CallOptions } from '@ai-sdk/provider';
|
|
8
|
+
|
|
9
|
+
/** Prompt shared by the generation tests. */
const prompt = 'A photorealistic astronaut riding a horse';

/**
 * Creates a `dall-e-3` image model whose URLs resolve to
 * `https://api.example.com/<modelId><path>`. When `headers` is omitted, a
 * bearer-token Authorization header is supplied.
 */
function createBasicModel(
  options: {
    headers?: () => Record<string, string | undefined>;
    fetch?: FetchFunction;
    currentDate?: () => Date;
    errorStructure?: ProviderErrorStructure<any>;
  } = {},
) {
  const { headers, fetch, currentDate, errorStructure } = options;
  const defaultHeaders = () => ({ Authorization: 'Bearer test-key' });

  return new OpenAICompatibleImageModel('dall-e-3', {
    provider: 'openai-compatible',
    headers: headers ?? defaultHeaders,
    url: ({ modelId, path }) => `https://api.example.com/${modelId}${path}`,
    fetch,
    errorStructure,
    _internal: { currentDate },
  });
}
|
|
33
|
+
|
|
34
|
+
/**
 * Builds a complete `ImageModelV3CallOptions` for the shared test `prompt`:
 * one 1024x1024 image with no files, mask, aspect ratio or seed, and empty
 * headers and provider options. Any field in `overrides` replaces the default.
 *
 * `overrides` is typed as a partial call-options object (rather than `{}`) so
 * that misspelled or mistyped override fields are caught by the compiler.
 */
function createDefaultGenerateParams(
  overrides: Partial<ImageModelV3CallOptions> = {},
): ImageModelV3CallOptions {
  return {
    // Reuse the module-level prompt so the request-body assertions stay in sync.
    prompt,
    files: undefined,
    mask: undefined,
    n: 1,
    size: '1024x1024',
    aspectRatio: undefined,
    seed: undefined,
    providerOptions: {},
    headers: {},
    abortSignal: undefined,
    ...overrides,
  };
}
|
|
49
|
+
|
|
50
|
+
describe('OpenAICompatibleImageModel', () => {
  // Mock endpoints: the default dall-e-3 generation URL (two base64 images)
  // and a Recraft-style URL used by the providerOptions-key test.
  const server = createTestServer({
    'https://api.example.com/dall-e-3/images/generations': {
      response: {
        type: 'json-value',
        body: {
          data: [
            {
              b64_json: 'test1234',
            },
            {
              b64_json: 'test5678',
            },
          ],
        },
      },
    },
    'https://external.api.recraft.ai/v1/images/generations': {
      response: {
        type: 'json-value',
        body: {
          data: [
            {
              b64_json: 'recraft-test-image',
            },
          ],
        },
      },
    },
  });

  describe('constructor', () => {
    it('should expose correct provider and model information', () => {
      const model = createBasicModel();

      expect(model.provider).toBe('openai-compatible');
      expect(model.modelId).toBe('dall-e-3');
      expect(model.specificationVersion).toBe('v3');
      expect(model.maxImagesPerCall).toBe(10);
    });
  });

  describe('doGenerate', () => {
    it('should pass the correct parameters', async () => {
      const model = createBasicModel();

      await model.doGenerate(
        createDefaultGenerateParams({
          n: 2,
          providerOptions: { openaiCompatible: { quality: 'hd' } },
        }),
      );

      expect(await server.calls[0].requestBodyJson).toStrictEqual({
        model: 'dall-e-3',
        prompt,
        n: 2,
        size: '1024x1024',
        quality: 'hd',
        response_format: 'b64_json',
      });
    });

    it('should use provider name from config for providerOptions key', async () => {
      // The 'recraft' providerOptions key is derived from the 'recraft.image'
      // provider id, so its options must be forwarded in the request body.
      const recraftModel = new OpenAICompatibleImageModel('recraft-v3', {
        provider: 'recraft.image',
        headers: () => ({ Authorization: 'Bearer test-key' }),
        url: ({ path }) => `https://external.api.recraft.ai/v1${path}`,
      });

      await recraftModel.doGenerate(
        createDefaultGenerateParams({
          prompt: 'A beautiful sunset',
          providerOptions: { recraft: { style: 'vector_illustration' } },
        }),
      );

      expect(await server.calls[0].requestBodyJson).toStrictEqual({
        model: 'recraft-v3',
        prompt: 'A beautiful sunset',
        n: 1,
        size: '1024x1024',
        style: 'vector_illustration',
        response_format: 'b64_json',
      });
    });

    it('should add warnings for unsupported settings', async () => {
      const model = createBasicModel();

      const result = await model.doGenerate(
        createDefaultGenerateParams({
          aspectRatio: '16:9',
          seed: 123,
        }),
      );

      expect(result.warnings).toMatchInlineSnapshot(`
        [
          {
            "details": "This model does not support aspect ratio. Use \`size\` instead.",
            "feature": "aspectRatio",
            "type": "unsupported",
          },
          {
            "feature": "seed",
            "type": "unsupported",
          },
        ]
      `);
    });

    it('should pass headers', async () => {
      const modelWithHeaders = createBasicModel({
        headers: () => ({
          'Custom-Provider-Header': 'provider-header-value',
        }),
      });

      await modelWithHeaders.doGenerate(
        createDefaultGenerateParams({
          headers: {
            'Custom-Request-Header': 'request-header-value',
          },
        }),
      );

      expect(server.calls[0].requestHeaders).toStrictEqual({
        'content-type': 'application/json',
        'custom-provider-header': 'provider-header-value',
        'custom-request-header': 'request-header-value',
      });
    });

    it('should handle API errors with custom error structure', async () => {
      // Define a custom error schema different from OpenAI's format
      const customErrorSchema = z.object({
        status: z.string(),
        details: z.object({
          errorMessage: z.string(),
          errorCode: z.number(),
        }),
      });

      server.urls[
        'https://api.example.com/dall-e-3/images/generations'
      ].response = {
        type: 'error',
        status: 400,
        body: JSON.stringify({
          status: 'error',
          details: {
            errorMessage: 'Custom provider error format',
            errorCode: 1234,
          },
        }),
      };

      const model = createBasicModel({
        errorStructure: {
          errorSchema: customErrorSchema,
          errorToMessage: data =>
            `Error ${data.details.errorCode}: ${data.details.errorMessage}`,
        },
      });

      await expect(
        model.doGenerate(createDefaultGenerateParams()),
      ).rejects.toMatchObject({
        message: 'Error 1234: Custom provider error format',
        statusCode: 400,
        url: 'https://api.example.com/dall-e-3/images/generations',
      });
    });

    it('should handle API errors with default error structure', async () => {
      server.urls[
        'https://api.example.com/dall-e-3/images/generations'
      ].response = {
        type: 'error',
        status: 400,
        body: JSON.stringify({
          error: {
            message: 'Invalid prompt content',
            type: 'invalid_request_error',
            param: null,
            code: null,
          },
        }),
      };

      const model = createBasicModel();

      await expect(
        model.doGenerate(createDefaultGenerateParams()),
      ).rejects.toMatchObject({
        message: 'Invalid prompt content',
        statusCode: 400,
        url: 'https://api.example.com/dall-e-3/images/generations',
      });
    });

    it('should return the raw b64_json content', async () => {
      const model = createBasicModel();
      const result = await model.doGenerate(
        createDefaultGenerateParams({
          n: 2,
        }),
      );

      expect(result.images).toHaveLength(2);
      expect(result.images[0]).toBe('test1234');
      expect(result.images[1]).toBe('test5678');
    });

    describe('response metadata', () => {
      it('should include timestamp, headers and modelId in response', async () => {
        const testDate = new Date('2024-01-01T00:00:00Z');
        const model = createBasicModel({
          currentDate: () => testDate,
        });

        const result = await model.doGenerate(createDefaultGenerateParams());

        expect(result.response).toStrictEqual({
          timestamp: testDate,
          modelId: 'dall-e-3',
          headers: expect.any(Object),
        });
      });
    });

    it('should use real date when no custom date provider is specified', async () => {
      const beforeDate = new Date();

      // Constructed without `_internal` on purpose, so the model must fall
      // back to the real clock.
      const model = new OpenAICompatibleImageModel('dall-e-3', {
        provider: 'openai-compatible',
        headers: () => ({ Authorization: 'Bearer test-key' }),
        url: ({ modelId, path }) => `https://api.example.com/${modelId}${path}`,
      });

      const result = await model.doGenerate(createDefaultGenerateParams());

      const afterDate = new Date();

      expect(result.response.timestamp.getTime()).toBeGreaterThanOrEqual(
        beforeDate.getTime(),
      );
      expect(result.response.timestamp.getTime()).toBeLessThanOrEqual(
        afterDate.getTime(),
      );
      expect(result.response.modelId).toBe('dall-e-3');
    });

    it('should pass the user setting in the request', async () => {
      const model = createBasicModel();

      await model.doGenerate(
        createDefaultGenerateParams({
          providerOptions: {
            openaiCompatible: {
              user: 'test-user-id',
            },
          },
        }),
      );

      expect(await server.calls[0].requestBodyJson).toStrictEqual({
        model: 'dall-e-3',
        prompt,
        n: 1,
        size: '1024x1024',
        user: 'test-user-id',
        response_format: 'b64_json',
      });
    });

    it('should not include user field in request when not set via provider options', async () => {
      const model = createBasicModel();

      await model.doGenerate(
        createDefaultGenerateParams({
          providerOptions: {
            openaiCompatible: {},
          },
        }),
      );

      const requestBody = await server.calls[0].requestBodyJson;
      expect(requestBody).toStrictEqual({
        model: 'dall-e-3',
        prompt,
        n: 1,
        size: '1024x1024',
        response_format: 'b64_json',
      });
      expect(requestBody).not.toHaveProperty('user');
    });
  });

  describe('Image Editing', () => {
    // Supplying `files` (and optionally `mask`) routes the call to the edits endpoint.
    const editServer = createTestServer({
      'https://api.example.com/dall-e-3/images/edits': {
        response: {
          type: 'json-value',
          body: {
            data: [{ b64_json: 'edited-image-base64' }],
          },
        },
      },
    });

    it('should send edit request with files', async () => {
      const model = createBasicModel();

      // Use Uint8Array for test data to avoid base64 decoding issues
      const imageData = new Uint8Array([137, 80, 78, 71]); // PNG magic bytes

      const result = await model.doGenerate(
        createDefaultGenerateParams({
          prompt: 'Turn the cat into a dog',
          files: [
            {
              type: 'file',
              data: imageData,
              mediaType: 'image/png',
            },
          ],
        }),
      );

      expect(result.images).toStrictEqual(['edited-image-base64']);
      expect(editServer.calls[0].requestUrl).toBe(
        'https://api.example.com/dall-e-3/images/edits',
      );
    });

    it('should send edit request with files and mask', async () => {
      const model = createBasicModel();

      // Use Uint8Array for test data
      const imageData = new Uint8Array([137, 80, 78, 71]);
      const maskData = new Uint8Array([137, 80, 78, 71]);

      const result = await model.doGenerate(
        createDefaultGenerateParams({
          prompt: 'Add a flamingo to the pool',
          files: [
            {
              type: 'file',
              data: imageData,
              mediaType: 'image/png',
            },
          ],
          mask: {
            type: 'file',
            data: maskData,
            mediaType: 'image/png',
          },
        }),
      );

      expect(result.images).toStrictEqual(['edited-image-base64']);
      expect(editServer.calls[0].requestUrl).toBe(
        'https://api.example.com/dall-e-3/images/edits',
      );
    });

    it('should send edit request with Uint8Array data', async () => {
      const model = createBasicModel();

      const imageUint8Array = new Uint8Array([104, 101, 108, 108, 111]);

      const result = await model.doGenerate(
        createDefaultGenerateParams({
          prompt: 'Edit this image',
          files: [
            {
              type: 'file',
              data: imageUint8Array,
              mediaType: 'image/png',
            },
          ],
        }),
      );

      expect(result.images).toStrictEqual(['edited-image-base64']);
    });

    it('should send edit request with multiple images', async () => {
      const model = createBasicModel();

      const image1 = new Uint8Array([137, 80, 78, 71]);
      const image2 = new Uint8Array([137, 80, 78, 71]);

      const result = await model.doGenerate(
        createDefaultGenerateParams({
          prompt: 'Combine these images',
          files: [
            {
              type: 'file',
              data: image1,
              mediaType: 'image/png',
            },
            {
              type: 'file',
              data: image2,
              mediaType: 'image/png',
            },
          ],
        }),
      );

      expect(result.images).toStrictEqual(['edited-image-base64']);
    });

    it('should include response metadata for edit requests', async () => {
      const testDate = new Date('2024-01-01T00:00:00Z');
      const model = createBasicModel({
        currentDate: () => testDate,
      });

      const imageData = new Uint8Array([137, 80, 78, 71]);

      const result = await model.doGenerate(
        createDefaultGenerateParams({
          prompt: 'Edit this image',
          files: [
            {
              type: 'file',
              data: imageData,
              mediaType: 'image/png',
            },
          ],
        }),
      );

      expect(result.response).toStrictEqual({
        timestamp: testDate,
        modelId: 'dall-e-3',
        headers: expect.any(Object),
      });
    });
  });
});
|