@ai-sdk/mistral 3.0.9 → 3.0.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +6 -0
- package/dist/index.js +1 -1
- package/dist/index.js.map +1 -1
- package/dist/index.mjs +1 -1
- package/dist/index.mjs.map +1 -1
- package/package.json +3 -2
- package/src/__fixtures__/mistral-generate-text.1.json +22 -0
- package/src/__snapshots__/convert-to-mistral-chat-messages.test.ts.snap +57 -0
- package/src/__snapshots__/mistral-embedding-model.test.ts.snap +44 -0
- package/src/convert-mistral-usage.ts +46 -0
- package/src/convert-to-mistral-chat-messages.test.ts +372 -0
- package/src/convert-to-mistral-chat-messages.ts +163 -0
- package/src/get-response-metadata.ts +15 -0
- package/src/index.ts +7 -0
- package/src/map-mistral-finish-reason.ts +17 -0
- package/src/mistral-chat-language-model.test.ts +1755 -0
- package/src/mistral-chat-language-model.ts +580 -0
- package/src/mistral-chat-options.ts +63 -0
- package/src/mistral-chat-prompt.ts +46 -0
- package/src/mistral-embedding-model.test.ts +127 -0
- package/src/mistral-embedding-model.ts +94 -0
- package/src/mistral-embedding-options.ts +1 -0
- package/src/mistral-error.ts +17 -0
- package/src/mistral-prepare-tools.test.ts +178 -0
- package/src/mistral-prepare-tools.ts +97 -0
- package/src/mistral-provider.ts +147 -0
- package/src/version.ts +6 -0
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import { EmbeddingModelV3Embedding } from '@ai-sdk/provider';
|
|
2
|
+
import { createTestServer } from '@ai-sdk/test-server/with-vitest';
|
|
3
|
+
import { createMistral } from './mistral-provider';
|
|
4
|
+
import { describe, it, expect, vi } from 'vitest';
|
|
5
|
+
|
|
6
|
+
// Replace the package version with a fixed value so the user-agent
// expectation in the 'should pass headers' test is stable.
vi.mock('./version', () => ({
  VERSION: '0.0.0-test',
}));

// Fixed embedding vectors returned by the mocked endpoint by default.
// NOTE: the 'content-length' expectation in 'should expose the raw response'
// is derived from the exact JSON serialization of these values.
const dummyEmbeddings = [
  [0.1, 0.2, 0.3, 0.4, 0.5],
  [0.6, 0.7, 0.8, 0.9, 1.0],
];
// Two input texts, matching the two dummy embeddings above.
const testValues = ['sunny day at the beach', 'rainy day in the city'];

// Shared provider/model used by most tests in this file.
const provider = createMistral({ apiKey: 'test-api-key' });
const model = provider.embeddingModel('mistral-embed');

