@mistralai/mistralai 0.0.8 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/.eslintrc.yml CHANGED
@@ -2,11 +2,12 @@ env:
2
2
  browser: true
3
3
  es2021: true
4
4
  extends: google
5
- ignorePatterns:
5
+ ignorePatterns:
6
6
  - examples/chat-react/
7
7
  parserOptions:
8
8
  ecmaVersion: latest
9
9
  sourceType: module
10
- rules:
10
+ rules:
11
11
  indent: ["error", 2]
12
12
  space-before-function-paren: ["error", "never"]
13
+ quotes: ["error", "single"]
@@ -3,7 +3,7 @@ name: Build and Publish
3
3
  on:
4
4
  push:
5
5
  branches: ["main"]
6
-
6
+
7
7
  # We only deploy on tags and main branch
8
8
  tags:
9
9
  # Only run on tags that match the following regex
@@ -14,10 +14,13 @@ on:
14
14
  pull_request:
15
15
 
16
16
  jobs:
17
-
18
- lint:
17
+ lint_and_test:
19
18
  runs-on: ubuntu-latest
20
19
 
20
+ strategy:
21
+ matrix:
22
+ node-version: [18, 20]
23
+
21
24
  steps:
22
25
  # Checkout the repository
23
26
  - name: Checkout
@@ -27,20 +30,25 @@ jobs:
27
30
  - name: set node version
28
31
  uses: actions/setup-node@v4
29
32
  with:
30
- node-version: 18
33
+ node-version: ${{ matrix.node-version }}
31
34
 
32
35
  # Install Build stuff
33
36
  - name: Install Dependencies
34
37
  run: |
35
38
  npm install
36
39
 
37
- # Ruff
40
+ # Eslint
38
41
  - name: ESlint check
39
42
  run: |
40
- ./node_modules/.bin/eslint .
43
+ npm run lint
44
+
45
+ # Run tests
46
+ - name: Run tests
47
+ run: |
48
+ npm run test
41
49
 
42
50
  publish:
43
- needs: lint
51
+ needs: lint_and_test
44
52
  runs-on: ubuntu-latest
45
53
  if: startsWith(github.ref, 'refs/tags')
46
54
 
@@ -60,6 +68,67 @@ jobs:
60
68
  run: |
61
69
  echo "//registry.npmjs.org/:_authToken=${{ secrets.NPM_TOKEN }}" >> .npmrc
62
70
  npm version ${{ github.ref_name }}
63
- npm publish
71
+ sed -i 's/VERSION = '\''0.0.1'\''/VERSION = '\''${{ github.ref_name }}'\''/g' src/client.js
72
+ npm publish
73
+
74
+ create_pr_on_public:
75
+ if: startsWith(github.ref, 'refs/tags')
76
+ runs-on: ubuntu-latest
77
+ needs: lint_and_test
78
+ steps:
79
+ - name: Checkout
80
+ uses: actions/checkout@v4
81
+ with:
82
+ token: ${{ secrets.PUBLIC_CLIENT_WRITE_TOKEN }}
83
+
84
+ - name: Pull public updates
85
+ env: # We cannot use the github bot token to push to the public repo, we have to use one with more permissions
86
+ GITHUB_TOKEN: ${{ secrets.PUBLIC_CLIENT_WRITE_TOKEN }}
87
+ run: |
88
+
89
+ set -x
90
+ git config --global user.name "GitHub Actions"
91
+ git config --global user.email "mayo@mistral.ai"
92
+
93
+ git remote add public https://github.com/mistralai/client-js.git
94
+ git remote update
95
+
96
+ # Create a diff of the changes, ignoring the ci workflow
97
+ git merge public/main --no-commit --no-ff --no-edit --allow-unrelated-histories --strategy-option ours
98
+
99
+ # If there are changes, commit them
100
+ if ! git diff --quiet; then
101
+ git commit -m "Update from public repo"
102
+ git push origin ${{github.ref}}
103
+ else
104
+ echo "No changes to apply"
105
+ fi
106
+
107
+ - name: Push to public repo
108
+ env:
109
+ GITHUB_TOKEN: ${{ secrets.PUBLIC_CLIENT_WRITE_TOKEN }}
110
+ run: |
111
+ git checkout public/main
112
+ git checkout -b update/${{github.ref_name}}
113
+
114
+ # write version number to version file
115
+ echo ${{github.ref_name}} > version.txt
116
+
117
+ git add .
118
+ git commit -m "Bump version file"
119
+
120
+ # create a diff of this ref and the public repo
121
+ git diff update/${{github.ref_name}} ${{github.ref_name}} --binary -- . ':!.github' > changes.diff
122
+
123
+ # apply the diff to the current branch
124
+ git apply changes.diff
125
+
126
+ # commit the changes
127
+ git add .
128
+ git commit -m "Update version to ${{github.ref_name}}"
129
+
130
+ # push the changes
131
+ git push public update/${{github.ref_name}}
64
132
 
65
-
133
+ # Create a PR from this branch to the public repo
134
+ gh pr create --title "Update client to ${{github.ref_name}}" --body "This PR was automatically created by a GitHub Action" --base main --head update/${{github.ref_name}} --repo mistralai/client-js
package/README.md CHANGED
@@ -11,16 +11,19 @@ You can install the library in your project using:
11
11
  `npm install @mistralai/mistralai`
12
12
 
13
13
  ## Usage
14
+
14
15
  ### Set up
