@llumiverse/drivers 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. package/lib/cjs/bedrock/index.js +338 -0
  2. package/lib/cjs/bedrock/index.js.map +1 -0
  3. package/lib/cjs/bedrock/s3.js +61 -0
  4. package/lib/cjs/bedrock/s3.js.map +1 -0
  5. package/lib/cjs/huggingface_ie.js +181 -0
  6. package/lib/cjs/huggingface_ie.js.map +1 -0
  7. package/lib/cjs/index.js +24 -0
  8. package/lib/cjs/index.js.map +1 -0
  9. package/lib/cjs/openai.js +205 -0
  10. package/lib/cjs/openai.js.map +1 -0
  11. package/lib/cjs/package.json +3 -0
  12. package/lib/cjs/replicate.js +290 -0
  13. package/lib/cjs/replicate.js.map +1 -0
  14. package/lib/cjs/test/TestErrorCompletionStream.js +20 -0
  15. package/lib/cjs/test/TestErrorCompletionStream.js.map +1 -0
  16. package/lib/cjs/test/TestValidationErrorCompletionStream.js +24 -0
  17. package/lib/cjs/test/TestValidationErrorCompletionStream.js.map +1 -0
  18. package/lib/cjs/test/index.js +109 -0
  19. package/lib/cjs/test/index.js.map +1 -0
  20. package/lib/cjs/test/utils.js +31 -0
  21. package/lib/cjs/test/utils.js.map +1 -0
  22. package/lib/cjs/togetherai/index.js +92 -0
  23. package/lib/cjs/togetherai/index.js.map +1 -0
  24. package/lib/cjs/togetherai/interfaces.js +3 -0
  25. package/lib/cjs/togetherai/interfaces.js.map +1 -0
  26. package/lib/cjs/vertexai/debug.js +13 -0
  27. package/lib/cjs/vertexai/debug.js.map +1 -0
  28. package/lib/cjs/vertexai/index.js +80 -0
  29. package/lib/cjs/vertexai/index.js.map +1 -0
  30. package/lib/cjs/vertexai/models/codey-chat.js +65 -0
  31. package/lib/cjs/vertexai/models/codey-chat.js.map +1 -0
  32. package/lib/cjs/vertexai/models/codey-text.js +35 -0
  33. package/lib/cjs/vertexai/models/codey-text.js.map +1 -0
  34. package/lib/cjs/vertexai/models/gemini.js +140 -0
  35. package/lib/cjs/vertexai/models/gemini.js.map +1 -0
  36. package/lib/cjs/vertexai/models/palm-model-base.js +65 -0
  37. package/lib/cjs/vertexai/models/palm-model-base.js.map +1 -0
  38. package/lib/cjs/vertexai/models/palm2-chat.js +65 -0
  39. package/lib/cjs/vertexai/models/palm2-chat.js.map +1 -0
  40. package/lib/cjs/vertexai/models/palm2-text.js +35 -0
  41. package/lib/cjs/vertexai/models/palm2-text.js.map +1 -0
  42. package/lib/cjs/vertexai/models.js +93 -0
  43. package/lib/cjs/vertexai/models.js.map +1 -0
  44. package/lib/cjs/vertexai/utils/prompts.js +52 -0
  45. package/lib/cjs/vertexai/utils/prompts.js.map +1 -0
  46. package/lib/cjs/vertexai/utils/tensor.js +87 -0
  47. package/lib/cjs/vertexai/utils/tensor.js.map +1 -0
  48. package/lib/esm/bedrock/index.js +331 -0
  49. package/lib/esm/bedrock/index.js.map +1 -0
  50. package/lib/esm/bedrock/s3.js +53 -0
  51. package/lib/esm/bedrock/s3.js.map +1 -0
  52. package/lib/esm/huggingface_ie.js +177 -0
  53. package/lib/esm/huggingface_ie.js.map +1 -0
  54. package/lib/esm/index.js +8 -0
  55. package/lib/esm/index.js.map +1 -0
  56. package/lib/esm/openai.js +198 -0
  57. package/lib/esm/openai.js.map +1 -0
  58. package/lib/esm/replicate.js +283 -0
  59. package/lib/esm/replicate.js.map +1 -0
  60. package/lib/esm/test/TestErrorCompletionStream.js +16 -0
  61. package/lib/esm/test/TestErrorCompletionStream.js.map +1 -0
  62. package/lib/esm/test/TestValidationErrorCompletionStream.js +20 -0
  63. package/lib/esm/test/TestValidationErrorCompletionStream.js.map +1 -0
  64. package/lib/esm/test/index.js +91 -0
  65. package/lib/esm/test/index.js.map +1 -0
  66. package/lib/esm/test/utils.js +25 -0
  67. package/lib/esm/test/utils.js.map +1 -0
  68. package/lib/esm/togetherai/index.js +88 -0
  69. package/lib/esm/togetherai/index.js.map +1 -0
  70. package/lib/esm/togetherai/interfaces.js +2 -0
  71. package/lib/esm/togetherai/interfaces.js.map +1 -0
  72. package/lib/esm/vertexai/debug.js +6 -0
  73. package/lib/esm/vertexai/debug.js.map +1 -0
  74. package/lib/esm/vertexai/index.js +76 -0
  75. package/lib/esm/vertexai/index.js.map +1 -0
  76. package/lib/esm/vertexai/models/codey-chat.js +61 -0
  77. package/lib/esm/vertexai/models/codey-chat.js.map +1 -0
  78. package/lib/esm/vertexai/models/codey-text.js +31 -0
  79. package/lib/esm/vertexai/models/codey-text.js.map +1 -0
  80. package/lib/esm/vertexai/models/gemini.js +136 -0
  81. package/lib/esm/vertexai/models/gemini.js.map +1 -0
  82. package/lib/esm/vertexai/models/palm-model-base.js +61 -0
  83. package/lib/esm/vertexai/models/palm-model-base.js.map +1 -0
  84. package/lib/esm/vertexai/models/palm2-chat.js +61 -0
  85. package/lib/esm/vertexai/models/palm2-chat.js.map +1 -0
  86. package/lib/esm/vertexai/models/palm2-text.js +31 -0
  87. package/lib/esm/vertexai/models/palm2-text.js.map +1 -0
  88. package/lib/esm/vertexai/models.js +87 -0
  89. package/lib/esm/vertexai/models.js.map +1 -0
  90. package/lib/esm/vertexai/utils/prompts.js +47 -0
  91. package/lib/esm/vertexai/utils/prompts.js.map +1 -0
  92. package/lib/esm/vertexai/utils/tensor.js +82 -0
  93. package/lib/esm/vertexai/utils/tensor.js.map +1 -0
  94. package/lib/types/bedrock/index.d.ts +88 -0
  95. package/lib/types/bedrock/index.d.ts.map +1 -0
  96. package/lib/types/bedrock/s3.d.ts +20 -0
  97. package/lib/types/bedrock/s3.d.ts.map +1 -0
  98. package/lib/types/huggingface_ie.d.ts +36 -0
  99. package/lib/types/huggingface_ie.d.ts.map +1 -0
  100. package/lib/types/index.d.ts +8 -0
  101. package/lib/types/index.d.ts.map +1 -0
  102. package/lib/types/openai.d.ts +36 -0
  103. package/lib/types/openai.d.ts.map +1 -0
  104. package/lib/types/replicate.d.ts +52 -0
  105. package/lib/types/replicate.d.ts.map +1 -0
  106. package/lib/types/test/TestErrorCompletionStream.d.ts +9 -0
  107. package/lib/types/test/TestErrorCompletionStream.d.ts.map +1 -0
  108. package/lib/types/test/TestValidationErrorCompletionStream.d.ts +9 -0
  109. package/lib/types/test/TestValidationErrorCompletionStream.d.ts.map +1 -0
  110. package/lib/types/test/index.d.ts +27 -0
  111. package/lib/types/test/index.d.ts.map +1 -0
  112. package/lib/types/test/utils.d.ts +5 -0
  113. package/lib/types/test/utils.d.ts.map +1 -0
  114. package/lib/types/togetherai/index.d.ts +23 -0
  115. package/lib/types/togetherai/index.d.ts.map +1 -0
  116. package/lib/types/togetherai/interfaces.d.ts +81 -0
  117. package/lib/types/togetherai/interfaces.d.ts.map +1 -0
  118. package/lib/types/vertexai/debug.d.ts +2 -0
  119. package/lib/types/vertexai/debug.d.ts.map +1 -0
  120. package/lib/types/vertexai/index.d.ts +26 -0
  121. package/lib/types/vertexai/index.d.ts.map +1 -0
  122. package/lib/types/vertexai/models/codey-chat.d.ts +51 -0
  123. package/lib/types/vertexai/models/codey-chat.d.ts.map +1 -0
  124. package/lib/types/vertexai/models/codey-text.d.ts +39 -0
  125. package/lib/types/vertexai/models/codey-text.d.ts.map +1 -0
  126. package/lib/types/vertexai/models/gemini.d.ts +11 -0
  127. package/lib/types/vertexai/models/gemini.d.ts.map +1 -0
  128. package/lib/types/vertexai/models/palm-model-base.d.ts +47 -0
  129. package/lib/types/vertexai/models/palm-model-base.d.ts.map +1 -0
  130. package/lib/types/vertexai/models/palm2-chat.d.ts +61 -0
  131. package/lib/types/vertexai/models/palm2-chat.d.ts.map +1 -0
  132. package/lib/types/vertexai/models/palm2-text.d.ts +39 -0
  133. package/lib/types/vertexai/models/palm2-text.d.ts.map +1 -0
  134. package/lib/types/vertexai/models.d.ts +14 -0
  135. package/lib/types/vertexai/models.d.ts.map +1 -0
  136. package/lib/types/vertexai/utils/prompts.d.ts +20 -0
  137. package/lib/types/vertexai/utils/prompts.d.ts.map +1 -0
  138. package/lib/types/vertexai/utils/tensor.d.ts +6 -0
  139. package/lib/types/vertexai/utils/tensor.d.ts.map +1 -0
  140. package/package.json +72 -0
  141. package/src/bedrock/index.ts +456 -0
  142. package/src/bedrock/s3.ts +62 -0
  143. package/src/huggingface_ie.ts +269 -0
  144. package/src/index.ts +7 -0
  145. package/src/openai.ts +254 -0
  146. package/src/replicate.ts +333 -0
  147. package/src/test/TestErrorCompletionStream.ts +17 -0
  148. package/src/test/TestValidationErrorCompletionStream.ts +21 -0
  149. package/src/test/index.ts +102 -0
  150. package/src/test/utils.ts +28 -0
  151. package/src/togetherai/index.ts +105 -0
  152. package/src/togetherai/interfaces.ts +88 -0
  153. package/src/vertexai/README.md +257 -0
  154. package/src/vertexai/debug.ts +6 -0
  155. package/src/vertexai/index.ts +99 -0
  156. package/src/vertexai/models/codey-chat.ts +115 -0
  157. package/src/vertexai/models/codey-text.ts +69 -0
  158. package/src/vertexai/models/gemini.ts +152 -0
  159. package/src/vertexai/models/palm-model-base.ts +122 -0
  160. package/src/vertexai/models/palm2-chat.ts +119 -0
  161. package/src/vertexai/models/palm2-text.ts +69 -0
  162. package/src/vertexai/models.ts +104 -0
  163. package/src/vertexai/utils/prompts.ts +66 -0
  164. package/src/vertexai/utils/tensor.ts +82 -0
package/src/huggingface_ie.ts ADDED
@@ -0,0 +1,269 @@
+ import {
+     AIModel,
+     AIModelStatus,
+     AbstractDriver,
+     BuiltinProviders,
+     DriverOptions,
+     ExecutionOptions,
+     PromptFormats
+ } from "@llumiverse/core";
+ import { transformAsyncIterator } from "@llumiverse/core/async";
+ import {
+     HfInference,
+     HfInferenceEndpoint,
+     TextGenerationStreamOutput
+ } from "@huggingface/inference";
+ import { FetchClient } from "api-fetch-client";
+
+ export interface HuggingFaceIEDriverOptions extends DriverOptions {
+     apiKey: string;
+     endpoint_url: string;
+ }
+
+ export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptions, string> {
+     service: FetchClient;
+     provider = BuiltinProviders.huggingface_ie;
+     _executor?: HfInferenceEndpoint;
+     defaultFormat = PromptFormats.genericTextLLM;
+
+     constructor(
+         options: HuggingFaceIEDriverOptions
+     ) {
+         super(options);
+         if (!options.endpoint_url) {
+             throw new Error(`Endpoint URL is required for ${this.provider}`);
+         }
+         this.service = new FetchClient(this.options.endpoint_url);
+         this.service.headers["Authorization"] = `Bearer ${this.options.apiKey}`;
+     }
+
+     async getModelURLEndpoint(
+         modelId: string
+     ): Promise<{ url: string; status: string; }> {
+         const res = (await this.service.get(`/${modelId}`)) as HuggingFaceIEModel;
+         return {
+             url: res.status.url,
+             status: getStatus(res),
+         };
+     }
+
+     async getExecutor(model: string) {
+         if (!this._executor) {
+             const endpoint = await this.getModelURLEndpoint(model);
+             if (!endpoint.url)
+                 throw new Error(
+                     `Endpoint URL not found for model ${model}`
+                 );
+             if (endpoint.status !== AIModelStatus.Available)
+                 throw new Error(
+                     `Endpoint ${model} is not running - current status: ${endpoint.status}`
+                 );
+
+             this._executor = new HfInference(this.options.apiKey).endpoint(
+                 endpoint.url
+             );
+         }
+         return this._executor;
+     }
+
+     async requestCompletionStream(prompt: string, options: ExecutionOptions) {
+         const executor = await this.getExecutor(options.model);
+         const req = executor.textGenerationStream({
+             inputs: prompt,
+             parameters: {
+                 temperature: options.temperature,
+                 max_new_tokens: options.max_tokens,
+             },
+         });
+
+         return transformAsyncIterator(req, (val: TextGenerationStreamOutput) => {
+             // special tokens like <s> are not part of the result
+             if (val.token.special) return "";
+             return val.token.text;
+         });
+     }
+
+     async requestCompletion(prompt: string, options: ExecutionOptions) {
+         const executor = await this.getExecutor(options.model);
+         const res = await executor.textGeneration({
+             inputs: prompt,
+             parameters: {
+                 temperature: options.temperature,
+                 max_new_tokens: options.max_tokens,
+             },
+         });
+
+         return {
+             result: res.generated_text,
+             token_usage: {
+                 result: res.generated_text.length,
+                 prompt: prompt.length,
+                 total: res.generated_text.length + prompt.length,
+             },
+         };
+     }
+
+     // ============== management API ==============
+
+     // Not implemented
+     async listTrainableModels(): Promise<AIModel<string>[]> {
+         return [];
+     }
+
+     async listModels(): Promise<AIModel[]> {
+         const res = await this.service.get("/");
+         const hfModels = res.items as HuggingFaceIEModel[];
+
+         const models: AIModel[] = hfModels.map((model: HuggingFaceIEModel) => ({
+             id: model.name,
+             name: `${model.name} [${model.model.repository}:${model.model.task}]`,
+             provider: this.provider,
+             tags: [model.model.task],
+             status: getStatus(model),
+         }));
+
+         return models;
+     }
+
+     async validateConnection(): Promise<boolean> {
+         try {
+             await this.service.get("/models");
+             return true;
+         } catch (error) {
+             return false;
+         }
+     }
+
+     async generateEmbeddings(content: string, model?: string): Promise<{ embeddings: number[], model: string; }> {
+         this.logger?.debug(`[Huggingface] Generating embeddings for ${content} on ${model}`);
+         throw new Error("Method not implemented.");
+     }
+
+ }
+
+ // map the Hugging Face endpoint state to an AIModelStatus
+ function getStatus(hfModel: HuggingFaceIEModel): AIModelStatus {
+     // possible states: [ pending, initializing, updating, updateFailed, running, paused, failed, scaledToZero ]
+     switch (hfModel.status.state) {
+         case "running":
+             return AIModelStatus.Available;
+         case "initializing":
+             return AIModelStatus.Pending;
+         case "updating":
+             return AIModelStatus.Pending;
+         case "updateFailed":
+             return AIModelStatus.Unavailable;
+         case "paused":
+             return AIModelStatus.Stopped;
+         case "failed":
+             return AIModelStatus.Unavailable;
+         case "scaledToZero":
+             return AIModelStatus.Available;
+         default:
+             return AIModelStatus.Unknown;
+     }
+ }
+
+ interface HuggingFaceIEModel {
+     accountId: string;
+     compute: {
+         accelerator: string;
+         instanceSize: string;
+         instanceType: string;
+         scaling: {
+             maxReplica: number;
+             minReplica: number;
+         };
+     };
+     model: {
+         framework: string;
+         image: {
+             huggingface: {};
+         };
+         repository: string;
+         revision: string;
+         task: string;
+     };
+     name: string;
+     provider: {
+         region: string;
+         vendor: string;
+     };
+     status: {
+         createdAt: string;
+         createdBy: {
+             id: string;
+             name: string;
+         };
+         message: string;
+         private: {
+             serviceName: string;
+         };
+         readyReplica: number;
+         state: string;
+         targetReplica: number;
+         updatedAt: string;
+         updatedBy: {
+             id: string;
+             name: string;
+         };
+         url: string;
+     };
+     type: string;
+ }
+
+ /*
+ Example of a model returned by the API:
+ {
+     "items": [
+         {
+             "accountId": "string",
+             "compute": {
+                 "accelerator": "cpu",
+                 "instanceSize": "large",
+                 "instanceType": "c6i",
+                 "scaling": {
+                     "maxReplica": 8,
+                     "minReplica": 2
+                 }
+             },
+             "model": {
+                 "framework": "custom",
+                 "image": {
+                     "huggingface": {}
+                 },
+                 "repository": "gpt2",
+                 "revision": "6c0e6080953db56375760c0471a8c5f2929baf11",
+                 "task": "text-classification"
+             },
+             "name": "my-endpoint",
+             "provider": {
+                 "region": "us-east-1",
+                 "vendor": "aws"
+             },
+             "status": {
+                 "createdAt": "2023-10-19T05:04:17.305Z",
+                 "createdBy": {
+                     "id": "string",
+                     "name": "string"
+                 },
+                 "message": "Endpoint is ready",
+                 "private": {
+                     "serviceName": "string"
+                 },
+                 "readyReplica": 2,
+                 "state": "pending",
+                 "targetReplica": 4,
+                 "updatedAt": "2023-10-19T05:04:17.305Z",
+                 "updatedBy": {
+                     "id": "string",
+                     "name": "string"
+                 },
+                 "url": "https://endpoint-id.region.vendor.endpoints.huggingface.cloud"
+             },
+             "type": "public"
+         }
+     ]
+ }
+ */
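
A minimal usage sketch for the driver above. The API key, the Inference Endpoints management base URL, and the endpoint name "my-endpoint" are all placeholders, and the cast hedges any required ExecutionOptions fields beyond the ones shown here:

import { HuggingFaceIEDriver } from "@llumiverse/drivers";

async function main() {
    // endpoint_url points at the endpoint management API; the driver
    // appends "/{modelId}" to it to resolve the endpoint's inference URL.
    const driver = new HuggingFaceIEDriver({
        apiKey: process.env.HF_API_KEY!, // placeholder
        endpoint_url: "https://api.endpoints.huggingface.cloud/v2/endpoint/my-account", // placeholder
    });

    // "my-endpoint" is hypothetical; it must be in the "running" state,
    // otherwise getExecutor() throws.
    const completion = await driver.requestCompletion("Write a haiku about the sea.", {
        model: "my-endpoint",
        temperature: 0.7,
        max_tokens: 128,
    } as any); // hedged: ExecutionOptions may declare additional required fields
    console.log(completion.result);
}

main().catch(console.error);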
package/src/index.ts ADDED
@@ -0,0 +1,7 @@
+ export * from "./bedrock/index.js";
+ export * from "./huggingface_ie.js";
+ export * from "./openai.js";
+ export * from "./replicate.js";
+ export * from "./vertexai/index.js";
+ export * from "./togetherai/index.js";
+ export * from "./test/index.js";
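
The ".js" extensions in these specifiers follow Node ESM resolution: TypeScript leaves import specifiers as written, so they must name the compiled output files. Through this barrel file, consumers import every driver from the package root; a small sketch using two class names defined in this diff:

// both classes are re-exported through the barrel file above
import { HuggingFaceIEDriver, OpenAIDriver } from "@llumiverse/drivers";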
package/src/openai.ts ADDED
@@ -0,0 +1,254 @@
+ import {
+     AIModel,
+     AbstractDriver,
+     BuiltinProviders,
+     Completion,
+     DataSource,
+     DriverOptions,
+     ExecutionOptions,
+     ExecutionTokenUsage,
+     ModelType,
+     PromptFormats,
+     PromptSegment,
+     TrainingJob,
+     TrainingJobStatus,
+     TrainingOptions,
+     TrainingPromptOptions
+ } from "@llumiverse/core";
+ import { asyncMap } from "@llumiverse/core/async";
+ import OpenAI from "openai";
+ import { Stream } from "openai/streaming";
+
+ const supportFineTunning = new Set([
+     "gpt-3.5-turbo-1106",
+     "gpt-3.5-turbo-0613",
+     "babbage-002",
+     "davinci-002",
+     "gpt-4-0613"
+ ]);
+
+ export interface OpenAIDriverOptions extends DriverOptions {
+     apiKey: string;
+ }
+
+ export class OpenAIDriver extends AbstractDriver<
+     OpenAIDriverOptions,
+     OpenAI.Chat.Completions.ChatCompletionMessageParam[]
+ > {
+     inputContentTypes: string[] = ["text/plain"];
+     generatedContentTypes: string[] = ["text/plain"];
+     service: OpenAI;
+     provider = BuiltinProviders.openai;
+     defaultFormat = PromptFormats.openai;
+
+     constructor(opts: OpenAIDriverOptions) {
+         super(opts);
+         this.service = new OpenAI({
+             apiKey: opts.apiKey,
+         });
+     }
+
+     createPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
+         // OpenAI only supports the openai prompt format - force it
+         return super.createPrompt(segments, { ...opts, format: PromptFormats.openai });
+     }
+
+     extractDataFromResponse(
+         options: ExecutionOptions,
+         result: OpenAI.Chat.Completions.ChatCompletion
+     ): Completion {
+         const tokenInfo: ExecutionTokenUsage = {
+             prompt: result.usage?.prompt_tokens,
+             result: result.usage?.completion_tokens,
+             total: result.usage?.total_tokens,
+         };
+
+         // if no schema, return the plain message content
+         if (!options.resultSchema) {
+             return {
+                 result: result.choices[0]?.message.content as string,
+                 token_usage: tokenInfo,
+             };
+         }
+
+         // we have a schema: return the function_call arguments
+         const data = result.choices[0]?.message.function_call?.arguments as any;
+         if (!data) {
+             this.logger?.error("[OpenAI] Response is not valid", result);
+             throw new Error("Response is not valid: no data");
+         }
+
+         return {
+             result: data,
+             token_usage: tokenInfo
+         };
+     }
+
+     async requestCompletionStream(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise<any> {
+         const mapFn = options.resultSchema
+             ? (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
+                 return (
+                     chunk.choices[0]?.delta?.function_call?.arguments ?? ""
+                 );
+             }
+             : (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
+                 return chunk.choices[0]?.delta?.content ?? "";
+             };
+
+         const stream = (await this.service.chat.completions.create({
+             stream: true,
+             model: options.model,
+             messages: prompt,
+             temperature: options.temperature,
+             n: 1,
+             max_tokens: options.max_tokens,
+             functions: options.resultSchema
+                 ? [
+                     {
+                         name: "format_output",
+                         parameters: options.resultSchema as any,
+                     },
+                 ]
+                 : undefined,
+             function_call: options.resultSchema
+                 ? { name: "format_output" }
+                 : undefined,
+         })) as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>;
+
+         return asyncMap(stream, mapFn);
+     }
+
+     async requestCompletion(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise<any> {
+         const functions = options.resultSchema
+             ? [
+                 {
+                     name: "format_output",
+                     parameters: options.resultSchema as any,
+                 },
+             ]
+             : undefined;
+
+         const res = await this.service.chat.completions.create({
+             stream: false,
+             model: options.model,
+             messages: prompt,
+             temperature: options.temperature,
+             n: 1,
+             max_tokens: options.max_tokens,
+             functions: functions,
+             function_call: options.resultSchema
+                 ? { name: "format_output" }
+                 : undefined,
+         });
+
+         return this.extractDataFromResponse(options, res);
+     }
+
+     createTrainingPrompt(options: TrainingPromptOptions): string {
+         if (options.model.includes("gpt")) {
+             return super.createTrainingPrompt(options);
+         } else {
+             // babbage, davinci not yet implemented
+             throw new Error("Unsupported model for training: " + options.model);
+         }
+     }
+
+     async startTraining(dataset: DataSource, options: TrainingOptions): Promise<TrainingJob> {
+         const url = await dataset.getURL();
+         const file = await this.service.files.create({
+             file: await fetch(url),
+             purpose: "fine-tune",
+         });
+
+         const job = await this.service.fineTuning.jobs.create({
+             training_file: file.id,
+             model: options.model,
+             hyperparameters: options.params
+         });
+
+         return jobInfo(job);
+     }
+
+     async cancelTraining(jobId: string): Promise<TrainingJob> {
+         const job = await this.service.fineTuning.jobs.cancel(jobId);
+         return jobInfo(job);
+     }
+
+     async getTrainingJob(jobId: string): Promise<TrainingJob> {
+         const job = await this.service.fineTuning.jobs.retrieve(jobId);
+         return jobInfo(job);
+     }
+
+     // ========= management API =============
+
+     async validateConnection(): Promise<boolean> {
+         try {
+             await this.service.models.list();
+             return true;
+         } catch (error) {
+             return false;
+         }
+     }
+
+     listTrainableModels(): Promise<AIModel<string>[]> {
+         return this._listModels((m) => supportFineTunning.has(m.id));
+     }
+
+     async listModels(): Promise<AIModel[]> {
+         return this._listModels();
+     }
+
+     async _listModels(filter?: (m: OpenAI.Models.Model) => boolean) {
+         let result = await this.service.models.list();
+         const models = filter ? result.data.filter(filter) : result.data;
+         return models.map((m) => ({
+             id: m.id,
+             name: m.id,
+             provider: this.provider,
+             owner: m.owned_by,
+             type: m.object === "model" ? ModelType.Text : ModelType.Unknown,
+         }));
+     }
+
+     async generateEmbeddings(content: string, model: string = "text-embedding-ada-002"): Promise<{ embeddings: number[], model: string; }> {
+         const res = await this.service.embeddings.create({
+             input: content,
+             model: model,
+         });
+
+         const embeddings = res.data[0].embedding;
+
+         if (!embeddings || embeddings.length === 0) {
+             throw new Error("No embedding found");
+         }
+
+         return { embeddings, model };
+     }
+
+ }
+
+ function jobInfo(job: OpenAI.FineTuning.Jobs.FineTuningJob): TrainingJob {
+     // job.status is one of: validating_files, queued, running, succeeded, failed, cancelled
+     const jobStatus = job.status;
+     let status = TrainingJobStatus.running;
+     let details: string | undefined;
+     if (jobStatus === 'succeeded') {
+         status = TrainingJobStatus.succeeded;
+     } else if (jobStatus === 'failed') {
+         status = TrainingJobStatus.failed;
+         details = job.error ? `${job.error.code} - ${job.error.message} ${job.error.param ? " [" + job.error.param + "]" : ""}` : "error";
+     } else if (jobStatus === 'cancelled') {
+         status = TrainingJobStatus.cancelled;
+     } else {
+         status = TrainingJobStatus.running;
+         details = jobStatus;
+     }
+     return {
+         id: job.id,
+         model: job.fine_tuned_model || undefined,
+         status,
+         details
+     };
+ }
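
A minimal sketch of the structured-output path above: when resultSchema is set, the driver forces a call to the synthetic format_output function, and extractDataFromResponse returns the raw function_call arguments, a JSON string rather than a parsed object. The API key, model name, and schema are illustrative, and the cast hedges any required ExecutionOptions fields beyond those shown:

import { OpenAIDriver } from "@llumiverse/drivers";

async function main() {
    const driver = new OpenAIDriver({ apiKey: process.env.OPENAI_API_KEY! }); // placeholder key

    const completion = await driver.requestCompletion(
        [{ role: "user", content: "Extract the city: 'I moved to Lyon in 2020.'" }],
        {
            model: "gpt-3.5-turbo",
            temperature: 0,
            max_tokens: 100,
            // the JSON schema is passed through as the format_output function parameters
            resultSchema: {
                type: "object",
                properties: { city: { type: "string" } },
                required: ["city"],
            },
        } as any, // hedged: ExecutionOptions may declare additional required fields
    );
    console.log(completion.result); // e.g. '{"city":"Lyon"}' - a JSON string, not a parsed object
}

main().catch(console.error);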