// Test server intercepting the Mistral embeddings endpoint; each test sets the
// response it should return via prepareJsonResponse().
const server = createTestServer({
  'https://api.mistral.ai/v1/embeddings': {},
});
|
|
22
|
+
|
|
23
|
+
describe('doEmbed', () => {
  // Configures the mocked embeddings endpoint to answer with a Mistral-style
  // JSON body built from the given embeddings and usage, plus optional extra
  // response headers.
  function prepareJsonResponse({
    embeddings = dummyEmbeddings,
    usage = { prompt_tokens: 8, total_tokens: 8 },
    headers,
  }: {
    embeddings?: EmbeddingModelV3Embedding[];
    usage?: { prompt_tokens: number; total_tokens: number };
    headers?: Record<string, string>;
  } = {}) {
    server.urls['https://api.mistral.ai/v1/embeddings'].response = {
      type: 'json-value',
      headers,
      body: {
        id: 'b322cfc2b9d34e2f8e14fc99874faee5',
        object: 'list',
        // one data entry per embedding, tagged with its position
        data: embeddings.map((embedding, i) => ({
          object: 'embedding',
          embedding,
          index: i,
        })),
        model: 'mistral-embed',
        usage,
      },
    };
  }

  it('should extract embedding', async () => {
    prepareJsonResponse();

    const { embeddings } = await model.doEmbed({ values: testValues });

    expect(embeddings).toStrictEqual(dummyEmbeddings);
  });

  // The reported token count comes from usage.prompt_tokens of the response.
  it('should extract usage', async () => {
    prepareJsonResponse({
      usage: { prompt_tokens: 20, total_tokens: 20 },
    });

    const { usage } = await model.doEmbed({ values: testValues });

    expect(usage).toStrictEqual({ tokens: 20 });
  });

  it('should expose the raw response', async () => {
    prepareJsonResponse({
      headers: { 'test-header': 'test-value' },
    });

    const { response } = await model.doEmbed({ values: testValues });

    expect(response?.headers).toStrictEqual({
      // default headers:
      // 267 is the byte length of the compact JSON body produced by
      // prepareJsonResponse() with its default fixture; it must be updated
      // whenever that fixture changes.
      'content-length': '267',
      'content-type': 'application/json',

      // custom header
      'test-header': 'test-value',
    });
    expect(response).toMatchSnapshot();
  });

  it('should pass the model and the values', async () => {
    prepareJsonResponse();

    await model.doEmbed({ values: testValues });

    expect(await server.calls[0].requestBodyJson).toStrictEqual({
      model: 'mistral-embed',
      input: testValues,
      encoding_format: 'float',
    });
  });

  it('should pass headers', async () => {
    prepareJsonResponse();

    // Dedicated provider instance so provider-level headers can be checked
    // together with request-level headers.
    const provider = createMistral({
      apiKey: 'test-api-key',
      headers: {
        'Custom-Provider-Header': 'provider-header-value',
      },
    });

    await provider.embedding('mistral-embed').doEmbed({
      values: testValues,
      headers: {
        'Custom-Request-Header': 'request-header-value',
      },
    });

    // The test server reports header names in lower case.
    const requestHeaders = server.calls[0].requestHeaders;

    expect(requestHeaders).toStrictEqual({
      authorization: 'Bearer test-api-key',
      'content-type': 'application/json',
      'custom-provider-header': 'provider-header-value',
      'custom-request-header': 'request-header-value',
    });
    // VERSION is mocked to '0.0.0-test' at the top of this file.
    expect(server.calls[0].requestUserAgent).toContain(
      `ai-sdk/mistral/0.0.0-test`,
    );
  });
});
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import {
|
|
2
|
+
EmbeddingModelV3,
|
|
3
|
+
TooManyEmbeddingValuesForCallError,
|
|
4
|
+
} from '@ai-sdk/provider';
|
|
5
|
+
import {
|
|
6
|
+
combineHeaders,
|
|
7
|
+
createJsonResponseHandler,
|
|
8
|
+
FetchFunction,
|
|
9
|
+
postJsonToApi,
|
|
10
|
+
} from '@ai-sdk/provider-utils';
|
|
11
|
+
import { z } from 'zod/v4';
|
|
12
|
+
import { MistralEmbeddingModelId } from './mistral-embedding-options';
|
|
13
|
+
import { mistralFailedResponseHandler } from './mistral-error';
|
|
14
|
+
|
|
15
|
+
/**
 * Configuration handed to `MistralEmbeddingModel` by the provider factory.
 */
interface MistralEmbeddingConfig {
  /** Provider identifier, exposed through `MistralEmbeddingModel.provider`. */
  provider: string;
  /** Base URL of the API; `/embeddings` is appended to it for each call. */
  baseURL: string;
  /** Invoked on every request to produce the base request headers. */
  headers: () => Record<string, string | undefined>;
  /** Optional custom fetch implementation used for the HTTP call. */
  fetch?: FetchFunction;
}
|
|
21
|
+
|
|
22
|
+
/**
 * Embedding model backed by Mistral's `/embeddings` endpoint.
 *
 * At most `maxEmbeddingsPerCall` values are accepted per `doEmbed` call, and
 * the model reports that it does not support parallel calls.
 */
export class MistralEmbeddingModel implements EmbeddingModelV3 {
  readonly specificationVersion = 'v3';
  readonly modelId: MistralEmbeddingModelId;
  readonly maxEmbeddingsPerCall = 32;
  readonly supportsParallelCalls = false;

  private readonly config: MistralEmbeddingConfig;

  /** Provider identifier taken from the injected configuration. */
  get provider(): string {
    return this.config.provider;
  }