16
+
15
17
  ```typescript
16
18
  import MistralClient from '@mistralai/mistralai';
17
19
 
18
- const apiKey = "Your API key";
20
+ const apiKey = process.env.MISTRAL_API_KEY || 'your_api_key';
19
21
 
20
22
  const client = new MistralClient(apiKey);
21
23
  ```
22
24
 
23
25
  ### List models
26
+
24
27
  ```typescript
25
28
  const listModelsResponse = await client.listModels();
26
29
  const listModels = listModelsResponse.data;
@@ -30,6 +33,7 @@ listModels.forEach((model) => {
30
33
  ```
31
34
 
32
35
  ### Chat with streaming
36
+
33
37
  ```typescript
34
38
  const chatStreamResponse = await client.chatStream({
35
39
  model: 'mistral-tiny',
@@ -44,7 +48,9 @@ for await (const chunk of chatStreamResponse) {
44
48
  }
45
49
  }
46
50
  ```
51
+
47
52
  ### Chat without streaming
53
+
48
54
  ```typescript
49
55
  const chatResponse = await client.chat({
50
56
  model: 'mistral-tiny',
@@ -53,7 +59,9 @@ const chatResponse = await client.chat({
53
59
 
54
60
  console.log('Chat:', chatResponse.choices[0].message.content);
55
61
  ```
56
- ###Embeddings
62
+
63
+ ### Embeddings
64
+
57
65
  ```typescript
58
66
  const input = [];
59
67
  for (let i = 0; i < 1; i++) {
@@ -67,6 +75,7 @@ const embeddingsBatchResponse = await client.embeddings({
67
75
 
68
76
  console.log('Embeddings Batch:', embeddingsBatchResponse.data);
69
77
  ```
78
+
70
79
  ## Run examples
71
80
 
72
81
  You can run the examples in the examples directory by installing them locally:
@@ -76,23 +85,39 @@ cd examples
76
85
  npm install .
77
86
  ```
78
87
 
79
- ### API Key Setup
88
+ ### API key setup
80
89
 
81
90
  Running the examples requires a Mistral AI API key.
82
91
 
83
- 1. Get your own Mistral API Key: <https://docs.mistral.ai/#api-access>
84
- 2. Set your Mistral API Key as an environment variable. You only need to do this once.
92
+ Get your own Mistral API Key: <https://docs.mistral.ai/#api-access>
93
+
94
+ ### Run the examples
95
+
96
+ ```bash
97
+ MISTRAL_API_KEY='your_api_key' node chat_with_streaming.js
98
+ ```
99
+
100
+ ### Persisting the API key in environment
101
+
102
+ Set your Mistral API Key as an environment variable. You only need to do this once.
85
103
 
86
104
  ```bash
87
105
  # set Mistral API Key (using zsh for example)
88
- $ echo 'export MISTRAL_API_KEY=[your_key_here]' >> ~/.zshenv
106
+ $ echo 'export MISTRAL_API_KEY=[your_api_key]' >> ~/.zshenv
89
107
 
90
108
  # reload the environment (or just quit and open a new terminal)
91
109
  $ source ~/.zshenv
92
110
  ```
93
111
 
94
- You can then run the examples using node:
112
+ You can then run the examples without appending the API key:
95
113
 
96
114
  ```bash
