@llumiverse/drivers 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. package/lib/cjs/bedrock/index.js +338 -0
  2. package/lib/cjs/bedrock/index.js.map +1 -0
  3. package/lib/cjs/bedrock/s3.js +61 -0
  4. package/lib/cjs/bedrock/s3.js.map +1 -0
  5. package/lib/cjs/huggingface_ie.js +181 -0
  6. package/lib/cjs/huggingface_ie.js.map +1 -0
  7. package/lib/cjs/index.js +24 -0
  8. package/lib/cjs/index.js.map +1 -0
  9. package/lib/cjs/openai.js +205 -0
  10. package/lib/cjs/openai.js.map +1 -0
  11. package/lib/cjs/package.json +3 -0
  12. package/lib/cjs/replicate.js +290 -0
  13. package/lib/cjs/replicate.js.map +1 -0
  14. package/lib/cjs/test/TestErrorCompletionStream.js +20 -0
  15. package/lib/cjs/test/TestErrorCompletionStream.js.map +1 -0
  16. package/lib/cjs/test/TestValidationErrorCompletionStream.js +24 -0
  17. package/lib/cjs/test/TestValidationErrorCompletionStream.js.map +1 -0
  18. package/lib/cjs/test/index.js +109 -0
  19. package/lib/cjs/test/index.js.map +1 -0
  20. package/lib/cjs/test/utils.js +31 -0
  21. package/lib/cjs/test/utils.js.map +1 -0
  22. package/lib/cjs/togetherai/index.js +92 -0
  23. package/lib/cjs/togetherai/index.js.map +1 -0
  24. package/lib/cjs/togetherai/interfaces.js +3 -0
  25. package/lib/cjs/togetherai/interfaces.js.map +1 -0
  26. package/lib/cjs/vertexai/debug.js +13 -0
  27. package/lib/cjs/vertexai/debug.js.map +1 -0
  28. package/lib/cjs/vertexai/index.js +80 -0
  29. package/lib/cjs/vertexai/index.js.map +1 -0
  30. package/lib/cjs/vertexai/models/codey-chat.js +65 -0
  31. package/lib/cjs/vertexai/models/codey-chat.js.map +1 -0
  32. package/lib/cjs/vertexai/models/codey-text.js +35 -0
  33. package/lib/cjs/vertexai/models/codey-text.js.map +1 -0
  34. package/lib/cjs/vertexai/models/gemini.js +140 -0
  35. package/lib/cjs/vertexai/models/gemini.js.map +1 -0
  36. package/lib/cjs/vertexai/models/palm-model-base.js +65 -0
  37. package/lib/cjs/vertexai/models/palm-model-base.js.map +1 -0
  38. package/lib/cjs/vertexai/models/palm2-chat.js +65 -0
  39. package/lib/cjs/vertexai/models/palm2-chat.js.map +1 -0
  40. package/lib/cjs/vertexai/models/palm2-text.js +35 -0
  41. package/lib/cjs/vertexai/models/palm2-text.js.map +1 -0
  42. package/lib/cjs/vertexai/models.js +93 -0
  43. package/lib/cjs/vertexai/models.js.map +1 -0
  44. package/lib/cjs/vertexai/utils/prompts.js +52 -0
  45. package/lib/cjs/vertexai/utils/prompts.js.map +1 -0
  46. package/lib/cjs/vertexai/utils/tensor.js +87 -0
  47. package/lib/cjs/vertexai/utils/tensor.js.map +1 -0
  48. package/lib/esm/bedrock/index.js +331 -0
  49. package/lib/esm/bedrock/index.js.map +1 -0
  50. package/lib/esm/bedrock/s3.js +53 -0
  51. package/lib/esm/bedrock/s3.js.map +1 -0
  52. package/lib/esm/huggingface_ie.js +177 -0
  53. package/lib/esm/huggingface_ie.js.map +1 -0
  54. package/lib/esm/index.js +8 -0
  55. package/lib/esm/index.js.map +1 -0
  56. package/lib/esm/openai.js +198 -0
  57. package/lib/esm/openai.js.map +1 -0
  58. package/lib/esm/replicate.js +283 -0
  59. package/lib/esm/replicate.js.map +1 -0
  60. package/lib/esm/test/TestErrorCompletionStream.js +16 -0
  61. package/lib/esm/test/TestErrorCompletionStream.js.map +1 -0
  62. package/lib/esm/test/TestValidationErrorCompletionStream.js +20 -0
  63. package/lib/esm/test/TestValidationErrorCompletionStream.js.map +1 -0
  64. package/lib/esm/test/index.js +91 -0
  65. package/lib/esm/test/index.js.map +1 -0
  66. package/lib/esm/test/utils.js +25 -0
  67. package/lib/esm/test/utils.js.map +1 -0
  68. package/lib/esm/togetherai/index.js +88 -0
  69. package/lib/esm/togetherai/index.js.map +1 -0
  70. package/lib/esm/togetherai/interfaces.js +2 -0
  71. package/lib/esm/togetherai/interfaces.js.map +1 -0
  72. package/lib/esm/vertexai/debug.js +6 -0
  73. package/lib/esm/vertexai/debug.js.map +1 -0
  74. package/lib/esm/vertexai/index.js +76 -0
  75. package/lib/esm/vertexai/index.js.map +1 -0
  76. package/lib/esm/vertexai/models/codey-chat.js +61 -0
  77. package/lib/esm/vertexai/models/codey-chat.js.map +1 -0
  78. package/lib/esm/vertexai/models/codey-text.js +31 -0
  79. package/lib/esm/vertexai/models/codey-text.js.map +1 -0
  80. package/lib/esm/vertexai/models/gemini.js +136 -0
  81. package/lib/esm/vertexai/models/gemini.js.map +1 -0
  82. package/lib/esm/vertexai/models/palm-model-base.js +61 -0
  83. package/lib/esm/vertexai/models/palm-model-base.js.map +1 -0
  84. package/lib/esm/vertexai/models/palm2-chat.js +61 -0
  85. package/lib/esm/vertexai/models/palm2-chat.js.map +1 -0
  86. package/lib/esm/vertexai/models/palm2-text.js +31 -0
  87. package/lib/esm/vertexai/models/palm2-text.js.map +1 -0
  88. package/lib/esm/vertexai/models.js +87 -0
  89. package/lib/esm/vertexai/models.js.map +1 -0
  90. package/lib/esm/vertexai/utils/prompts.js +47 -0
  91. package/lib/esm/vertexai/utils/prompts.js.map +1 -0
  92. package/lib/esm/vertexai/utils/tensor.js +82 -0
  93. package/lib/esm/vertexai/utils/tensor.js.map +1 -0
  94. package/lib/types/bedrock/index.d.ts +88 -0
  95. package/lib/types/bedrock/index.d.ts.map +1 -0
  96. package/lib/types/bedrock/s3.d.ts +20 -0
  97. package/lib/types/bedrock/s3.d.ts.map +1 -0
  98. package/lib/types/huggingface_ie.d.ts +36 -0
  99. package/lib/types/huggingface_ie.d.ts.map +1 -0
  100. package/lib/types/index.d.ts +8 -0
  101. package/lib/types/index.d.ts.map +1 -0
  102. package/lib/types/openai.d.ts +36 -0
  103. package/lib/types/openai.d.ts.map +1 -0
  104. package/lib/types/replicate.d.ts +52 -0
  105. package/lib/types/replicate.d.ts.map +1 -0
  106. package/lib/types/test/TestErrorCompletionStream.d.ts +9 -0
  107. package/lib/types/test/TestErrorCompletionStream.d.ts.map +1 -0
  108. package/lib/types/test/TestValidationErrorCompletionStream.d.ts +9 -0
  109. package/lib/types/test/TestValidationErrorCompletionStream.d.ts.map +1 -0
  110. package/lib/types/test/index.d.ts +27 -0
  111. package/lib/types/test/index.d.ts.map +1 -0
  112. package/lib/types/test/utils.d.ts +5 -0
  113. package/lib/types/test/utils.d.ts.map +1 -0
  114. package/lib/types/togetherai/index.d.ts +23 -0
  115. package/lib/types/togetherai/index.d.ts.map +1 -0
  116. package/lib/types/togetherai/interfaces.d.ts +81 -0
  117. package/lib/types/togetherai/interfaces.d.ts.map +1 -0
  118. package/lib/types/vertexai/debug.d.ts +2 -0
  119. package/lib/types/vertexai/debug.d.ts.map +1 -0
  120. package/lib/types/vertexai/index.d.ts +26 -0
  121. package/lib/types/vertexai/index.d.ts.map +1 -0
  122. package/lib/types/vertexai/models/codey-chat.d.ts +51 -0
  123. package/lib/types/vertexai/models/codey-chat.d.ts.map +1 -0
  124. package/lib/types/vertexai/models/codey-text.d.ts +39 -0
  125. package/lib/types/vertexai/models/codey-text.d.ts.map +1 -0
  126. package/lib/types/vertexai/models/gemini.d.ts +11 -0
  127. package/lib/types/vertexai/models/gemini.d.ts.map +1 -0
  128. package/lib/types/vertexai/models/palm-model-base.d.ts +47 -0
  129. package/lib/types/vertexai/models/palm-model-base.d.ts.map +1 -0
  130. package/lib/types/vertexai/models/palm2-chat.d.ts +61 -0
  131. package/lib/types/vertexai/models/palm2-chat.d.ts.map +1 -0
  132. package/lib/types/vertexai/models/palm2-text.d.ts +39 -0
  133. package/lib/types/vertexai/models/palm2-text.d.ts.map +1 -0
  134. package/lib/types/vertexai/models.d.ts +14 -0
  135. package/lib/types/vertexai/models.d.ts.map +1 -0
  136. package/lib/types/vertexai/utils/prompts.d.ts +20 -0
  137. package/lib/types/vertexai/utils/prompts.d.ts.map +1 -0
  138. package/lib/types/vertexai/utils/tensor.d.ts +6 -0
  139. package/lib/types/vertexai/utils/tensor.d.ts.map +1 -0
  140. package/package.json +72 -0
  141. package/src/bedrock/index.ts +456 -0
  142. package/src/bedrock/s3.ts +62 -0
  143. package/src/huggingface_ie.ts +269 -0
  144. package/src/index.ts +7 -0
  145. package/src/openai.ts +254 -0
  146. package/src/replicate.ts +333 -0
  147. package/src/test/TestErrorCompletionStream.ts +17 -0
  148. package/src/test/TestValidationErrorCompletionStream.ts +21 -0
  149. package/src/test/index.ts +102 -0
  150. package/src/test/utils.ts +28 -0
  151. package/src/togetherai/index.ts +105 -0
  152. package/src/togetherai/interfaces.ts +88 -0
  153. package/src/vertexai/README.md +257 -0
  154. package/src/vertexai/debug.ts +6 -0
  155. package/src/vertexai/index.ts +99 -0
  156. package/src/vertexai/models/codey-chat.ts +115 -0
  157. package/src/vertexai/models/codey-text.ts +69 -0
  158. package/src/vertexai/models/gemini.ts +152 -0
  159. package/src/vertexai/models/palm-model-base.ts +122 -0
  160. package/src/vertexai/models/palm2-chat.ts +119 -0
  161. package/src/vertexai/models/palm2-text.ts +69 -0
  162. package/src/vertexai/models.ts +104 -0
  163. package/src/vertexai/utils/prompts.ts +66 -0
  164. package/src/vertexai/utils/tensor.ts +82 -0