  /**
   * @param modelId - Mistral embedding model id (e.g. `'mistral-embed'`).
   * @param config - Provider-supplied settings (base URL, headers, fetch).
   */
  constructor(
    modelId: MistralEmbeddingModelId,
    config: MistralEmbeddingConfig,
  ) {
    this.modelId = modelId;
    this.config = config;
  }

  /**
   * Embeds `values` with a single POST to `${baseURL}/embeddings`, requesting
   * float-encoded vectors.
   *
   * @throws TooManyEmbeddingValuesForCallError when more than
   *   `maxEmbeddingsPerCall` values are passed.
   * @returns The embeddings in response order, token usage (from
   *   `usage.prompt_tokens`, if present), and the raw response headers/body.
   */
  async doEmbed({
    values,
    abortSignal,
    headers,
  }: Parameters<EmbeddingModelV3['doEmbed']>[0]): Promise<
    Awaited<ReturnType<EmbeddingModelV3['doEmbed']>>
  > {
    const { maxEmbeddingsPerCall } = this;

    // Reject oversized batches before any network traffic happens.
    if (values.length > maxEmbeddingsPerCall) {
      throw new TooManyEmbeddingValuesForCallError({
        provider: this.provider,
        modelId: this.modelId,
        maxEmbeddingsPerCall,
        values,
      });
    }

    const apiResult = await postJsonToApi({
      url: `${this.config.baseURL}/embeddings`,
      // per-call headers are combined with the provider headers
      headers: combineHeaders(this.config.headers(), headers),
      body: {
        model: this.modelId,
        input: values,
        encoding_format: 'float',
      },
      failedResponseHandler: mistralFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        MistralTextEmbeddingResponseSchema,
      ),
      abortSignal,
      fetch: this.config.fetch,
    });

    const { data, usage } = apiResult.value;

    return {
      warnings: [],
      embeddings: data.map(({ embedding }) => embedding),
      usage: usage ? { tokens: usage.prompt_tokens } : undefined,
      response: { headers: apiResult.responseHeaders, body: apiResult.rawValue },
    };
  }
}
|
|
88
|
+
|
|
89
|
+
// Minimal schema for the embeddings response, focused on the fields the
// implementation reads (the embedding vectors and prompt token usage).
// Keeping it small limits breakage when the API changes and increases efficiency.
const MistralTextEmbeddingResponseSchema = z.object({
  data: z.object({ embedding: z.number().array() }).array(),
  usage: z.object({ prompt_tokens: z.number() }).nullish(),
});
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
/**
 * Identifier of a Mistral embedding model. Any string is accepted; the
 * `(string & {})` member keeps the known `'mistral-embed'` literal from being
 * collapsed into plain `string`.
 */
export type MistralEmbeddingModelId = (string & {}) | 'mistral-embed';
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import { createJsonErrorResponseHandler } from '@ai-sdk/provider-utils';
|
|
2
|
+
import { z } from 'zod/v4';
|
|
3
|
+
|
|
4
|
+
/**
 * Shape of the JSON error body parsed from failed Mistral API responses.
 */
const mistralErrorDataSchema = z.object({
  object: z.literal('error'),
  message: z.string(),
  type: z.string(),
  param: z.string().nullable(),
  code: z.string().nullable(),
});

/** Parsed Mistral error payload. */
export type MistralErrorData = z.output<typeof mistralErrorDataSchema>;

/**
 * Failed-response handler shared by the Mistral models: parses the error body
 * with `mistralErrorDataSchema` and uses its `message` as the error message.
 */