97
- MISTRAL_API_KEY=XXXX node chat_with_streaming.js
115
+ node chat_with_streaming.js
116
+ ```
117
+ After the env variable setup the client will find the `MISTRAL_API_KEY` by itself
118
+
119
+ ```typescript
120
+ import MistralClient from '@mistralai/mistralai';
121
+
122
+ const client = new MistralClient();
98
123
  ```
@@ -26,6 +26,7 @@
26
26
  "devDependencies": {
27
27
  "eslint": "^8.55.0",
28
28
  "eslint-config-google": "^0.14.0",
29
+ "jest": "^29.7.0",
29
30
  "prettier": "2.8.8"
30
31
  }
31
32
  },
@@ -16023,16 +16024,16 @@
16023
16024
  }
16024
16025
  },
16025
16026
  "node_modules/typescript": {
16026
- "version": "5.3.3",
16027
- "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.3.3.tgz",
16028
- "integrity": "sha512-pXWcraxM0uxAS+tN0AG/BF2TyqmHO014Z070UsJ+pFvYuRSq8KH8DmWpnbXe0pEPDHXZV3FcAbJkijJ5oNEnWw==",
16027
+ "version": "4.9.5",
16028
+ "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz",
16029
+ "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==",
16029
16030
  "peer": true,
16030
16031
  "bin": {
16031
16032
  "tsc": "bin/tsc",
16032
16033
  "tsserver": "bin/tsserver"
16033
16034
  },
16034
16035
  "engines": {
16035
- "node": ">=14.17"
16036
+ "node": ">=4.2.0"
16036
16037
  }
16037
16038
  },
16038
16039
  "node_modules/unbox-primitive": {
@@ -19300,6 +19301,7 @@
19300
19301
  "requires": {
19301
19302
  "eslint": "^8.55.0",
19302
19303
  "eslint-config-google": "^0.14.0",
19304
+ "jest": "^29.7.0",
19303
19305
  "node-fetch": "^2.6.7",
19304
19306
  "prettier": "2.8.8"
19305
19307
  }
@@ -28464,9 +28466,9 @@
28464
28466
  }
28465
28467
  },
28466
28468
  "typescript": {
28467
- "version": "5.3.3",
28468
- "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.3.3.tgz",
28469
- "integrity": "sha512-pXWcraxM0uxAS+tN0AG/BF2TyqmHO014Z070UsJ+pFvYuRSq8KH8DmWpnbXe0pEPDHXZV3FcAbJkijJ5oNEnWw==",
28469
+ "version": "4.9.5",
28470
+ "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz",
28471
+ "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==",
28470
28472
  "peer": true
28471
28473
  },
28472
28474
  "unbox-primitive": {
@@ -13,11 +13,6 @@
13
13
  "scripts": {
14
14
  "start": "react-scripts start"
15
15
  },
16
- // "eslintConfig": {
17
- // "extends": [
18
- // "react-app"
19
- // ]
20
- // },
21
16
  "browserslist": {
22
17
  "production": [
23
18
  ">0.2%",
@@ -0,0 +1,122 @@
1
+ import MistralClient from '@mistralai/mistralai';
2
+
3
+ const apiKey = process.env.MISTRAL_API_KEY;
4
+
5
+ // Assuming we have the following data
6
+ const data = {
7
+ transactionId: ['T1001', 'T1002', 'T1003', 'T1004', 'T1005'],
8
+ customerId: ['C001', 'C002', 'C003', 'C002', 'C001'],
9
+ paymentAmount: [125.50, 89.99, 120.00, 54.30, 210.20],
10
+ paymentDate: [
11
+ '2021-10-05', '2021-10-06', '2021-10-07', '2021-10-05', '2021-10-08',
12
+ ],
13
+ paymentStatus: ['Paid', 'Unpaid', 'Paid', 'Paid', 'Pending'],
14
+ };
15
+
16
+ /**
17
+ * This function retrieves the payment status of a transaction id.
18
+ * @param {object} data - The data object.
19
+ * @param {string} transactionId - The transaction id.
20
+ * @return {string} - The payment status.
21
+ */
22
+ function retrievePaymentStatus({data, transactionId}) {
23
+ const transactionIndex = data.transactionId.indexOf(transactionId);
24
+ if (transactionIndex != -1) {
25
+ return JSON.stringify({status: data.payment_status[transactionIndex]});
26
+ } else {
27
+ return JSON.stringify({status: 'error - transaction id not found.'});
28
+ }
29
+ }
30
+
31
+ /**
32
+ * This function retrieves the payment date of a transaction id.
33
+ * @param {object} data - The data object.
34
+ * @param {string} transactionId - The transaction id.
35
+ * @return {string} - The payment date.
36
+ *
37
+ */
38
+ function retrievePaymentDate({data, transactionId}) {
39
+ const transactionIndex = data.transactionId.indexOf(transactionId);
40
+ if (transactionIndex != -1) {
41
+ return JSON.stringify({status: data.payment_date[transactionIndex]});
42
+ } else {
43
+ return JSON.stringify({status: 'error - transaction id not found.'});
44
+ }
45
+ }
46
+
47
+ const namesToFunctions = {
48
+ retrievePaymentStatus: (transactionId) =>
49
+ retrievePaymentStatus({data, ...transactionId}),
50
+ retrievePaymentDate: (transactionId) =>
51
+ retrievePaymentDate({data, ...transactionId}),
52
+ };
53
+
54
+ const tools = [
55
+ {
56
+ type: 'function',
57
+ function: {
58
+ name: 'retrievePaymentStatus',
59
+ description: 'Get payment status of a transaction id',
60
+ parameters: {
61
+ type: 'object',
62
+ required: ['transactionId'],
63
+ properties: {transactionId:
64
+ {type: 'string', description: 'The transaction id.'},
65
+ },
66
+ },
67
+ },
68
+ },
69
+ {
70
+ type: 'function',
71
+ function: {
72
+ name: 'retrievePaymentDate',
73
+ description: 'Get payment date of a transaction id',
74
+ parameters: {
75
+ type: 'object',
76
+ required: ['transactionId'],
77
+ properties: {transactionId:
78
+ {type: 'string', description: 'The transaction id.'},
79
+ },
80
+ },
81
+ },
82
+ },
83
+ ];
84
+
85
+
86
+ const model = 'mistral-large';
87
+
88
+ const client = new MistralClient(apiKey, 'https://api-2.aurocloud.net');
89
+
90
+ const messages = [
91
+ {role: 'user', content: 'What\'s the status of my transaction?'},
92
+ ];
93
+
94
+ let response = await client.chat({
95
+ model: model, messages: messages, tools: tools,
96
+ });
97
+
98
+
99
+ console.log(response.choices[0].message.content);
100
+
101
+ messages.push(
102
+ {role: 'assistant', content: response.choices[0].message.content},
103
+ );
104
+ messages.push({role: 'user', content: 'My transaction ID is T1001.'});
105
+
106
+ response = await client.chat({model: model, messages: messages, tools: tools});
107
+
108
+ const toolCall = response.choices[0].message.toolCalls[0];
109
+ const functionName = toolCall.function.name;
110
+ const functionParams = JSON.parse(toolCall.function.arguments);
111
+
112
+ console.log(`calling functionName: ${functionName}`);
113
+ console.log(`functionParams: ${toolCall.function.arguments}`);
114
+
115
+ const functionResult = namesToFunctions[functionName](functionParams);
116
+
117
+ messages.push(response.choices[0].message);
118
+ messages.push({role: 'tool', name: functionName, content: functionResult});
119
+
120
+ response = await client.chat({model: model, messages: messages, tools: tools});
121
+
122
+ console.log(response.choices[0].message.content);
@@ -0,0 +1,13 @@
1
+ import MistralClient from '@mistralai/mistralai';
2
+
3
+ const apiKey = process.env.MISTRAL_API_KEY;
4
+
5
+ const client = new MistralClient(apiKey);
6
+
7
+ const chatResponse = await client.chat({
8
+ model: 'mistral-large',
9
+ messages: [{role: 'user', content: 'What is the best French cheese?'}],
10
+ responseFormat: {type: 'json_object'},
11
+ });
12
+
13
+ console.log('Chat:', chatResponse.choices[0].message.content);
@@ -21,6 +21,7 @@
21
21
  "devDependencies": {
22
22
  "eslint": "^8.55.0",
23
23
  "eslint-config-google": "^0.14.0",
24
+ "jest": "^29.7.0",
24
25
  "prettier": "2.8.8"
25
26
  }
26
27
  },
package/package.json CHANGED
@@ -1,22 +1,32 @@
1
1
  {
2
2
  "name": "@mistralai/mistralai",
3
- "version": "0.0.8",
3
+ "version": "0.1.3",
4
4
  "description": "",
5
5
  "author": "bam4d@mistral.ai",
6
6
  "license": "ISC",
7
7
  "type": "module",
8
8
  "main": "src/client.js",
9
+ "scripts": {
10
+ "lint": "./node_modules/.bin/eslint .",
11
+ "test": "node --experimental-vm-modules node_modules/.bin/jest"
12
+ },
13
+ "jest": {
14
+ "testPathIgnorePatterns": [
15
+ "examples"
16
+ ]
17
+ },
9
18
  "repository": {
10
19
  "type": "git",
11
20
  "url": "https://github.com/mistralai/client-js"
12
21
  },
13
- "types": "src/mistralai.d.ts",
22
+ "types": "src/client.d.ts",
14
23
  "dependencies": {
15
24
  "node-fetch": "^2.6.7"
16
25
  },
17
26
  "devDependencies": {
18
27
  "eslint": "^8.55.0",
19
28
  "eslint-config-google": "^0.14.0",
20
- "prettier": "2.8.8"
29
+ "prettier": "2.8.8",
30
+ "jest": "^29.7.0"
21
31
  }
22
32
  }
package/src/client.d.ts CHANGED
@@ -29,6 +29,42 @@ declare module '@mistralai/mistralai' {
29
29
  data: Model[];
30
30
  }
31
31
 
32
+ export interface Function {
33
+ name: string;
34
+ description: string;
35
+ parameters: object;
36
+ }
37
+
38
+ export enum ToolType {
39
+ function = 'function',
40
+ }
41
+
42
+ export interface FunctionCall {
43
+ name: string;
44
+ arguments: string;
45
+ }
46
+
47
+ export interface ToolCalls {
48
+ id: 'null';
49
+ type: ToolType = ToolType.function;
50
+ function: FunctionCall;
51
+ }
52
+
53
+ export enum ResponseFormats {
54
+ text = 'text',
55
+ json_object = 'json_object',
56
+ }
57
+
58
+ export enum ToolChoice {
59
+ auto = 'auto',
60
+ any = 'any',
61
+ none = 'none',
62
+ }
63
+
64
+ export interface ResponseFormat {
65
+ type: ResponseFormats = ResponseFormats.text;
66
+ }
67
+
32
68
  export interface TokenUsage {
33
69
  prompt_tokens: number;
34
70
  completion_tokens: number;
@@ -49,6 +85,7 @@ declare module '@mistralai/mistralai' {
49
85
  delta: {
50
86
  role?: string;
51
87
  content?: string;
88
+ tool_calls?: ToolCalls[];
52
89
  };
53
90
  finish_reason: string;
54
91
  }
@@ -95,35 +132,56 @@ declare module '@mistralai/mistralai' {
95
132
 
96
133
  private _makeChatCompletionRequest(
97
134
  model: string,
98
- messages: Array<{ role: string; content: string }>,
135
+ messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>,
136
+ tools?: Array<{ type: string; function:Function; }>,
99
137
  temperature?: number,
100
138
  maxTokens?: number,
101
139
  topP?: number,
102
140
  randomSeed?: number,
103
141
  stream?: boolean,
104
- safeMode?: boolean
142
+ /**
143
+ * @deprecated use safePrompt instead
144
+ */
145
+ safeMode?: boolean,
146
+ safePrompt?: boolean,
147
+ toolChoice?: ToolChoice,
148
+ responseFormat?: ResponseFormat
105
149
  ): object;
106
150
 
107
151
  listModels(): Promise<ListModelsResponse>;
108
152
 
109
153
  chat(options: {
110
154
  model: string;
111
- messages: Array<{ role: string; content: string }>;
155
+ messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>;
156
+ tools?: Array<{ type: string; function:Function; }>;
112
157
  temperature?: number;
113
158
  maxTokens?: number;
114
159
  topP?: number;
115
160
  randomSeed?: number;
161
+ /**
162
+ * @deprecated use safePrompt instead
163
+ */
116
164
  safeMode?: boolean;
165
+ safePrompt?: boolean;
166
+ toolChoice?: ToolChoice;
167
+ responseFormat?: ResponseFormat;
117
168
  }): Promise<ChatCompletionResponse>;
118
169
 
119
170
  chatStream(options: {
120
171
  model: string;
121
- messages: Array<{ role: string; content: string }>;
172
+ messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>;
173
+ tools?: Array<{ type: string; function:Function; }>;
122
174
  temperature?: number;
123
175
  maxTokens?: number;
124
176
  topP?: number;
125
177
  randomSeed?: number;
178
+ /**
179
+ * @deprecated use safePrompt instead
180
+ */
126
181
  safeMode?: boolean;
182
+ safePrompt?: boolean;
183
+ toolChoice?: ToolChoice;
184
+ responseFormat?: ResponseFormat;
127
185
  }): AsyncGenerator<ChatCompletionResponseChunk, void, unknown>;
128
186
 
129
187
  embeddings(options: {
package/src/client.js CHANGED
@@ -1,16 +1,26 @@
1
- let fetch;
2
1
  let isNode = false;
3
- if (typeof globalThis.fetch === 'undefined') {
4
- fetch = (await import('node-fetch')).default;
5
- isNode = true;
6
- } else {
7
- fetch = globalThis.fetch;
8
- }
9
-
10
2
 
3
+ const VERSION = '0.0.3';
11
4
  const RETRY_STATUS_CODES = [429, 500, 502, 503, 504];
12
5
  const ENDPOINT = 'https://api.mistral.ai';
13
6
 
7
+ /**
8
+ * Initialize fetch
9
+ * @return {Promise<void>}
10
+ */
11
+ async function initializeFetch() {
12
+ if (typeof window === 'undefined' ||
13
+ typeof globalThis.fetch === 'undefined') {
14
+ const nodeFetch = await import('node-fetch');
15
+ fetch = nodeFetch.default;
16
+ isNode = true;
17
+ } else {
18
+ fetch = globalThis.fetch;
19
+ }
20
+ }
21
+
22
+ initializeFetch();
23
+
14
24
  /**
15
25
  * MistralAPIError
16
26
  * @return {MistralAPIError}
@@ -51,6 +61,10 @@ class MistralClient {
51
61
 
52
62
  this.maxRetries = maxRetries;
53
63
  this.timeout = timeout;
64
+
65
+ if (this.endpoint.indexOf('inference.azure.com')) {
66
+ this.modelDefault = 'mistral';
67
+ }
54
68
  }
55
69
 
56
70
  /**
@@ -65,6 +79,8 @@ class MistralClient {
65
79
  const options = {
66
80
  method: method,
67
81
  headers: {
82
+ 'User-Agent': `mistral-client-js/${VERSION}`,
83
+ 'Accept': request?.stream ? 'text/event-stream' : 'application/json',
68
84
  'Content-Type': 'application/json',
69
85
  'Authorization': `Bearer ${this.apiKey}`,
70
86
  },
@@ -85,14 +101,13 @@ class MistralClient {
85
101
  // Chrome does not support async iterators yet, so polyfill it
86
102
  const asyncIterator = async function* () {
87
103
  try {
88
- const decoder = new TextDecoder();
89
104
  while (true) {
90
105
  // Read from the stream
91
106
  const {done, value} = await reader.read();
92
107
  // Exit if we're done
93
108
  if (done) return;
94
109
  // Else yield the chunk
95
- yield decoder.decode(value, {stream: true});
110
+ yield value;
96
111
  }
97
112
  } finally {
98
113
  reader.releaseLock();
@@ -104,14 +119,19 @@ class MistralClient {
104
119
  }
105
120
  return await response.json();
106
121
  } else if (RETRY_STATUS_CODES.includes(response.status)) {
107
- console.debug(`Retrying request, attempt: ${attempts + 1}`);
122
+ console.debug(
123
+ `Retrying request on response status: ${response.status}`,
124
+ `Response: ${await response.text()}`,
125
+ `Attempt: ${attempts + 1}`,
126
+ );
108
127
  // eslint-disable-next-line max-len
109
128
  await new Promise((resolve) =>
110
129
  setTimeout(resolve, Math.pow(2, (attempts + 1)) * 500),
111
130
  );
112
131
  } else {
113
132
  throw new MistralAPIError(
114
- `HTTP error! status: ${response.status}`,
133
+ `HTTP error! status: ${response.status} ` +
134
+ `Response: \n${await response.text()}`,
115
135
  );
116
136
  }
117
137
  } catch (error) {
@@ -133,33 +153,50 @@ class MistralClient {
133
153
  * Creates a chat completion request
134
154
  * @param {*} model
135
155
  * @param {*} messages
156
+ * @param {*} tools
136
157
  * @param {*} temperature
137
158
  * @param {*} maxTokens
138
159
  * @param {*} topP
139
160
  * @param {*} randomSeed
140
161
  * @param {*} stream
141
- * @param {*} safeMode
162
+ * @param {*} safeMode deprecated use safePrompt instead
163
+ * @param {*} safePrompt
164
+ * @param {*} toolChoice
165
+ * @param {*} responseFormat
142
166
  * @return {Promise<Object>}
143
167
  */
144
168
  _makeChatCompletionRequest = function(
145
169
  model,
146
170
  messages,
171
+ tools,
147
172
  temperature,
148
173
  maxTokens,
149
174
  topP,
150
175
  randomSeed,
151
176
  stream,
152
177
  safeMode,
178
+ safePrompt,
179
+ toolChoice,
180
+ responseFormat,
153
181
  ) {
182
+ // if modelDefault and model are undefined, throw an error
183
+ if (!model && !this.modelDefault) {
184
+ throw new MistralAPIError(
185
+ 'You must provide a model name',
186
+ );
187
+ }
154
188
  return {
155
- model: model,
189
+ model: model ?? this.modelDefault,
156
190
  messages: messages,
191
+ tools: tools ?? undefined,
157
192
  temperature: temperature ?? undefined,
158
193
  max_tokens: maxTokens ?? undefined,
159
194
  top_p: topP ?? undefined,
160
195
  random_seed: randomSeed ?? undefined,
161
196
  stream: stream ?? undefined,
162
- safe_prompt: safeMode ?? undefined,
197
+ safe_prompt: (safeMode || safePrompt) ?? undefined,
198
+ tool_choice: toolChoice ?? undefined,
199
+ response_format: responseFormat ?? undefined,
163
200
  };
164
201
  };
165
202
 
@@ -177,31 +214,43 @@ class MistralClient {
177
214
  * @param {*} model the name of the model to chat with, e.g. mistral-tiny
178
215
  * @param {*} messages an array of messages to chat with, e.g.
179
216
  * [{role: 'user', content: 'What is the best French cheese?'}]
217
+ * @param {*} tools a list of tools to use.
180
218
  * @param {*} temperature the temperature to use for sampling, e.g. 0.5
181
219
  * @param {*} maxTokens the maximum number of tokens to generate, e.g. 100
182
220
  * @param {*} topP the cumulative probability of tokens to generate, e.g. 0.9
183
221
  * @param {*} randomSeed the random seed to use for sampling, e.g. 42
184
- * @param {*} safeMode whether to use safe mode, e.g. true
222
+ * @param {*} safeMode deprecated use safePrompt instead
223
+ * @param {*} safePrompt whether to use safe mode, e.g. true
224
+ * @param {*} toolChoice the tool to use, e.g. 'auto'
225
+ * @param {*} responseFormat the format of the response, e.g. 'json_format'
185
226
  * @return {Promise<Object>}
186
227
  */
187
228
  chat = async function({
188
229
  model,
189
230
  messages,
231
+ tools,
190
232
  temperature,
191
233
  maxTokens,
192
234
  topP,
193
235
  randomSeed,
194
236
  safeMode,
237
+ safePrompt,
238
+ toolChoice,
239
+ responseFormat,
195
240
  }) {
196
241
  const request = this._makeChatCompletionRequest(
197
242
  model,
198
243
  messages,
244
+ tools,
199
245
  temperature,
200
246
  maxTokens,
201
247
  topP,
202
248
  randomSeed,
203
249
  false,
204
250
  safeMode,
251
+ safePrompt,
252
+ toolChoice,
253
+ responseFormat,
205
254
  );
206
255
  const response = await this._request(
207
256
  'post',
@@ -216,31 +265,43 @@ class MistralClient {
216
265
  * @param {*} model the name of the model to chat with, e.g. mistral-tiny
217
266
  * @param {*} messages an array of messages to chat with, e.g.
218
267
  * [{role: 'user', content: 'What is the best French cheese?'}]
268
+ * @param {*} tools a list of tools to use.
219
269
  * @param {*} temperature the temperature to use for sampling, e.g. 0.5
220
270
  * @param {*} maxTokens the maximum number of tokens to generate, e.g. 100
221
271
  * @param {*} topP the cumulative probability of tokens to generate, e.g. 0.9
222
272
  * @param {*} randomSeed the random seed to use for sampling, e.g. 42
223
- * @param {*} safeMode whether to use safe mode, e.g. true
273
+ * @param {*} safeMode deprecated use safePrompt instead
274
+ * @param {*} safePrompt whether to use safe mode, e.g. true
275
+ * @param {*} toolChoice the tool to use, e.g. 'auto'
276
+ * @param {*} responseFormat the format of the response, e.g. 'json_format'
224
277
  * @return {Promise<Object>}
225
278
  */
226
279
  chatStream = async function* ({
227
280
  model,
228
281
  messages,
282
+ tools,
229
283
  temperature,
230
284
  maxTokens,
231
285
  topP,
232
286
  randomSeed,
233
287
  safeMode,
288
+ safePrompt,
289
+ toolChoice,
290
+ responseFormat,
234
291
  }) {
235
292
  const request = this._makeChatCompletionRequest(
236
293
  model,
237
294
  messages,
295
+ tools,
238
296
  temperature,
239
297
  maxTokens,
240
298
  topP,
241
299
  randomSeed,
242
300
  true,
243
301
  safeMode,
302
+ safePrompt,
303
+ toolChoice,
304
+ responseFormat,
244
305
  );
245
306
  const response = await this._request(
246
307
  'post',
@@ -249,9 +310,9 @@ class MistralClient {
249
310
  );
250
311
 
251
312
  let buffer = '';
252
-
313
+ const decoder = new TextDecoder();
253
314
  for await (const chunk of response) {
254
- buffer += chunk;
315
+ buffer += decoder.decode(chunk, {stream: true});
255
316
  let firstNewline;
256
317
  while ((firstNewline = buffer.indexOf('\n')) !== -1) {
257
318
  const chunkLine = buffer.substring(0, firstNewline);
@@ -0,0 +1,179 @@
1
+ import MistralClient from '../src/client';
2
+ import {
3
+ mockListModels,
4
+ mockFetch,
5
+ mockChatResponseStreamingPayload,
6
+ mockEmbeddingRequest,
7
+ mockEmbeddingResponsePayload,
8
+ mockChatResponsePayload,
9
+ mockFetchStream,
10
+ } from './utils';
11
+
12
+ // Test the list models endpoint
13
+ describe('Mistral Client', () => {
14
+ let client;
15
+ beforeEach(() => {
16
+ client = new MistralClient();
17
+ });
18
+
19
+ describe('chat()', () => {
20
+ it('should return a chat response object', async() => {
21
+ // Mock the fetch function
22
+ const mockResponse = mockChatResponsePayload();
23
+ globalThis.fetch = mockFetch(200, mockResponse);
24
+
25
+ const response = await client.chat({
26
+ model: 'mistral-small',
27
+ messages: [
28
+ {
29
+ role: 'user',
30
+ content: 'What is the best French cheese?',
31
+ },
32
+ ],
33
+ });
34
+ expect(response).toEqual(mockResponse);
35
+ });
36
+
37
+ it('should return a chat response object if safeMode is set', async() => {
38
+ // Mock the fetch function
39
+ const mockResponse = mockChatResponsePayload();
40
+ globalThis.fetch = mockFetch(200, mockResponse);
41
+
42
+ const response = await client.chat({
43
+ model: 'mistral-small',
44
+ messages: [
45
+ {
46
+ role: 'user',
47
+ content: 'What is the best French cheese?',
48
+ },
49
+ ],
50
+ safeMode: true,
51
+ });
52
+ expect(response).toEqual(mockResponse);
53
+ });
54
+
55
+ it('should return a chat response object if safePrompt is set', async() => {
56
+ // Mock the fetch function
57
+ const mockResponse = mockChatResponsePayload();
58
+ globalThis.fetch = mockFetch(200, mockResponse);
59
+
60
+ const response = await client.chat({
61
+ model: 'mistral-small',
62
+ messages: [
63
+ {
64
+ role: 'user',
65
+ content: 'What is the best French cheese?',
66
+ },
67
+ ],
68
+ safePrompt: true,
69
+ });
70
+ expect(response).toEqual(mockResponse);
71
+ });
72
+ });
73
+
74
+ describe('chatStream()', () => {
75
+ it('should return parsed, streamed response', async() => {
76
+ // Mock the fetch function
77
+ const mockResponse = mockChatResponseStreamingPayload();
78
+ globalThis.fetch = mockFetchStream(200, mockResponse);
79
+
80
+ const response = await client.chatStream({
81
+ model: 'mistral-small',
82
+ messages: [
83
+ {
84
+ role: 'user',
85
+ content: 'What is the best French cheese?',
86
+ },
87
+ ],
88
+ });
89
+
90
+ const parsedResponse = [];
91
+ for await (const r of response) {
92
+ parsedResponse.push(r);
93
+ }
94
+
95
+ expect(parsedResponse.length).toEqual(11);
96
+ });
97
+
98
+ it('should return parsed, streamed response with safeMode', async() => {
99
+ // Mock the fetch function
100
+ const mockResponse = mockChatResponseStreamingPayload();
101
+ globalThis.fetch = mockFetchStream(200, mockResponse);
102
+
103
+ const response = await client.chatStream({
104
+ model: 'mistral-small',
105
+ messages: [
106
+ {
107
+ role: 'user',
108
+ content: 'What is the best French cheese?',
109
+ },
110
+ ],
111
+ safeMode: true,
112
+ });
113
+
114
+ const parsedResponse = [];
115
+ for await (const r of response) {
116
+ parsedResponse.push(r);
117
+ }
118
+
119
+ expect(parsedResponse.length).toEqual(11);
120
+ });
121
+
122
+ it('should return parsed, streamed response with safePrompt', async() => {
123
+ // Mock the fetch function
124
+ const mockResponse = mockChatResponseStreamingPayload();
125
+ globalThis.fetch = mockFetchStream(200, mockResponse);
126
+
127
+ const response = await client.chatStream({
128
+ model: 'mistral-small',
129
+ messages: [
130
+ {
131
+ role: 'user',
132
+ content: 'What is the best French cheese?',
133
+ },
134
+ ],
135
+ safePrompt: true,
136
+ });
137
+
138
+ const parsedResponse = [];
139
+ for await (const r of response) {
140
+ parsedResponse.push(r);
141
+ }
142
+
143
+ expect(parsedResponse.length).toEqual(11);
144
+ });
145
+ });
146
+
147
+ describe('embeddings()', () => {
148
+ it('should return embeddings', async() => {
149
+ // Mock the fetch function
150
+ const mockResponse = mockEmbeddingResponsePayload();
151
+ globalThis.fetch = mockFetch(200, mockResponse);
152
+
153
+ const response = await client.embeddings(mockEmbeddingRequest);
154
+ expect(response).toEqual(mockResponse);
155
+ });
156
+ });
157
+
158
+ describe('embeddings() batched', () => {
159
+ it('should return batched embeddings', async() => {
160
+ // Mock the fetch function
161
+ const mockResponse = mockEmbeddingResponsePayload(10);
162
+ globalThis.fetch = mockFetch(200, mockResponse);
163
+
164
+ const response = await client.embeddings(mockEmbeddingRequest);
165
+ expect(response).toEqual(mockResponse);
166
+ });
167
+ });
168
+
169
+ describe('listModels()', () => {
170
+ it('should return a list of models', async() => {
171
+ // Mock the fetch function
172
+ const mockResponse = mockListModels();
173
+ globalThis.fetch = mockFetch(200, mockResponse);
174
+
175
+ const response = await client.listModels();
176
+ expect(response).toEqual(mockResponse);
177
+ });
178
+ });
179
+ });
package/tests/utils.js ADDED
@@ -0,0 +1,257 @@
1
+ import jest from 'jest-mock';
2
+
3
/**
 * Build a jest mock standing in for the global fetch function that
 * always resolves with the given HTTP status and JSON payload.
 * @param {*} status HTTP status code of the fake response
 * @param {*} payload value resolved by response.json()
 * @return {Object} jest mock resolving to a fetch-like response
 */
export function mockFetch(status, payload) {
  const ok = status >= 200 && status < 300;
  return jest.fn(() =>
    Promise.resolve({
      status,
      ok,
      json: () => Promise.resolve(payload),
      text: () => Promise.resolve(JSON.stringify(payload)),
    }),
  );
}
19
+
20
/**
 * Build a jest mock standing in for fetch on streaming endpoints.
 * The fake response body is an async iterator that yields the entries
 * of `payload` one at a time.
 * @param {*} status HTTP status code of the fake response
 * @param {*} payload list of chunks to stream as the response body
 * @return {Object} jest mock resolving to a fetch-like response
 */
export function mockFetchStream(status, payload) {
  // NOTE: shift() consumes `payload` as the stream is read; the first
  // falsy (or missing) entry terminates the stream.
  async function* streamChunks() {
    for (;;) {
      const chunk = payload.shift();
      if (!chunk) return;
      yield chunk;
    }
  }

  return jest.fn(() =>
    Promise.resolve({
      status,
      ok: status >= 200 && status < 300,
      body: streamChunks(),
    }),
  );
}
47
+
48
/**
 * Build a canned response for the models listing endpoint containing
 * the four public Mistral models, each with one permission record.
 * @return {Object} models list payload
 */
export function mockListModels() {
  const created = 1703186988;

  // All entries share the same shape; only id and permission id vary.
  const makeModel = (id, permissionId) => ({
    id,
    object: 'model',
    created,
    owned_by: 'mistralai',
    root: null,
    parent: null,
    permission: [
      {
        id: permissionId,
        object: 'model_permission',
        created,
        allow_create_engine: false,
        allow_sampling: true,
        allow_logprobs: false,
        allow_search_indices: false,
        allow_view: true,
        allow_fine_tuning: false,
        organization: '*',
        group: null,
        is_blocking: false,
      },
    ],
  });

  return {
    object: 'list',
    data: [
      makeModel('mistral-medium',
        'modelperm-15bebaf316264adb84b891bf06a84933'),
      makeModel('mistral-small',
        'modelperm-d0dced5c703242fa862f4ca3f241c00e'),
      makeModel('mistral-tiny',
        'modelperm-0e64e727c3a94f17b29f8895d4be2910'),
      makeModel('mistral-embed',
        'modelperm-ebdff9046f524e628059447b5932e3ad'),
    ],
  };
}
155
+
156
/**
 * Build a canned (non-streaming) chat completion response with a
 * single assistant choice.
 * @return {Object} chat completion payload
 */
export function mockChatResponsePayload() {
  const message = {
    role: 'assistant',
    content: 'What is the best French cheese?',
  };
  return {
    id: 'chat-98c8c60e3fbf4fc49658eddaf447357c',
    object: 'chat.completion',
    created: 1703165682,
    choices: [{finish_reason: 'stop', message, index: 0}],
    model: 'mistral-small',
    usage: {prompt_tokens: 90, total_tokens: 90, completion_tokens: 0},
  };
}
179
+
180
/**
 * Build a canned streaming chat completion: one role chunk, ten
 * content chunks and a final '[DONE]' terminator, each encoded as an
 * SSE-style 'data: ...\n\n' byte chunk.
 * @return {Object} list of encoded stream chunks
 */
export function mockChatResponseStreamingPayload() {
  const encoder = new TextEncoder();
  const encodeEvent = (data) => encoder.encode('data: ' + data + '\n\n');

  const roleChunk = encodeEvent(JSON.stringify({
    id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
    model: 'mistral-small',
    choices: [
      {
        index: 0,
        delta: {role: 'assistant'},
        finish_reason: null,
      },
    ],
  }));

  const contentChunks = Array.from({length: 10}, (_, i) =>
    encodeEvent(JSON.stringify({
      id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
      object: 'chat.completion.chunk',
      created: 1703168544,
      model: 'mistral-small',
      choices: [
        {
          index: i,
          delta: {content: `stream response ${i}`},
          finish_reason: null,
        },
      ],
    })),
  );

  return [roleChunk, ...contentChunks, encodeEvent('[DONE]')];
}
225
+
226
/**
 * Build a canned embeddings response with `batchSize` embedding rows.
 *
 * Fixes the original `[...] * batchSize`: JavaScript has no list
 * repetition, so multiplying an array by a number coerced `data` to
 * NaN instead of producing an array.
 * @param {number} batchSize number of embedding entries to return
 * @return {Object} embeddings response payload
 */
export function mockEmbeddingResponsePayload(batchSize = 1) {
  return {
    id: 'embd-98c8c60e3fbf4fc49658eddaf447357c',
    object: 'list',
    // One entry per batch element, indexed 0..batchSize-1.
    data: Array.from({length: batchSize}, (_, index) => ({
      object: 'embedding',
      embedding: [-0.018585205078125, 0.027099609375, 0.02587890625],
      index,
    })),
    model: 'mistral-embed',
    usage: {prompt_tokens: 90, total_tokens: 90, completion_tokens: 0},
  };
}
247
+
248
/**
 * Build a canned embeddings request payload.
 * @return {Object} request with model and input fields
 */
export function mockEmbeddingRequest() {
  const model = 'mistral-embed';
  const input = 'embed';
  return {model, input};
}