@@ -0,0 +1,333 @@
1
+ import {
2
+ AIModel,
3
+ AbstractDriver,
4
+ BuiltinProviders,
5
+ Completion,
6
+ DataSource,
7
+ DriverOptions,
8
+ ExecutionOptions,
9
+ ModelSearchPayload,
10
+ PromptFormats,
11
+ TrainingJob,
12
+ TrainingJobStatus,
13
+ TrainingOptions
14
+ } from "@llumiverse/core";
15
+ import { EventStream } from "@llumiverse/core/async";
16
+ import EventSource from "eventsource";
17
+ import Replicate, { Prediction } from "replicate";
18
+
19
+ let cachedTrainableModels: AIModel[] | undefined;
20
+ let cachedTrainableModelsTimestamp: number = 0;
21
+
22
+ const supportFineTunning = new Set([
23
+ "meta/llama-2-70b-chat",
24
+ "meta/llama-2-13b-chat",
25
+ "meta/llama-2-7b-chat",
26
+ "meta/llama-2-7b",
27
+ "meta/llama-2-70b",
28
+ "meta/llama-2-13b",
29
+ "mistralai/mistral-7b-v0.1"
30
+ ]);
31
+
32
/**
 * Options accepted by {@link ReplicateDriver}.
 */
export interface ReplicateDriverOptions extends DriverOptions {
    // Replicate API token used to authenticate every request.
    apiKey: string;
}
35
+
36
+ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
37
+ provider = BuiltinProviders.replicate;
38
+ service: Replicate;
39
+ defaultFormat = PromptFormats.genericTextLLM;
40
+
41
+ static parseModelId(modelId: string) {
42
+ const [owner, modelPart] = modelId.split("/");
43
+ const i = modelPart.indexOf(':');
44
+ if (i === -1) {
45
+ throw new Error("Invalid model id. Expected format: owner/model:version");
46
+ }
47
+ return {
48
+ owner, model: modelPart.slice(0, i), version: modelPart.slice(i + 1)
49
+ }
50
+ }
51
+
52
+ constructor(options: ReplicateDriverOptions) {
53
+ super(options);
54
+ this.service = new Replicate({
55
+ auth: options.apiKey,
56
+ });
57
+ }
58
+
59
+ extractDataFromResponse(prompt: string, response: Prediction): Completion {
60
+ const text = response.output.join("");
61
+ return {
62
+ result: text,
63
+ token_usage: {
64
+ result: response.output.length,
65
+ prompt: prompt.length,
66
+ total: response.output.length + prompt.length,
67
+ },
68
+ };
69
+ }
70
+
71
+ async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<string>> {
72
+ const model = ReplicateDriver.parseModelId(options.model);
73
+ const predictionData = {
74
+ input: {
75
+ prompt: prompt,
76
+ max_new_tokens: options.max_tokens || 1024,
77
+ temperature: options.temperature,
78
+ },
79
+ version: model.version,
80
+ stream: true, //streaming described here https://replicate.com/blog/streaming
81
+ };
82
+
83
+ const prediction =
84
+ await this.service.predictions.create(predictionData);
85
+
86
+ const stream = new EventStream<string>();
87
+
88
+ const source = new EventSource(prediction.urls.stream!);
89
+ source.addEventListener("output", (e: any) => {
90
+ stream.push(e.data);
91
+ });
92
+ source.addEventListener("error", (e: any) => {
93
+ let error: any;
94
+ try {
95
+ error = JSON.parse(e.data);
96
+ } catch (error) {
97
+ error = JSON.stringify(e);
98
+ }
99
+ this.logger?.error(e, error, "Error in SSE stream");
100
+ });
101
+ source.addEventListener("done", () => {
102
+ try {
103
+ stream.close(""); // not using e.data which is {}
104
+ } finally {
105
+ source.close();
106
+ }
107
+ });
108
+ return stream;
109
+ }
110
+
111
+ async requestCompletion(prompt: string, options: ExecutionOptions) {
112
+ const model = ReplicateDriver.parseModelId(options.model);
113
+ const predictionData = {
114
+ input: {
115
+ prompt: prompt,
116
+ max_new_tokens: options.max_tokens || 1024,
117
+ temperature: options.temperature,
118
+ },
119
+ version: model.version,
120
+ //TODO stream
121
+ //stream: stream, //streaming described here https://replicate.com/blog/streaming
122
+ };
123
+
124
+ const prediction =
125
+ await this.service.predictions.create(predictionData);
126
+
127
+ //TODO stream
128
+ //if we're streaming, return right away for the stream handler to handle
129
+ // if (stream) return prediction;
130
+
131
+ //not streaming, wait for the result
132
+ const res = await this.service.wait(prediction, {});
133
+
134
+ const text = res.output.join("");
135
+ return {
136
+ result: text,
137
+ token_usage: {
138
+ result: res.output.length,
139
+ prompt: prompt.length,
140
+ total: res.output.length + prompt.length,
141
+ },
142
+ };
143
+ }
144
+
145
+ async startTraining(dataset: DataSource, options: TrainingOptions): Promise<TrainingJob> {
146
+ if (options.name.indexOf("/") === -1) {
147
+ throw new Error("Invalid target model name. Expected format: owner/model");
148
+ }
149
+ const { owner, model, version } = ReplicateDriver.parseModelId(options.model);
150
+ const job = await this.service.trainings.create(owner, model, version, {
151
+ destination: options.name as any,
152
+ input: {
153
+ train_data: dataset.getURL(),
154
+ },
155
+ })
156
+ return jobInfo(job, options.name);
157
+ }
158
+
159
+ /**
160
+ * This method is not returning a consistent TrainingJob like the one returned by startTraining
161
+ * Instead of returning the full model name `owner/model:version` it returns only the version `version
162
+ * @param jobId
163
+ * @returns
164
+ */
165
+ async cancelTraining(jobId: string): Promise<TrainingJob> {
166
+ const job = await this.service.trainings.cancel(jobId);
167
+ return jobInfo(job);
168
+ }
169
+
170
+ /**
171
+ * This method is not returning a consistent TrainingJob like the one returned by startTraining
172
+ * Instead of returning the full model name `owner/model:version` it returns only the version `version
173
+ * @param jobId
174
+ * @returns
175
+ */
176
+ async getTrainingJob(jobId: string): Promise<TrainingJob> {
177
+ const job = await this.service.trainings.get(jobId);
178
+ return jobInfo(job);
179
+ }
180
+
181
+ // ========= management API =============
182
+
183
+ async validateConnection(): Promise<boolean> {
184
+ try {
185
+ await this.service.predictions.list();
186
+ return true;
187
+ } catch (error) {
188
+ return false;
189
+ }
190
+ }
191
+
192
+ async _listTrainableModels(): Promise<AIModel[]> {
193
+ const promises = Array.from(supportFineTunning).map(id => {
194
+ const [owner, model] = id.split('/');
195
+ return this.service.models.get(owner, model)
196
+ });
197
+ const results = await Promise.all(promises);
198
+ return results.filter(m => !!m.latest_version).map(m => {
199
+ const fullName = m.owner + '/' + m.name;
200
+ const v = m.latest_version!;
201
+ return {
202
+ id: fullName + ':' + v.id,
203
+ name:
204
+ fullName + "@" + v.cog_version + ":" + v.id.slice(0, 6),
205
+ provider: this.provider,
206
+ owner: m.owner,
207
+ description: m.description,
208
+ } as AIModel;
209
+ });
210
+ }
211
+
212
+ async listTrainableModels(): Promise<AIModel[]> {
213
+ if (!cachedTrainableModels || Date.now() > cachedTrainableModelsTimestamp + 12 * 3600 * 1000) { // 12 hours
214
+ cachedTrainableModels = await this._listTrainableModels();
215
+ cachedTrainableModelsTimestamp = Date.now();
216
+ }
217
+ return cachedTrainableModels;
218
+ }
219
+
220
+ async listModels(params: ModelSearchPayload): Promise<AIModel[]> {
221
+ if (!params.text) {
222
+ return this.listTrainableModels();
223
+ }
224
+ const [owner, model] = params.text.split("/");
225
+ if (!owner || !model) {
226
+ throw new Error("Invalid model name. Expected format: owner/model");
227
+ }
228
+
229
+ return this.listModelVersions(owner, model);
230
+ }
231
+
232
+ async listModelVersions(owner: string, model: string): Promise<AIModel[]> {
233
+ const [rModel, versions] = await Promise.all([
234
+ this.service.models.get(owner, model),
235
+ this.service.models.versions.list(owner, model),
236
+ ]);
237
+
238
+ if (!rModel || !versions || versions.length === 0) {
239
+ throw new Error("Model not found or no versions avaialble");
240
+ }
241
+
242
+ const models: AIModel[] = (versions as any).results.map((v: any) => {
243
+ const fullName = rModel.owner + '/' + rModel.name;
244
+ return {
245
+ id: fullName + ':' + v.id,
246
+ name:
247
+ fullName + "@" + v.cog_version + ":" + v.id.slice(0, 6),
248
+ provider: this.provider,
249
+ owner: rModel.owner,
250
+ description: rModel.description,
251
+ canTrain: supportFineTunning.has(fullName),
252
+ } as AIModel;
253
+ });
254
+
255
+ //set latest version
256
+ //const idx = models.findIndex(m => m.id === rModel.latest_version?.id);
257
+ //models[idx].name = rModel.name + "@latest"
258
+
259
+ return models;
260
+ }
261
+
262
+ async searchModels(params: ModelSearchPayload): Promise<AIModel[]> {
263
+ const res = await this.service.request("models/search", {
264
+ params: {
265
+ query: params.text,
266
+ },
267
+ });
268
+
269
+ const rModels = ((await res.json()) as any).models;
270
+
271
+ const models: AIModel[] = rModels.map((v: any) => {
272
+ return {
273
+ id: v.name,
274
+ name: v.name,
275
+ provider: this.provider,
276
+ owner: v.username,
277
+ description: v.description,
278
+ has_versions: true,
279
+ };
280
+ });
281
+
282
+ return models;
283
+ }
284
+
285
+ generateEmbeddings(content: string, model?: string): Promise<{ embeddings: number[], model: string; }> {
286
+ this.logger?.debug(`[Replicate] Generating embeddings for ${content} on ${model}`);
287
+ throw new Error("Method not implemented.");
288
+ }
289
+
290
+ }
291
+
292
+ function jobInfo(job: Prediction, modelName?: string): TrainingJob {
293
+ // 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled'
294
+ const jobStatus = job.status;
295
+ let details: string | undefined;
296
+ let status = TrainingJobStatus.running;
297
+ if (jobStatus === 'succeeded') {
298
+ status = TrainingJobStatus.succeeded;
299
+ } else if (jobStatus === 'failed') {
300
+ status = TrainingJobStatus.failed;
301
+ const error = job.error;
302
+ if (typeof error === 'string') {
303
+ details = error;
304
+ } else {
305
+ const parts = [];
306
+ if (error.code) {
307
+ parts.push(error.code + ' - ');
308
+ }
309
+ if (error.message) {
310
+ parts.push(error.message);
311
+ }
312
+ if (parts.length) {
313
+ details = parts.join(' ');
314
+ } else {
315
+ details = JSON.stringify(error);
316
+ }
317
+ }
318
+ details = job.error ? `${job.error.code} - ${job.error.message} ${job.error.param ? " [" + job.error.param + "]" : ""}` : "error";
319
+ } else if (jobStatus === 'canceled') {
320
+ status = TrainingJobStatus.cancelled;
321
+ } else {
322
+ status = TrainingJobStatus.running;
323
+ details = job.status;
324
+ }
325
+
326
+ return {
327
+ id: job.id,
328
+ status,
329
+ details,
330
+ model: modelName ? modelName + ':' + job.version : job.version
331
+ } as TrainingJob;
332
+
333
+ }
@@ -0,0 +1,17 @@
1
+ import { CompletionStream, ExecutionOptions, ExecutionResponse, PromptSegment } from "@llumiverse/core";
2
+ import { sleep, throwError } from "./utils.js";
3
+
4
+ export class TestErrorCompletionStream implements CompletionStream<PromptSegment[]> {
5
+
6
+ completion: ExecutionResponse<PromptSegment[]> | undefined;
7
+
8
+ constructor(public segments: PromptSegment[],
9
+ public options: ExecutionOptions) {
10
+ }
11
+ async *[Symbol.asyncIterator]() {
12
+ yield "Started TestError. Next we will thrown an error.\n";
13
+ sleep(1000);
14
+ throwError("Testing stream completion error.", this.segments);
15
+ }
16
+ }
17
+
@@ -0,0 +1,21 @@
1
+ import { CompletionStream, ExecutionOptions, ExecutionResponse, PromptSegment } from "@llumiverse/core";
2
+ import { createValidationErrorCompletion, sleep } from "./utils.js";
3
+
4
+
5
+ export class TestValidationErrorCompletionStream implements CompletionStream<PromptSegment[]> {
6
+
7
+ completion: ExecutionResponse<PromptSegment[]> | undefined;
8
+
9
+ constructor(public segments: PromptSegment[],
10
+ public options: ExecutionOptions) {
11
+ }
12
+ async *[Symbol.asyncIterator]() {
13
+ yield "Started TestValidationError.\n";
14
+ await sleep(1000);
15
+ yield "chunk1\n"
16
+ await sleep(1000);
17
+ yield "chunk2\n"
18
+ await sleep(1000);
19
+ this.completion = createValidationErrorCompletion(this.segments);
20
+ }
21
+ }
@@ -0,0 +1,102 @@
1
+ import { AIModel, AIModelStatus, CompletionStream, Driver, ExecutionOptions, ExecutionResponse, ModelType, PromptOptions, PromptSegment, TrainingJob } from "@llumiverse/core";
2
+ import { TestErrorCompletionStream } from "./TestErrorCompletionStream.js";
3
+ import { TestValidationErrorCompletionStream } from "./TestValidationErrorCompletionStream.js";
4
+ import { createValidationErrorCompletion, sleep, throwError } from "./utils.js";
5
+
6
+ export * from "./TestErrorCompletionStream.js";
7
+ export * from "./TestValidationErrorCompletionStream.js";
8
+
9
/**
 * Model ids understood by the TestDriver; each id selects a failure scenario.
 */