export const mistralFailedResponseHandler = createJsonErrorResponseHandler({
  errorSchema: mistralErrorDataSchema,
  errorToMessage: ({ message }) => message,
});
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import { describe, it, expect } from 'vitest';
|
|
2
|
+
import { prepareTools } from './mistral-prepare-tools';
|
|
3
|
+
|
|
4
|
+
// Tests for how prepareTools forwards the per-tool `strict` flag.
describe('prepareTools', () => {
  it('should pass through strict mode when strict is true', () => {
    const result = prepareTools({
      tools: [
        {
          type: 'function',
          name: 'testFunction',
          description: 'A test function',
          inputSchema: { type: 'object', properties: {} },
          strict: true,
        },
      ],
    });

    expect(result).toMatchInlineSnapshot(`
      {
        "toolChoice": undefined,
        "toolWarnings": [],
        "tools": [
          {
            "function": {
              "description": "A test function",
              "name": "testFunction",
              "parameters": {
                "properties": {},
                "type": "object",
              },
              "strict": true,
            },
            "type": "function",
          },
        ],
      }
    `);
  });

  it('should pass through strict mode when strict is false', () => {
    const result = prepareTools({
      tools: [
        {
          type: 'function',
          name: 'testFunction',
          description: 'A test function',
          inputSchema: { type: 'object', properties: {} },
          strict: false,
        },
      ],
    });

    // `false` is forwarded explicitly rather than dropped.
    expect(result).toMatchInlineSnapshot(`
      {
        "toolChoice": undefined,
        "toolWarnings": [],
        "tools": [
          {
            "function": {
              "description": "A test function",
              "name": "testFunction",
              "parameters": {
                "properties": {},
                "type": "object",
              },
              "strict": false,
            },
            "type": "function",
          },
        ],
      }
    `);
  });

  it('should not include strict mode when strict is undefined', () => {
    const result = prepareTools({
      tools: [
        {
          type: 'function',
          name: 'testFunction',
          description: 'A test function',
          inputSchema: { type: 'object', properties: {} },
        },
      ],
    });

    // No "strict" key at all in the function definition.
    expect(result).toMatchInlineSnapshot(`
      {
        "toolChoice": undefined,
        "toolWarnings": [],
        "tools": [
          {
            "function": {
              "description": "A test function",
              "name": "testFunction",
              "parameters": {
                "properties": {},
                "type": "object",
              },
            },
            "type": "function",
          },
        ],
      }
    `);
  });

  // Each tool keeps its own setting; the order of the tools is preserved.
  it('should pass through strict mode for multiple tools with different strict settings', () => {
    const result = prepareTools({
      tools: [
        {
          type: 'function',
          name: 'strictTool',
          description: 'A strict tool',
          inputSchema: { type: 'object', properties: {} },
          strict: true,
        },
        {
          type: 'function',
          name: 'nonStrictTool',
          description: 'A non-strict tool',
          inputSchema: { type: 'object', properties: {} },
          strict: false,
        },
        {
          type: 'function',
          name: 'defaultTool',
          description: 'A tool without strict setting',
          inputSchema: { type: 'object', properties: {} },
        },
      ],
    });

    expect(result).toMatchInlineSnapshot(`
      {
        "toolChoice": undefined,
        "toolWarnings": [],
        "tools": [
          {
            "function": {
              "description": "A strict tool",
              "name": "strictTool",
              "parameters": {
                "properties": {},
                "type": "object",
              },
              "strict": true,
            },
            "type": "function",
          },
          {
            "function": {
              "description": "A non-strict tool",
              "name": "nonStrictTool",
              "parameters": {
                "properties": {},
                "type": "object",
              },
              "strict": false,
            },
            "type": "function",
          },
          {
            "function": {
              "description": "A tool without strict setting",
              "name": "defaultTool",
              "parameters": {
                "properties": {},
                "type": "object",
              },
            },
            "type": "function",
          },
        ],
      }
    `);
  });
});
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import {
|
|
2
|
+
LanguageModelV3CallOptions,
|
|
3
|
+
SharedV3Warning,
|
|
4
|
+
UnsupportedFunctionalityError,
|
|
5
|
+
} from '@ai-sdk/provider';
|
|
6
|
+
import { MistralToolChoice } from './mistral-chat-prompt';
|
|
7
|
+
|
|
8
|
+
/**
 * Function-tool definition in the shape sent to the Mistral chat API.
 * Shared by the return type and the internal accumulator of `prepareTools`
 * so the two cannot drift apart.
 */
type MistralFunctionTool = {
  type: 'function';
  function: {
    name: string;
    description: string | undefined;
    parameters: unknown;
    strict?: boolean;
  };
};

