@ai-sdk/deepinfra 2.0.16 → 2.0.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +8 -0
- package/dist/index.js +1 -1
- package/dist/index.mjs +1 -1
- package/package.json +4 -3
- package/src/deepinfra-chat-options.ts +69 -0
- package/src/deepinfra-completion-options.ts +4 -0
- package/src/deepinfra-embedding-options.ts +19 -0
- package/src/deepinfra-image-model.test.ts +431 -0
- package/src/deepinfra-image-model.ts +194 -0
- package/src/deepinfra-image-settings.ts +12 -0
- package/src/deepinfra-provider.test.ts +184 -0
- package/src/deepinfra-provider.ts +161 -0
- package/src/index.ts +7 -0
- package/src/version.ts +6 -0
package/CHANGELOG.md
CHANGED
package/dist/index.js
CHANGED
package/dist/index.mjs
CHANGED
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@ai-sdk/deepinfra",
|
|
3
|
-
"version": "2.0.16",
|
|
3
|
+
"version": "2.0.17",
|
|
4
4
|
"license": "Apache-2.0",
|
|
5
5
|
"sideEffects": false,
|
|
6
6
|
"main": "./dist/index.js",
|
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
"types": "./dist/index.d.ts",
|
|
9
9
|
"files": [
|
|
10
10
|
"dist/**/*",
|
|
11
|
+
"src",
|
|
11
12
|
"CHANGELOG.md",
|
|
12
13
|
"README.md"
|
|
13
14
|
],
|
|
@@ -20,7 +21,7 @@
|
|
|
20
21
|
}
|
|
21
22
|
},
|
|
22
23
|
"dependencies": {
|
|
23
|
-
"@ai-sdk/openai-compatible": "2.0.
|
|
24
|
+
"@ai-sdk/openai-compatible": "2.0.17",
|
|
24
25
|
"@ai-sdk/provider": "3.0.4",
|
|
25
26
|
"@ai-sdk/provider-utils": "4.0.8"
|
|
26
27
|
},
|
|
@@ -29,7 +30,7 @@
|
|
|
29
30
|
"tsup": "^8",
|
|
30
31
|
"typescript": "5.8.3",
|
|
31
32
|
"zod": "3.25.76",
|
|
32
|
-
"@ai-sdk/test-server": "1.0.
|
|
33
|
+
"@ai-sdk/test-server": "1.0.2",
|
|
33
34
|
"@vercel/ai-tsconfig": "0.0.0"
|
|
34
35
|
},
|
|
35
36
|
"peerDependencies": {
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
/**
 * Identifiers of DeepInfra text-generation models with editor autocompletion.
 *
 * Any other string is still accepted via the trailing `(string & {})` member,
 * so models that are missing from this list can be used as well.
 *
 * @see https://deepinfra.com/models/text-generation
 */
export type DeepInfraChatModelId =
  // 01.AI / Austism / BigCode / Code Llama
  | '01-ai/Yi-34B-Chat'
  | 'Austism/chronos-hermes-13b-v2'
  | 'bigcode/starcoder2-15b-instruct-v0.1'
  | 'bigcode/starcoder2-15b'
  | 'codellama/CodeLlama-34b-Instruct-hf'
  | 'codellama/CodeLlama-70b-Instruct-hf'
  // Cognitive Computations / Databricks / DeepInfra / DeepSeek
  | 'cognitivecomputations/dolphin-2.6-mixtral-8x7b'
  | 'cognitivecomputations/dolphin-2.9.1-llama-3-70b'
  | 'databricks/dbrx-instruct'
  | 'deepinfra/airoboros-70b'
  | 'deepseek-ai/DeepSeek-V3'
  // Google
  | 'google/codegemma-7b-it'
  | 'google/gemma-1.1-7b-it'
  | 'google/gemma-2-27b-it'
  | 'google/gemma-2-9b-it'
  // Community fine-tunes
  | 'Gryphe/MythoMax-L2-13b-turbo'
  | 'Gryphe/MythoMax-L2-13b'
  | 'HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1'
  | 'KoboldAI/LLaMA2-13B-Tiefighter'
  | 'lizpreciatior/lzlv_70b_fp16_hf'
  | 'mattshumer/Reflection-Llama-3.1-70B'
  // Meta Llama
  | 'meta-llama/Llama-2-13b-chat-hf'
  | 'meta-llama/Llama-2-70b-chat-hf'
  | 'meta-llama/Llama-2-7b-chat-hf'
  | 'meta-llama/Llama-3.2-11B-Vision-Instruct'
  | 'meta-llama/Llama-3.2-1B-Instruct'
  | 'meta-llama/Llama-3.2-3B-Instruct'
  | 'meta-llama/Llama-3.2-90B-Vision-Instruct'
  | 'meta-llama/Llama-3.3-70B-Instruct-Turbo'
  | 'meta-llama/Llama-3.3-70B-Instruct'
  | 'meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8'
  | 'meta-llama/Llama-4-Scout-17B-16E-Instruct'
  | 'meta-llama/Meta-Llama-3-70B-Instruct'
  | 'meta-llama/Meta-Llama-3-8B-Instruct'
  | 'meta-llama/Meta-Llama-3.1-405B-Instruct'
  | 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo'
  | 'meta-llama/Meta-Llama-3.1-70B-Instruct'
  | 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
  | 'meta-llama/Meta-Llama-3.1-8B-Instruct'
  // Microsoft
  | 'microsoft/Phi-3-medium-4k-instruct'
  | 'microsoft/WizardLM-2-7B'
  | 'microsoft/WizardLM-2-8x22B'
  // Mistral AI
  | 'mistralai/Mistral-7B-Instruct-v0.1'
  | 'mistralai/Mistral-7B-Instruct-v0.2'
  | 'mistralai/Mistral-7B-Instruct-v0.3'
  | 'mistralai/Mistral-Nemo-Instruct-2407'
  | 'mistralai/Mixtral-8x22B-Instruct-v0.1'
  | 'mistralai/Mixtral-8x22B-v0.1'
  | 'mistralai/Mixtral-8x7B-Instruct-v0.1'
  // Nous / NVIDIA / OpenBMB / OpenChat / Phind
  | 'NousResearch/Hermes-3-Llama-3.1-405B'
  | 'nvidia/Llama-3.1-Nemotron-70B-Instruct'
  | 'nvidia/Nemotron-4-340B-Instruct'
  | 'openbmb/MiniCPM-Llama3-V-2_5'
  | 'openchat/openchat_3.5'
  | 'openchat/openchat-3.6-8b'
  | 'Phind/Phind-CodeLlama-34B-v2'
  // Qwen
  | 'Qwen/Qwen2-72B-Instruct'
  | 'Qwen/Qwen2-7B-Instruct'
  | 'Qwen/Qwen2.5-72B-Instruct'
  | 'Qwen/Qwen2.5-7B-Instruct'
  | 'Qwen/Qwen2.5-Coder-32B-Instruct'
  | 'Qwen/Qwen2.5-Coder-7B'
  | 'Qwen/QwQ-32B-Preview'
  // Sao10K
  | 'Sao10K/L3-70B-Euryale-v2.1'
  | 'Sao10K/L3-8B-Lunaris-v1'
  | 'Sao10K/L3.1-70B-Euryale-v2.2'
  // Escape hatch: any other model id string.
  | (string & {});
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
/**
 * Identifiers of DeepInfra embedding models with editor autocompletion.
 * Other strings are accepted through the trailing `(string & {})` member.
 *
 * @see https://deepinfra.com/models/embeddings
 */
export type DeepInfraEmbeddingModelId =
  // BAAI
  | 'BAAI/bge-base-en-v1.5'
  | 'BAAI/bge-large-en-v1.5'
  | 'BAAI/bge-m3'
  // intfloat
  | 'intfloat/e5-base-v2'
  | 'intfloat/e5-large-v2'
  | 'intfloat/multilingual-e5-large'
  // sentence-transformers
  | 'sentence-transformers/all-MiniLM-L12-v2'
  | 'sentence-transformers/all-MiniLM-L6-v2'
  | 'sentence-transformers/all-mpnet-base-v2'
  | 'sentence-transformers/clip-ViT-B-32'
  | 'sentence-transformers/clip-ViT-B-32-multilingual-v1'
  | 'sentence-transformers/multi-qa-mpnet-base-dot-v1'
  | 'sentence-transformers/paraphrase-MiniLM-L6-v2'
  // Others
  | 'shibing624/text2vec-base-chinese'
  | 'thenlper/gte-base'
  | 'thenlper/gte-large'
  // Escape hatch: any other model id string.
  | (string & {});
|
|
@@ -0,0 +1,431 @@
|
|
|
1
|
+
import { createTestServer } from '@ai-sdk/test-server/with-vitest';
|
|
2
|
+
import { describe, expect, it } from 'vitest';
|
|
3
|
+
import { DeepInfraImageModel } from './deepinfra-image-model';
|
|
4
|
+
import { FetchFunction } from '@ai-sdk/provider-utils';
|
|
5
|
+
|
|
6
|
+
const prompt = 'A cute baby sea otter';

/**
 * Builds a `stability-ai/sdxl` DeepInfraImageModel that targets the mocked
 * `https://api.example.com` server.
 *
 * @param settings.headers provider-level headers; defaults to a test API key
 * @param settings.fetch optional fetch override
 * @param settings.currentDate optional clock used for response timestamps
 */
function createBasicModel(
  settings: {
    headers?: () => Record<string, string>;
    fetch?: FetchFunction;
    currentDate?: () => Date;
  } = {},
) {
  const defaultHeaders = () => ({ 'api-key': 'test-key' });

  return new DeepInfraImageModel('stability-ai/sdxl', {
    provider: 'deepinfra',
    baseURL: 'https://api.example.com',
    headers: settings.headers ?? defaultHeaders,
    fetch: settings.fetch,
    _internal: { currentDate: settings.currentDate },
  });
}
|
|
27
|
+
|
|
28
|
+
describe('DeepInfraImageModel', () => {
  const testDate = new Date('2024-01-01T00:00:00Z');
  const server = createTestServer({
    'https://api.example.com/*': {
      response: {
        type: 'json-value',
        body: { images: ['data:image/png;base64,test-image-data'] },
      },
    },
  });

  type CallOptions = Parameters<DeepInfraImageModel['doGenerate']>[0];

  // Baseline doGenerate options shared by every test: one image, no files,
  // no mask, no size/aspect ratio/seed, empty provider options. Each test
  // overrides only the fields it exercises.
  const callOptions = (overrides: Partial<CallOptions> = {}): CallOptions => ({
    prompt,
    files: undefined,
    mask: undefined,
    n: 1,
    size: undefined,
    aspectRatio: undefined,
    seed: undefined,
    providerOptions: {},
    ...overrides,
  });

  describe('doGenerate', () => {
    it('should pass the correct parameters including aspect ratio and seed', async () => {
      await createBasicModel().doGenerate(
        callOptions({
          aspectRatio: '16:9',
          seed: 42,
          providerOptions: { deepinfra: { additional_param: 'value' } },
        }),
      );

      expect(await server.calls[0].requestBodyJson).toStrictEqual({
        prompt,
        aspect_ratio: '16:9',
        seed: 42,
        num_images: 1,
        additional_param: 'value',
      });
    });

    it('should call the correct url', async () => {
      await createBasicModel().doGenerate(callOptions());

      const [call] = server.calls;
      expect(call.requestMethod).toStrictEqual('POST');
      expect(call.requestUrl).toStrictEqual(
        'https://api.example.com/stability-ai/sdxl',
      );
    });

    it('should pass headers', async () => {
      const modelWithHeaders = createBasicModel({
        headers: () => ({ 'Custom-Provider-Header': 'provider-header-value' }),
      });

      await modelWithHeaders.doGenerate(
        callOptions({
          headers: { 'Custom-Request-Header': 'request-header-value' },
        }),
      );

      expect(server.calls[0].requestHeaders).toStrictEqual({
        'content-type': 'application/json',
        'custom-provider-header': 'provider-header-value',
        'custom-request-header': 'request-header-value',
      });
    });

    it('should handle API errors', async () => {
      server.urls['https://api.example.com/*'].response = {
        type: 'error',
        status: 400,
        body: JSON.stringify({ error: { message: 'Bad Request' } }),
      };

      const model = createBasicModel();
      await expect(model.doGenerate(callOptions())).rejects.toThrow(
        'Bad Request',
      );
    });

    it('should handle size parameter', async () => {
      await createBasicModel().doGenerate(
        callOptions({ size: '1024x768', seed: 42 }),
      );

      expect(await server.calls[0].requestBodyJson).toStrictEqual({
        prompt,
        width: '1024',
        height: '768',
        seed: 42,
        num_images: 1,
      });
    });

    it('should respect the abort signal', async () => {
      const model = createBasicModel();
      const controller = new AbortController();

      const pending = model.doGenerate(
        callOptions({ abortSignal: controller.signal }),
      );
      controller.abort();

      await expect(pending).rejects.toThrow('This operation was aborted');
    });

    describe('response metadata', () => {
      it('should include timestamp, headers and modelId in response', async () => {
        const model = createBasicModel({ currentDate: () => testDate });

        const { response } = await model.doGenerate(callOptions());

        expect(response).toStrictEqual({
          timestamp: testDate,
          modelId: 'stability-ai/sdxl',
          headers: expect.any(Object),
        });
      });

      it('should include response headers from API call', async () => {
        server.urls['https://api.example.com/*'].response = {
          type: 'json-value',
          headers: { 'x-request-id': 'test-request-id' },
          body: { images: ['data:image/png;base64,test-image-data'] },
        };

        const { response } = await createBasicModel().doGenerate(
          callOptions(),
        );

        expect(response.headers).toStrictEqual({
          'content-length': '52',
          'x-request-id': 'test-request-id',
          'content-type': 'application/json',
        });
      });
    });
  });

  describe('constructor', () => {
    it('should expose correct provider and model information', () => {
      const model = createBasicModel();

      expect(model.provider).toBe('deepinfra');
      expect(model.modelId).toBe('stability-ai/sdxl');
      expect(model.specificationVersion).toBe('v3');
      expect(model.maxImagesPerCall).toBe(1);
    });
  });

  describe('Image Editing', () => {
    const editUrl = 'https://edit.example.com/openai/images/edits';

    const editServer = createTestServer({
      [editUrl]: {
        response: {
          type: 'json-value',
          body: {
            created: 1234567890,
            data: [{ b64_json: 'edited-image-base64' }],
          },
        },
      },
    });

    // baseURL ends in /inference, so edit requests resolve to editUrl.
    const editModelSettings = {
      provider: 'deepinfra',
      baseURL: 'https://edit.example.com/inference',
      headers: () => ({ 'api-key': 'test-key' }),
    };
    const editModel = new DeepInfraImageModel(
      'black-forest-labs/FLUX.1-Kontext-dev',
      editModelSettings,
    );

    // Wraps raw bytes as an inline PNG file input.
    const pngFile = (data: Uint8Array) => ({
      type: 'file' as const,
      data,
      mediaType: 'image/png',
    });

    it('should send edit request with files', async () => {
      const imageData = new Uint8Array([137, 80, 78, 71]); // PNG magic bytes

      const { images } = await editModel.doGenerate(
        callOptions({
          prompt: 'Turn the cat into a dog',
          files: [pngFile(imageData)],
          size: '1024x1024',
        }),
      );

      expect(images).toStrictEqual(['edited-image-base64']);
      expect(editServer.calls[0].requestUrl).toBe(editUrl);
    });

    it('should send edit request with files and mask', async () => {
      const imageData = new Uint8Array([137, 80, 78, 71]);
      const maskData = new Uint8Array([255, 255, 255, 0]);

      const { images } = await editModel.doGenerate(
        callOptions({
          prompt: 'Add a flamingo to the pool',
          files: [pngFile(imageData)],
          mask: pngFile(maskData),
        }),
      );

      expect(images).toStrictEqual(['edited-image-base64']);
      expect(editServer.calls[0].requestUrl).toBe(editUrl);
    });

    it('should send edit request with multiple images', async () => {
      const first = new Uint8Array([137, 80, 78, 71]);
      const second = new Uint8Array([137, 80, 78, 71]);

      const { images } = await editModel.doGenerate(
        callOptions({
          prompt: 'Combine these images',
          files: [pngFile(first), pngFile(second)],
        }),
      );

      expect(images).toStrictEqual(['edited-image-base64']);
    });

    it('should include response metadata for edit requests', async () => {
      const modelWithDate = new DeepInfraImageModel(
        'black-forest-labs/FLUX.1-Kontext-dev',
        { ...editModelSettings, _internal: { currentDate: () => testDate } },
      );

      const { response } = await modelWithDate.doGenerate(
        callOptions({
          prompt: 'Edit this image',
          files: [pngFile(new Uint8Array([137, 80, 78, 71]))],
        }),
      );

      expect(response).toStrictEqual({
        timestamp: testDate,
        modelId: 'black-forest-labs/FLUX.1-Kontext-dev',
        headers: expect.any(Object),
      });
    });

    it('should pass provider options in edit request', async () => {
      await editModel.doGenerate(
        callOptions({
          prompt: 'Edit with custom options',
          files: [pngFile(new Uint8Array([137, 80, 78, 71]))],
          providerOptions: { deepinfra: { guidance: 7.5 } },
        }),
      );

      // The request should have been routed to the edit endpoint.
      expect(editServer.calls[0].requestUrl).toBe(editUrl);
    });
  });
});
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
import {
|
|
2
|
+
ImageModelV3,
|
|
3
|
+
ImageModelV3File,
|
|
4
|
+
SharedV3Warning,
|
|
5
|
+
} from '@ai-sdk/provider';
|
|
6
|
+
import {
|
|
7
|
+
combineHeaders,
|
|
8
|
+
convertBase64ToUint8Array,
|
|
9
|
+
convertToFormData,
|
|
10
|
+
createJsonErrorResponseHandler,
|
|
11
|
+
createJsonResponseHandler,
|
|
12
|
+
downloadBlob,
|
|
13
|
+
FetchFunction,
|
|
14
|
+
postFormDataToApi,
|
|
15
|
+
postJsonToApi,
|
|
16
|
+
} from '@ai-sdk/provider-utils';
|
|
17
|
+
import { DeepInfraImageModelId } from './deepinfra-image-settings';
|
|
18
|
+
import { z } from 'zod/v4';
|
|
19
|
+
|
|
20
|
+
/** Construction-time settings for `DeepInfraImageModel`. */
type DeepInfraImageModelConfig = {
  /** Provider name exposed through the model's `provider` getter. */
  provider: string;
  /** Base URL; text-to-image requests are POSTed to `${baseURL}/${modelId}`. */
  baseURL: string;
  /** Produces provider-level headers, combined with per-call headers. */
  headers: () => Record<string, string>;
  /** Optional fetch implementation passed to the request helpers. */
  fetch?: FetchFunction;
  /** Internal hooks, e.g. for deterministic tests. */
  _internal?: {
    /** Clock used for the response timestamp; defaults to `new Date()`. */
    currentDate?: () => Date;
  };
};
|
|
29
|
+
|
|
30
|
+
/**
 * Image model backed by DeepInfra.
 *
 * Without input `files`, the request is a JSON POST to DeepInfra's native
 * inference API at `${baseURL}/${modelId}`. With one or more input `files`,
 * the request is sent as multipart form data to DeepInfra's OpenAI-compatible
 * `/images/edits` endpoint, which is derived from `baseURL` by `getEditUrl`.
 */
export class DeepInfraImageModel implements ImageModelV3 {
  readonly specificationVersion = 'v3';
  readonly maxImagesPerCall = 1;

  get provider(): string {
    return this.config.provider;
  }

  constructor(
    readonly modelId: DeepInfraImageModelId,
    private readonly config: DeepInfraImageModelConfig,
  ) {}

  /**
   * Generates images, or edits the given `files` when any are provided.
   *
   * `providerOptions.deepinfra` is spread into the request last, so its keys
   * override the mapped settings. Settings that the edit request has no field
   * for (`aspectRatio`, `seed`) are reported in `warnings`; they are no longer
   * dropped silently.
   *
   * @returns Base64 image strings, warnings, and response metadata
   *   (timestamp, model id, response headers). In generation mode, a leading
   *   `data:image/<type>;base64,` prefix is stripped from each image.
   */
  async doGenerate({
    prompt,
    n,
    size,
    aspectRatio,
    seed,
    providerOptions,
    headers,
    abortSignal,
    files,
    mask,
  }: Parameters<ImageModelV3['doGenerate']>[0]): Promise<
    Awaited<ReturnType<ImageModelV3['doGenerate']>>
  > {
    const warnings: Array<SharedV3Warning> = [];
    const currentDate = this.config._internal?.currentDate?.() ?? new Date();

    // Image editing mode - use OpenAI-compatible /images/edits endpoint
    if (files != null && files.length > 0) {
      // The edit form data below has no field for these settings, so tell
      // the caller instead of ignoring them without notice.
      if (aspectRatio != null) {
        warnings.push({
          type: 'unsupported',
          feature: 'aspectRatio',
          details:
            'aspectRatio is ignored for image editing requests. Use `size` instead.',
        });
      }
      if (seed != null) {
        warnings.push({
          type: 'unsupported',
          feature: 'seed',
          details: 'seed is ignored for image editing requests.',
        });
      }

      const { value: response, responseHeaders } = await postFormDataToApi({
        url: this.getEditUrl(),
        headers: combineHeaders(this.config.headers(), headers),
        formData: convertToFormData<DeepInfraFormDataInput>(
          {
            model: this.modelId,
            prompt,
            image: await Promise.all(files.map(file => fileToBlob(file))),
            mask: mask != null ? await fileToBlob(mask) : undefined,
            n,
            size,
            ...(providerOptions.deepinfra ?? {}),
          },
          // Multiple images are sent as repeated `image` fields (no `[]`).
          { useArrayBrackets: false },
        ),
        failedResponseHandler: createJsonErrorResponseHandler({
          errorSchema: deepInfraEditErrorSchema,
          errorToMessage: error => error.error?.message ?? 'Unknown error',
        }),
        successfulResponseHandler: createJsonResponseHandler(
          deepInfraEditResponseSchema,
        ),
        abortSignal,
        fetch: this.config.fetch,
      });

      return {
        images: response.data.map(item => item.b64_json),
        warnings,
        response: {
          timestamp: currentDate,
          modelId: this.modelId,
          headers: responseHeaders,
        },
      };
    }

    // Standard image generation mode
    // Some deepinfra models support size while others support aspect ratio.
    // Allow passing either and leave it up to the server to validate.
    const splitSize = size?.split('x');
    const { value: response, responseHeaders } = await postJsonToApi({
      url: `${this.config.baseURL}/${this.modelId}`,
      headers: combineHeaders(this.config.headers(), headers),
      body: {
        prompt,
        num_images: n,
        ...(aspectRatio && { aspect_ratio: aspectRatio }),
        ...(splitSize && { width: splitSize[0], height: splitSize[1] }),
        ...(seed != null && { seed }),
        ...(providerOptions.deepinfra ?? {}),
      },
      failedResponseHandler: createJsonErrorResponseHandler({
        errorSchema: deepInfraErrorSchema,
        errorToMessage: error => error.detail.error,
      }),
      successfulResponseHandler: createJsonResponseHandler(
        deepInfraImageResponseSchema,
      ),
      abortSignal,
      fetch: this.config.fetch,
    });

    return {
      // Accept any image subtype (e.g. `svg+xml`), not only word characters.
      images: response.images.map(image =>
        image.replace(/^data:image\/[\w.+-]+;base64,/, ''),
      ),
      warnings,
      response: {
        timestamp: currentDate,
        modelId: this.modelId,
        headers: responseHeaders,
      },
    };
  }

  /**
   * Derives the OpenAI-compatible edit URL from `baseURL`.
   *
   * The default baseURL is typically `https://api.deepinfra.com/v1/inference`,
   * and edits must go to `https://api.deepinfra.com/v1/openai/images/edits`.
   * Only a `/inference` *path segment* is rewritten. The match must be
   * preceded by a non-slash character and followed by `/` or end of string,
   * so the `//inference` of a host like `https://inference.example.com` and
   * segments like `/inferences` are left untouched.
   */
  private getEditUrl(): string {
    const baseUrl = this.config.baseURL.replace(
      /([^/])\/inference(?=\/|$)/,
      '$1/openai',
    );
    return `${baseUrl}/images/edits`;
  }
}
|
|
146
|
+
|
|
147
|
+
/**
 * Error body of DeepInfra's native inference endpoint,
 * shaped as `{ detail: { error: string } }`.
 */
export const deepInfraErrorSchema = z.object({
  detail: z.object({ error: z.string() }),
});
|
|
152
|
+
|
|
153
|
+
/**
 * Successful text-to-image response, reduced to the one field the model
 * reads: the `images` array of (possibly data-URL-prefixed) base64 strings.
 * Validating only what is used keeps the client tolerant of additive API
 * changes.
 */
export const deepInfraImageResponseSchema = z.object({
  images: z.array(z.string()),
});
|
|
158
|
+
|
|
159
|
+
/**
 * Error body of the OpenAI-compatible image edit endpoint. The `error`
 * object is optional, so callers must fall back when it is missing.
 */
export const deepInfraEditErrorSchema = z.object({
  error: z.object({ message: z.string() }).optional(),
});
|
|
167
|
+
|
|
168
|
+
/**
 * Successful response of the OpenAI-compatible image edit endpoint: a
 * `data` array whose entries each carry a base64 image in `b64_json`.
 */
export const deepInfraEditResponseSchema = z.object({
  data: z.array(
    z.object({
      b64_json: z.string(),
    }),
  ),
});
|
|
172
|
+
|
|
173
|
+
/**
 * Multipart fields sent to the image edit endpoint. The index signature
 * admits the arbitrary extra keys spread in from `providerOptions.deepinfra`.
 */
type DeepInfraFormDataInput = {
  [key: string]: unknown;
  model: string;
  prompt: string | undefined;
  /** One Blob per input image. */
  image: Blob | Blob[];
  mask?: Blob;
  n: number;
  size: `${number}x${number}` | undefined;
};
|
|
182
|
+
|
|
183
|
+
/**
 * Turns an image input into a Blob suitable for a multipart upload.
 *
 * Inline files keep their `mediaType`; their data is used as-is when it is
 * already a `Uint8Array` and base64-decoded otherwise. URL inputs are
 * downloaded with `downloadBlob`.
 */
async function fileToBlob(file: ImageModelV3File): Promise<Blob> {
  if (file.type !== 'url') {
    const bytes =
      file.data instanceof Uint8Array
        ? file.data
        : convertBase64ToUint8Array(file.data);
    return new Blob([bytes as BlobPart], { type: file.mediaType });
  }

  return downloadBlob(file.url);
}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
/**
 * Identifiers of DeepInfra text-to-image models with editor autocompletion.
 * Other strings are accepted through the trailing `(string & {})` member.
 *
 * @see https://deepinfra.com/models/text-to-image
 */
export type DeepInfraImageModelId =
  | 'stabilityai/sd3.5'
  // Black Forest Labs FLUX family
  | 'black-forest-labs/FLUX-1.1-pro'
  | 'black-forest-labs/FLUX-1-schnell'
  | 'black-forest-labs/FLUX-1-dev'
  | 'black-forest-labs/FLUX-pro'
  | 'black-forest-labs/FLUX.1-Kontext-dev'
  | 'black-forest-labs/FLUX.1-Kontext-pro'
  // More Stability AI models
  | 'stabilityai/sd3.5-medium'
  | 'stabilityai/sdxl-turbo'
  // Escape hatch: any other model id string.
  | (string & {});
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
import { DeepInfraImageModel } from './deepinfra-image-model';
|
|
2
|
+
import { createDeepInfra } from './deepinfra-provider';
|
|
3
|
+
import {
|
|
4
|
+
OpenAICompatibleChatLanguageModel,
|
|
5
|
+
OpenAICompatibleCompletionLanguageModel,
|
|
6
|
+
OpenAICompatibleEmbeddingModel,
|
|
7
|
+
} from '@ai-sdk/openai-compatible';
|
|
8
|
+
import { LanguageModelV3, EmbeddingModelV3 } from '@ai-sdk/provider';
|
|
9
|
+
import { loadApiKey } from '@ai-sdk/provider-utils';
|
|
10
|
+
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
|
|
11
|
+
|
|
12
|
+
// The chat-model import is replaced by a spy (see the vi.mock below); this
// alias exposes it with Vitest's Mock type so tests can read `.mock.calls`.
const OpenAICompatibleChatLanguageModelMock =
  OpenAICompatibleChatLanguageModel as unknown as Mock;

// The provider only needs to construct these models, so bare spies suffice.
vi.mock('@ai-sdk/openai-compatible', () => ({
  OpenAICompatibleChatLanguageModel: vi.fn(),
  OpenAICompatibleCompletionLanguageModel: vi.fn(),
  OpenAICompatibleEmbeddingModel: vi.fn(),
}));

// Keep the real provider-utils, but stub API-key loading and make
// withoutTrailingSlash the identity function.
vi.mock('@ai-sdk/provider-utils', async importOriginal => ({
  ...(await importOriginal<typeof import('@ai-sdk/provider-utils')>()),
  loadApiKey: vi.fn().mockReturnValue('mock-api-key'),
  withoutTrailingSlash: vi.fn(url => url),
}));

// The image model is replaced by a spy so constructor arguments can be asserted.
vi.mock('./deepinfra-image-model', () => ({
  DeepInfraImageModel: vi.fn(),
}));
|
|
34
|
+
|
|
35
|
+
describe('DeepInfraProvider', () => {
|
|
36
|
+
let mockLanguageModel: LanguageModelV3;
|
|
37
|
+
let mockEmbeddingModel: EmbeddingModelV3;
|
|
38
|
+
|
|
39
|
+
beforeEach(() => {
|
|
40
|
+
// Mock implementations of models
|
|
41
|
+
mockLanguageModel = {
|
|
42
|
+
// Add any required methods for LanguageModelV3
|
|
43
|
+
} as LanguageModelV3;
|
|
44
|
+
mockEmbeddingModel = {
|
|
45
|
+
// Add any required methods for EmbeddingModelV3
|
|
46
|
+
} as EmbeddingModelV3;
|
|
47
|
+
|
|
48
|
+
// Reset mocks
|
|
49
|
+
vi.clearAllMocks();
|
|
50
|
+
});
|
|
51
|
+
|
|
52
|
+
describe('createDeepInfra', () => {
|
|
53
|
+
it('should create a DeepInfraProvider instance with default options', () => {
|
|
54
|
+
const provider = createDeepInfra();
|
|
55
|
+
const model = provider('model-id');
|
|
56
|
+
|
|
57
|
+
// Use the mocked version
|
|
58
|
+
const constructorCall =
|
|
59
|
+
OpenAICompatibleChatLanguageModelMock.mock.calls[0];
|
|
60
|
+
const config = constructorCall[1];
|
|
61
|
+
config.headers();
|
|
62
|
+
|
|
63
|
+
expect(loadApiKey).toHaveBeenCalledWith({
|
|
64
|
+
apiKey: undefined,
|
|
65
|
+
environmentVariableName: 'DEEPINFRA_API_KEY',
|
|
66
|
+
description: "DeepInfra's API key",
|
|
67
|
+
});
|
|
68
|
+
});
|
|
69
|
+
|
|
70
|
+
it('should create a DeepInfraProvider instance with custom options', () => {
|
|
71
|
+
const options = {
|
|
72
|
+
apiKey: 'custom-key',
|
|
73
|
+
baseURL: 'https://custom.url',
|
|
74
|
+
headers: { 'Custom-Header': 'value' },
|
|
75
|
+
};
|
|
76
|
+
const provider = createDeepInfra(options);
|
|
77
|
+
const model = provider('model-id');
|
|
78
|
+
|
|
79
|
+
const constructorCall =
|
|
80
|
+
OpenAICompatibleChatLanguageModelMock.mock.calls[0];
|
|
81
|
+
const config = constructorCall[1];
|
|
82
|
+
config.headers();
|
|
83
|
+
|
|
84
|
+
expect(loadApiKey).toHaveBeenCalledWith({
|
|
85
|
+
apiKey: 'custom-key',
|
|
86
|
+
environmentVariableName: 'DEEPINFRA_API_KEY',
|
|
87
|
+
description: "DeepInfra's API key",
|
|
88
|
+
});
|
|
89
|
+
});
|
|
90
|
+
|
|
91
|
+
it('should return a chat model when called as a function', () => {
|
|
92
|
+
const provider = createDeepInfra();
|
|
93
|
+
const modelId = 'foo-model-id';
|
|
94
|
+
|
|
95
|
+
const model = provider(modelId);
|
|
96
|
+
expect(model).toBeInstanceOf(OpenAICompatibleChatLanguageModel);
|
|
97
|
+
});
|
|
98
|
+
});
|
|
99
|
+
|
|
100
|
+
describe('chatModel', () => {
|
|
101
|
+
it('should construct a chat model with correct configuration', () => {
|
|
102
|
+
const provider = createDeepInfra();
|
|
103
|
+
const modelId = 'deepinfra-chat-model';
|
|
104
|
+
|
|
105
|
+
const model = provider.chatModel(modelId);
|
|
106
|
+
|
|
107
|
+
expect(model).toBeInstanceOf(OpenAICompatibleChatLanguageModel);
|
|
108
|
+
expect(OpenAICompatibleChatLanguageModelMock).toHaveBeenCalledWith(
|
|
109
|
+
modelId,
|
|
110
|
+
expect.objectContaining({
|
|
111
|
+
provider: 'deepinfra.chat',
|
|
112
|
+
}),
|
|
113
|
+
);
|
|
114
|
+
});
|
|
115
|
+
});
|
|
116
|
+
|
|
117
|
+
// `completionModel` should build an OpenAI-compatible completion model.
describe('completionModel', () => {
  it('should construct a completion model with correct configuration', () => {
    const completionModel = createDeepInfra().completionModel(
      'deepinfra-completion-model',
    );

    expect(completionModel).toBeInstanceOf(
      OpenAICompatibleCompletionLanguageModel,
    );
  });
});
|
|
127
|
+
|
|
128
|
+
// `embeddingModel` should build an OpenAI-compatible embedding model.
describe('embeddingModel', () => {
  it('should construct a text embedding model with correct configuration', () => {
    const embeddingModel = createDeepInfra().embeddingModel(
      'deepinfra-embedding-model',
    );

    expect(embeddingModel).toBeInstanceOf(OpenAICompatibleEmbeddingModel);
  });
});
|
|
138
|
+
|
|
139
|
+
// `image` should build a DeepInfraImageModel whose base URL is the provider
// base URL with `/inference` appended.
describe('image', () => {
  const imageModelId = 'deepinfra-image-model';

  it('should construct an image model with correct configuration', () => {
    const imageModel = createDeepInfra().image(imageModelId);

    expect(imageModel).toBeInstanceOf(DeepInfraImageModel);
    expect(DeepInfraImageModel).toHaveBeenCalledWith(
      imageModelId,
      expect.objectContaining({
        provider: 'deepinfra.image',
        baseURL: 'https://api.deepinfra.com/v1/inference',
      }),
    );
  });

  it('should use default settings when none provided', () => {
    const imageModel = createDeepInfra().image(imageModelId);

    expect(imageModel).toBeInstanceOf(DeepInfraImageModel);
    expect(DeepInfraImageModel).toHaveBeenCalledWith(
      imageModelId,
      expect.any(Object),
    );
  });

  it('should respect custom baseURL', () => {
    const customBaseURL = 'https://custom.api.deepinfra.com';
    createDeepInfra({ baseURL: customBaseURL }).image(imageModelId);

    expect(DeepInfraImageModel).toHaveBeenCalledWith(
      imageModelId,
      expect.objectContaining({ baseURL: `${customBaseURL}/inference` }),
    );
  });
});
|
|
184
|
+
});
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
import {
|
|
2
|
+
LanguageModelV3,
|
|
3
|
+
EmbeddingModelV3,
|
|
4
|
+
ProviderV3,
|
|
5
|
+
ImageModelV3,
|
|
6
|
+
} from '@ai-sdk/provider';
|
|
7
|
+
import {
|
|
8
|
+
OpenAICompatibleChatLanguageModel,
|
|
9
|
+
OpenAICompatibleCompletionLanguageModel,
|
|
10
|
+
OpenAICompatibleEmbeddingModel,
|
|
11
|
+
} from '@ai-sdk/openai-compatible';
|
|
12
|
+
import {
|
|
13
|
+
FetchFunction,
|
|
14
|
+
loadApiKey,
|
|
15
|
+
withoutTrailingSlash,
|
|
16
|
+
withUserAgentSuffix,
|
|
17
|
+
} from '@ai-sdk/provider-utils';
|
|
18
|
+
import { DeepInfraChatModelId } from './deepinfra-chat-options';
|
|
19
|
+
import { DeepInfraEmbeddingModelId } from './deepinfra-embedding-options';
|
|
20
|
+
import { DeepInfraCompletionModelId } from './deepinfra-completion-options';
|
|
21
|
+
import { DeepInfraImageModelId } from './deepinfra-image-settings';
|
|
22
|
+
import { DeepInfraImageModel } from './deepinfra-image-model';
|
|
23
|
+
import { VERSION } from './version';
|
|
24
|
+
|
|
25
|
+
/**
 * Settings accepted by `createDeepInfra`. Every field is optional.
 */
export interface DeepInfraProviderSettings {
  /**
   * DeepInfra API key. When omitted, the key is loaded from the
   * `DEEPINFRA_API_KEY` environment variable.
   */
  apiKey?: string;

  /**
   * Base URL for API calls; defaults to `https://api.deepinfra.com/v1`.
   * A trailing slash is removed. Chat, completion and embedding requests go
   * to `<baseURL>/openai/...`, and image models use `<baseURL>/inference`.
   */
  baseURL?: string;

  /**
   * Extra headers sent with every request. They are merged after the
   * generated `Authorization` header, so they can override it.
   */
  headers?: Record<string, string>;

  /**
   * Custom fetch implementation, e.g. to intercept requests as middleware or
   * to stub the network in tests.
   */
  fetch?: FetchFunction;
}
|
|
44
|
+
|
|
45
|
+
/**
 * DeepInfra provider: callable directly with a chat model ID, and exposing a
 * factory method for each supported model type.
 */
export interface DeepInfraProvider extends ProviderV3 {
  /**
   * Creates a chat model for text generation (same as `chatModel`).
   */
  (modelId: DeepInfraChatModelId): LanguageModelV3;

  /**
   * Creates a chat model for text generation.
   */
  chatModel(modelId: DeepInfraChatModelId): LanguageModelV3;

  /**
   * Creates a model for image generation (alias of `imageModel`).
   */
  image(modelId: DeepInfraImageModelId): ImageModelV3;

  /**
   * Creates a model for image generation.
   */
  imageModel(modelId: DeepInfraImageModelId): ImageModelV3;

  /**
   * Creates a chat model for text generation (alias of `chatModel`).
   */
  languageModel(modelId: DeepInfraChatModelId): LanguageModelV3;

  /**
   * Creates a completion model for text generation.
   */
  completionModel(modelId: DeepInfraCompletionModelId): LanguageModelV3;

  /**
   * Creates an embedding model for computing text embeddings.
   */
  embeddingModel(modelId: DeepInfraEmbeddingModelId): EmbeddingModelV3;

  /**
   * @deprecated Use `embeddingModel` instead.
   */
  textEmbeddingModel(modelId: DeepInfraEmbeddingModelId): EmbeddingModelV3;
}
|
|
86
|
+
|
|
87
|
+
/**
 * Creates a DeepInfra provider.
 *
 * Chat, completion and embedding models use the OpenAI-compatible model
 * classes and send requests to `<baseURL>/openai<path>`. Image models use
 * `DeepInfraImageModel` with `<baseURL>/inference` as their base URL.
 *
 * @param options - Optional provider settings (API key, base URL, extra
 *   headers, custom fetch).
 * @returns A provider that can be called directly with a chat model ID and
 *   also exposes a factory for each supported model type.
 */
export function createDeepInfra(
  options: DeepInfraProviderSettings = {},
): DeepInfraProvider {
  // The default is applied *after* stripping the trailing slash, and with `||`
  // rather than `??`. This way an empty or slash-only `baseURL` resolves to the
  // default for every model type, instead of producing relative request URLs.
  const baseURL =
    withoutTrailingSlash(options.baseURL) || 'https://api.deepinfra.com/v1';

  // Evaluated lazily: the API key is loaded only when a model asks for its
  // headers, not when the provider is created. User-supplied headers are
  // spread last, so they take precedence over `Authorization`.
  const getHeaders = () =>
    withUserAgentSuffix(
      {
        Authorization: `Bearer ${loadApiKey({
          apiKey: options.apiKey,
          environmentVariableName: 'DEEPINFRA_API_KEY',
          description: "DeepInfra's API key",
        })}`,
        ...options.headers,
      },
      `ai-sdk/deepinfra/${VERSION}`,
    );

  interface CommonModelConfig {
    provider: string;
    url: ({ path }: { path: string }) => string;
    headers: () => Record<string, string>;
    fetch?: FetchFunction;
  }

  /** Shared configuration for models served via the `/openai` endpoints. */
  const getCommonModelConfig = (modelType: string): CommonModelConfig => ({
    provider: `deepinfra.${modelType}`,
    url: ({ path }) => `${baseURL}/openai${path}`,
    headers: getHeaders,
    fetch: options.fetch,
  });

  const createChatModel = (modelId: DeepInfraChatModelId) =>
    new OpenAICompatibleChatLanguageModel(
      modelId,
      getCommonModelConfig('chat'),
    );

  const createCompletionModel = (modelId: DeepInfraCompletionModelId) =>
    new OpenAICompatibleCompletionLanguageModel(
      modelId,
      getCommonModelConfig('completion'),
    );

  const createEmbeddingModel = (modelId: DeepInfraEmbeddingModelId) =>
    new OpenAICompatibleEmbeddingModel(
      modelId,
      getCommonModelConfig('embedding'),
    );

  const createImageModel = (modelId: DeepInfraImageModelId) =>
    new DeepInfraImageModel(modelId, {
      ...getCommonModelConfig('image'),
      // Image models get a separate `/inference` base URL instead of `/openai`.
      baseURL: `${baseURL}/inference`,
    });

  // Calling the provider directly is shorthand for `chatModel`.
  const provider = (modelId: DeepInfraChatModelId) => createChatModel(modelId);

  provider.specificationVersion = 'v3' as const;
  provider.completionModel = createCompletionModel;
  provider.chatModel = createChatModel;
  provider.image = createImageModel;
  provider.imageModel = createImageModel;
  provider.languageModel = createChatModel;
  provider.embeddingModel = createEmbeddingModel;
  provider.textEmbeddingModel = createEmbeddingModel;

  return provider;
}
|
|
160
|
+
|
|
161
|
+
/**
 * Default DeepInfra provider instance, created without explicit settings
 * (default base URL; API key taken from `DEEPINFRA_API_KEY`).
 */
export const deepinfra: DeepInfraProvider = createDeepInfra({});
|
package/src/index.ts
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
export { createDeepInfra, deepinfra } from './deepinfra-provider';
|
|
2
|
+
export type {
|
|
3
|
+
DeepInfraProvider,
|
|
4
|
+
DeepInfraProviderSettings,
|
|
5
|
+
} from './deepinfra-provider';
|
|
6
|
+
export type { OpenAICompatibleErrorData as DeepInfraErrorData } from '@ai-sdk/openai-compatible';
|
|
7
|
+
export { VERSION } from './version';
|