export enum TestDriverModels {
    // The driver itself throws during execution.
    executionError = "execution-error",
    // The driver completes but the result carries a validation error.
    validationError = "validation-error",
}
13
+
14
+ export class TestDriver implements Driver<PromptSegment[]> {
15
+ provider = "test";
16
+
17
+ createTrainingPrompt(): string {
18
+ throw new Error("Method not implemented.");
19
+ }
20
+
21
+ startTraining(): Promise<TrainingJob> {
22
+ throw new Error("Method not implemented.");
23
+ }
24
+
25
+ cancelTraining(): Promise<TrainingJob> {
26
+ throw new Error("Method not implemented.");
27
+ }
28
+
29
+ getTrainingJob(_jobId: string): Promise<TrainingJob> {
30
+ throw new Error("Method not implemented.");
31
+ }
32
+
33
+ createPrompt(segments: PromptSegment[], _opts: PromptOptions): PromptSegment[] {
34
+ return segments;
35
+ }
36
+ execute(segments: PromptSegment[], options: ExecutionOptions): Promise<ExecutionResponse<PromptSegment[]>> {
37
+ switch (options.model) {
38
+ case TestDriverModels.executionError:
39
+ return this.executeError(segments, options);
40
+ case TestDriverModels.validationError:
41
+ return this.executeValidationError(segments, options);
42
+ default:
43
+ throwError("[test driver] Unknown model: " + options.model, segments)
44
+ }
45
+ }
46
+ async stream(segments: PromptSegment[], options: ExecutionOptions): Promise<CompletionStream<PromptSegment[]>> {
47
+ switch (options.model) {
48
+ case TestDriverModels.executionError:
49
+ return new TestErrorCompletionStream(segments, options);
50
+ case TestDriverModels.validationError:
51
+ return new TestValidationErrorCompletionStream(segments, options);
52
+ default:
53
+ throwError("[test driver] Unknown model: " + options.model, segments)
54
+ }
55
+ }
56
+
57
+ async listTrainableModels(): Promise<AIModel<string>[]> {
58
+ return [];
59
+ }
60
+
61
+ async listModels(): Promise<AIModel<string>[]> {
62
+ return [
63
+ {
64
+ id: TestDriverModels.executionError,
65
+ name: "Execution Error",
66
+ type: ModelType.Test,
67
+ provider: this.provider,
68
+ status: AIModelStatus.Available,
69
+ description: "Test execution errors",
70
+ tags: [],
71
+ },
72
+ {
73
+ id: TestDriverModels.validationError,
74
+ name: "Validation Error",
75
+ type: ModelType.Test,
76
+ provider: this.provider,
77
+ status: AIModelStatus.Available,
78
+ description: "Test validation errors",
79
+ tags: [],
80
+ },
81
+ ]
82
+ }
83
+ validateConnection(): Promise<boolean> {
84
+ throw new Error("Method not implemented.");
85
+ }
86
+ generateEmbeddings(): Promise<{ embeddings: number[]; model: string; }> {
87
+ throw new Error("Method not implemented.");
88
+ }
89
+
90
+ // ============== execution error ==================
91
+ async executeError(segments: PromptSegment[], _options: ExecutionOptions): Promise<ExecutionResponse<PromptSegment[]>> {
92
+ await sleep(1000);
93
+ throwError("Testing stream completion error.", segments);
94
+ }
95
+ // ============== validation error ==================
96
+ async executeValidationError(segments: PromptSegment[], _options: ExecutionOptions): Promise<ExecutionResponse<PromptSegment[]>> {
97
+ await sleep(3000);
98
+ return createValidationErrorCompletion(segments);
99
+ }
100
+
101
+ }
102
+
@@ -0,0 +1,28 @@
1
+ import { ExecutionResponse, PromptSegment } from "@llumiverse/core";
2
+
3
+ export function throwError(message: string, prompt: PromptSegment[]): never {
4
+ const err = new Error(message);
5
+ (err as any).prompt = prompt;
6
+ throw err;
7
+ }
8
+
9
+ export function createValidationErrorCompletion(segments: PromptSegment[]) {
10
+ return {
11
+ result: "An invalid result",
12
+ prompt: segments,
13
+ execution_time: 3000,
14
+ error: {
15
+ code: "validation_error",
16
+ message: "Result cannot be validated!",
17
+ },
18
+ token_usage: {
19
+ result: 10,
20
+ prompt: 10,
21
+ total: 20,
22
+ },
23
+ } as ExecutionResponse<PromptSegment[]>;
24
+ }
25
+
26
+ export function sleep(ms: number) {
27
+ return new Promise(resolve => setTimeout(resolve, ms));
28
+ }
@@ -0,0 +1,105 @@
1
+ import { AIModel, AbstractDriver, Completion, DriverOptions, ExecutionOptions, PromptFormats } from "@llumiverse/core";
2
+ import { FetchClient, ServerSentEvent } from "api-fetch-client";
3
+ import { TogetherModelInfo } from "./interfaces.js";
4
+
5
/**
 * Options accepted by the TogetherAI driver.
 */