/**
 * Converts AI SDK tool definitions and tool choice into Mistral's request format.
 *
 * - A missing or empty `tools` array yields no tools and no tool choice,
 *   regardless of `toolChoice`.
 * - Provider tools (`type: 'provider'`) are skipped and reported as
 *   `unsupported` warnings.
 * - A tool's `strict` flag is forwarded only when it is set (true or false).
 * - Tool choice `auto`/`none` is passed through, `required` maps to `'any'`, and
 *   `tool` keeps only the named tool and forces it via `'any'`.
 *
 * @returns The Mistral tools, the Mistral tool choice, and collected warnings.
 * @throws UnsupportedFunctionalityError for an unknown tool choice type.
 */
export function prepareTools({
  tools,
  toolChoice,
}: {
  tools: LanguageModelV3CallOptions['tools'];
  toolChoice?: LanguageModelV3CallOptions['toolChoice'];
}): {
  tools: MistralFunctionTool[] | undefined;
  toolChoice: MistralToolChoice | undefined;
  toolWarnings: SharedV3Warning[];
} {
  const toolWarnings: SharedV3Warning[] = [];

  // An empty tools array is treated like "no tools" to prevent API errors.
  if (tools == null || tools.length === 0) {
    return { tools: undefined, toolChoice: undefined, toolWarnings };
  }

  const mistralTools: MistralFunctionTool[] = [];

  for (const tool of tools) {
    if (tool.type === 'provider') {
      toolWarnings.push({
        type: 'unsupported',
        feature: `provider-defined tool ${tool.id}`,
      });
    } else {
      mistralTools.push({
        type: 'function',
        function: {
          name: tool.name,
          description: tool.description,
          parameters: tool.inputSchema,
          // only send `strict` when the caller set it explicitly
          ...(tool.strict != null ? { strict: tool.strict } : {}),
        },
      });
    }
  }

  if (toolChoice == null) {
    return { tools: mistralTools, toolChoice: undefined, toolWarnings };
  }

  const type = toolChoice.type;

  switch (type) {
    case 'auto':
    case 'none':
      return { tools: mistralTools, toolChoice: type, toolWarnings };
    case 'required':
      return { tools: mistralTools, toolChoice: 'any', toolWarnings };

    // mistral does not support tool mode directly,
    // so we filter the tools and force the tool choice through 'any'
    case 'tool':
      return {
        tools: mistralTools.filter(
          tool => tool.function.name === toolChoice.toolName,
        ),
        toolChoice: 'any',
        toolWarnings,
      };
    default: {
      const _exhaustiveCheck: never = type;
      throw new UnsupportedFunctionalityError({
        functionality: `tool choice type: ${_exhaustiveCheck}`,
      });
    }
  }
}
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
import {
|
|
2
|
+
EmbeddingModelV3,
|
|
3
|
+
LanguageModelV3,
|
|
4
|
+
NoSuchModelError,
|
|
5
|
+
ProviderV3,
|
|
6
|
+
} from '@ai-sdk/provider';
|
|
7
|
+
import {
|
|
8
|
+
FetchFunction,
|
|
9
|
+
loadApiKey,
|
|
10
|
+
withoutTrailingSlash,
|
|
11
|
+
withUserAgentSuffix,
|
|
12
|
+
} from '@ai-sdk/provider-utils';
|
|
13
|
+
import { MistralChatLanguageModel } from './mistral-chat-language-model';
|
|
14
|
+
import { MistralChatModelId } from './mistral-chat-options';
|
|
15
|
+
import { MistralEmbeddingModel } from './mistral-embedding-model';
|
|
16
|
+
import { MistralEmbeddingModelId } from './mistral-embedding-options';
|
|
17
|
+
import { VERSION } from './version';
|
|
18
|
+
|
|
19
|
+
/**
 * Mistral provider. It can be called directly with a chat model id and also
 * exposes factory methods for chat and embedding models.
 */
export interface MistralProvider extends ProviderV3 {
  /**
   * Creates a chat model for text generation (in `createMistral` this uses the
   * same factory as `languageModel` and `chat`).
   */
  (modelId: MistralChatModelId): LanguageModelV3;

  /**
Creates a model for text generation.
   */
  languageModel(modelId: MistralChatModelId): LanguageModelV3;

  /**
Creates a model for text generation (alias of `languageModel`).
   */
  chat(modelId: MistralChatModelId): LanguageModelV3;

  /**
   * Creates a model for text embeddings.
   */
  embedding(modelId: MistralEmbeddingModelId): EmbeddingModelV3;

