@llumiverse/drivers 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. package/lib/cjs/bedrock/index.js +338 -0
  2. package/lib/cjs/bedrock/index.js.map +1 -0
  3. package/lib/cjs/bedrock/s3.js +61 -0
  4. package/lib/cjs/bedrock/s3.js.map +1 -0
  5. package/lib/cjs/huggingface_ie.js +181 -0
  6. package/lib/cjs/huggingface_ie.js.map +1 -0
  7. package/lib/cjs/index.js +24 -0
  8. package/lib/cjs/index.js.map +1 -0
  9. package/lib/cjs/openai.js +205 -0
  10. package/lib/cjs/openai.js.map +1 -0
  11. package/lib/cjs/package.json +3 -0
  12. package/lib/cjs/replicate.js +290 -0
  13. package/lib/cjs/replicate.js.map +1 -0
  14. package/lib/cjs/test/TestErrorCompletionStream.js +20 -0
  15. package/lib/cjs/test/TestErrorCompletionStream.js.map +1 -0
  16. package/lib/cjs/test/TestValidationErrorCompletionStream.js +24 -0
  17. package/lib/cjs/test/TestValidationErrorCompletionStream.js.map +1 -0
  18. package/lib/cjs/test/index.js +109 -0
  19. package/lib/cjs/test/index.js.map +1 -0
  20. package/lib/cjs/test/utils.js +31 -0
  21. package/lib/cjs/test/utils.js.map +1 -0
  22. package/lib/cjs/togetherai/index.js +92 -0
  23. package/lib/cjs/togetherai/index.js.map +1 -0
  24. package/lib/cjs/togetherai/interfaces.js +3 -0
  25. package/lib/cjs/togetherai/interfaces.js.map +1 -0
  26. package/lib/cjs/vertexai/debug.js +13 -0
  27. package/lib/cjs/vertexai/debug.js.map +1 -0
  28. package/lib/cjs/vertexai/index.js +80 -0
  29. package/lib/cjs/vertexai/index.js.map +1 -0
  30. package/lib/cjs/vertexai/models/codey-chat.js +65 -0
  31. package/lib/cjs/vertexai/models/codey-chat.js.map +1 -0
  32. package/lib/cjs/vertexai/models/codey-text.js +35 -0
  33. package/lib/cjs/vertexai/models/codey-text.js.map +1 -0
  34. package/lib/cjs/vertexai/models/gemini.js +140 -0
  35. package/lib/cjs/vertexai/models/gemini.js.map +1 -0
  36. package/lib/cjs/vertexai/models/palm-model-base.js +65 -0
  37. package/lib/cjs/vertexai/models/palm-model-base.js.map +1 -0
  38. package/lib/cjs/vertexai/models/palm2-chat.js +65 -0
  39. package/lib/cjs/vertexai/models/palm2-chat.js.map +1 -0
  40. package/lib/cjs/vertexai/models/palm2-text.js +35 -0
  41. package/lib/cjs/vertexai/models/palm2-text.js.map +1 -0
  42. package/lib/cjs/vertexai/models.js +93 -0
  43. package/lib/cjs/vertexai/models.js.map +1 -0
  44. package/lib/cjs/vertexai/utils/prompts.js +52 -0
  45. package/lib/cjs/vertexai/utils/prompts.js.map +1 -0
  46. package/lib/cjs/vertexai/utils/tensor.js +87 -0
  47. package/lib/cjs/vertexai/utils/tensor.js.map +1 -0
  48. package/lib/esm/bedrock/index.js +331 -0
  49. package/lib/esm/bedrock/index.js.map +1 -0
  50. package/lib/esm/bedrock/s3.js +53 -0
  51. package/lib/esm/bedrock/s3.js.map +1 -0
  52. package/lib/esm/huggingface_ie.js +177 -0
  53. package/lib/esm/huggingface_ie.js.map +1 -0
  54. package/lib/esm/index.js +8 -0
  55. package/lib/esm/index.js.map +1 -0
  56. package/lib/esm/openai.js +198 -0
  57. package/lib/esm/openai.js.map +1 -0
  58. package/lib/esm/replicate.js +283 -0
  59. package/lib/esm/replicate.js.map +1 -0
  60. package/lib/esm/test/TestErrorCompletionStream.js +16 -0
  61. package/lib/esm/test/TestErrorCompletionStream.js.map +1 -0
  62. package/lib/esm/test/TestValidationErrorCompletionStream.js +20 -0
  63. package/lib/esm/test/TestValidationErrorCompletionStream.js.map +1 -0
  64. package/lib/esm/test/index.js +91 -0
  65. package/lib/esm/test/index.js.map +1 -0
  66. package/lib/esm/test/utils.js +25 -0
  67. package/lib/esm/test/utils.js.map +1 -0
  68. package/lib/esm/togetherai/index.js +88 -0
  69. package/lib/esm/togetherai/index.js.map +1 -0
  70. package/lib/esm/togetherai/interfaces.js +2 -0
  71. package/lib/esm/togetherai/interfaces.js.map +1 -0
  72. package/lib/esm/vertexai/debug.js +6 -0
  73. package/lib/esm/vertexai/debug.js.map +1 -0
  74. package/lib/esm/vertexai/index.js +76 -0
  75. package/lib/esm/vertexai/index.js.map +1 -0
  76. package/lib/esm/vertexai/models/codey-chat.js +61 -0
  77. package/lib/esm/vertexai/models/codey-chat.js.map +1 -0
  78. package/lib/esm/vertexai/models/codey-text.js +31 -0
  79. package/lib/esm/vertexai/models/codey-text.js.map +1 -0
  80. package/lib/esm/vertexai/models/gemini.js +136 -0
  81. package/lib/esm/vertexai/models/gemini.js.map +1 -0
  82. package/lib/esm/vertexai/models/palm-model-base.js +61 -0
  83. package/lib/esm/vertexai/models/palm-model-base.js.map +1 -0
  84. package/lib/esm/vertexai/models/palm2-chat.js +61 -0
  85. package/lib/esm/vertexai/models/palm2-chat.js.map +1 -0
  86. package/lib/esm/vertexai/models/palm2-text.js +31 -0
  87. package/lib/esm/vertexai/models/palm2-text.js.map +1 -0
  88. package/lib/esm/vertexai/models.js +87 -0
  89. package/lib/esm/vertexai/models.js.map +1 -0
  90. package/lib/esm/vertexai/utils/prompts.js +47 -0
  91. package/lib/esm/vertexai/utils/prompts.js.map +1 -0
  92. package/lib/esm/vertexai/utils/tensor.js +82 -0
  93. package/lib/esm/vertexai/utils/tensor.js.map +1 -0
  94. package/lib/types/bedrock/index.d.ts +88 -0
  95. package/lib/types/bedrock/index.d.ts.map +1 -0
  96. package/lib/types/bedrock/s3.d.ts +20 -0
  97. package/lib/types/bedrock/s3.d.ts.map +1 -0
  98. package/lib/types/huggingface_ie.d.ts +36 -0
  99. package/lib/types/huggingface_ie.d.ts.map +1 -0
  100. package/lib/types/index.d.ts +8 -0
  101. package/lib/types/index.d.ts.map +1 -0
  102. package/lib/types/openai.d.ts +36 -0
  103. package/lib/types/openai.d.ts.map +1 -0
  104. package/lib/types/replicate.d.ts +52 -0
  105. package/lib/types/replicate.d.ts.map +1 -0
  106. package/lib/types/test/TestErrorCompletionStream.d.ts +9 -0
  107. package/lib/types/test/TestErrorCompletionStream.d.ts.map +1 -0
  108. package/lib/types/test/TestValidationErrorCompletionStream.d.ts +9 -0
  109. package/lib/types/test/TestValidationErrorCompletionStream.d.ts.map +1 -0
  110. package/lib/types/test/index.d.ts +27 -0
  111. package/lib/types/test/index.d.ts.map +1 -0
  112. package/lib/types/test/utils.d.ts +5 -0
  113. package/lib/types/test/utils.d.ts.map +1 -0
  114. package/lib/types/togetherai/index.d.ts +23 -0
  115. package/lib/types/togetherai/index.d.ts.map +1 -0
  116. package/lib/types/togetherai/interfaces.d.ts +81 -0
  117. package/lib/types/togetherai/interfaces.d.ts.map +1 -0
  118. package/lib/types/vertexai/debug.d.ts +2 -0
  119. package/lib/types/vertexai/debug.d.ts.map +1 -0
  120. package/lib/types/vertexai/index.d.ts +26 -0
  121. package/lib/types/vertexai/index.d.ts.map +1 -0
  122. package/lib/types/vertexai/models/codey-chat.d.ts +51 -0
  123. package/lib/types/vertexai/models/codey-chat.d.ts.map +1 -0
  124. package/lib/types/vertexai/models/codey-text.d.ts +39 -0
  125. package/lib/types/vertexai/models/codey-text.d.ts.map +1 -0
  126. package/lib/types/vertexai/models/gemini.d.ts +11 -0
  127. package/lib/types/vertexai/models/gemini.d.ts.map +1 -0
  128. package/lib/types/vertexai/models/palm-model-base.d.ts +47 -0
  129. package/lib/types/vertexai/models/palm-model-base.d.ts.map +1 -0
  130. package/lib/types/vertexai/models/palm2-chat.d.ts +61 -0
  131. package/lib/types/vertexai/models/palm2-chat.d.ts.map +1 -0
  132. package/lib/types/vertexai/models/palm2-text.d.ts +39 -0
  133. package/lib/types/vertexai/models/palm2-text.d.ts.map +1 -0
  134. package/lib/types/vertexai/models.d.ts +14 -0
  135. package/lib/types/vertexai/models.d.ts.map +1 -0
  136. package/lib/types/vertexai/utils/prompts.d.ts +20 -0
  137. package/lib/types/vertexai/utils/prompts.d.ts.map +1 -0
  138. package/lib/types/vertexai/utils/tensor.d.ts +6 -0
  139. package/lib/types/vertexai/utils/tensor.d.ts.map +1 -0
  140. package/package.json +72 -0
  141. package/src/bedrock/index.ts +456 -0
  142. package/src/bedrock/s3.ts +62 -0
  143. package/src/huggingface_ie.ts +269 -0
  144. package/src/index.ts +7 -0
  145. package/src/openai.ts +254 -0
  146. package/src/replicate.ts +333 -0
  147. package/src/test/TestErrorCompletionStream.ts +17 -0
  148. package/src/test/TestValidationErrorCompletionStream.ts +21 -0
  149. package/src/test/index.ts +102 -0
  150. package/src/test/utils.ts +28 -0
  151. package/src/togetherai/index.ts +105 -0
  152. package/src/togetherai/interfaces.ts +88 -0
  153. package/src/vertexai/README.md +257 -0
  154. package/src/vertexai/debug.ts +6 -0
  155. package/src/vertexai/index.ts +99 -0
  156. package/src/vertexai/models/codey-chat.ts +115 -0
  157. package/src/vertexai/models/codey-text.ts +69 -0
  158. package/src/vertexai/models/gemini.ts +152 -0
  159. package/src/vertexai/models/palm-model-base.ts +122 -0
  160. package/src/vertexai/models/palm2-chat.ts +119 -0
  161. package/src/vertexai/models/palm2-text.ts +69 -0
  162. package/src/vertexai/models.ts +104 -0
  163. package/src/vertexai/utils/prompts.ts +66 -0
  164. package/src/vertexai/utils/tensor.ts +82 -0
package/src/bedrock/index.ts
@@ -0,0 +1,456 @@
+ import { Bedrock, CreateModelCustomizationJobCommand, FoundationModelSummary, GetModelCustomizationJobCommand, GetModelCustomizationJobCommandOutput, ModelCustomizationJobStatus, StopModelCustomizationJobCommand } from "@aws-sdk/client-bedrock";
+ import { BedrockRuntime, InvokeModelCommandOutput, ResponseStream } from "@aws-sdk/client-bedrock-runtime";
+ import { S3Client } from "@aws-sdk/client-s3";
+ import { AIModel, AbstractDriver, BuiltinProviders, Completion, DataSource, DriverOptions, ExecutionOptions, ModelSearchPayload, PromptFormats, TrainingJob, TrainingJobStatus, TrainingOptions } from "@llumiverse/core";
+ import { transformAsyncIterator } from "@llumiverse/core/async";
+ import { AwsCredentialIdentity, Provider } from "@smithy/types";
+ import mnemonist from "mnemonist";
+ import { forceUploadFile } from "./s3.js";
+
+ const { LRUCache } = mnemonist;
+
+ // caches each model's streaming capability so getFoundationModel is called at most once per model id
+ const supportStreamingCache = new LRUCache<string, boolean>(4096);
+
+ export interface BedrockModelCapabilities {
+     name: string;
+     canStream: boolean;
+ }
+
+ export interface BedrockDriverOptions extends DriverOptions {
+     /**
+      * The AWS region
+      */
+     region: string;
+     /**
+      * The bucket name to be used for training.
+      * It will be created if it does not already exist.
+      */
+     training_bucket?: string;
+
+     /**
+      * The role ARN to be used for training
+      */
+     training_role_arn?: string;
+
+     /**
+      * The credentials to use to access AWS
+      */
+     credentials?: AwsCredentialIdentity | Provider<AwsCredentialIdentity>;
+ }
+
+ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, string> {
+
+     provider = BuiltinProviders.bedrock;
+
+     private _executor?: BedrockRuntime;
+     private _service?: Bedrock;
+
+     defaultFormat = PromptFormats.genericTextLLM;
+
+     constructor(options: BedrockDriverOptions) {
+         super(options);
+         if (!options.region) {
+             throw new Error("No region found. Set the region in the environment's endpoint URL.");
+         }
+     }
+
+     // lazily create the runtime client used to invoke models
+     getExecutor() {
+         if (!this._executor) {
+             this._executor = new BedrockRuntime({
+                 region: this.options.region,
+                 credentials: this.options.credentials,
+             });
+         }
+         return this._executor;
+     }
+
+     // lazily create the management client used for model listing and training
+     getService() {
+         if (!this._service) {
+             this._service = new Bedrock({
+                 region: this.options.region,
+                 credentials: this.options.credentials,
+             });
+         }
+         return this._service;
+     }
+
+     extractDataFromResponse(prompt: string, response: InvokeModelCommandOutput): Completion {
+
+         const decoder = new TextDecoder();
+         const body = decoder.decode(response.body);
+         const result = JSON.parse(body);
+
+         // each provider returns the generated text under a different property
+         const getText = () => {
+             if (result.completion) {
+                 return result.completion;
+             } else if (result.generation) {
+                 return result.generation;
+             } else if (result.generations) {
+                 return result.generations[0].text;
+             } else if (result.completions) {
+                 // AI21
+                 return result.completions[0].data?.text;
+             } else {
+                 return result.toString();
+             }
+         };
+
+         const text = getText();
+
+         return {
+             result: text,
+             token_usage: {
+                 result: text?.length,
+                 prompt: prompt.length,
+                 total: text?.length + prompt.length,
+             }
+         }
+     }
+
+     async requestCompletion(prompt: string, options: ExecutionOptions): Promise<Completion> {
+
+         const payload = this.preparePayload(prompt, options);
+         const executor = this.getExecutor();
+         const res = await executor.invokeModel({
+             modelId: options.model,
+             contentType: "application/json",
+             body: JSON.stringify(payload),
+         });
+
+         return this.extractDataFromResponse(prompt, res);
+     }
+
+     protected async canStream(options: ExecutionOptions): Promise<boolean> {
+         let canStream = supportStreamingCache.get(options.model);
+         if (canStream == null) {
+             const response = await this.getService().getFoundationModel({
+                 modelIdentifier: options.model
+             });
+             canStream = response.modelDetails?.responseStreamingSupported ?? false;
+             supportStreamingCache.set(options.model, canStream);
+         }
+         return canStream;
+     }
+
+     async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<string>> {
+         const payload = this.preparePayload(prompt, options);
+         const executor = this.getExecutor();
+         return executor.invokeModelWithResponseStream({
+             modelId: options.model,
+             contentType: "application/json",
+             body: JSON.stringify(payload),
+         }).then((res) => {
+
+             if (!res.body) {
+                 throw new Error("Body not found");
+             }
+             const decoder = new TextDecoder();
+
+             return transformAsyncIterator(res.body, (stream: ResponseStream) => {
+                 const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
+                 if (segment.completion) {
+                     return segment.completion;
+                 } else if (segment.completions) {
+                     return segment.completions[0].data?.text;
+                 } else if (segment.generation) {
+                     return segment.generation;
+                 } else if (segment.generations) {
+                     return segment.generations[0].text;
+                 } else {
+                     return segment.toString();
+                 }
+             });
+
+         }).catch((err) => {
+             this.logger.error("[Bedrock] Failed to stream", err);
+             throw err;
+         });
+     }
+
+
+     preparePayload(prompt: string, options: ExecutionOptions) {
+
+         // splitting the model ARN on "/" should give the provider
+         //TODO: check if this works with custom models
+         //const provider = options.model.split("/")[0];
+         const contains = (str: string, substr: string) => str.indexOf(substr) !== -1;
+
+         if (contains(options.model, "meta")) {
+             return {
+                 prompt,
+                 temperature: options.temperature,
+                 max_gen_len: options.max_tokens,
+             } as LLama2RequestPayload
+         } else if (contains(options.model, "anthropic")) {
+             return {
+                 prompt: prompt,
+                 temperature: options.temperature,
+                 max_tokens_to_sample: options.max_tokens ?? 256,
+             } as ClaudeRequestPayload;
+         } else if (contains(options.model, "ai21")) {
+             return {
+                 prompt: prompt,
+                 temperature: options.temperature,
+                 maxTokens: options.max_tokens,
+             } as AI21RequestPayload;
+         } else if (contains(options.model, "cohere")) {
+             return {
+                 prompt: prompt,
+                 temperature: options.temperature,
+                 max_tokens: options.max_tokens,
+                 p: 0.9,
+             } as CohereRequestPayload;
+         } else if (contains(options.model, "amazon")) {
+             return {
+                 inputText: prompt,
+                 textGenerationConfig: {
+                     temperature: options.temperature,
+                     topP: 0.9,
+                     maxTokenCount: options.max_tokens,
+                     stopSequences: ["\n"],
+                 },
+             } as AmazonRequestPayload;
+         } else {
+             throw new Error("Cannot prepare payload for unknown provider: " + options.model);
+         }
+
+     }
+
+     async startTraining(dataset: DataSource, options: TrainingOptions): Promise<TrainingJob> {
+
+         // convert options.params to Record<string, string>
+         const params: Record<string, string> = {};
+         for (const [key, value] of Object.entries(options.params || {})) {
+             params[key] = String(value);
+         }
+
+         if (!this.options.training_bucket) {
+             throw new Error("Training cannot be used since the 'training_bucket' property was not specified in driver options")
+         }
+
+         const s3 = new S3Client({ region: this.options.region, credentials: this.options.credentials });
+         const upload = await forceUploadFile(s3, dataset.getStream(), this.options.training_bucket, dataset.name);
+
+         const service = this.getService();
+         const response = await service.send(new CreateModelCustomizationJobCommand({
+             jobName: options.name + "-job",
+             customModelName: options.name,
+             roleArn: this.options.training_role_arn || undefined,
+             baseModelIdentifier: options.model,
+             clientRequestToken: "llumiverse-" + Date.now(),
+             trainingDataConfig: {
+                 s3Uri: `s3://${upload.Bucket}/${upload.Key}`,
+             },
+             outputDataConfig: undefined,
+             hyperParameters: params,
+             //TODO not supported?
+             //customizationType: "FINE_TUNING",
+         }));
+
+         const job = await service.send(new GetModelCustomizationJobCommand({
+             jobIdentifier: response.jobArn
+         }));
+
+         return jobInfo(job, response.jobArn!);
+     }
+
+     async cancelTraining(jobId: string): Promise<TrainingJob> {
+         const service = this.getService();
+         await service.send(new StopModelCustomizationJobCommand({
+             jobIdentifier: jobId
+         }));
+         const job = await service.send(new GetModelCustomizationJobCommand({
+             jobIdentifier: jobId
+         }));
+
+         return jobInfo(job, jobId);
+     }
+
+     async getTrainingJob(jobId: string): Promise<TrainingJob> {
+         const service = this.getService();
+         const job = await service.send(new GetModelCustomizationJobCommand({
+             jobIdentifier: jobId
+         }));
+
+         return jobInfo(job, jobId);
+     }
+
+     // ===================== management API ==================
+
+     async validateConnection(): Promise<boolean> {
+         const service = this.getService();
+         this.logger.debug("[Bedrock] validating connection", service.config.credentials.name);
+         // return true: if the client has been initialized, the connection is valid
+         return true;
+     }
+
+
+     async listTrainableModels(): Promise<AIModel<string>[]> {
+         this.logger.debug("[Bedrock] listing trainable models");
+         return this._listModels(m => m.customizationsSupported ? m.customizationsSupported.includes("FINE_TUNING") : false);
+     }
+
+     async listModels(_params: ModelSearchPayload): Promise<AIModel[]> {
+         this.logger.debug("[Bedrock] listing models");
+         // exclude trainable models since they are not executable
+         const filter = (m: FoundationModelSummary) => m.inferenceTypesSupported?.includes("ON_DEMAND") ?? false;
+         return this._listModels(filter);
+     }
+
+     async _listModels(foundationFilter?: (m: FoundationModelSummary) => boolean): Promise<AIModel[]> {
+         const service = this.getService();
+         const [foundationals, customs] = await Promise.all([
+             service.listFoundationModels({}),
+             service.listCustomModels({}),
+         ]);
+
+         if (!foundationals.modelSummaries) {
+             throw new Error("Foundation models not found");
+         }
+
+         let fmodels = foundationals.modelSummaries || [];
+         if (foundationFilter) {
+             fmodels = fmodels.filter(foundationFilter);
+         }
+
+         const aimodels: AIModel[] = fmodels.map((m) => {
+
+             if (!m.modelId) {
+                 throw new Error("modelId not found");
+             }
+
+             const model: AIModel = {
+                 id: m.modelArn ?? m.modelId,
+                 name: `${m.providerName} ${m.modelName}`,
+                 provider: this.provider,
+                 description: `id: ${m.modelId}`,
+                 owner: m.providerName,
+                 canStream: m.responseStreamingSupported ?? false,
+                 tags: m.outputModalities ?? [],
+             };
+
+             return model;
+         });
+
+         // add custom models
+         if (customs.modelSummaries) {
+             customs.modelSummaries.forEach((m) => {
+
+                 if (!m.modelArn) {
+                     throw new Error("Model ID not found");
+                 }
+
+                 const model: AIModel = {
+                     id: m.modelArn,
+                     name: m.modelName ?? m.modelArn,
+                     provider: this.provider,
+                     description: `Custom model from ${m.baseModelName}`,
+                     isCustom: true,
+                 };
+
+                 aimodels.push(model);
+             });
+         }
+
+         return aimodels;
+     }
+
+     async generateEmbeddings(content: string, model: string = "amazon.titan-embed-text-v1"): Promise<{ embeddings: number[], model: string; }> {
+
+         this.logger.info("[Bedrock] Generating embeddings with model " + model);
+
+         const executor = this.getExecutor();
+         const res = await executor.invokeModel(
+             {
+                 modelId: model,
+                 contentType: "text/plain",
+                 body: content,
+             }
+         );
+
+         const decoder = new TextDecoder();
+         const body = decoder.decode(res.body);
+
+         const result = JSON.parse(body);
+
+         if (!result.embedding) {
+             throw new Error("Embeddings not found");
+         }
+
+         // return the shape declared in the signature rather than the bare vector
+         return { embeddings: result.embedding, model };
+
+     }
+
+ }
+
+
+ interface LLama2RequestPayload {
+     prompt: string;
+     temperature: number;
+     top_p?: number;
+     max_gen_len: number;
+ }
+
+ interface ClaudeRequestPayload {
+     prompt: string;
+     temperature?: number;
+     max_tokens_to_sample?: number;
+     top_p?: number,
+     top_k?: number,
+     stop_sequences?: [string];
+ }
+
+ interface AI21RequestPayload {
+     prompt: string;
+     temperature: number;
+     maxTokens: number;
+ }
+
+ interface CohereRequestPayload {
+     prompt: string;
+     temperature: number;
+     max_tokens?: number;
+     p?: number;
+ }
+
+ interface AmazonRequestPayload {
+     inputText: string,
+     textGenerationConfig: {
+         temperature: number,
+         topP: number,
+         maxTokenCount: number,
+         stopSequences: [string];
+     };
+ }
+
+ function jobInfo(job: GetModelCustomizationJobCommandOutput, jobId: string): TrainingJob {
+     const jobStatus = job.status;
+     let status = TrainingJobStatus.running;
+     let details: string | undefined;
+     if (jobStatus === ModelCustomizationJobStatus.COMPLETED) {
+         status = TrainingJobStatus.succeeded;
+     } else if (jobStatus === ModelCustomizationJobStatus.FAILED) {
+         status = TrainingJobStatus.failed;
+         details = job.failureMessage || "error";
+     } else if (jobStatus === ModelCustomizationJobStatus.STOPPED) {
+         status = TrainingJobStatus.cancelled;
+     } else {
+         status = TrainingJobStatus.running;
+         details = jobStatus;
+     }
+     return {
+         id: jobId,
+         model: job.outputModelArn,
+         status,
+         details
+     }
+ }
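
For orientation, here is a minimal usage sketch for the driver above. It is not part of the package diff: it assumes the package index re-exports BedrockDriver, that requestCompletion is a reasonable entry point (the AbstractDriver base class may expose a higher-level wrapper instead), and that the model id and option values shown are valid in your region. Credentials fall back to the AWS default provider chain when not passed explicitly.

import { BedrockDriver } from "@llumiverse/drivers";

async function demo() {
    // region is mandatory; the constructor throws without it
    const driver = new BedrockDriver({ region: "us-east-1" });

    // preparePayload() maps these generic options to the provider-specific body
    // (max_tokens_to_sample for Anthropic, maxTokens for AI21, etc.)
    const completion = await driver.requestCompletion("Say hello.", {
        model: "anthropic.claude-v2",   // illustrative model id
        temperature: 0.7,
        max_tokens: 128,
    } as any);                          // ExecutionOptions may require more fields than shown

    console.log(completion.result);
}
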
package/src/bedrock/s3.ts
@@ -0,0 +1,62 @@
+ import { CreateBucketCommand, HeadBucketCommand, S3Client } from "@aws-sdk/client-s3";
+ import { Progress, Upload } from "@aws-sdk/lib-storage";
+ import { Readable } from "stream";
+
+ export async function doesBucketExist(s3: S3Client, bucketName: string): Promise<boolean> {
+     try {
+         await s3.send(new HeadBucketCommand({ Bucket: bucketName }));
+         return true;
+     } catch (err: any) {
+         if (err.name === 'NotFound') {
+             return false;
+         }
+         throw err;
+     }
+ }
+
+ export function createBucket(s3: S3Client, bucketName: string) {
+     return s3.send(new CreateBucketCommand({
+         Bucket: bucketName
+     }));
+ }
+
+
+ export async function tryCreateBucket(s3: S3Client, bucketName: string) {
+     const exists = await doesBucketExist(s3, bucketName);
+     if (!exists) {
+         return createBucket(s3, bucketName);
+     }
+ }
+
+
+ export async function uploadFile(s3: S3Client, source: Readable | string | Buffer, bucketName: string, file: string, onProgress?: (progress: Progress) => void) {
+
+     const upload = new Upload({
+         client: s3,
+         params: {
+             Bucket: bucketName,
+             Key: file,
+             Body: source,
+         }
+     });
+
+     onProgress && upload.on("httpUploadProgress", onProgress);
+
+     const result = await upload.done();
+     return result;
+ }
+
+ /**
+  * Create the bucket if it does not already exist, then upload the file.
+  * @param s3
+  * @param source
+  * @param bucketName
+  * @param file
+  * @param onProgress
+  * @returns
+  */
+ export async function forceUploadFile(s3: S3Client, source: Readable | string | Buffer, bucketName: string, file: string, onProgress?: (progress: Progress) => void) {
+     // make sure the bucket exists
+     await tryCreateBucket(s3, bucketName);
+     return uploadFile(s3, source, bucketName, file, onProgress);
+ }
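
And a similar sketch for the S3 helpers, showing how forceUploadFile composes them: tryCreateBucket first ensures the bucket exists, then uploadFile runs a managed (multipart-capable) upload via @aws-sdk/lib-storage. This is illustrative, not from the package; the region, bucket name, key, and local file path are placeholders, and the helper is imported by its internal path as the driver does.

import { S3Client } from "@aws-sdk/client-s3";
import { createReadStream } from "fs";
import { forceUploadFile } from "./s3.js"; // internal helper, not a documented package export

async function uploadDataset() {
    const s3 = new S3Client({ region: "us-east-1" }); // placeholder region
    const result = await forceUploadFile(
        s3,
        createReadStream("dataset.jsonl"),  // Readable | string | Buffer are all accepted
        "my-training-bucket",               // created on first use if missing
        "datasets/dataset.jsonl",           // object key
        (p) => console.log(`uploaded ${p.loaded ?? 0} bytes`),
    );
    console.log(`stored at s3://${result.Bucket}/${result.Key}`);
}
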