@mistralai/mistralai 0.0.8 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.eslintrc.yml +3 -2
- package/.github/workflows/build_publish.yaml +78 -9
- package/README.md +33 -8
- package/examples/chat-react/package-lock.json +9 -7
- package/examples/chat-react/package.json +0 -5
- package/examples/function_calling.js +122 -0
- package/examples/json_format.js +13 -0
- package/examples/package-lock.json +1 -0
- package/package.json +13 -3
- package/src/client.d.ts +62 -4
- package/src/client.js +80 -19
- package/tests/client.test.js +179 -0
- package/tests/utils.js +257 -0
package/.eslintrc.yml
CHANGED
|
@@ -2,11 +2,12 @@ env:
|
|
|
2
2
|
browser: true
|
|
3
3
|
es2021: true
|
|
4
4
|
extends: google
|
|
5
|
-
ignorePatterns:
|
|
5
|
+
ignorePatterns:
|
|
6
6
|
- examples/chat-react/
|
|
7
7
|
parserOptions:
|
|
8
8
|
ecmaVersion: latest
|
|
9
9
|
sourceType: module
|
|
10
|
-
rules:
|
|
10
|
+
rules:
|
|
11
11
|
indent: ["error", 2]
|
|
12
12
|
space-before-function-paren: ["error", "never"]
|
|
13
|
+
quotes: ["error", "single"]
|
|
@@ -3,7 +3,7 @@ name: Build and Publish
|
|
|
3
3
|
on:
|
|
4
4
|
push:
|
|
5
5
|
branches: ["main"]
|
|
6
|
-
|
|
6
|
+
|
|
7
7
|
# We only deploy on tags and main branch
|
|
8
8
|
tags:
|
|
9
9
|
# Only run on tags that match the following regex
|
|
@@ -14,10 +14,13 @@ on:
|
|
|
14
14
|
pull_request:
|
|
15
15
|
|
|
16
16
|
jobs:
|
|
17
|
-
|
|
18
|
-
lint:
|
|
17
|
+
lint_and_test:
|
|
19
18
|
runs-on: ubuntu-latest
|
|
20
19
|
|
|
20
|
+
strategy:
|
|
21
|
+
matrix:
|
|
22
|
+
node-version: [18, 20]
|
|
23
|
+
|
|
21
24
|
steps:
|
|
22
25
|
# Checkout the repository
|
|
23
26
|
- name: Checkout
|
|
@@ -27,20 +30,25 @@ jobs:
|
|
|
27
30
|
- name: set node version
|
|
28
31
|
uses: actions/setup-node@v4
|
|
29
32
|
with:
|
|
30
|
-
node-version:
|
|
33
|
+
node-version: ${{ matrix.node-version }}
|
|
31
34
|
|
|
32
35
|
# Install Build stuff
|
|
33
36
|
- name: Install Dependencies
|
|
34
37
|
run: |
|
|
35
38
|
npm install
|
|
36
39
|
|
|
37
|
-
#
|
|
40
|
+
# Eslint
|
|
38
41
|
- name: ESlint check
|
|
39
42
|
run: |
|
|
40
|
-
|
|
43
|
+
npm run lint
|
|
44
|
+
|
|
45
|
+
# Run tests
|
|
46
|
+
- name: Run tests
|
|
47
|
+
run: |
|
|
48
|
+
npm run test
|
|
41
49
|
|
|
42
50
|
publish:
|
|
43
|
-
needs:
|
|
51
|
+
needs: lint_and_test
|
|
44
52
|
runs-on: ubuntu-latest
|
|
45
53
|
if: startsWith(github.ref, 'refs/tags')
|
|
46
54
|
|
|
@@ -60,6 +68,67 @@ jobs:
|
|
|
60
68
|
run: |
|
|
61
69
|
echo "//registry.npmjs.org/:_authToken=${{ secrets.NPM_TOKEN }}" >> .npmrc
|
|
62
70
|
npm version ${{ github.ref_name }}
|
|
63
|
-
|
|
71
|
+
sed -i 's/VERSION = '\''0.0.1'\''/VERSION = '\''${{ github.ref_name }}'\''/g' src/client.js
|
|
72
|
+
npm publish
|
|
73
|
+
|
|
74
|
+
create_pr_on_public:
|
|
75
|
+
if: startsWith(github.ref, 'refs/tags')
|
|
76
|
+
runs-on: ubuntu-latest
|
|
77
|
+
needs: lint_and_test
|
|
78
|
+
steps:
|
|
79
|
+
- name: Checkout
|
|
80
|
+
uses: actions/checkout@v4
|
|
81
|
+
with:
|
|
82
|
+
token: ${{ secrets.PUBLIC_CLIENT_WRITE_TOKEN }}
|
|
83
|
+
|
|
84
|
+
- name: Pull public updates
|
|
85
|
+
env: # The default GitHub bot token cannot push to the public repo, so use a token with more permissions
|
|
86
|
+
GITHUB_TOKEN: ${{ secrets.PUBLIC_CLIENT_WRITE_TOKEN }}
|
|
87
|
+
run: |
|
|
88
|
+
|
|
89
|
+
set -x
|
|
90
|
+
git config --global user.name "GitHub Actions"
|
|
91
|
+
git config --global user.email "mayo@mistral.ai"
|
|
92
|
+
|
|
93
|
+
git remote add public https://github.com/mistralai/client-js.git
|
|
94
|
+
git remote update
|
|
95
|
+
|
|
96
|
+
# Merge public/main without committing, resolving conflicts in favor of our side
|
|
97
|
+
git merge public/main --no-commit --no-ff --no-edit --allow-unrelated-histories --strategy-option ours
|
|
98
|
+
|
|
99
|
+
# If there are changes, commit them
|
|
100
|
+
if ! git diff --quiet; then
|
|
101
|
+
git commit -m "Update from public repo"
|
|
102
|
+
git push origin ${{github.ref}}
|
|
103
|
+
else
|
|
104
|
+
echo "No changes to apply"
|
|
105
|
+
fi
|
|
106
|
+
|
|
107
|
+
- name: Push to public repo
|
|
108
|
+
env:
|
|
109
|
+
GITHUB_TOKEN: ${{ secrets.PUBLIC_CLIENT_WRITE_TOKEN }}
|
|
110
|
+
run: |
|
|
111
|
+
git checkout public/main
|
|
112
|
+
git checkout -b update/${{github.ref_name}}
|
|
113
|
+
|
|
114
|
+
# write version number to version file
|
|
115
|
+
echo ${{github.ref_name}} > version.txt
|
|
116
|
+
|
|
117
|
+
git add .
|
|
118
|
+
git commit -m "Bump version file"
|
|
119
|
+
|
|
120
|
+
# create a binary-safe diff from the update branch to this tag, excluding .github
|
|
121
|
+
git diff update/${{github.ref_name}} ${{github.ref_name}} --binary -- . ':!.github' > changes.diff
|
|
122
|
+
|
|
123
|
+
# apply the diff to the current branch
|
|
124
|
+
git apply changes.diff
|
|
125
|
+
|
|
126
|
+
# commit the changes
|
|
127
|
+
git add .
|
|
128
|
+
git commit -m "Update version to ${{github.ref_name}}"
|
|
129
|
+
|
|
130
|
+
# push the changes
|
|
131
|
+
git push public update/${{github.ref_name}}
|
|
64
132
|
|
|
65
|
-
|
|
133
|
+
# Create a PR from this branch to the public repo
|
|
134
|
+
gh pr create --title "Update client to ${{github.ref_name}}" --body "This PR was automatically created by a GitHub Action" --base main --head update/${{github.ref_name}} --repo mistralai/client-js
|
package/README.md
CHANGED
|
@@ -11,16 +11,19 @@ You can install the library in your project using:
|
|
|
11
11
|
`npm install @mistralai/mistralai`
|
|
12
12
|
|
|
13
13
|
## Usage
|
|
14
|
+
|
|
14
15
|
### Set up
|
|
16
|
+
|
|
15
17
|
```typescript
|
|
16
18
|
import MistralClient from '@mistralai/mistralai';
|
|
17
19
|
|
|
18
|
-
const apiKey =
|
|
20
|
+
const apiKey = process.env.MISTRAL_API_KEY || 'your_api_key';
|
|
19
21
|
|
|
20
22
|
const client = new MistralClient(apiKey);
|
|
21
23
|
```
|
|
22
24
|
|
|
23
25
|
### List models
|
|
26
|
+
|
|
24
27
|
```typescript
|
|
25
28
|
const listModelsResponse = await client.listModels();
|
|
26
29
|
const listModels = listModelsResponse.data;
|
|
@@ -30,6 +33,7 @@ listModels.forEach((model) => {
|
|
|
30
33
|
```
|
|
31
34
|
|
|
32
35
|
### Chat with streaming
|
|
36
|
+
|
|
33
37
|
```typescript
|
|
34
38
|
const chatStreamResponse = await client.chatStream({
|
|
35
39
|
model: 'mistral-tiny',
|
|
@@ -44,7 +48,9 @@ for await (const chunk of chatStreamResponse) {
|
|
|
44
48
|
}
|
|
45
49
|
}
|
|
46
50
|
```
|
|
51
|
+
|
|
47
52
|
### Chat without streaming
|
|
53
|
+
|
|
48
54
|
```typescript
|
|
49
55
|
const chatResponse = await client.chat({
|
|
50
56
|
model: 'mistral-tiny',
|
|
@@ -53,7 +59,9 @@ const chatResponse = await client.chat({
|
|
|
53
59
|
|
|
54
60
|
console.log('Chat:', chatResponse.choices[0].message.content);
|
|
55
61
|
```
|
|
56
|
-
|
|
62
|
+
|
|
63
|
+
### Embeddings
|
|
64
|
+
|
|
57
65
|
```typescript
|
|
58
66
|
const input = [];
|
|
59
67
|
for (let i = 0; i < 1; i++) {
|
|
@@ -67,6 +75,7 @@ const embeddingsBatchResponse = await client.embeddings({
|
|
|
67
75
|
|
|
68
76
|
console.log('Embeddings Batch:', embeddingsBatchResponse.data);
|
|
69
77
|
```
|
|
78
|
+
|
|
70
79
|
## Run examples
|
|
71
80
|
|
|
72
81
|
You can run the examples in the examples directory by installing them locally:
|
|
@@ -76,23 +85,39 @@ cd examples
|
|
|
76
85
|
npm install .
|
|
77
86
|
```
|
|
78
87
|
|
|
79
|
-
### API
|
|
88
|
+
### API key setup
|
|
80
89
|
|
|
81
90
|
Running the examples requires a Mistral AI API key.
|
|
82
91
|
|
|
83
|
-
|
|
84
|
-
|
|
92
|
+
Get your own Mistral API Key: <https://docs.mistral.ai/#api-access>
|
|
93
|
+
|
|
94
|
+
### Run the examples
|
|
95
|
+
|
|
96
|
+
```bash
|
|
97
|
+
MISTRAL_API_KEY='your_api_key' node chat_with_streaming.js
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
### Persisting the API key in environment
|
|
101
|
+
|
|
102
|
+
Set your Mistral API Key as an environment variable. You only need to do this once.
|
|
85
103
|
|
|
86
104
|
```bash
|
|
87
105
|
# set Mistral API Key (using zsh for example)
|
|
88
|
-
$ echo 'export MISTRAL_API_KEY=[
|
|
106
|
+
$ echo 'export MISTRAL_API_KEY=[your_api_key]' >> ~/.zshenv
|
|
89
107
|
|
|
90
108
|
# reload the environment (or just quit and open a new terminal)
|
|
91
109
|
$ source ~/.zshenv
|
|
92
110
|
```
|
|
93
111
|
|
|
94
|
-
You can then run the examples
|
|
112
|
+
You can then run the examples without appending the API key:
|
|
95
113
|
|
|
96
114
|
```bash
|
|
97
|
-
|
|
115
|
+
node chat_with_streaming.js
|
|
116
|
+
```
|
|
117
|
+
Once the environment variable is set, the client reads `MISTRAL_API_KEY` on its own:
|
|
118
|
+
|
|
119
|
+
```typescript
|
|
120
|
+
import MistralClient from '@mistralai/mistralai';
|
|
121
|
+
|
|
122
|
+
const client = new MistralClient();
|
|
98
123
|
```
|
|
@@ -26,6 +26,7 @@
|
|
|
26
26
|
"devDependencies": {
|
|
27
27
|
"eslint": "^8.55.0",
|
|
28
28
|
"eslint-config-google": "^0.14.0",
|
|
29
|
+
"jest": "^29.7.0",
|
|
29
30
|
"prettier": "2.8.8"
|
|
30
31
|
}
|
|
31
32
|
},
|
|
@@ -16023,16 +16024,16 @@
|
|
|
16023
16024
|
}
|
|
16024
16025
|
},
|
|
16025
16026
|
"node_modules/typescript": {
|
|
16026
|
-
"version": "
|
|
16027
|
-
"resolved": "https://registry.npmjs.org/typescript/-/typescript-
|
|
16028
|
-
"integrity": "sha512-
|
|
16027
|
+
"version": "4.9.5",
|
|
16028
|
+
"resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz",
|
|
16029
|
+
"integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==",
|
|
16029
16030
|
"peer": true,
|
|
16030
16031
|
"bin": {
|
|
16031
16032
|
"tsc": "bin/tsc",
|
|
16032
16033
|
"tsserver": "bin/tsserver"
|
|
16033
16034
|
},
|
|
16034
16035
|
"engines": {
|
|
16035
|
-
"node": ">=
|
|
16036
|
+
"node": ">=4.2.0"
|
|
16036
16037
|
}
|
|
16037
16038
|
},
|
|
16038
16039
|
"node_modules/unbox-primitive": {
|
|
@@ -19300,6 +19301,7 @@
|
|
|
19300
19301
|
"requires": {
|
|
19301
19302
|
"eslint": "^8.55.0",
|
|
19302
19303
|
"eslint-config-google": "^0.14.0",
|
|
19304
|
+
"jest": "^29.7.0",
|
|
19303
19305
|
"node-fetch": "^2.6.7",
|
|
19304
19306
|
"prettier": "2.8.8"
|
|
19305
19307
|
}
|
|
@@ -28464,9 +28466,9 @@
|
|
|
28464
28466
|
}
|
|
28465
28467
|
},
|
|
28466
28468
|
"typescript": {
|
|
28467
|
-
"version": "
|
|
28468
|
-
"resolved": "https://registry.npmjs.org/typescript/-/typescript-
|
|
28469
|
-
"integrity": "sha512-
|
|
28469
|
+
"version": "4.9.5",
|
|
28470
|
+
"resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz",
|
|
28471
|
+
"integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==",
|
|
28470
28472
|
"peer": true
|
|
28471
28473
|
},
|
|
28472
28474
|
"unbox-primitive": {
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
import MistralClient from '@mistralai/mistralai';
|
|
2
|
+
|
|
3
|
+
const apiKey = process.env.MISTRAL_API_KEY;
|
|
4
|
+
|
|
5
|
+
// Assuming we have the following data
|
|
6
|
+
const data = {
|
|
7
|
+
transactionId: ['T1001', 'T1002', 'T1003', 'T1004', 'T1005'],
|
|
8
|
+
customerId: ['C001', 'C002', 'C003', 'C002', 'C001'],
|
|
9
|
+
paymentAmount: [125.50, 89.99, 120.00, 54.30, 210.20],
|
|
10
|
+
paymentDate: [
|
|
11
|
+
'2021-10-05', '2021-10-06', '2021-10-07', '2021-10-05', '2021-10-08',
|
|
12
|
+
],
|
|
13
|
+
paymentStatus: ['Paid', 'Unpaid', 'Paid', 'Paid', 'Pending'],
|
|
14
|
+
};
|
|
15
|
+
|
|
16
|
+
/**
|
|
17
|
+
* This function retrieves the payment status of a transaction id.
|
|
18
|
+
* @param {object} data - The data object.
|
|
19
|
+
* @param {string} transactionId - The transaction id.
|
|
20
|
+
* @return {string} - The payment status.
|
|
21
|
+
*/
|
|
22
|
+
function retrievePaymentStatus({data, transactionId}) {
|
|
23
|
+
const transactionIndex = data.transactionId.indexOf(transactionId);
|
|
24
|
+
if (transactionIndex != -1) {
|
|
25
|
+
return JSON.stringify({status: data.payment_status[transactionIndex]});
|
|
26
|
+
} else {
|
|
27
|
+
return JSON.stringify({status: 'error - transaction id not found.'});
|
|
28
|
+
}
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
/**
|
|
32
|
+
* This function retrieves the payment date of a transaction id.
|
|
33
|
+
* @param {object} data - The data object.
|
|
34
|
+
* @param {string} transactionId - The transaction id.
|
|
35
|
+
* @return {string} - The payment date.
|
|
36
|
+
*
|
|
37
|
+
*/
|
|
38
|
+
function retrievePaymentDate({data, transactionId}) {
|
|
39
|
+
const transactionIndex = data.transactionId.indexOf(transactionId);
|
|
40
|
+
if (transactionIndex != -1) {
|
|
41
|
+
return JSON.stringify({status: data.payment_date[transactionIndex]});
|
|
42
|
+
} else {
|
|
43
|
+
return JSON.stringify({status: 'error - transaction id not found.'});
|
|
44
|
+
}
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
const namesToFunctions = {
|
|
48
|
+
retrievePaymentStatus: (transactionId) =>
|
|
49
|
+
retrievePaymentStatus({data, ...transactionId}),
|
|
50
|
+
retrievePaymentDate: (transactionId) =>
|
|
51
|
+
retrievePaymentDate({data, ...transactionId}),
|
|
52
|
+
};
|
|
53
|
+
|
|
54
|
+
const tools = [
|
|
55
|
+
{
|
|
56
|
+
type: 'function',
|
|
57
|
+
function: {
|
|
58
|
+
name: 'retrievePaymentStatus',
|
|
59
|
+
description: 'Get payment status of a transaction id',
|
|
60
|
+
parameters: {
|
|
61
|
+
type: 'object',
|
|
62
|
+
required: ['transactionId'],
|
|
63
|
+
properties: {transactionId:
|
|
64
|
+
{type: 'string', description: 'The transaction id.'},
|
|
65
|
+
},
|
|
66
|
+
},
|
|
67
|
+
},
|
|
68
|
+
},
|
|
69
|
+
{
|
|
70
|
+
type: 'function',
|
|
71
|
+
function: {
|
|
72
|
+
name: 'retrievePaymentDate',
|
|
73
|
+
description: 'Get payment date of a transaction id',
|
|
74
|
+
parameters: {
|
|
75
|
+
type: 'object',
|
|
76
|
+
required: ['transactionId'],
|
|
77
|
+
properties: {transactionId:
|
|
78
|
+
{type: 'string', description: 'The transaction id.'},
|
|
79
|
+
},
|
|
80
|
+
},
|
|
81
|
+
},
|
|
82
|
+
},
|
|
83
|
+
];
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
const model = 'mistral-large';
|
|
87
|
+
|
|
88
|
+
const client = new MistralClient(apiKey, 'https://api-2.aurocloud.net');
|
|
89
|
+
|
|
90
|
+
const messages = [
|
|
91
|
+
{role: 'user', content: 'What\'s the status of my transaction?'},
|
|
92
|
+
];
|
|
93
|
+
|
|
94
|
+
let response = await client.chat({
|
|
95
|
+
model: model, messages: messages, tools: tools,
|
|
96
|
+
});
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
console.log(response.choices[0].message.content);
|
|
100
|
+
|
|
101
|
+
messages.push(
|
|
102
|
+
{role: 'assistant', content: response.choices[0].message.content},
|
|
103
|
+
);
|
|
104
|
+
messages.push({role: 'user', content: 'My transaction ID is T1001.'});
|
|
105
|
+
|
|
106
|
+
response = await client.chat({model: model, messages: messages, tools: tools});
|
|
107
|
+
|
|
108
|
+
const toolCall = response.choices[0].message.toolCalls[0];
|
|
109
|
+
const functionName = toolCall.function.name;
|
|
110
|
+
const functionParams = JSON.parse(toolCall.function.arguments);
|
|
111
|
+
|
|
112
|
+
console.log(`calling functionName: ${functionName}`);
|
|
113
|
+
console.log(`functionParams: ${toolCall.function.arguments}`);
|
|
114
|
+
|
|
115
|
+
const functionResult = namesToFunctions[functionName](functionParams);
|
|
116
|
+
|
|
117
|
+
messages.push(response.choices[0].message);
|
|
118
|
+
messages.push({role: 'tool', name: functionName, content: functionResult});
|
|
119
|
+
|
|
120
|
+
response = await client.chat({model: model, messages: messages, tools: tools});
|
|
121
|
+
|
|
122
|
+
console.log(response.choices[0].message.content);
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import MistralClient from '@mistralai/mistralai';
|
|
2
|
+
|
|
3
|
+
const apiKey = process.env.MISTRAL_API_KEY;
|
|
4
|
+
|
|
5
|
+
const client = new MistralClient(apiKey);
|
|
6
|
+
|
|
7
|
+
const chatResponse = await client.chat({
|
|
8
|
+
model: 'mistral-large',
|
|
9
|
+
messages: [{role: 'user', content: 'What is the best French cheese?'}],
|
|
10
|
+
responseFormat: {type: 'json_object'},
|
|
11
|
+
});
|
|
12
|
+
|
|
13
|
+
console.log('Chat:', chatResponse.choices[0].message.content);
|
package/package.json
CHANGED
|
@@ -1,22 +1,32 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@mistralai/mistralai",
|
|
3
|
-
"version": "0.
|
|
3
|
+
"version": "0.1.3",
|
|
4
4
|
"description": "",
|
|
5
5
|
"author": "bam4d@mistral.ai",
|
|
6
6
|
"license": "ISC",
|
|
7
7
|
"type": "module",
|
|
8
8
|
"main": "src/client.js",
|
|
9
|
+
"scripts": {
|
|
10
|
+
"lint": "./node_modules/.bin/eslint .",
|
|
11
|
+
"test": "node --experimental-vm-modules node_modules/.bin/jest"
|
|
12
|
+
},
|
|
13
|
+
"jest": {
|
|
14
|
+
"testPathIgnorePatterns": [
|
|
15
|
+
"examples"
|
|
16
|
+
]
|
|
17
|
+
},
|
|
9
18
|
"repository": {
|
|
10
19
|
"type": "git",
|
|
11
20
|
"url": "https://github.com/mistralai/client-js"
|
|
12
21
|
},
|
|
13
|
-
"types": "src/
|
|
22
|
+
"types": "src/client.d.ts",
|
|
14
23
|
"dependencies": {
|
|
15
24
|
"node-fetch": "^2.6.7"
|
|
16
25
|
},
|
|
17
26
|
"devDependencies": {
|
|
18
27
|
"eslint": "^8.55.0",
|
|
19
28
|
"eslint-config-google": "^0.14.0",
|
|
20
|
-
"prettier": "2.8.8"
|
|
29
|
+
"prettier": "2.8.8",
|
|
30
|
+
"jest": "^29.7.0"
|
|
21
31
|
}
|
|
22
32
|
}
|
package/src/client.d.ts
CHANGED
|
@@ -29,6 +29,42 @@ declare module '@mistralai/mistralai' {
|
|
|
29
29
|
data: Model[];
|
|
30
30
|
}
|
|
31
31
|
|
|
32
|
+
export interface Function {
|
|
33
|
+
name: string;
|
|
34
|
+
description: string;
|
|
35
|
+
parameters: object;
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
export enum ToolType {
|
|
39
|
+
function = 'function',
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
export interface FunctionCall {
|
|
43
|
+
name: string;
|
|
44
|
+
arguments: string;
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
export interface ToolCalls {
|
|
48
|
+
id: 'null';
|
|
49
|
+
type: ToolType = ToolType.function;
|
|
50
|
+
function: FunctionCall;
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
export enum ResponseFormats {
|
|
54
|
+
text = 'text',
|
|
55
|
+
json_object = 'json_object',
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
export enum ToolChoice {
|
|
59
|
+
auto = 'auto',
|
|
60
|
+
any = 'any',
|
|
61
|
+
none = 'none',
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
export interface ResponseFormat {
|
|
65
|
+
type: ResponseFormats = ResponseFormats.text;
|
|
66
|
+
}
|
|
67
|
+
|
|
32
68
|
export interface TokenUsage {
|
|
33
69
|
prompt_tokens: number;
|
|
34
70
|
completion_tokens: number;
|
|
@@ -49,6 +85,7 @@ declare module '@mistralai/mistralai' {
|
|
|
49
85
|
delta: {
|
|
50
86
|
role?: string;
|
|
51
87
|
content?: string;
|
|
88
|
+
tool_calls?: ToolCalls[];
|
|
52
89
|
};
|
|
53
90
|
finish_reason: string;
|
|
54
91
|
}
|
|
@@ -95,35 +132,56 @@ declare module '@mistralai/mistralai' {
|
|
|
95
132
|
|
|
96
133
|
private _makeChatCompletionRequest(
|
|
97
134
|
model: string,
|
|
98
|
-
messages: Array<{ role: string; content: string }>,
|
|
135
|
+
messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>,
|
|
136
|
+
tools?: Array<{ type: string; function:Function; }>,
|
|
99
137
|
temperature?: number,
|
|
100
138
|
maxTokens?: number,
|
|
101
139
|
topP?: number,
|
|
102
140
|
randomSeed?: number,
|
|
103
141
|
stream?: boolean,
|
|
104
|
-
|
|
142
|
+
/**
|
|
143
|
+
* @deprecated use safePrompt instead
|
|
144
|
+
*/
|
|
145
|
+
safeMode?: boolean,
|
|
146
|
+
safePrompt?: boolean,
|
|
147
|
+
toolChoice?: ToolChoice,
|
|
148
|
+
responseFormat?: ResponseFormat
|
|
105
149
|
): object;
|
|
106
150
|
|
|
107
151
|
listModels(): Promise<ListModelsResponse>;
|
|
108
152
|
|
|
109
153
|
chat(options: {
|
|
110
154
|
model: string;
|
|
111
|
-
messages: Array<{ role: string; content: string }>;
|
|
155
|
+
messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>;
|
|
156
|
+
tools?: Array<{ type: string; function:Function; }>;
|
|
112
157
|
temperature?: number;
|
|
113
158
|
maxTokens?: number;
|
|
114
159
|
topP?: number;
|
|
115
160
|
randomSeed?: number;
|
|
161
|
+
/**
|
|
162
|
+
* @deprecated use safePrompt instead
|
|
163
|
+
*/
|
|
116
164
|
safeMode?: boolean;
|
|
165
|
+
safePrompt?: boolean;
|
|
166
|
+
toolChoice?: ToolChoice;
|
|
167
|
+
responseFormat?: ResponseFormat;
|
|
117
168
|
}): Promise<ChatCompletionResponse>;
|
|
118
169
|
|
|
119
170
|
chatStream(options: {
|
|
120
171
|
model: string;
|
|
121
|
-
messages: Array<{ role: string; content: string }>;
|
|
172
|
+
messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>;
|
|
173
|
+
tools?: Array<{ type: string; function:Function; }>;
|
|
122
174
|
temperature?: number;
|
|
123
175
|
maxTokens?: number;
|
|
124
176
|
topP?: number;
|
|
125
177
|
randomSeed?: number;
|
|
178
|
+
/**
|
|
179
|
+
* @deprecated use safePrompt instead
|
|
180
|
+
*/
|
|
126
181
|
safeMode?: boolean;
|
|
182
|
+
safePrompt?: boolean;
|
|
183
|
+
toolChoice?: ToolChoice;
|
|
184
|
+
responseFormat?: ResponseFormat;
|
|
127
185
|
}): AsyncGenerator<ChatCompletionResponseChunk, void, unknown>;
|
|
128
186
|
|
|
129
187
|
embeddings(options: {
|
package/src/client.js
CHANGED
|
@@ -1,16 +1,26 @@
|
|
|
1
|
-
let fetch;
|
|
2
1
|
let isNode = false;
|
|
3
|
-
if (typeof globalThis.fetch === 'undefined') {
|
|
4
|
-
fetch = (await import('node-fetch')).default;
|
|
5
|
-
isNode = true;
|
|
6
|
-
} else {
|
|
7
|
-
fetch = globalThis.fetch;
|
|
8
|
-
}
|
|
9
|
-
|
|
10
2
|
|
|
3
|
+
const VERSION = '0.0.3';
|
|
11
4
|
const RETRY_STATUS_CODES = [429, 500, 502, 503, 504];
|
|
12
5
|
const ENDPOINT = 'https://api.mistral.ai';
|
|
13
6
|
|
|
7
|
+
/**
|
|
8
|
+
* Initialize fetch
|
|
9
|
+
* @return {Promise<void>}
|
|
10
|
+
*/
|
|
11
|
+
async function initializeFetch() {
|
|
12
|
+
if (typeof window === 'undefined' ||
|
|
13
|
+
typeof globalThis.fetch === 'undefined') {
|
|
14
|
+
const nodeFetch = await import('node-fetch');
|
|
15
|
+
fetch = nodeFetch.default;
|
|
16
|
+
isNode = true;
|
|
17
|
+
} else {
|
|
18
|
+
fetch = globalThis.fetch;
|
|
19
|
+
}
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
initializeFetch();
|
|
23
|
+
|
|
14
24
|
/**
|
|
15
25
|
* MistralAPIError
|
|
16
26
|
* @return {MistralAPIError}
|
|
@@ -51,6 +61,10 @@ class MistralClient {
|
|
|
51
61
|
|
|
52
62
|
this.maxRetries = maxRetries;
|
|
53
63
|
this.timeout = timeout;
|
|
64
|
+
|
|
65
|
+
if (this.endpoint.indexOf('inference.azure.com')) {
|
|
66
|
+
this.modelDefault = 'mistral';
|
|
67
|
+
}
|
|
54
68
|
}
|
|
55
69
|
|
|
56
70
|
/**
|
|
@@ -65,6 +79,8 @@ class MistralClient {
|
|
|
65
79
|
const options = {
|
|
66
80
|
method: method,
|
|
67
81
|
headers: {
|
|
82
|
+
'User-Agent': `mistral-client-js/${VERSION}`,
|
|
83
|
+
'Accept': request?.stream ? 'text/event-stream' : 'application/json',
|
|
68
84
|
'Content-Type': 'application/json',
|
|
69
85
|
'Authorization': `Bearer ${this.apiKey}`,
|
|
70
86
|
},
|
|
@@ -85,14 +101,13 @@ class MistralClient {
|
|
|
85
101
|
// Chrome does not support async iterators yet, so polyfill it
|
|
86
102
|
const asyncIterator = async function* () {
|
|
87
103
|
try {
|
|
88
|
-
const decoder = new TextDecoder();
|
|
89
104
|
while (true) {
|
|
90
105
|
// Read from the stream
|
|
91
106
|
const {done, value} = await reader.read();
|
|
92
107
|
// Exit if we're done
|
|
93
108
|
if (done) return;
|
|
94
109
|
// Else yield the chunk
|
|
95
|
-
yield
|
|
110
|
+
yield value;
|
|
96
111
|
}
|
|
97
112
|
} finally {
|
|
98
113
|
reader.releaseLock();
|
|
@@ -104,14 +119,19 @@ class MistralClient {
|
|
|
104
119
|
}
|
|
105
120
|
return await response.json();
|
|
106
121
|
} else if (RETRY_STATUS_CODES.includes(response.status)) {
|
|
107
|
-
console.debug(
|
|
122
|
+
console.debug(
|
|
123
|
+
`Retrying request on response status: ${response.status}`,
|
|
124
|
+
`Response: ${await response.text()}`,
|
|
125
|
+
`Attempt: ${attempts + 1}`,
|
|
126
|
+
);
|
|
108
127
|
// eslint-disable-next-line max-len
|
|
109
128
|
await new Promise((resolve) =>
|
|
110
129
|
setTimeout(resolve, Math.pow(2, (attempts + 1)) * 500),
|
|
111
130
|
);
|
|
112
131
|
} else {
|
|
113
132
|
throw new MistralAPIError(
|
|
114
|
-
`HTTP error! status: ${response.status}
|
|
133
|
+
`HTTP error! status: ${response.status} ` +
|
|
134
|
+
`Response: \n${await response.text()}`,
|
|
115
135
|
);
|
|
116
136
|
}
|
|
117
137
|
} catch (error) {
|
|
@@ -133,33 +153,50 @@ class MistralClient {
|
|
|
133
153
|
* Creates a chat completion request
|
|
134
154
|
* @param {*} model
|
|
135
155
|
* @param {*} messages
|
|
156
|
+
* @param {*} tools
|
|
136
157
|
* @param {*} temperature
|
|
137
158
|
* @param {*} maxTokens
|
|
138
159
|
* @param {*} topP
|
|
139
160
|
* @param {*} randomSeed
|
|
140
161
|
* @param {*} stream
|
|
141
|
-
* @param {*} safeMode
|
|
162
|
+
* @param {*} safeMode deprecated use safePrompt instead
|
|
163
|
+
* @param {*} safePrompt
|
|
164
|
+
* @param {*} toolChoice
|
|
165
|
+
* @param {*} responseFormat
|
|
142
166
|
* @return {Promise<Object>}
|
|
143
167
|
*/
|
|
144
168
|
_makeChatCompletionRequest = function(
|
|
145
169
|
model,
|
|
146
170
|
messages,
|
|
171
|
+
tools,
|
|
147
172
|
temperature,
|
|
148
173
|
maxTokens,
|
|
149
174
|
topP,
|
|
150
175
|
randomSeed,
|
|
151
176
|
stream,
|
|
152
177
|
safeMode,
|
|
178
|
+
safePrompt,
|
|
179
|
+
toolChoice,
|
|
180
|
+
responseFormat,
|
|
153
181
|
) {
|
|
182
|
+
// if neither model nor modelDefault is set, throw an error
|
|
183
|
+
if (!model && !this.modelDefault) {
|
|
184
|
+
throw new MistralAPIError(
|
|
185
|
+
'You must provide a model name',
|
|
186
|
+
);
|
|
187
|
+
}
|
|
154
188
|
return {
|
|
155
|
-
model: model,
|
|
189
|
+
model: model ?? this.modelDefault,
|
|
156
190
|
messages: messages,
|
|
191
|
+
tools: tools ?? undefined,
|
|
157
192
|
temperature: temperature ?? undefined,
|
|
158
193
|
max_tokens: maxTokens ?? undefined,
|
|
159
194
|
top_p: topP ?? undefined,
|
|
160
195
|
random_seed: randomSeed ?? undefined,
|
|
161
196
|
stream: stream ?? undefined,
|
|
162
|
-
safe_prompt: safeMode ?? undefined,
|
|
197
|
+
safe_prompt: (safeMode || safePrompt) ?? undefined,
|
|
198
|
+
tool_choice: toolChoice ?? undefined,
|
|
199
|
+
response_format: responseFormat ?? undefined,
|
|
163
200
|
};
|
|
164
201
|
};
|
|
165
202
|
|
|
@@ -177,31 +214,43 @@ class MistralClient {
|
|
|
177
214
|
* @param {*} model the name of the model to chat with, e.g. mistral-tiny
|
|
178
215
|
* @param {*} messages an array of messages to chat with, e.g.
|
|
179
216
|
* [{role: 'user', content: 'What is the best French cheese?'}]
|
|
217
|
+
* @param {*} tools a list of tools to use.
|
|
180
218
|
* @param {*} temperature the temperature to use for sampling, e.g. 0.5
|
|
181
219
|
* @param {*} maxTokens the maximum number of tokens to generate, e.g. 100
|
|
182
220
|
* @param {*} topP the cumulative probability of tokens to generate, e.g. 0.9
|
|
183
221
|
* @param {*} randomSeed the random seed to use for sampling, e.g. 42
|
|
184
|
-
* @param {*} safeMode
|
|
222
|
+
* @param {*} safeMode deprecated use safePrompt instead
|
|
223
|
+
* @param {*} safePrompt whether to use safe mode, e.g. true
|
|
224
|
+
* @param {*} toolChoice the tool choice mode, one of 'auto', 'any' or 'none'
|
|
225
|
+
* @param {*} responseFormat the format of the response, e.g. {type: 'json_object'}
|
|
185
226
|
* @return {Promise<Object>}
|
|
186
227
|
*/
|
|
187
228
|
chat = async function({
|
|
188
229
|
model,
|
|
189
230
|
messages,
|
|
231
|
+
tools,
|
|
190
232
|
temperature,
|
|
191
233
|
maxTokens,
|
|
192
234
|
topP,
|
|
193
235
|
randomSeed,
|
|
194
236
|
safeMode,
|
|
237
|
+
safePrompt,
|
|
238
|
+
toolChoice,
|
|
239
|
+
responseFormat,
|
|
195
240
|
}) {
|
|
196
241
|
const request = this._makeChatCompletionRequest(
|
|
197
242
|
model,
|
|
198
243
|
messages,
|
|
244
|
+
tools,
|
|
199
245
|
temperature,
|
|
200
246
|
maxTokens,
|
|
201
247
|
topP,
|
|
202
248
|
randomSeed,
|
|
203
249
|
false,
|
|
204
250
|
safeMode,
|
|
251
|
+
safePrompt,
|
|
252
|
+
toolChoice,
|
|
253
|
+
responseFormat,
|
|
205
254
|
);
|
|
206
255
|
const response = await this._request(
|
|
207
256
|
'post',
|
|
@@ -216,31 +265,43 @@ class MistralClient {
|
|
|
216
265
|
* @param {*} model the name of the model to chat with, e.g. mistral-tiny
|
|
217
266
|
* @param {*} messages an array of messages to chat with, e.g.
|
|
218
267
|
* [{role: 'user', content: 'What is the best French cheese?'}]
|
|
268
|
+
* @param {*} tools a list of tools to use.
|
|
219
269
|
* @param {*} temperature the temperature to use for sampling, e.g. 0.5
|
|
220
270
|
* @param {*} maxTokens the maximum number of tokens to generate, e.g. 100
|
|
221
271
|
* @param {*} topP the cumulative probability of tokens to generate, e.g. 0.9
|
|
222
272
|
* @param {*} randomSeed the random seed to use for sampling, e.g. 42
|
|
223
|
-
* @param {*} safeMode
|
|
273
|
+
* @param {*} safeMode deprecated use safePrompt instead
|
|
274
|
+
* @param {*} safePrompt whether to use safe mode, e.g. true
|
|
275
|
+
* @param {*} toolChoice the tool choice mode, one of 'auto', 'any' or 'none'
|
|
276
|
+
* @param {*} responseFormat the format of the response, e.g. {type: 'json_object'}
|
|
224
277
|
* @return {Promise<Object>}
|
|
225
278
|
*/
|
|
226
279
|
chatStream = async function* ({
|
|
227
280
|
model,
|
|
228
281
|
messages,
|
|
282
|
+
tools,
|
|
229
283
|
temperature,
|
|
230
284
|
maxTokens,
|
|
231
285
|
topP,
|
|
232
286
|
randomSeed,
|
|
233
287
|
safeMode,
|
|
288
|
+
safePrompt,
|
|
289
|
+
toolChoice,
|
|
290
|
+
responseFormat,
|
|
234
291
|
}) {
|
|
235
292
|
const request = this._makeChatCompletionRequest(
|
|
236
293
|
model,
|
|
237
294
|
messages,
|
|
295
|
+
tools,
|
|
238
296
|
temperature,
|
|
239
297
|
maxTokens,
|
|
240
298
|
topP,
|
|
241
299
|
randomSeed,
|
|
242
300
|
true,
|
|
243
301
|
safeMode,
|
|
302
|
+
safePrompt,
|
|
303
|
+
toolChoice,
|
|
304
|
+
responseFormat,
|
|
244
305
|
);
|
|
245
306
|
const response = await this._request(
|
|
246
307
|
'post',
|
|
@@ -249,9 +310,9 @@ class MistralClient {
|
|
|
249
310
|
);
|
|
250
311
|
|
|
251
312
|
let buffer = '';
|
|
252
|
-
|
|
313
|
+
const decoder = new TextDecoder();
|
|
253
314
|
for await (const chunk of response) {
|
|
254
|
-
buffer += chunk;
|
|
315
|
+
buffer += decoder.decode(chunk, {stream: true});
|
|
255
316
|
let firstNewline;
|
|
256
317
|
while ((firstNewline = buffer.indexOf('\n')) !== -1) {
|
|
257
318
|
const chunkLine = buffer.substring(0, firstNewline);
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
import MistralClient from '../src/client';
|
|
2
|
+
import {
|
|
3
|
+
mockListModels,
|
|
4
|
+
mockFetch,
|
|
5
|
+
mockChatResponseStreamingPayload,
|
|
6
|
+
mockEmbeddingRequest,
|
|
7
|
+
mockEmbeddingResponsePayload,
|
|
8
|
+
mockChatResponsePayload,
|
|
9
|
+
mockFetchStream,
|
|
10
|
+
} from './utils';
|
|
11
|
+
|
|
12
|
+
// Unit tests for MistralClient. Every test replaces globalThis.fetch with a
// mock built in ./utils, so no real HTTP request is issued.
describe('Mistral Client', () => {
  let client;
  beforeEach(() => {
    client = new MistralClient();
  });

  /**
   * Build a fresh chat request for the single-question conversation used by
   * every chat test (a new object per call, so tests cannot leak state).
   * @param {Object} extra extra request options (e.g. safeMode, safePrompt)
   * @return {Object}
   */
  const chatRequest = (extra = {}) => ({
    model: 'mistral-small',
    messages: [
      {
        role: 'user',
        content: 'What is the best French cheese?',
      },
    ],
    ...extra,
  });

  /**
   * Drain an async iterable (e.g. the chatStream() result) into an array.
   * @param {AsyncIterable} stream
   * @return {Promise<Array>}
   */
  const collect = async(stream) => {
    const chunks = [];
    for await (const chunk of stream) {
      chunks.push(chunk);
    }
    return chunks;
  };

  describe('chat()', () => {
    it('should return a chat response object', async() => {
      // Mock the fetch function
      const mockResponse = mockChatResponsePayload();
      globalThis.fetch = mockFetch(200, mockResponse);

      const response = await client.chat(chatRequest());
      expect(response).toEqual(mockResponse);
    });

    it('should return a chat response object if safeMode is set', async() => {
      // Mock the fetch function
      const mockResponse = mockChatResponsePayload();
      globalThis.fetch = mockFetch(200, mockResponse);

      const response = await client.chat(chatRequest({safeMode: true}));
      expect(response).toEqual(mockResponse);
    });

    it('should return a chat response object if safePrompt is set', async() => {
      // Mock the fetch function
      const mockResponse = mockChatResponsePayload();
      globalThis.fetch = mockFetch(200, mockResponse);

      const response = await client.chat(chatRequest({safePrompt: true}));
      expect(response).toEqual(mockResponse);
    });
  });

  describe('chatStream()', () => {
    it('should return parsed, streamed response', async() => {
      // Mock the fetch function
      const mockResponse = mockChatResponseStreamingPayload();
      globalThis.fetch = mockFetchStream(200, mockResponse);

      const response = await client.chatStream(chatRequest());

      // 1 role chunk + 10 content chunks; the trailing [DONE] is not yielded.
      const parsedResponse = await collect(response);
      expect(parsedResponse.length).toEqual(11);
    });

    it('should return parsed, streamed response with safeMode', async() => {
      // Mock the fetch function
      const mockResponse = mockChatResponseStreamingPayload();
      globalThis.fetch = mockFetchStream(200, mockResponse);

      const response = await client.chatStream(chatRequest({safeMode: true}));

      const parsedResponse = await collect(response);
      expect(parsedResponse.length).toEqual(11);
    });

    it('should return parsed, streamed response with safePrompt', async() => {
      // Mock the fetch function
      const mockResponse = mockChatResponseStreamingPayload();
      globalThis.fetch = mockFetchStream(200, mockResponse);

      const response = await client.chatStream(
          chatRequest({safePrompt: true}),
      );

      const parsedResponse = await collect(response);
      expect(parsedResponse.length).toEqual(11);
    });
  });

  describe('embeddings()', () => {
    it('should return embeddings', async() => {
      // Mock the fetch function
      const mockResponse = mockEmbeddingResponsePayload();
      globalThis.fetch = mockFetch(200, mockResponse);

      // mockEmbeddingRequest is a factory: call it to get the request body.
      const response = await client.embeddings(mockEmbeddingRequest());
      expect(response).toEqual(mockResponse);
    });
  });

  describe('embeddings() batched', () => {
    it('should return batched embeddings', async() => {
      // Mock the fetch function
      const mockResponse = mockEmbeddingResponsePayload(10);
      globalThis.fetch = mockFetch(200, mockResponse);

      const response = await client.embeddings(mockEmbeddingRequest());
      expect(response).toEqual(mockResponse);
    });
  });

  describe('listModels()', () => {
    it('should return a list of models', async() => {
      // Mock the fetch function
      const mockResponse = mockListModels();
      globalThis.fetch = mockFetch(200, mockResponse);

      const response = await client.listModels();
      expect(response).toEqual(mockResponse);
    });
  });
});
|
package/tests/utils.js
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
import jest from 'jest-mock';
|
|
2
|
+
|
|
3
|
+
/**
 * Create a jest mock standing in for the global `fetch`. Each call resolves
 * to a minimal Response-like object that carries `payload`.
 * @param {*} status HTTP status code reported by the fake response
 * @param {*} payload value produced by `json()` and serialized by `text()`
 * @return {Object} jest mock function
 */
export function mockFetch(status, payload) {
  const isOk = status >= 200 && status < 300;

  // A new response object is built for every invocation of the mock.
  const buildResponse = () => ({
    json: async() => payload,
    text: async() => JSON.stringify(payload),
    status,
    ok: isOk,
  });

  return jest.fn(async() => buildResponse());
}
|
|
19
|
+
|
|
20
|
+
/**
 * Create a jest mock for `fetch` whose response `body` is an async iterable
 * over the chunks in `payload`, emulating a streamed HTTP response.
 *
 * The chunks are copied each time the mock is invoked. The caller's array is
 * never mutated, every call streams the full payload again, and falsy chunks
 * (e.g. an empty string) are yielded instead of ending the stream early.
 * @param {*} status HTTP status code reported by the fake response
 * @param {Array} payload chunks yielded in order (e.g. Uint8Array values)
 * @return {Object} jest mock function
 */
export function mockFetchStream(status, payload) {
  /**
   * Yield every chunk of `chunks` in order.
   * @param {Array} chunks
   */
  const asyncIterator = async function* (chunks) {
    for (const chunk of chunks) {
      yield chunk;
    }
  };

  return jest.fn(() =>
    Promise.resolve({
      // Snapshot the payload so repeated fetches each get the full stream.
      body: asyncIterator([...payload]),
      status,
      ok: status >= 200 && status < 300,
    }),
  );
}
|
|
47
|
+
|
|
48
|
+
/**
 * Mock payload of the list-models endpoint: four models, each with a single
 * model_permission entry.
 * @return {Object}
 */
export function mockListModels() {
  const created = 1703186988;

  // [model id, permission id] for every model in the fake listing.
  const catalog = [
    ['mistral-medium', 'modelperm-15bebaf316264adb84b891bf06a84933'],
    ['mistral-small', 'modelperm-d0dced5c703242fa862f4ca3f241c00e'],
    ['mistral-tiny', 'modelperm-0e64e727c3a94f17b29f8895d4be2910'],
    ['mistral-embed', 'modelperm-ebdff9046f524e628059447b5932e3ad'],
  ];

  const toModelEntry = ([modelId, permissionId]) => ({
    id: modelId,
    object: 'model',
    created,
    owned_by: 'mistralai',
    root: null,
    parent: null,
    permission: [
      {
        id: permissionId,
        object: 'model_permission',
        created,
        allow_create_engine: false,
        allow_sampling: true,
        allow_logprobs: false,
        allow_search_indices: false,
        allow_view: true,
        allow_fine_tuning: false,
        organization: '*',
        group: null,
        is_blocking: false,
      },
    ],
  });

  return {
    object: 'list',
    data: catalog.map(toModelEntry),
  };
}
|
|
155
|
+
|
|
156
|
+
/**
 * Mock (non-streaming) chat completion response with a single assistant
 * choice.
 * @return {Object}
 */
export function mockChatResponsePayload() {
  const message = {
    role: 'assistant',
    content: 'What is the best French cheese?',
  };
  const onlyChoice = {finish_reason: 'stop', message, index: 0};
  const usage = {prompt_tokens: 90, total_tokens: 90, completion_tokens: 0};

  return {
    id: 'chat-98c8c60e3fbf4fc49658eddaf447357c',
    object: 'chat.completion',
    created: 1703165682,
    choices: [onlyChoice],
    model: 'mistral-small',
    usage,
  };
}
|
|
179
|
+
|
|
180
|
+
/**
 * Mock chat completion stream. It is a list of UTF-8 encoded server-sent-event
 * chunks: one role chunk, ten content chunks, then the `[DONE]` sentinel.
 * @return {Object}
 */
export function mockChatResponseStreamingPayload() {
  const encoder = new TextEncoder();
  const completionId = 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e';

  // Frame one payload string as a "data:" event terminated by a blank line.
  const frame = (data) => encoder.encode(`data: ${data}\n\n`);

  const roleChunk = frame(JSON.stringify({
    id: completionId,
    model: 'mistral-small',
    choices: [{index: 0, delta: {role: 'assistant'}, finish_reason: null}],
  }));

  const contentChunks = Array.from({length: 10}, (_, n) => frame(
      JSON.stringify({
        id: completionId,
        object: 'chat.completion.chunk',
        created: 1703168544,
        model: 'mistral-small',
        choices: [
          {
            index: n,
            delta: {content: `stream response ${n}`},
            finish_reason: null,
          },
        ],
      }),
  ));

  return [roleChunk, ...contentChunks, frame('[DONE]')];
}
|
|
225
|
+
|
|
226
|
+
/**
 * Mock embeddings response.
 *
 * `data` holds `batchSize` embedding entries, so a batched request can be
 * simulated. Each entry gets its own copy of the vector and its position in
 * the batch as `index`.
 * @param {number} batchSize number of embedding entries in `data` (default 1)
 * @return {Object}
 */
export function mockEmbeddingResponsePayload(batchSize = 1) {
  return {
    id: 'embd-98c8c60e3fbf4fc49658eddaf447357c',
    object: 'list',
    // Build the batch explicitly: multiplying an array literal by a number
    // yields NaN rather than a repeated array.
    data: Array.from({length: batchSize}, (_, index) => ({
      object: 'embedding',
      embedding: [-0.018585205078125, 0.027099609375, 0.02587890625],
      index,
    })),
    model: 'mistral-embed',
    usage: {prompt_tokens: 90, total_tokens: 90, completion_tokens: 0},
  };
}
|
|
247
|
+
|
|
248
|
+
/**
 * Build the embeddings request body used by the tests. A new object is
 * returned on every call.
 * @return {Object}
 */
export function mockEmbeddingRequest() {
  const request = {};
  request.model = 'mistral-embed';
  request.input = 'embed';
  return request;
}
|