@llumiverse/drivers 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/lib/cjs/bedrock/index.js +338 -0
- package/lib/cjs/bedrock/index.js.map +1 -0
- package/lib/cjs/bedrock/s3.js +61 -0
- package/lib/cjs/bedrock/s3.js.map +1 -0
- package/lib/cjs/huggingface_ie.js +181 -0
- package/lib/cjs/huggingface_ie.js.map +1 -0
- package/lib/cjs/index.js +24 -0
- package/lib/cjs/index.js.map +1 -0
- package/lib/cjs/openai.js +205 -0
- package/lib/cjs/openai.js.map +1 -0
- package/lib/cjs/package.json +3 -0
- package/lib/cjs/replicate.js +290 -0
- package/lib/cjs/replicate.js.map +1 -0
- package/lib/cjs/test/TestErrorCompletionStream.js +20 -0
- package/lib/cjs/test/TestErrorCompletionStream.js.map +1 -0
- package/lib/cjs/test/TestValidationErrorCompletionStream.js +24 -0
- package/lib/cjs/test/TestValidationErrorCompletionStream.js.map +1 -0
- package/lib/cjs/test/index.js +109 -0
- package/lib/cjs/test/index.js.map +1 -0
- package/lib/cjs/test/utils.js +31 -0
- package/lib/cjs/test/utils.js.map +1 -0
- package/lib/cjs/togetherai/index.js +92 -0
- package/lib/cjs/togetherai/index.js.map +1 -0
- package/lib/cjs/togetherai/interfaces.js +3 -0
- package/lib/cjs/togetherai/interfaces.js.map +1 -0
- package/lib/cjs/vertexai/debug.js +13 -0
- package/lib/cjs/vertexai/debug.js.map +1 -0
- package/lib/cjs/vertexai/index.js +80 -0
- package/lib/cjs/vertexai/index.js.map +1 -0
- package/lib/cjs/vertexai/models/codey-chat.js +65 -0
- package/lib/cjs/vertexai/models/codey-chat.js.map +1 -0
- package/lib/cjs/vertexai/models/codey-text.js +35 -0
- package/lib/cjs/vertexai/models/codey-text.js.map +1 -0
- package/lib/cjs/vertexai/models/gemini.js +140 -0
- package/lib/cjs/vertexai/models/gemini.js.map +1 -0
- package/lib/cjs/vertexai/models/palm-model-base.js +65 -0
- package/lib/cjs/vertexai/models/palm-model-base.js.map +1 -0
- package/lib/cjs/vertexai/models/palm2-chat.js +65 -0
- package/lib/cjs/vertexai/models/palm2-chat.js.map +1 -0
- package/lib/cjs/vertexai/models/palm2-text.js +35 -0
- package/lib/cjs/vertexai/models/palm2-text.js.map +1 -0
- package/lib/cjs/vertexai/models.js +93 -0
- package/lib/cjs/vertexai/models.js.map +1 -0
- package/lib/cjs/vertexai/utils/prompts.js +52 -0
- package/lib/cjs/vertexai/utils/prompts.js.map +1 -0
- package/lib/cjs/vertexai/utils/tensor.js +87 -0
- package/lib/cjs/vertexai/utils/tensor.js.map +1 -0
- package/lib/esm/bedrock/index.js +331 -0
- package/lib/esm/bedrock/index.js.map +1 -0
- package/lib/esm/bedrock/s3.js +53 -0
- package/lib/esm/bedrock/s3.js.map +1 -0
- package/lib/esm/huggingface_ie.js +177 -0
- package/lib/esm/huggingface_ie.js.map +1 -0
- package/lib/esm/index.js +8 -0
- package/lib/esm/index.js.map +1 -0
- package/lib/esm/openai.js +198 -0
- package/lib/esm/openai.js.map +1 -0
- package/lib/esm/replicate.js +283 -0
- package/lib/esm/replicate.js.map +1 -0
- package/lib/esm/test/TestErrorCompletionStream.js +16 -0
- package/lib/esm/test/TestErrorCompletionStream.js.map +1 -0
- package/lib/esm/test/TestValidationErrorCompletionStream.js +20 -0
- package/lib/esm/test/TestValidationErrorCompletionStream.js.map +1 -0
- package/lib/esm/test/index.js +91 -0
- package/lib/esm/test/index.js.map +1 -0
- package/lib/esm/test/utils.js +25 -0
- package/lib/esm/test/utils.js.map +1 -0
- package/lib/esm/togetherai/index.js +88 -0
- package/lib/esm/togetherai/index.js.map +1 -0
- package/lib/esm/togetherai/interfaces.js +2 -0
- package/lib/esm/togetherai/interfaces.js.map +1 -0
- package/lib/esm/vertexai/debug.js +6 -0
- package/lib/esm/vertexai/debug.js.map +1 -0
- package/lib/esm/vertexai/index.js +76 -0
- package/lib/esm/vertexai/index.js.map +1 -0
- package/lib/esm/vertexai/models/codey-chat.js +61 -0
- package/lib/esm/vertexai/models/codey-chat.js.map +1 -0
- package/lib/esm/vertexai/models/codey-text.js +31 -0
- package/lib/esm/vertexai/models/codey-text.js.map +1 -0
- package/lib/esm/vertexai/models/gemini.js +136 -0
- package/lib/esm/vertexai/models/gemini.js.map +1 -0
- package/lib/esm/vertexai/models/palm-model-base.js +61 -0
- package/lib/esm/vertexai/models/palm-model-base.js.map +1 -0
- package/lib/esm/vertexai/models/palm2-chat.js +61 -0
- package/lib/esm/vertexai/models/palm2-chat.js.map +1 -0
- package/lib/esm/vertexai/models/palm2-text.js +31 -0
- package/lib/esm/vertexai/models/palm2-text.js.map +1 -0
- package/lib/esm/vertexai/models.js +87 -0
- package/lib/esm/vertexai/models.js.map +1 -0
- package/lib/esm/vertexai/utils/prompts.js +47 -0
- package/lib/esm/vertexai/utils/prompts.js.map +1 -0
- package/lib/esm/vertexai/utils/tensor.js +82 -0
- package/lib/esm/vertexai/utils/tensor.js.map +1 -0
- package/lib/types/bedrock/index.d.ts +88 -0
- package/lib/types/bedrock/index.d.ts.map +1 -0
- package/lib/types/bedrock/s3.d.ts +20 -0
- package/lib/types/bedrock/s3.d.ts.map +1 -0
- package/lib/types/huggingface_ie.d.ts +36 -0
- package/lib/types/huggingface_ie.d.ts.map +1 -0
- package/lib/types/index.d.ts +8 -0
- package/lib/types/index.d.ts.map +1 -0
- package/lib/types/openai.d.ts +36 -0
- package/lib/types/openai.d.ts.map +1 -0
- package/lib/types/replicate.d.ts +52 -0
- package/lib/types/replicate.d.ts.map +1 -0
- package/lib/types/test/TestErrorCompletionStream.d.ts +9 -0
- package/lib/types/test/TestErrorCompletionStream.d.ts.map +1 -0
- package/lib/types/test/TestValidationErrorCompletionStream.d.ts +9 -0
- package/lib/types/test/TestValidationErrorCompletionStream.d.ts.map +1 -0
- package/lib/types/test/index.d.ts +27 -0
- package/lib/types/test/index.d.ts.map +1 -0
- package/lib/types/test/utils.d.ts +5 -0
- package/lib/types/test/utils.d.ts.map +1 -0
- package/lib/types/togetherai/index.d.ts +23 -0
- package/lib/types/togetherai/index.d.ts.map +1 -0
- package/lib/types/togetherai/interfaces.d.ts +81 -0
- package/lib/types/togetherai/interfaces.d.ts.map +1 -0
- package/lib/types/vertexai/debug.d.ts +2 -0
- package/lib/types/vertexai/debug.d.ts.map +1 -0
- package/lib/types/vertexai/index.d.ts +26 -0
- package/lib/types/vertexai/index.d.ts.map +1 -0
- package/lib/types/vertexai/models/codey-chat.d.ts +51 -0
- package/lib/types/vertexai/models/codey-chat.d.ts.map +1 -0
- package/lib/types/vertexai/models/codey-text.d.ts +39 -0
- package/lib/types/vertexai/models/codey-text.d.ts.map +1 -0
- package/lib/types/vertexai/models/gemini.d.ts +11 -0
- package/lib/types/vertexai/models/gemini.d.ts.map +1 -0
- package/lib/types/vertexai/models/palm-model-base.d.ts +47 -0
- package/lib/types/vertexai/models/palm-model-base.d.ts.map +1 -0
- package/lib/types/vertexai/models/palm2-chat.d.ts +61 -0
- package/lib/types/vertexai/models/palm2-chat.d.ts.map +1 -0
- package/lib/types/vertexai/models/palm2-text.d.ts +39 -0
- package/lib/types/vertexai/models/palm2-text.d.ts.map +1 -0
- package/lib/types/vertexai/models.d.ts +14 -0
- package/lib/types/vertexai/models.d.ts.map +1 -0
- package/lib/types/vertexai/utils/prompts.d.ts +20 -0
- package/lib/types/vertexai/utils/prompts.d.ts.map +1 -0
- package/lib/types/vertexai/utils/tensor.d.ts +6 -0
- package/lib/types/vertexai/utils/tensor.d.ts.map +1 -0
- package/package.json +72 -0
- package/src/bedrock/index.ts +456 -0
- package/src/bedrock/s3.ts +62 -0
- package/src/huggingface_ie.ts +269 -0
- package/src/index.ts +7 -0
- package/src/openai.ts +254 -0
- package/src/replicate.ts +333 -0
- package/src/test/TestErrorCompletionStream.ts +17 -0
- package/src/test/TestValidationErrorCompletionStream.ts +21 -0
- package/src/test/index.ts +102 -0
- package/src/test/utils.ts +28 -0
- package/src/togetherai/index.ts +105 -0
- package/src/togetherai/interfaces.ts +88 -0
- package/src/vertexai/README.md +257 -0
- package/src/vertexai/debug.ts +6 -0
- package/src/vertexai/index.ts +99 -0
- package/src/vertexai/models/codey-chat.ts +115 -0
- package/src/vertexai/models/codey-text.ts +69 -0
- package/src/vertexai/models/gemini.ts +152 -0
- package/src/vertexai/models/palm-model-base.ts +122 -0
- package/src/vertexai/models/palm2-chat.ts +119 -0
- package/src/vertexai/models/palm2-text.ts +69 -0
- package/src/vertexai/models.ts +104 -0
- package/src/vertexai/utils/prompts.ts +66 -0
- package/src/vertexai/utils/tensor.ts +82 -0
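The listing above reflects a dual-build layout: CommonJS output under lib/cjs, ES modules under lib/esm, type declarations under lib/types, and the TypeScript sources under src. A minimal consumer sketch, assuming the root entry point re-exports the driver classes (the exact export names are not shown in this diff):

// ESM consumers resolve the package to lib/esm/index.js (types from lib/types)
import { OpenAIDriver } from "@llumiverse/drivers";

// CJS consumers would instead resolve to lib/cjs/index.js:
//   const { OpenAIDriver } = require("@llumiverse/drivers");

const driver = new OpenAIDriver({ apiKey: process.env.OPENAI_API_KEY ?? "" });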
package/src/huggingface_ie.ts
ADDED

@@ -0,0 +1,269 @@
+import {
+    AIModel,
+    AIModelStatus,
+    AbstractDriver,
+    BuiltinProviders,
+    DriverOptions,
+    ExecutionOptions,
+    PromptFormats
+} from "@llumiverse/core";
+import { transformAsyncIterator } from "@llumiverse/core/async";
+import {
+    HfInference,
+    HfInferenceEndpoint,
+    TextGenerationStreamOutput
+} from "@huggingface/inference";
+import { FetchClient } from "api-fetch-client";
+
+export interface HuggingFaceIEDriverOptions extends DriverOptions {
+    apiKey: string;
+    endpoint_url: string;
+}
+
+export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptions, string> {
+    service: FetchClient;
+    provider = BuiltinProviders.huggingface_ie;
+    _executor?: HfInferenceEndpoint;
+    defaultFormat = PromptFormats.genericTextLLM;
+
+    constructor(
+        options: HuggingFaceIEDriverOptions
+    ) {
+        super(options);
+        if (!options.endpoint_url) {
+            throw new Error(`Endpoint URL is required for ${this.provider}`);
+        }
+        this.service = new FetchClient(this.options.endpoint_url);
+        this.service.headers["Authorization"] = `Bearer ${this.options.apiKey}`;
+    }
+
+    async getModelURLEndpoint(
+        modelId: string
+    ): Promise<{ url: string; status: string; }> {
+        const res = (await this.service.get(`/${modelId}`)) as HuggingFaceIEModel;
+        return {
+            url: res.status.url,
+            status: getStatus(res),
+        };
+    }
+
+    async getExecutor(model: string) {
+        if (!this._executor) {
+            const endpoint = await this.getModelURLEndpoint(model);
+            if (!endpoint.url)
+                throw new Error(
+                    `Endpoint URL not found for model ${model}`
+                );
+            if (endpoint.status !== AIModelStatus.Available)
+                throw new Error(
+                    `Endpoint ${model} is not running - current status: ${endpoint.status}`
+                );
+
+            this._executor = new HfInference(this.options.apiKey).endpoint(
+                endpoint.url
+            );
+        }
+        return this._executor;
+    }
+
+    async requestCompletionStream(prompt: string, options: ExecutionOptions) {
+        const executor = await this.getExecutor(options.model);
+        const req = executor.textGenerationStream({
+            inputs: prompt,
+            parameters: {
+                temperature: options.temperature,
+                max_new_tokens: options.max_tokens,
+            },
+        });
+
+        return transformAsyncIterator(req, (val: TextGenerationStreamOutput) => {
+            // special tokens like <s> are not part of the result
+            if (val.token.special) return "";
+            return val.token.text;
+        });
+    }
+
+    async requestCompletion(prompt: string, options: ExecutionOptions) {
+        const executor = await this.getExecutor(options.model);
+        const res = await executor.textGeneration({
+            inputs: prompt,
+            parameters: {
+                temperature: options.temperature,
+                max_new_tokens: options.max_tokens,
+            },
+        });
+
+        return {
+            result: res.generated_text,
+            token_usage: {
+                result: res.generated_text.length,
+                prompt: prompt.length,
+                total: res.generated_text.length + prompt.length,
+            },
+        };
+
+    }
+
+    // ============== management API ==============
+
+    // Not implemented
+    async listTrainableModels(): Promise<AIModel<string>[]> {
+        return [];
+    }
+
+    async listModels(): Promise<AIModel[]> {
+        const res = await this.service.get("/");
+        const hfModels = res.items as HuggingFaceIEModel[];
+
+        const models: AIModel[] = hfModels.map((model: HuggingFaceIEModel) => ({
+            id: model.name,
+            name: `${model.name} [${model.model.repository}:${model.model.task}]`,
+            provider: this.provider,
+            tags: [model.model.task],
+            status: getStatus(model),
+        }));
+
+        return models;
+    }
+
+    async validateConnection(): Promise<boolean> {
+        try {
+            await this.service.get("/models");
+            return true;
+        } catch (error) {
+            return false;
+        }
+    }
+
+    async generateEmbeddings(content: string, model?: string): Promise<{ embeddings: number[], model: string; }> {
+        this.logger?.debug(`[Huggingface] Generating embeddings for ${content} on ${model}`);
+        throw new Error("Method not implemented.");
+    }
+
+}
+
+// get status from the HF endpoint state
+function getStatus(hfModel: HuggingFaceIEModel): AIModelStatus {
+    // [ pending, initializing, updating, updateFailed, running, paused, failed, scaledToZero ]
+    switch (hfModel.status.state) {
+        case "running":
+            return AIModelStatus.Available;
+        case "initializing":
+            return AIModelStatus.Pending;
+        case "updating":
+            return AIModelStatus.Pending;
+        case "updateFailed":
+            return AIModelStatus.Unavailable;
+        case "paused":
+            return AIModelStatus.Stopped;
+        case "failed":
+            return AIModelStatus.Unavailable;
+        case "scaledToZero":
+            return AIModelStatus.Available;
+        default:
+            return AIModelStatus.Unknown;
+    }
+}
+
+interface HuggingFaceIEModel {
+    accountId: string;
+    compute: {
+        accelerator: string;
+        instanceSize: string;
+        instanceType: string;
+        scaling: {
+            maxReplica: number;
+            minReplica: number;
+        };
+    };
+    model: {
+        framework: string;
+        image: {
+            huggingface: {};
+        };
+        repository: string;
+        revision: string;
+        task: string;
+    };
+    name: string;
+    provider: {
+        region: string;
+        vendor: string;
+    };
+    status: {
+        createdAt: string;
+        createdBy: {
+            id: string;
+            name: string;
+        };
+        message: string;
+        private: {
+            serviceName: string;
+        };
+        readyReplica: number;
+        state: string;
+        targetReplica: number;
+        updatedAt: string;
+        updatedBy: {
+            id: string;
+            name: string;
+        };
+        url: string;
+    };
+    type: string;
+}
+
+/*
+Example of a model returned by the API:
+{
+    "items": [
+        {
+            "accountId": "string",
+            "compute": {
+                "accelerator": "cpu",
+                "instanceSize": "large",
+                "instanceType": "c6i",
+                "scaling": {
+                    "maxReplica": 8,
+                    "minReplica": 2
+                }
+            },
+            "model": {
+                "framework": "custom",
+                "image": {
+                    "huggingface": {}
+                },
+                "repository": "gpt2",
+                "revision": "6c0e6080953db56375760c0471a8c5f2929baf11",
+                "task": "text-classification"
+            },
+            "name": "my-endpoint",
+            "provider": {
+                "region": "us-east-1",
+                "vendor": "aws"
+            },
+            "status": {
+                "createdAt": "2023-10-19T05:04:17.305Z",
+                "createdBy": {
+                    "id": "string",
+                    "name": "string"
+                },
+                "message": "Endpoint is ready",
+                "private": {
+                    "serviceName": "string"
+                },
+                "readyReplica": 2,
+                "state": "pending",
+                "targetReplica": 4,
+                "updatedAt": "2023-10-19T05:04:17.305Z",
+                "updatedBy": {
+                    "id": "string",
+                    "name": "string"
+                },
+                "url": "https://endpoint-id.region.vendor.endpoints.huggingface.cloud"
+            },
+            "type": "public"
+        }
+    ]
+}
+*/
package/src/index.ts
ADDED
package/src/openai.ts
ADDED
@@ -0,0 +1,254 @@
+import {
+    AIModel,
+    AbstractDriver,
+    BuiltinProviders,
+    Completion,
+    DataSource,
+    DriverOptions,
+    ExecutionOptions,
+    ExecutionTokenUsage,
+    ModelType,
+    PromptFormats,
+    PromptSegment,
+    TrainingJob,
+    TrainingJobStatus,
+    TrainingOptions,
+    TrainingPromptOptions
+} from "@llumiverse/core";
+import { asyncMap } from "@llumiverse/core/async";
+import OpenAI from "openai";
+import { Stream } from "openai/streaming";
+
+const supportFineTunning = new Set([
+    "gpt-3.5-turbo-1106",
+    "gpt-3.5-turbo-0613",
+    "babbage-002",
+    "davinci-002",
+    "gpt-4-0613"
+]);
+
+export interface OpenAIDriverOptions extends DriverOptions {
+    apiKey: string;
+}
+
+export class OpenAIDriver extends AbstractDriver<
+    OpenAIDriverOptions,
+    OpenAI.Chat.Completions.ChatCompletionMessageParam[]
+> {
+    inputContentTypes: string[] = ["text/plain"];
+    generatedContentTypes: string[] = ["text/plain"];
+    service: OpenAI;
+    provider = BuiltinProviders.openai;
+    defaultFormat = PromptFormats.openai;
+
+    constructor(opts: OpenAIDriverOptions) {
+        super(opts);
+        this.service = new OpenAI({
+            apiKey: opts.apiKey,
+        });
+    }
+
+    createPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
+        // OpenAI only supports the openai format - force the format
+        return super.createPrompt(segments, { ...opts, format: PromptFormats.openai })
+    }
+
+    extractDataFromResponse(
+        options: ExecutionOptions,
+        result: OpenAI.Chat.Completions.ChatCompletion
+    ): Completion {
+        const tokenInfo: ExecutionTokenUsage = {
+            prompt: result.usage?.prompt_tokens,
+            result: result.usage?.completion_tokens,
+            total: result.usage?.total_tokens,
+        };
+
+        // if no schema, return content
+        if (!options.resultSchema) {
+            return {
+                result: result.choices[0]?.message.content as string,
+                token_usage: tokenInfo,
+            }
+        }
+
+        // we have a schema: get the content and return after validation
+        const data = result.choices[0]?.message.function_call?.arguments as any;
+        if (!data) {
+            this.logger?.error("[OpenAI] Response is not valid", result);
+            throw new Error("Response is not valid: no data");
+        }
+
+        return {
+            result: data,
+            token_usage: tokenInfo
+        };
+    }
+
+    async requestCompletionStream(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise<any> {
+        const mapFn = options.resultSchema
+            ? (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
+                return (
+                    chunk.choices[0]?.delta?.function_call?.arguments ?? ""
+                );
+            }
+            : (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
+                return chunk.choices[0]?.delta?.content ?? "";
+            };
+
+        const stream = (await this.service.chat.completions.create({
+            stream: true,
+            model: options.model,
+            messages: prompt,
+            temperature: options.temperature,
+            n: 1,
+            max_tokens: options.max_tokens,
+            functions: options.resultSchema
+                ? [
+                    {
+                        name: "format_output",
+                        parameters: options.resultSchema as any,
+                    },
+                ]
+                : undefined,
+            function_call: options.resultSchema
+                ? { name: "format_output" }
+                : undefined,
+        })) as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>;
+
+        return asyncMap(stream, mapFn);
+    }
+
+    async requestCompletion(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise<any> {
+        const functions = options.resultSchema
+            ? [
+                {
+                    name: "format_output",
+                    parameters: options.resultSchema as any,
+                },
+            ]
+            : undefined;
+
+        const res = await this.service.chat.completions.create({
+            stream: false,
+            model: options.model,
+            messages: prompt,
+            temperature: options.temperature,
+            n: 1,
+            max_tokens: options.max_tokens,
+            functions: functions,
+            function_call: options.resultSchema
+                ? { name: "format_output" }
+                : undefined,
+        });
+
+        return this.extractDataFromResponse(options, res);
+    }
+
+    createTrainingPrompt(options: TrainingPromptOptions): string {
+        if (options.model.includes("gpt")) {
+            return super.createTrainingPrompt(options);
+        } else {
+            // babbage, davinci not yet implemented
+            throw new Error("Unsupported model for training: " + options.model);
+        }
+    }
+
+    async startTraining(dataset: DataSource, options: TrainingOptions): Promise<TrainingJob> {
+        const url = await dataset.getURL();
+        const file = await this.service.files.create({
+            file: await fetch(url),
+            purpose: "fine-tune",
+        });
+
+        const job = await this.service.fineTuning.jobs.create({
+            training_file: file.id,
+            model: options.model,
+            hyperparameters: options.params
+        })
+
+        return jobInfo(job);
+    }
+
+    async cancelTraining(jobId: string): Promise<TrainingJob> {
+        const job = await this.service.fineTuning.jobs.cancel(jobId);
+        return jobInfo(job);
+    }
+
+    async getTrainingJob(jobId: string): Promise<TrainingJob> {
+        const job = await this.service.fineTuning.jobs.retrieve(jobId);
+        return jobInfo(job);
+    }
+
+    // ========= management API =============
+
+    async validateConnection(): Promise<boolean> {
+        try {
+            await this.service.models.list();
+            return true;
+        } catch (error) {
+            return false;
+        }
+    }
+
+    listTrainableModels(): Promise<AIModel<string>[]> {
+        return this._listModels((m) => supportFineTunning.has(m.id));
+    }
+
+    async listModels(): Promise<AIModel[]> {
+        return this._listModels();
+    }
+
+    async _listModels(filter?: (m: OpenAI.Models.Model) => boolean) {
+        let result = await this.service.models.list();
+        const models = filter ? result.data.filter(filter) : result.data;
+        return models.map((m) => ({
+            id: m.id,
+            name: m.id,
+            provider: this.provider,
+            owner: m.owned_by,
+            type: m.object === "model" ? ModelType.Text : ModelType.Unknown,
+        }));
+    }
+
+    async generateEmbeddings(content: string, model: string = "text-embedding-ada-002"): Promise<{ embeddings: number[], model: string; }> {
+        const res = await this.service.embeddings.create({
+            input: content,
+            model: model,
+        });
+
+        const embeddings = res.data[0].embedding;
+
+        if (!embeddings || embeddings.length === 0) {
+            throw new Error("No embedding found");
+        }
+
+        return { embeddings, model };
+    }
+
+}
+
+function jobInfo(job: OpenAI.FineTuning.Jobs.FineTuningJob): TrainingJob {
+    // status is one of: `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`
+    const jobStatus = job.status;
+    let status = TrainingJobStatus.running;
+    let details: string | undefined;
+    if (jobStatus === 'succeeded') {
+        status = TrainingJobStatus.succeeded;
+    } else if (jobStatus === 'failed') {
+        status = TrainingJobStatus.failed;
+        details = job.error ? `${job.error.code} - ${job.error.message} ${job.error.param ? " [" + job.error.param + "]" : ""}` : "error";
+    } else if (jobStatus === 'cancelled') {
+        status = TrainingJobStatus.cancelled;
+    } else {
+        status = TrainingJobStatus.running;
+        details = jobStatus;
+    }
+    return {
+        id: job.id,
+        model: job.fine_tuned_model || undefined,
+        status,
+        details
+    }
+}
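To illustrate the function-calling path in OpenAIDriver above: when resultSchema is set, the driver registers a synthetic format_output function built from the schema, forces the model to call it, and returns the call's JSON arguments as the result. A hedged sketch; the schema, prompt, model name, and loose cast are illustrative only:

import { OpenAIDriver } from "@llumiverse/drivers";

async function extractCity() {
    const driver = new OpenAIDriver({ apiKey: process.env.OPENAI_API_KEY ?? "" });

    const res = await driver.requestCompletion(
        [{ role: "user", content: "Extract the city from: I live in Paris." }],
        {
            model: "gpt-3.5-turbo",
            // the driver passes this JSON schema as the format_output function's parameters
            resultSchema: {
                type: "object",
                properties: { city: { type: "string" } },
            },
        } as any // ExecutionOptions may require more fields than shown here
    );
    console.log(res.result); // e.g. the string '{"city":"Paris"}'
}

extractCity().catch(console.error);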