interface TogetherAIDriverOptions extends DriverOptions {
    // Together AI API key; sent as a Bearer token on every request.
    apiKey: string;
}
8
+
9
+ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, string> {
10
+ provider: string;
11
+ apiKey: string;
12
+ defaultFormat: PromptFormats;
13
+ fetchClient: FetchClient;
14
+
15
+ constructor(options: TogetherAIDriverOptions) {
16
+ super(options);
17
+ this.provider = "togetherai";
18
+ this.defaultFormat = PromptFormats.genericTextLLM;
19
+ this.apiKey = options.apiKey;
20
+ this.fetchClient = new FetchClient('https://api.together.xyz').withHeaders({
21
+ authorization: `Bearer ${this.apiKey}`
22
+ });
23
+ }
24
+
25
+ async requestCompletion(prompt: string, options: ExecutionOptions): Promise<Completion<any>> {
26
+ const res = await this.fetchClient.post('/v1/completions', {
27
+ payload: {
28
+ model: options.model,
29
+ prompt: prompt,
30
+ max_tokens: options.max_tokens ?? 1024,
31
+ temperature: options.temperature ?? 0.7,
32
+ }
33
+ })
34
+
35
+ const text = res.choices[0]?.text ?? '';
36
+ const usage = res.usage || {};
37
+ return {
38
+ result: text,
39
+ token_usage: {
40
+ prompt: usage.prompt_tokens,
41
+ result: usage.completion_tokens,
42
+ total: usage.total_tokens,
43
+ },
44
+ }
45
+ }
46
+
47
+ async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<string>> {
48
+
49
+ const stream = await this.fetchClient.post('/v1/completions', {
50
+ payload: {
51
+ model: options.model,
52
+ prompt: prompt,
53
+ max_tokens: options.max_tokens ?? 1024,
54
+ temperature: options.temperature ?? 0.7,
55
+ stream: true,
56
+ },
57
+ reader: 'sse'
58
+ })
59
+
60
+ return stream.pipeThrough(new TransformStream<ServerSentEvent, string>({
61
+ transform(event: ServerSentEvent, controller) {
62
+ if (event.type === 'event' && event.data && event.data !== '[DONE]') {
63
+ try {
64
+ const data = JSON.parse(event.data);
65
+ controller.enqueue(data.choices[0]?.text ?? '');
66
+ } catch (err) {
67
+ // double check for the last event whicb is not a JSON - at this time togetherai returrns the string [DONE]
68
+ // do nothing - happens if data is not a JSON - the last event data is the [DONE] string
69
+ }
70
+ }
71
+ }
72
+ }));
73
+
74
+ }
75
+
76
+ async listModels(): Promise<AIModel<string>[]> {
77
+ const models: TogetherModelInfo[] = await this.fetchClient.get("/models/info");
78
+ // logObject('#### LIST MODELS RESULT IS', models[0]);
79
+
80
+ const aimodels = models.map(m => {
81
+ return {
82
+ id: m.name,
83
+ name: m.display_name,
84
+ description: m.description,
85
+ provider: this.provider,
86
+ formats: [PromptFormats.genericTextLLM],
87
+ }
88
+ });
89
+
90
+ return aimodels;
91
+
92
+ }
93
+
94
+ listTrainableModels(): Promise<AIModel<string>[]> {
95
+ throw new Error("Method not implemented.");
96
+ }
97
+ validateConnection(): Promise<boolean> {
98
+ throw new Error("Method not implemented.");
99
+ }
100
+ //@ts-ignore
101
+ generateEmbeddings(content: string, model?: string | undefined): Promise<{ embeddings: number[]; model: string; }> {
102
+ throw new Error("Method not implemented.");
103
+ }
104
+
105
+ }
@@ -0,0 +1,88 @@
1
// Shapes below mirror the JSON returned by Together AI's /models/info endpoint
// (see TogetherAIDriver.listModels). Field names are kept exactly as the API
// returns them, including apparent misspellings.

