@mistralai/mistralai 0.4.0 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +41 -0
- package/examples/file.jsonl +3 -0
- package/examples/files.js +27 -0
- package/examples/jobs.js +39 -0
- package/package.json +1 -1
- package/src/client.d.ts +4 -5
- package/src/client.js +18 -24
- package/src/files.d.ts +30 -0
- package/src/files.js +65 -0
- package/src/jobs.d.ts +86 -0
- package/src/jobs.js +82 -0
- package/tests/files.test.js +65 -0
- package/tests/jobs.test.js +71 -0
- package/tests/utils.js +155 -24
package/README.md
CHANGED
|
@@ -78,6 +78,47 @@ const embeddingsBatchResponse = await client.embeddings({
|
|
|
78
78
|
console.log('Embeddings Batch:', embeddingsBatchResponse.data);
|
|
79
79
|
```
|
|
80
80
|
|
|
81
|
+
### Files
|
|
82
|
+
|
|
83
|
+
```typescript
|
|
84
|
+
// Create a new file
|
|
85
|
+
const file = fs.readFileSync('file.jsonl');
|
|
86
|
+
const createdFile = await client.files.create({ file });
|
|
87
|
+
|
|
88
|
+
// List files
|
|
89
|
+
const files = await client.files.list();
|
|
90
|
+
|
|
91
|
+
// Retrieve a file
|
|
92
|
+
const retrievedFile = await client.files.retrieve({ fileId: createdFile.id });
|
|
93
|
+
|
|
94
|
+
// Delete a file
|
|
95
|
+
const deletedFile = await client.files.delete({ fileId: createdFile.id });
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
### Fine-tuning Jobs
|
|
99
|
+
|
|
100
|
+
```typescript
|
|
101
|
+
// Create a new job
|
|
102
|
+
const createdJob = await client.jobs.create({
|
|
103
|
+
model: 'open-mistral-7B',
|
|
104
|
+
trainingFiles: [trainingFile.id],
|
|
105
|
+
validationFiles: [validationFile.id],
|
|
106
|
+
hyperparameters: {
|
|
107
|
+
trainingSteps: 10,
|
|
108
|
+
learningRate: 0.0001,
|
|
109
|
+
},
|
|
110
|
+
});
|
|
111
|
+
|
|
112
|
+
// List jobs
|
|
113
|
+
const jobs = await client.jobs.list();
|
|
114
|
+
|
|
115
|
+
// Retrieve a job
|
|
116
|
+
const retrievedJob = await client.jobs.retrieve({ jobId: createdJob.id });
|
|
117
|
+
|
|
118
|
+
// Cancel a job
|
|
119
|
+
const canceledJob = await client.jobs.cancel({ jobId: createdJob.id });
|
|
120
|
+
```
|
|
121
|
+
|
|
81
122
|
## Run examples
|
|
82
123
|
|
|
83
124
|
You can run the examples in the examples directory by installing them locally:
|
package/examples/file.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
1
|
+
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
|
|
2
|
+
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
|
|
3
|
+
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters.", "weight": 0}]}
|
package/examples/files.js
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import MistralClient from '@mistralai/mistralai';
|
|
2
|
+
import * as fs from 'fs';
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
const apiKey = process.env.MISTRAL_API_KEY;
|
|
6
|
+
|
|
7
|
+
const client = new MistralClient(apiKey);
|
|
8
|
+
|
|
9
|
+
// Create a new file
|
|
10
|
+
const blob = new Blob(
|
|
11
|
+
[fs.readFileSync('file.jsonl')],
|
|
12
|
+
{type: 'application/json'},
|
|
13
|
+
);
|
|
14
|
+
const createdFile = await client.files.create({file: blob});
|
|
15
|
+
console.log(createdFile);
|
|
16
|
+
|
|
17
|
+
// List files
|
|
18
|
+
const files = await client.files.list();
|
|
19
|
+
console.log(files);
|
|
20
|
+
|
|
21
|
+
// Retrieve a file
|
|
22
|
+
const retrievedFile = await client.files.retrieve({fileId: createdFile.id});
|
|
23
|
+
console.log(retrievedFile);
|
|
24
|
+
|
|
25
|
+
// Delete a file
|
|
26
|
+
const deletedFile = await client.files.delete({fileId: createdFile.id});
|
|
27
|
+
console.log(deletedFile);
|
package/examples/jobs.js
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import MistralClient from '@mistralai/mistralai';
|
|
2
|
+
import * as fs from 'fs';
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
const apiKey = process.env.MISTRAL_API_KEY;
|
|
6
|
+
|
|
7
|
+
const client = new MistralClient(apiKey);
|
|
8
|
+
|
|
9
|
+
// Create a new file
|
|
10
|
+
const blob = new Blob(
|
|
11
|
+
[fs.readFileSync('file.jsonl')],
|
|
12
|
+
{type: 'application/json'},
|
|
13
|
+
);
|
|
14
|
+
const createdFile = await client.files.create({file: blob});
|
|
15
|
+
|
|
16
|
+
// Create a new job
|
|
17
|
+
const hyperparameters = {
|
|
18
|
+
training_steps: 10,
|
|
19
|
+
learning_rate: 0.0001,
|
|
20
|
+
};
|
|
21
|
+
const createdJob = await client.jobs.create({
|
|
22
|
+
model: 'open-mistral-7b',
|
|
23
|
+
trainingFiles: [createdFile.id],
|
|
24
|
+
validationFiles: [createdFile.id],
|
|
25
|
+
hyperparameters,
|
|
26
|
+
});
|
|
27
|
+
console.log(createdJob);
|
|
28
|
+
|
|
29
|
+
// List jobs
|
|
30
|
+
const jobs = await client.jobs.list();
|
|
31
|
+
console.log(jobs);
|
|
32
|
+
|
|
33
|
+
// Retrieve a job
|
|
34
|
+
const retrievedJob = await client.jobs.retrieve({jobId: createdJob.id});
|
|
35
|
+
console.log(retrievedJob);
|
|
36
|
+
|
|
37
|
+
// Cancel a job
|
|
38
|
+
const canceledJob = await client.jobs.cancel({jobId: createdJob.id});
|
|
39
|
+
console.log(canceledJob);
|
package/package.json
CHANGED
package/src/client.d.ts
CHANGED
|
@@ -182,16 +182,15 @@ declare module "@mistralai/mistralai" {
|
|
|
182
182
|
): AsyncGenerator<ChatCompletionResponseChunk, void>;
|
|
183
183
|
|
|
184
184
|
completion(
|
|
185
|
-
|
|
186
|
-
|
|
185
|
+
request: CompletionRequest,
|
|
186
|
+
options?: ChatRequestOptions
|
|
187
187
|
): Promise<ChatCompletionResponse>;
|
|
188
188
|
|
|
189
189
|
completionStream(
|
|
190
|
-
|
|
191
|
-
|
|
190
|
+
request: CompletionRequest,
|
|
191
|
+
options?: ChatRequestOptions
|
|
192
192
|
): AsyncGenerator<ChatCompletionResponseChunk, void>;
|
|
193
193
|
|
|
194
|
-
|
|
195
194
|
embeddings(options: {
|
|
196
195
|
model: string;
|
|
197
196
|
input: string | string[];
|
package/src/client.js
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
|
-
|
|
1
|
+
import FilesClient from './files.js';
|
|
2
|
+
import JobsClient from './jobs.js';
|
|
3
|
+
|
|
4
|
+
const VERSION = '0.5.0';
|
|
2
5
|
const RETRY_STATUS_CODES = [429, 500, 502, 503, 504];
|
|
3
6
|
const ENDPOINT = 'https://api.mistral.ai';
|
|
4
7
|
|
|
@@ -79,6 +82,9 @@ class MistralClient {
|
|
|
79
82
|
if (this.endpoint.indexOf('inference.azure.com')) {
|
|
80
83
|
this.modelDefault = 'mistral';
|
|
81
84
|
}
|
|
85
|
+
|
|
86
|
+
this.files = new FilesClient(this);
|
|
87
|
+
this.jobs = new JobsClient(this);
|
|
82
88
|
}
|
|
83
89
|
|
|
84
90
|
/**
|
|
@@ -98,9 +104,10 @@ class MistralClient {
|
|
|
98
104
|
* @param {*} path
|
|
99
105
|
* @param {*} request
|
|
100
106
|
* @param {*} signal
|
|
107
|
+
* @param {*} formData
|
|
101
108
|
* @return {Promise<*>}
|
|
102
109
|
*/
|
|
103
|
-
_request = async function(method, path, request, signal) {
|
|
110
|
+
_request = async function(method, path, request, signal, formData = null) {
|
|
104
111
|
const url = `${this.endpoint}/${path}`;
|
|
105
112
|
const options = {
|
|
106
113
|
method: method,
|
|
@@ -110,13 +117,18 @@ class MistralClient {
|
|
|
110
117
|
'Content-Type': 'application/json',
|
|
111
118
|
'Authorization': `Bearer ${this.apiKey}`,
|
|
112
119
|
},
|
|
113
|
-
body: method !== 'get' ? JSON.stringify(request) : null,
|
|
114
120
|
signal: combineSignals([
|
|
115
121
|
AbortSignal.timeout(this.timeout * 1000),
|
|
116
122
|
signal,
|
|
117
123
|
]),
|
|
124
|
+
body: method !== 'get' ? formData ?? JSON.stringify(request) : null,
|
|
125
|
+
timeout: this.timeout * 1000,
|
|
118
126
|
};
|
|
119
127
|
|
|
128
|
+
if (formData) {
|
|
129
|
+
delete options.headers['Content-Type'];
|
|
130
|
+
}
|
|
131
|
+
|
|
120
132
|
for (let attempts = 0; attempts < this.maxRetries; attempts++) {
|
|
121
133
|
try {
|
|
122
134
|
const response = await this._fetch(url, options);
|
|
@@ -161,7 +173,7 @@ class MistralClient {
|
|
|
161
173
|
} else {
|
|
162
174
|
throw new MistralAPIError(
|
|
163
175
|
`HTTP error! status: ${response.status} ` +
|
|
164
|
-
|
|
176
|
+
`Response: \n${await response.text()}`,
|
|
165
177
|
);
|
|
166
178
|
}
|
|
167
179
|
} catch (error) {
|
|
@@ -467,16 +479,7 @@ class MistralClient {
|
|
|
467
479
|
* @return {Promise<Object>}
|
|
468
480
|
*/
|
|
469
481
|
completion = async function(
|
|
470
|
-
{
|
|
471
|
-
model,
|
|
472
|
-
prompt,
|
|
473
|
-
suffix,
|
|
474
|
-
temperature,
|
|
475
|
-
maxTokens,
|
|
476
|
-
topP,
|
|
477
|
-
randomSeed,
|
|
478
|
-
stop,
|
|
479
|
-
},
|
|
482
|
+
{model, prompt, suffix, temperature, maxTokens, topP, randomSeed, stop},
|
|
480
483
|
{signal} = {},
|
|
481
484
|
) {
|
|
482
485
|
const request = this._makeCompletionRequest(
|
|
@@ -523,16 +526,7 @@ class MistralClient {
|
|
|
523
526
|
* @return {Promise<Object>}
|
|
524
527
|
*/
|
|
525
528
|
completionStream = async function* (
|
|
526
|
-
{
|
|
527
|
-
model,
|
|
528
|
-
prompt,
|
|
529
|
-
suffix,
|
|
530
|
-
temperature,
|
|
531
|
-
maxTokens,
|
|
532
|
-
topP,
|
|
533
|
-
randomSeed,
|
|
534
|
-
stop,
|
|
535
|
-
},
|
|
529
|
+
{model, prompt, suffix, temperature, maxTokens, topP, randomSeed, stop},
|
|
536
530
|
{signal} = {},
|
|
537
531
|
) {
|
|
538
532
|
const request = this._makeCompletionRequest(
|
package/src/files.d.ts
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
export enum Purpose {
|
|
2
|
+
finetune = 'fine-tune',
|
|
3
|
+
}
|
|
4
|
+
|
|
5
|
+
export interface FileObject {
|
|
6
|
+
id: string;
|
|
7
|
+
object: string;
|
|
8
|
+
bytes: number;
|
|
9
|
+
created_at: number;
|
|
10
|
+
filename: string;
|
|
11
|
+
purpose?: Purpose;
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
export interface FileDeleted {
|
|
15
|
+
id: string;
|
|
16
|
+
object: string;
|
|
17
|
+
deleted: boolean;
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
export class FilesClient {
|
|
21
|
+
constructor(client: MistralClient);
|
|
22
|
+
|
|
23
|
+
create(options: { file: File; purpose?: string }): Promise<FileObject>;
|
|
24
|
+
|
|
25
|
+
retrieve(options: { fileId: string }): Promise<FileObject>;
|
|
26
|
+
|
|
27
|
+
list(): Promise<FileObject[]>;
|
|
28
|
+
|
|
29
|
+
delete(options: { fileId: string }): Promise<FileDeleted>;
|
|
30
|
+
}
|
package/src/files.js
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Class representing a client for file operations.
|
|
3
|
+
*/
|
|
4
|
+
class FilesClient {
|
|
5
|
+
/**
|
|
6
|
+
* Create a FilesClient object.
|
|
7
|
+
* @param {MistralClient} client - The client object used for making requests.
|
|
8
|
+
*/
|
|
9
|
+
constructor(client) {
|
|
10
|
+
this.client = client;
|
|
11
|
+
}
|
|
12
|
+
|
|
13
|
+
/**
|
|
14
|
+
* Create a new file.
|
|
15
|
+
* @param {File} file - The file to be created.
|
|
16
|
+
* @param {string} purpose - The purpose of the file. Default is 'fine-tune'.
|
|
17
|
+
* @return {Promise<*>} A promise that resolves to a FileObject.
|
|
18
|
+
* @throws {MistralAPIError} If no response is received from the server.
|
|
19
|
+
*/
|
|
20
|
+
async create({file, purpose = 'fine-tune'}) {
|
|
21
|
+
const formData = new FormData();
|
|
22
|
+
formData.append('file', file);
|
|
23
|
+
formData.append('purpose', purpose);
|
|
24
|
+
const response = await this.client._request(
|
|
25
|
+
'post',
|
|
26
|
+
'v1/files',
|
|
27
|
+
null,
|
|
28
|
+
undefined,
|
|
29
|
+
formData,
|
|
30
|
+
);
|
|
31
|
+
return response;
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
/**
|
|
35
|
+
* Retrieve a file.
|
|
36
|
+
* @param {string} fileId - The ID of the file to retrieve.
|
|
37
|
+
* @return {Promise<*>} A promise that resolves to the file data.
|
|
38
|
+
*/
|
|
39
|
+
async retrieve({fileId}) {
|
|
40
|
+
const response = await this.client._request('get', `v1/files/${fileId}`);
|
|
41
|
+
return response;
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
/**
|
|
45
|
+
* List all files.
|
|
46
|
+
* @return {Promise<Array<FileObject>>} A promise that resolves to
|
|
47
|
+
* an array of FileObject.
|
|
48
|
+
*/
|
|
49
|
+
async list() {
|
|
50
|
+
const response = await this.client._request('get', 'v1/files');
|
|
51
|
+
return response;
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
/**
|
|
55
|
+
* Delete a file.
|
|
56
|
+
* @param {string} fileId - The ID of the file to delete.
|
|
57
|
+
* @return {Promise<*>} A promise that resolves to the response.
|
|
58
|
+
*/
|
|
59
|
+
async delete({fileId}) {
|
|
60
|
+
const response = await this.client._request('delete', `v1/files/${fileId}`);
|
|
61
|
+
return response;
|
|
62
|
+
}
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
export default FilesClient;
|
package/src/jobs.d.ts
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
export enum JobStatus {
|
|
2
|
+
QUEUED = 'QUEUED',
|
|
3
|
+
STARTED = 'STARTED',
|
|
4
|
+
RUNNING = 'RUNNING',
|
|
5
|
+
FAILED = 'FAILED',
|
|
6
|
+
SUCCESS = 'SUCCESS',
|
|
7
|
+
CANCELLED = 'CANCELLED',
|
|
8
|
+
CANCELLATION_REQUESTED = 'CANCELLATION_REQUESTED',
|
|
9
|
+
}
|
|
10
|
+
|
|
11
|
+
export interface TrainingParameters {
|
|
12
|
+
training_steps: number;
|
|
13
|
+
learning_rate: number;
|
|
14
|
+
}
|
|
15
|
+
|
|
16
|
+
export interface WandbIntegration {
|
|
17
|
+
type: Literal<'wandb'>;
|
|
18
|
+
project: string;
|
|
19
|
+
name: string | null;
|
|
20
|
+
api_key: string | null;
|
|
21
|
+
run_name: string | null;
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
export type Integration = WandbIntegration;
|
|
25
|
+
|
|
26
|
+
export interface Job {
|
|
27
|
+
id: string;
|
|
28
|
+
hyperparameters: TrainingParameters;
|
|
29
|
+
fine_tuned_model: string;
|
|
30
|
+
model: string;
|
|
31
|
+
status: JobStatus;
|
|
32
|
+
jobType: string;
|
|
33
|
+
created_at: number;
|
|
34
|
+
modified_at: number;
|
|
35
|
+
training_files: string[];
|
|
36
|
+
validation_files?: string[];
|
|
37
|
+
object: 'job';
|
|
38
|
+
integrations: Integration[];
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
export interface Event {
|
|
42
|
+
name: string;
|
|
43
|
+
data?: Record<string, unknown>;
|
|
44
|
+
created_at: number;
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
export interface Metric {
|
|
48
|
+
train_loss: float | null;
|
|
49
|
+
valid_loss: float | null;
|
|
50
|
+
valid_mean_token_accuracy: float | null;
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
export interface Checkpoint {
|
|
54
|
+
metrics: Metric;
|
|
55
|
+
step_number: int;
|
|
56
|
+
created_at: int;
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
export interface DetailedJob extends Job {
|
|
60
|
+
events: Event[];
|
|
61
|
+
checkpoints: Checkpoint[];
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
export interface Jobs {
|
|
65
|
+
data: Job[];
|
|
66
|
+
object: 'list';
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
export class JobsClient {
|
|
70
|
+
constructor(client: MistralClient);
|
|
71
|
+
|
|
72
|
+
create(options: {
|
|
73
|
+
model: string;
|
|
74
|
+
trainingFiles: string[];
|
|
75
|
+
validationFiles?: string[];
|
|
76
|
+
hyperparameters?: TrainingParameters;
|
|
77
|
+
suffix?: string;
|
|
78
|
+
integrations?: Integration[];
|
|
79
|
+
}): Promise<Job>;
|
|
80
|
+
|
|
81
|
+
retrieve(options: { jobId: string }): Promise<DetailedJob>;
|
|
82
|
+
|
|
83
|
+
list(params?: Record<string, unknown>): Promise<Jobs>;
|
|
84
|
+
|
|
85
|
+
cancel(options: { jobId: string }): Promise<DetailedJob>;
|
|
86
|
+
}
|
package/src/jobs.js
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Class representing a client for job operations.
|
|
3
|
+
*/
|
|
4
|
+
class JobsClient {
|
|
5
|
+
/**
|
|
6
|
+
* Create a JobsClient object.
|
|
7
|
+
* @param {MistralClient} client - The client object used for making requests.
|
|
8
|
+
*/
|
|
9
|
+
constructor(client) {
|
|
10
|
+
this.client = client;
|
|
11
|
+
}
|
|
12
|
+
|
|
13
|
+
/**
|
|
14
|
+
* Create a new job.
|
|
15
|
+
* @param {string} model - The model to be used for the job.
|
|
16
|
+
* @param {Array<string>} trainingFiles - The list of training files.
|
|
17
|
+
* @param {Array<string>} validationFiles - The list of validation files.
|
|
18
|
+
* @param {TrainingParameters} hyperparameters - The hyperparameters.
|
|
19
|
+
* @param {string} suffix - The suffix for the job.
|
|
20
|
+
* @param {Array<Integration>} integrations - The integrations for the job.
|
|
21
|
+
* @return {Promise<*>} A promise that resolves to a Job object.
|
|
22
|
+
* @throws {MistralAPIError} If no response is received from the server.
|
|
23
|
+
*/
|
|
24
|
+
async create({
|
|
25
|
+
model,
|
|
26
|
+
trainingFiles,
|
|
27
|
+
validationFiles = [],
|
|
28
|
+
hyperparameters = {
|
|
29
|
+
training_steps: 1800,
|
|
30
|
+
learning_rate: 1.0e-4,
|
|
31
|
+
},
|
|
32
|
+
suffix = null,
|
|
33
|
+
integrations = null,
|
|
34
|
+
}) {
|
|
35
|
+
const response = await this.client._request('post', 'v1/fine_tuning/jobs', {
|
|
36
|
+
model,
|
|
37
|
+
training_files: trainingFiles,
|
|
38
|
+
validation_files: validationFiles,
|
|
39
|
+
hyperparameters,
|
|
40
|
+
suffix,
|
|
41
|
+
integrations,
|
|
42
|
+
});
|
|
43
|
+
return response;
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
/**
|
|
47
|
+
* Retrieve a job.
|
|
48
|
+
* @param {string} jobId - The ID of the job to retrieve.
|
|
49
|
+
* @return {Promise<*>} A promise that resolves to the job data.
|
|
50
|
+
*/
|
|
51
|
+
async retrieve({jobId}) {
|
|
52
|
+
const response = await this.client._request(
|
|
53
|
+
'get', `v1/fine_tuning/jobs/${jobId}`, {},
|
|
54
|
+
);
|
|
55
|
+
return response;
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
/**
|
|
59
|
+
* List all jobs.
|
|
60
|
+
* @return {Promise<Array<Job>>} A promise that resolves to an array of Job.
|
|
61
|
+
*/
|
|
62
|
+
async list() {
|
|
63
|
+
const response = await this.client._request(
|
|
64
|
+
'get', 'v1/fine_tuning/jobs', {},
|
|
65
|
+
);
|
|
66
|
+
return response;
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
/**
|
|
70
|
+
* Cancel a job.
|
|
71
|
+
* @param {string} jobId - The ID of the job to cancel.
|
|
72
|
+
* @return {Promise<*>} A promise that resolves to the response.
|
|
73
|
+
*/
|
|
74
|
+
async cancel({jobId}) {
|
|
75
|
+
const response = await this.client._request(
|
|
76
|
+
'post', `v1/fine_tuning/jobs/${jobId}/cancel`, {},
|
|
77
|
+
);
|
|
78
|
+
return response;
|
|
79
|
+
}
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
export default JobsClient;
|
package/tests/files.test.js
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import MistralClient from '../src/client';
|
|
2
|
+
import {
|
|
3
|
+
mockFetch,
|
|
4
|
+
mockFileResponsePayload,
|
|
5
|
+
mockFilesResponsePayload,
|
|
6
|
+
mockDeletedFileResponsePayload,
|
|
7
|
+
} from './utils';
|
|
8
|
+
|
|
9
|
+
// Test the list models endpoint
|
|
10
|
+
describe('Mistral Client', () => {
|
|
11
|
+
let client;
|
|
12
|
+
beforeEach(() => {
|
|
13
|
+
client = new MistralClient();
|
|
14
|
+
});
|
|
15
|
+
|
|
16
|
+
describe('create()', () => {
|
|
17
|
+
it('should return a file response object', async() => {
|
|
18
|
+
// Mock the fetch function
|
|
19
|
+
const mockResponse = mockFileResponsePayload();
|
|
20
|
+
client._fetch = mockFetch(200, mockResponse);
|
|
21
|
+
|
|
22
|
+
const response = await client.files.create({
|
|
23
|
+
file: null,
|
|
24
|
+
});
|
|
25
|
+
expect(response).toEqual(mockResponse);
|
|
26
|
+
});
|
|
27
|
+
});
|
|
28
|
+
|
|
29
|
+
describe('retrieve()', () => {
|
|
30
|
+
it('should return a file response object', async() => {
|
|
31
|
+
// Mock the fetch function
|
|
32
|
+
const mockResponse = mockFileResponsePayload();
|
|
33
|
+
client._fetch = mockFetch(200, mockResponse);
|
|
34
|
+
|
|
35
|
+
const response = await client.files.retrieve({
|
|
36
|
+
fileId: 'fileId',
|
|
37
|
+
});
|
|
38
|
+
expect(response).toEqual(mockResponse);
|
|
39
|
+
});
|
|
40
|
+
});
|
|
41
|
+
|
|
42
|
+
describe('retrieve()', () => {
|
|
43
|
+
it('should return a list of files response object', async() => {
|
|
44
|
+
// Mock the fetch function
|
|
45
|
+
const mockResponse = mockFilesResponsePayload();
|
|
46
|
+
client._fetch = mockFetch(200, mockResponse);
|
|
47
|
+
|
|
48
|
+
const response = await client.files.list();
|
|
49
|
+
expect(response).toEqual(mockResponse);
|
|
50
|
+
});
|
|
51
|
+
});
|
|
52
|
+
|
|
53
|
+
describe('delete()', () => {
|
|
54
|
+
it('should return a deleted file response object', async() => {
|
|
55
|
+
// Mock the fetch function
|
|
56
|
+
const mockResponse = mockDeletedFileResponsePayload();
|
|
57
|
+
client._fetch = mockFetch(200, mockResponse);
|
|
58
|
+
|
|
59
|
+
const response = await client.files.delete({
|
|
60
|
+
fileId: 'fileId',
|
|
61
|
+
});
|
|
62
|
+
expect(response).toEqual(mockResponse);
|
|
63
|
+
});
|
|
64
|
+
});
|
|
65
|
+
});
|
package/tests/jobs.test.js
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import MistralClient from '../src/client';
|
|
2
|
+
import {
|
|
3
|
+
mockFetch,
|
|
4
|
+
mockJobResponsePayload,
|
|
5
|
+
mockJobsResponsePayload,
|
|
6
|
+
mockDeletedJobResponsePayload,
|
|
7
|
+
} from './utils';
|
|
8
|
+
|
|
9
|
+
// Test the jobs endpoint
|
|
10
|
+
describe('Mistral Client', () => {
|
|
11
|
+
let client;
|
|
12
|
+
beforeEach(() => {
|
|
13
|
+
client = new MistralClient();
|
|
14
|
+
});
|
|
15
|
+
|
|
16
|
+
describe('createJob()', () => {
|
|
17
|
+
it('should return a job response object', async() => {
|
|
18
|
+
// Mock the fetch function
|
|
19
|
+
const mockResponse = mockJobResponsePayload();
|
|
20
|
+
client._fetch = mockFetch(200, mockResponse);
|
|
21
|
+
|
|
22
|
+
const response = await client.jobs.create({
|
|
23
|
+
model: 'mistral-medium',
|
|
24
|
+
trainingFiles: [],
|
|
25
|
+
validationFiles: [],
|
|
26
|
+
hyperparameters: {
|
|
27
|
+
training_steps: 1800,
|
|
28
|
+
learning_rate: 1.0e-4,
|
|
29
|
+
},
|
|
30
|
+
});
|
|
31
|
+
expect(response).toEqual(mockResponse);
|
|
32
|
+
});
|
|
33
|
+
});
|
|
34
|
+
|
|
35
|
+
describe('retrieveJob()', () => {
|
|
36
|
+
it('should return a job response object', async() => {
|
|
37
|
+
// Mock the fetch function
|
|
38
|
+
const mockResponse = mockJobResponsePayload();
|
|
39
|
+
client._fetch = mockFetch(200, mockResponse);
|
|
40
|
+
|
|
41
|
+
const response = await client.jobs.retrieve({
|
|
42
|
+
jobId: 'jobId',
|
|
43
|
+
});
|
|
44
|
+
expect(response).toEqual(mockResponse);
|
|
45
|
+
});
|
|
46
|
+
});
|
|
47
|
+
|
|
48
|
+
describe('listJobs()', () => {
|
|
49
|
+
it('should return a list of jobs response object', async() => {
|
|
50
|
+
// Mock the fetch function
|
|
51
|
+
const mockResponse = mockJobsResponsePayload();
|
|
52
|
+
client._fetch = mockFetch(200, mockResponse);
|
|
53
|
+
|
|
54
|
+
const response = await client.jobs.list();
|
|
55
|
+
expect(response).toEqual(mockResponse);
|
|
56
|
+
});
|
|
57
|
+
});
|
|
58
|
+
|
|
59
|
+
describe('cancelJob()', () => {
|
|
60
|
+
it('should return a deleted job response object', async() => {
|
|
61
|
+
// Mock the fetch function
|
|
62
|
+
const mockResponse = mockDeletedJobResponsePayload();
|
|
63
|
+
client._fetch = mockFetch(200, mockResponse);
|
|
64
|
+
|
|
65
|
+
const response = await client.jobs.cancel({
|
|
66
|
+
jobId: 'jobId',
|
|
67
|
+
});
|
|
68
|
+
expect(response).toEqual(mockResponse);
|
|
69
|
+
});
|
|
70
|
+
});
|
|
71
|
+
});
|
package/tests/utils.js
CHANGED
|
@@ -183,40 +183,45 @@ export function mockChatResponsePayload() {
|
|
|
183
183
|
*/
|
|
184
184
|
export function mockChatResponseStreamingPayload() {
|
|
185
185
|
const encoder = new TextEncoder();
|
|
186
|
-
const firstMessage =
|
|
187
|
-
|
|
188
|
-
JSON.stringify({
|
|
189
|
-
id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
|
|
190
|
-
model: 'mistral-small-latest',
|
|
191
|
-
choices: [
|
|
192
|
-
{
|
|
193
|
-
index: 0,
|
|
194
|
-
delta: {role: 'assistant'},
|
|
195
|
-
finish_reason: null,
|
|
196
|
-
},
|
|
197
|
-
],
|
|
198
|
-
}) +
|
|
199
|
-
'\n\n')];
|
|
200
|
-
const lastMessage = [encoder.encode('data: [DONE]\n\n')];
|
|
201
|
-
|
|
202
|
-
const dataMessages = [];
|
|
203
|
-
for (let i = 0; i < 10; i++) {
|
|
204
|
-
dataMessages.push(encoder.encode(
|
|
186
|
+
const firstMessage = [
|
|
187
|
+
encoder.encode(
|
|
205
188
|
'data: ' +
|
|
206
189
|
JSON.stringify({
|
|
207
190
|
id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
|
|
208
|
-
object: 'chat.completion.chunk',
|
|
209
|
-
created: 1703168544,
|
|
210
191
|
model: 'mistral-small-latest',
|
|
211
192
|
choices: [
|
|
212
193
|
{
|
|
213
|
-
index:
|
|
214
|
-
delta: {
|
|
194
|
+
index: 0,
|
|
195
|
+
delta: {role: 'assistant'},
|
|
215
196
|
finish_reason: null,
|
|
216
197
|
},
|
|
217
198
|
],
|
|
218
199
|
}) +
|
|
219
|
-
'\n\n'
|
|
200
|
+
'\n\n',
|
|
201
|
+
),
|
|
202
|
+
];
|
|
203
|
+
const lastMessage = [encoder.encode('data: [DONE]\n\n')];
|
|
204
|
+
|
|
205
|
+
const dataMessages = [];
|
|
206
|
+
for (let i = 0; i < 10; i++) {
|
|
207
|
+
dataMessages.push(
|
|
208
|
+
encoder.encode(
|
|
209
|
+
'data: ' +
|
|
210
|
+
JSON.stringify({
|
|
211
|
+
id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
|
|
212
|
+
object: 'chat.completion.chunk',
|
|
213
|
+
created: 1703168544,
|
|
214
|
+
model: 'mistral-small-latest',
|
|
215
|
+
choices: [
|
|
216
|
+
{
|
|
217
|
+
index: i,
|
|
218
|
+
delta: {content: `stream response ${i}`},
|
|
219
|
+
finish_reason: null,
|
|
220
|
+
},
|
|
221
|
+
],
|
|
222
|
+
}) +
|
|
223
|
+
'\n\n',
|
|
224
|
+
),
|
|
220
225
|
);
|
|
221
226
|
}
|
|
222
227
|
|
|
@@ -255,3 +260,129 @@ export function mockEmbeddingRequest() {
|
|
|
255
260
|
input: 'embed',
|
|
256
261
|
};
|
|
257
262
|
}
|
|
263
|
+
|
|
264
|
+
/**
|
|
265
|
+
* Mock file response payload
|
|
266
|
+
* @return {Object}
|
|
267
|
+
*/
|
|
268
|
+
export function mockFileResponsePayload() {
|
|
269
|
+
return {
|
|
270
|
+
id: 'fileId',
|
|
271
|
+
object: 'file',
|
|
272
|
+
bytes: 0,
|
|
273
|
+
created_at: 1633046400000,
|
|
274
|
+
filename: 'file.jsonl',
|
|
275
|
+
purpose: 'fine-tune',
|
|
276
|
+
};
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
/**
|
|
280
|
+
* Mock files response payload
|
|
281
|
+
* @return {Object}
|
|
282
|
+
*/
|
|
283
|
+
export function mockFilesResponsePayload() {
|
|
284
|
+
return {
|
|
285
|
+
data: [
|
|
286
|
+
{
|
|
287
|
+
id: 'fileId',
|
|
288
|
+
object: 'file',
|
|
289
|
+
bytes: 0,
|
|
290
|
+
created_at: 1633046400000,
|
|
291
|
+
filename: 'file.jsonl',
|
|
292
|
+
purpose: 'fine-tune',
|
|
293
|
+
},
|
|
294
|
+
],
|
|
295
|
+
object: 'list',
|
|
296
|
+
};
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
/**
|
|
300
|
+
* Mock deleted file response payload
|
|
301
|
+
* @return {Object}
|
|
302
|
+
*/
|
|
303
|
+
export function mockDeletedFileResponsePayload() {
|
|
304
|
+
return {
|
|
305
|
+
id: 'fileId',
|
|
306
|
+
object: 'file',
|
|
307
|
+
deleted: true,
|
|
308
|
+
};
|
|
309
|
+
}
|
|
310
|
+
|
|
311
|
+
/**
|
|
312
|
+
* Mock job response payload
|
|
313
|
+
* @return {Object}
|
|
314
|
+
*/
|
|
315
|
+
export function mockJobResponsePayload() {
|
|
316
|
+
return {
|
|
317
|
+
id: 'jobId',
|
|
318
|
+
hyperparameters: {
|
|
319
|
+
training_steps: 1800,
|
|
320
|
+
learning_rate: 1.0e-4,
|
|
321
|
+
},
|
|
322
|
+
fine_tuned_model: 'fine_tuned_model_id',
|
|
323
|
+
model: 'mistral-medium',
|
|
324
|
+
status: 'QUEUED',
|
|
325
|
+
job_type: 'fine_tuning',
|
|
326
|
+
created_at: 1633046400000,
|
|
327
|
+
modified_at: 1633046400000,
|
|
328
|
+
training_files: ['file1.jsonl', 'file2.jsonl'],
|
|
329
|
+
validation_files: ['file3.jsonl', 'file4.jsonl'],
|
|
330
|
+
object: 'job',
|
|
331
|
+
};
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
/**
|
|
335
|
+
* Mock jobs response payload
|
|
336
|
+
* @return {Object}
|
|
337
|
+
*/
|
|
338
|
+
export function mockJobsResponsePayload() {
|
|
339
|
+
return {
|
|
340
|
+
data: [
|
|
341
|
+
{
|
|
342
|
+
id: 'jobId1',
|
|
343
|
+
hyperparameters: {
|
|
344
|
+
training_steps: 1800,
|
|
345
|
+
learning_rate: 1.0e-4,
|
|
346
|
+
},
|
|
347
|
+
fine_tuned_model: 'fine_tuned_model_id1',
|
|
348
|
+
model: 'mistral-medium',
|
|
349
|
+
status: 'QUEUED',
|
|
350
|
+
job_type: 'fine_tuning',
|
|
351
|
+
created_at: 1633046400000,
|
|
352
|
+
modified_at: 1633046400000,
|
|
353
|
+
training_files: ['file1.jsonl', 'file2.jsonl'],
|
|
354
|
+
validation_files: ['file3.jsonl', 'file4.jsonl'],
|
|
355
|
+
object: 'job',
|
|
356
|
+
},
|
|
357
|
+
{
|
|
358
|
+
id: 'jobId2',
|
|
359
|
+
hyperparameters: {
|
|
360
|
+
training_steps: 1800,
|
|
361
|
+
learning_rate: 1.0e-4,
|
|
362
|
+
},
|
|
363
|
+
fine_tuned_model: 'fine_tuned_model_id2',
|
|
364
|
+
model: 'mistral-medium',
|
|
365
|
+
status: 'RUNNING',
|
|
366
|
+
job_type: 'fine_tuning',
|
|
367
|
+
created_at: 1633046400000,
|
|
368
|
+
modified_at: 1633046400000,
|
|
369
|
+
training_files: ['file5.jsonl', 'file6.jsonl'],
|
|
370
|
+
validation_files: ['file7.jsonl', 'file8.jsonl'],
|
|
371
|
+
object: 'job',
|
|
372
|
+
},
|
|
373
|
+
],
|
|
374
|
+
object: 'list',
|
|
375
|
+
};
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
/**
|
|
379
|
+
* Mock deleted job response payload
|
|
380
|
+
* @return {Object}
|
|
381
|
+
*/
|
|
382
|
+
export function mockDeletedJobResponsePayload() {
|
|
383
|
+
return {
|
|
384
|
+
id: 'jobId',
|
|
385
|
+
object: 'job',
|
|
386
|
+
deleted: true,
|
|
387
|
+
};
|
|
388
|
+
}
|