  /**
   * Creates a model for text embeddings (alias of `embedding`).
   */
  embeddingModel: (modelId: MistralEmbeddingModelId) => EmbeddingModelV3;

  /**
   * @deprecated Use `embedding` instead.
   */
  textEmbedding(modelId: MistralEmbeddingModelId): EmbeddingModelV3;

  /**
   * @deprecated Use `embeddingModel` instead.
   */
  textEmbeddingModel(modelId: MistralEmbeddingModelId): EmbeddingModelV3;
}
|
|
52
|
+
|
|
53
|
+
/**
 * Settings accepted by `createMistral`. Every field is optional.
 */
export interface MistralProviderSettings {
  /**
Use a different URL prefix for API calls, e.g. to use proxy servers.
The default prefix is `https://api.mistral.ai/v1`.
   */
  baseURL?: string;

  /**
API key that is sent as a Bearer token in the `Authorization` header.
It defaults to the `MISTRAL_API_KEY` environment variable.
   */
  apiKey?: string;

  /**
Custom headers to include in the requests. They are applied after the
`Authorization` header, so they can override it.
   */
  headers?: Record<string, string>;

  /**
Custom fetch implementation. You can use it as a middleware to intercept requests,
or to provide a custom fetch implementation for e.g. testing.
   */
  fetch?: FetchFunction;

  /**
Custom ID generator. It is forwarded to the chat models created by this
provider; embedding models do not receive it.
   */
  generateId?: () => string;
}
|
|
79
|
+
|
|
80
|
+
/**
 * Create a Mistral AI provider instance.
 *
 * The returned provider is callable (`provider(modelId)` creates a chat model)
 * and exposes:
 * - `languageModel` / `chat` for chat models,
 * - `embedding` / `embeddingModel` / `textEmbedding` / `textEmbeddingModel`
 *   for embedding models, and
 * - `imageModel`, which always throws `NoSuchModelError`.
 *
 * @param options - Provider settings; all fields are optional.
 * @returns The configured Mistral provider.
 */
export function createMistral(
  options: MistralProviderSettings = {},
): MistralProvider {
  const baseURL =
    withoutTrailingSlash(options.baseURL) ?? 'https://api.mistral.ai/v1';

  // Called by the models for each request, so the API key is resolved at
  // request time rather than when the provider is created.
  const getHeaders = () => {
    const apiKey = loadApiKey({
      apiKey: options.apiKey,
      environmentVariableName: 'MISTRAL_API_KEY',
      description: 'Mistral',
    });

    return withUserAgentSuffix(
      {
        Authorization: `Bearer ${apiKey}`,
        // spread last, so custom headers may override Authorization
        ...options.headers,
      },
      `ai-sdk/mistral/${VERSION}`,
    );
  };

  const createChatModel = (modelId: MistralChatModelId) =>
    new MistralChatLanguageModel(modelId, {
      provider: 'mistral.chat',
      baseURL,
      headers: getHeaders,
      fetch: options.fetch,
      generateId: options.generateId,
    });

  const createEmbeddingModel = (modelId: MistralEmbeddingModelId) =>
    new MistralEmbeddingModel(modelId, {
      provider: 'mistral.embedding',
      baseURL,
      headers: getHeaders,
      fetch: options.fetch,
    });

  // Callable provider; it is a factory and must not be used with `new`.
  const provider = function (modelId: MistralChatModelId) {
    if (!new.target) {
      return createChatModel(modelId);
    }

    throw new Error(
      'The Mistral model function cannot be called with the new keyword.',
    );
  };

  provider.specificationVersion = 'v3' as const;
  provider.languageModel = createChatModel;
  provider.chat = createChatModel;
  provider.embedding = createEmbeddingModel;
  provider.embeddingModel = createEmbeddingModel;
  provider.textEmbedding = createEmbeddingModel;
  provider.textEmbeddingModel = createEmbeddingModel;

  provider.imageModel = (modelId: string) => {
    throw new NoSuchModelError({ modelId, modelType: 'imageModel' });
  };

  return provider;
}
|
|
143
|
+
|
|
144
|
+
/**
 * Default Mistral provider instance, created with no explicit settings, so it
 * uses the default base URL and reads the API key from the `MISTRAL_API_KEY`
 * environment variable.
 */
export const mistral: MistralProvider = createMistral();
|