// Where/how a model instance appears in Together's catalog UI.
interface ModelInstanceConfig {
    appearsIn: any[];
    order: number;
}

// Prompting configuration advertised for a model.
interface Config {
    stop: string[];
    prompt_format: string;
    chat_template: string;
}

// Per-model pricing; units are as reported by the API (not documented here).
interface Pricing {
    input: number;
    output: number;
    hourly: number;
}

// Physical placement of a serving instance.
interface Instance {
    avzone: string;
    cluster: string;
}

// Map keyed by an identifier the API chooses; values are counts.
interface Ask {
    [key: string]: number;
}

// Map of GPU type to count.
interface Gpu {
    [key: string]: number;
}

// Detailed price breakdown used inside Depth.
interface Price {
    base: number;
    finetune: number;
    hourly: number;
    input: number;
    output: number;
}

// Per-zone serving statistics.
interface Stat {
    avzone: string;
    cluster: string;
    capacity: number;
    qps: number;
    throughput_in: number;
    throughput_out: number;
    error_rate: number;
    retry_rate: number;
}

// Capacity/market depth information for a model.
interface Depth {
    num_asks: number;
    num_bids: number;
    num_running: number;
    asks: Ask;
    asks_updated: string;
    gpus: Gpu;
    qps: number;
    permit_required: boolean;
    price: Price;
    throughput_in: number;
    throughput_out: number;
    stats: Stat[];
}

/**
 * One entry of the /models/info response.
 */
export interface TogetherModelInfo {
    modelInstanceConfig: ModelInstanceConfig;
    _id: string;
    // Model id used for inference requests (e.g. passed as `model`).
    name: string;
    // Human-readable name shown to users.
    display_name: string;
    display_type: string;
    description: string;
    license: string;
    creator_organization: string;
    hardware_label: string;
    num_parameters: number;
    show_in_playground: boolean;
    isFeaturedModel: boolean;
    context_length: number;
    config: Config;
    pricing: Pricing;
    created_at: string;
    // sic — field name as returned by the API (presumably "updated_at").
    update_at: string;
    instances: Instance[];
    access: string;
    link: string;
    descriptionLink: string;
    depth: Depth;
}