@llumiverse/drivers 0.12.3 → 0.14.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. package/README.md +13 -11
  2. package/lib/cjs/bedrock/index.js +63 -18
  3. package/lib/cjs/bedrock/index.js.map +1 -1
  4. package/lib/cjs/bedrock/payloads.js +3 -0
  5. package/lib/cjs/bedrock/payloads.js.map +1 -0
  6. package/lib/cjs/bedrock/s3.js +5 -6
  7. package/lib/cjs/bedrock/s3.js.map +1 -1
  8. package/lib/cjs/groq/index.js +6 -6
  9. package/lib/cjs/groq/index.js.map +1 -1
  10. package/lib/cjs/huggingface_ie.js.map +1 -1
  11. package/lib/cjs/index.js +4 -2
  12. package/lib/cjs/index.js.map +1 -1
  13. package/lib/cjs/mistral/index.js +5 -5
  14. package/lib/cjs/mistral/index.js.map +1 -1
  15. package/lib/cjs/openai/azure.js +31 -0
  16. package/lib/cjs/openai/azure.js.map +1 -0
  17. package/lib/cjs/{openai.js → openai/index.js} +17 -27
  18. package/lib/cjs/openai/index.js.map +1 -0
  19. package/lib/cjs/openai/openai.js +21 -0
  20. package/lib/cjs/openai/openai.js.map +1 -0
  21. package/lib/cjs/replicate.js +1 -1
  22. package/lib/cjs/replicate.js.map +1 -1
  23. package/lib/cjs/test/index.js +1 -1
  24. package/lib/cjs/test/index.js.map +1 -1
  25. package/lib/cjs/test/utils.js +3 -4
  26. package/lib/cjs/test/utils.js.map +1 -1
  27. package/lib/cjs/togetherai/index.js +2 -2
  28. package/lib/cjs/togetherai/index.js.map +1 -1
  29. package/lib/cjs/vertexai/debug.js +1 -2
  30. package/lib/cjs/vertexai/debug.js.map +1 -1
  31. package/lib/cjs/vertexai/embeddings/embeddings-text.js +1 -2
  32. package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -1
  33. package/lib/cjs/vertexai/index.js +3 -2
  34. package/lib/cjs/vertexai/index.js.map +1 -1
  35. package/lib/cjs/vertexai/models/codey-chat.js +3 -3
  36. package/lib/cjs/vertexai/models/codey-chat.js.map +1 -1
  37. package/lib/cjs/vertexai/models/codey-text.js +2 -2
  38. package/lib/cjs/vertexai/models/codey-text.js.map +1 -1
  39. package/lib/cjs/vertexai/models/gemini.js +43 -27
  40. package/lib/cjs/vertexai/models/gemini.js.map +1 -1
  41. package/lib/cjs/vertexai/models/palm-model-base.js +1 -1
  42. package/lib/cjs/vertexai/models/palm-model-base.js.map +1 -1
  43. package/lib/cjs/vertexai/models/palm2-chat.js +3 -3
  44. package/lib/cjs/vertexai/models/palm2-chat.js.map +1 -1
  45. package/lib/cjs/vertexai/models/palm2-text.js +2 -2
  46. package/lib/cjs/vertexai/models/palm2-text.js.map +1 -1
  47. package/lib/cjs/vertexai/models.js +39 -17
  48. package/lib/cjs/vertexai/models.js.map +1 -1
  49. package/lib/cjs/vertexai/utils/tensor.js +2 -3
  50. package/lib/cjs/vertexai/utils/tensor.js.map +1 -1
  51. package/lib/cjs/watsonx/index.js +124 -0
  52. package/lib/cjs/watsonx/index.js.map +1 -0
  53. package/lib/cjs/watsonx/interfaces.js +3 -0
  54. package/lib/cjs/watsonx/interfaces.js.map +1 -0
  55. package/lib/esm/bedrock/index.js +63 -18
  56. package/lib/esm/bedrock/index.js.map +1 -1
  57. package/lib/esm/bedrock/payloads.js +2 -0
  58. package/lib/esm/bedrock/payloads.js.map +1 -0
  59. package/lib/esm/groq/index.js +7 -7
  60. package/lib/esm/groq/index.js.map +1 -1
  61. package/lib/esm/huggingface_ie.js.map +1 -1
  62. package/lib/esm/index.js +4 -2
  63. package/lib/esm/index.js.map +1 -1
  64. package/lib/esm/mistral/index.js +6 -6
  65. package/lib/esm/mistral/index.js.map +1 -1
  66. package/lib/esm/openai/azure.js +27 -0
  67. package/lib/esm/openai/azure.js.map +1 -0
  68. package/lib/esm/{openai.js → openai/index.js} +16 -23
  69. package/lib/esm/openai/index.js.map +1 -0
  70. package/lib/esm/openai/openai.js +14 -0
  71. package/lib/esm/openai/openai.js.map +1 -0
  72. package/lib/esm/replicate.js +1 -1
  73. package/lib/esm/replicate.js.map +1 -1
  74. package/lib/esm/test/index.js +1 -1
  75. package/lib/esm/test/index.js.map +1 -1
  76. package/lib/esm/togetherai/index.js +2 -2
  77. package/lib/esm/togetherai/index.js.map +1 -1
  78. package/lib/esm/vertexai/index.js +3 -2
  79. package/lib/esm/vertexai/index.js.map +1 -1
  80. package/lib/esm/vertexai/models/codey-chat.js +3 -3
  81. package/lib/esm/vertexai/models/codey-chat.js.map +1 -1
  82. package/lib/esm/vertexai/models/codey-text.js +2 -2
  83. package/lib/esm/vertexai/models/codey-text.js.map +1 -1
  84. package/lib/esm/vertexai/models/gemini.js +44 -28
  85. package/lib/esm/vertexai/models/gemini.js.map +1 -1
  86. package/lib/esm/vertexai/models/palm-model-base.js +1 -1
  87. package/lib/esm/vertexai/models/palm-model-base.js.map +1 -1
  88. package/lib/esm/vertexai/models/palm2-chat.js +3 -3
  89. package/lib/esm/vertexai/models/palm2-chat.js.map +1 -1
  90. package/lib/esm/vertexai/models/palm2-text.js +2 -2
  91. package/lib/esm/vertexai/models/palm2-text.js.map +1 -1
  92. package/lib/esm/vertexai/models.js +35 -13
  93. package/lib/esm/vertexai/models.js.map +1 -1
  94. package/lib/esm/watsonx/index.js +120 -0
  95. package/lib/esm/watsonx/index.js.map +1 -0
  96. package/lib/esm/watsonx/interfaces.js +2 -0
  97. package/lib/esm/watsonx/interfaces.js.map +1 -0
  98. package/lib/types/bedrock/index.d.ts +4 -46
  99. package/lib/types/bedrock/index.d.ts.map +1 -1
  100. package/lib/types/bedrock/payloads.d.ts +68 -0
  101. package/lib/types/bedrock/payloads.d.ts.map +1 -0
  102. package/lib/types/groq/index.d.ts +1 -1
  103. package/lib/types/groq/index.d.ts.map +1 -1
  104. package/lib/types/huggingface_ie.d.ts +5 -5
  105. package/lib/types/index.d.ts +4 -2
  106. package/lib/types/index.d.ts.map +1 -1
  107. package/lib/types/mistral/index.d.ts +1 -1
  108. package/lib/types/mistral/index.d.ts.map +1 -1
  109. package/lib/types/openai/azure.d.ts +20 -0
  110. package/lib/types/openai/azure.d.ts.map +1 -0
  111. package/lib/types/{openai.d.ts → openai/index.d.ts} +10 -20
  112. package/lib/types/openai/index.d.ts.map +1 -0
  113. package/lib/types/openai/openai.d.ts +15 -0
  114. package/lib/types/openai/openai.d.ts.map +1 -0
  115. package/lib/types/test/index.d.ts +2 -2
  116. package/lib/types/test/index.d.ts.map +1 -1
  117. package/lib/types/test/utils.d.ts.map +1 -1
  118. package/lib/types/vertexai/index.d.ts +3 -1
  119. package/lib/types/vertexai/index.d.ts.map +1 -1
  120. package/lib/types/vertexai/models/gemini.d.ts +2 -1
  121. package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
  122. package/lib/types/vertexai/models/palm-model-base.d.ts +1 -1
  123. package/lib/types/vertexai/models/palm-model-base.d.ts.map +1 -1
  124. package/lib/types/vertexai/models.d.ts +1 -1
  125. package/lib/types/vertexai/models.d.ts.map +1 -1
  126. package/lib/types/watsonx/index.d.ts +27 -0
  127. package/lib/types/watsonx/index.d.ts.map +1 -0
  128. package/lib/types/watsonx/interfaces.d.ts +61 -0
  129. package/lib/types/watsonx/interfaces.d.ts.map +1 -0
  130. package/package.json +24 -18
  131. package/src/bedrock/index.ts +72 -70
  132. package/src/bedrock/payloads.ts +67 -0
  133. package/src/groq/index.ts +7 -7
  134. package/src/huggingface_ie.ts +1 -1
  135. package/src/index.ts +5 -2
  136. package/src/mistral/index.ts +6 -6
  137. package/src/openai/azure.ts +54 -0
  138. package/src/{openai.ts → openai/index.ts} +24 -28
  139. package/src/openai/openai.ts +33 -0
  140. package/src/replicate.ts +5 -5
  141. package/src/test/index.ts +2 -3
  142. package/src/togetherai/index.ts +2 -2
  143. package/src/vertexai/index.ts +6 -3
  144. package/src/vertexai/models/codey-chat.ts +3 -3
  145. package/src/vertexai/models/codey-text.ts +2 -2
  146. package/src/vertexai/models/gemini.ts +56 -32
  147. package/src/vertexai/models/palm-model-base.ts +1 -2
  148. package/src/vertexai/models/palm2-chat.ts +3 -3
  149. package/src/vertexai/models/palm2-text.ts +2 -2
  150. package/src/vertexai/models.ts +42 -15
  151. package/src/watsonx/index.ts +161 -0
  152. package/src/watsonx/interfaces.ts +71 -0
  153. package/lib/cjs/openai.js.map +0 -1
  154. package/lib/esm/openai.js.map +0 -1
  155. package/lib/types/openai.d.ts.map +0 -1
@@ -6,6 +6,7 @@ import { transformAsyncIterator } from "@llumiverse/core/async";
6
6
  import { ClaudeMessagesPrompt, formatClaudePrompt } from "@llumiverse/core/formatters";
7
7
  import { AwsCredentialIdentity, Provider } from "@smithy/types";
8
8
  import mnemonist from "mnemonist";
9
+ import { AI21RequestPayload, AmazonRequestPayload, ClaudeRequestPayload, CohereCommandRPayload, CohereRequestPayload, LLama2RequestPayload, MistralPayload } from "./payloads.js";
9
10
  import { forceUploadFile } from "./s3.js";
10
11
 
11
12
  const { LRUCache } = mnemonist;
@@ -23,7 +24,7 @@ export interface BedrockDriverOptions extends DriverOptions {
23
24
  */
24
25
  region: string;
25
26
  /**
26
- * Tthe bucket name to be used for training.
27
+ * Tthe bucket name to be used for training.
27
28
  * It will be created oif nto already exixts
28
29
  */
29
30
  training_bucket?: string;
@@ -62,6 +63,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
62
63
  this._executor = new BedrockRuntime({
63
64
  region: this.options.region,
64
65
  credentials: this.options.credentials,
66
+
65
67
  });
66
68
  }
67
69
  return this._executor;
@@ -77,13 +79,13 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
77
79
  return this._service;
78
80
  }
79
81
 
80
- protected formatPrompt(segments: PromptSegment[], opts: PromptOptions): BedrockPrompt {
82
+ protected async formatPrompt(segments: PromptSegment[], opts: PromptOptions): Promise<BedrockPrompt> {
81
83
  //TODO move the anthropic test in abstract driver?
82
84
  if (opts.model.includes('anthropic')) {
83
85
  //TODO: need to type better the types aren't checked properly by TS
84
- return formatClaudePrompt(segments, opts.resultSchema);
86
+ return await formatClaudePrompt(segments, opts.result_schema);
85
87
  } else {
86
- return super.formatPrompt(segments, opts) as string;
88
+ return await super.formatPrompt(segments, opts) as string;
87
89
  }
88
90
  }
89
91
 
@@ -98,14 +100,23 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
98
100
  // LLAMA2
99
101
  return [result.generation, result.stop_reason]; // comes in coirrect format (stop, length)
100
102
  } else if (result.generations) {
101
- // COHERE
103
+ // Cohere
102
104
  return [result.generations[0].text, cohereFinishReason(result.generations[0].finish_reason)];
105
+ } else if (result.chat_history) {
106
+ //Cohere Command R
107
+ return [result.text, cohereFinishReason(result.finish_reason)];
103
108
  } else if (result.completions) {
104
109
  //A21
105
110
  return [result.completions[0].data?.text, a21FinishReason(result.completions[0].finishReason?.reason)];
106
111
  } else if (result.content) {
107
- // anthropic claude
108
- return [result.content[0]?.text || '', claudeFinishReason(result.stop_reason)];
112
+ // Claude
113
+ //if last prompt.messages is {, add { to the response
114
+ const p = prompt as ClaudeMessagesPrompt;
115
+ const lastMessage = (p as ClaudeMessagesPrompt).messages[p.messages.length - 1];
116
+ const res = lastMessage.content[0].text === '{' ? '{' + result.content[0]?.text : result.content[0]?.text;
117
+
118
+ return [res, claudeFinishReason(result.stop_reason)];
119
+
109
120
  } else if (result.outputs) {
110
121
  // mistral
111
122
  return [result.outputs[0]?.text, result.outputs[0]?.stop_reason]; // the stop reason is in the expected format ("stop" and "length")
@@ -162,7 +173,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
162
173
  return canStream;
163
174
  }
164
175
 
165
- async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<string>> {
176
+ async requestCompletionStream(prompt: BedrockPrompt, options: ExecutionOptions): Promise<AsyncIterable<string>> {
166
177
  const payload = this.preparePayload(prompt, options);
167
178
  const executor = this.getExecutor();
168
179
  return executor.invokeModelWithResponseStream({
@@ -176,12 +187,24 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
176
187
  }
177
188
  const decoder = new TextDecoder();
178
189
 
190
+ const addBracket = () => {
191
+ if (typeof prompt === 'object' && (prompt as ClaudeMessagesPrompt).messages) {
192
+ const p = prompt as ClaudeMessagesPrompt;
193
+ const lastMessage = p.messages[p.messages.length - 1];
194
+ return lastMessage.content[0].text === '{';
195
+ }
196
+ return false;
197
+ };
198
+
179
199
  return transformAsyncIterator(res.body, (stream: ResponseStream) => {
180
200
  const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
201
+ //console.log("Debug Segment for model " + options.model, JSON.stringify(segment));
181
202
  if (segment.delta) { // who is this?
182
203
  return segment.delta.text || '';
183
204
  } else if (segment.completion) { // who is this?
184
205
  return segment.completion;
206
+ } else if (segment.text) { //cohere
207
+ return segment.text;
185
208
  } else if (segment.completions) {
186
209
  return segment.completions[0].data?.text;
187
210
  } else if (segment.generation) {
@@ -201,7 +224,9 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
201
224
  segment.toString();
202
225
  }
203
226
 
204
- });
227
+ },
228
+ () => addBracket() ? '{' : ''
229
+ );
205
230
 
206
231
  }).catch((err) => {
207
232
  this.logger.error("[Bedrock] Failed to stream", err);
@@ -224,12 +249,23 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
224
249
  temperature: options.temperature,
225
250
  max_gen_len: options.max_tokens,
226
251
  } as LLama2RequestPayload
227
- } else if (contains(options.model, "anthropic")) {
252
+ } else if (contains(options.model, "claude")) {
253
+
254
+ const maxToken = () => {
255
+ if (options.max_tokens) {
256
+ return options.max_tokens;
257
+
258
+ } else if (contains(options.model, "claude-3-5")) {
259
+ return 8192;
260
+ } else {
261
+ return 4096
262
+ }
263
+ }
228
264
  return {
229
265
  anthropic_version: "bedrock-2023-05-31",
230
266
  ...(prompt as ClaudeMessagesPrompt),
231
267
  temperature: options.temperature,
232
- max_tokens: options.max_tokens,
268
+ max_tokens: maxToken(),
233
269
  } as ClaudeRequestPayload;
234
270
  } else if (contains(options.model, "ai21")) {
235
271
  return {
@@ -237,12 +273,19 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
237
273
  temperature: options.temperature,
238
274
  maxTokens: options.max_tokens,
239
275
  } as AI21RequestPayload;
240
- } else if (contains(options.model, "cohere")) {
276
+ } else if (contains(options.model, "command-r-plus")) {
277
+ return {
278
+ message: prompt as string,
279
+ max_tokens: options.max_tokens,
280
+ temperature: options.temperature,
281
+ } as CohereCommandRPayload;
282
+
283
+ }
284
+ else if (contains(options.model, "cohere")) {
241
285
  return {
242
286
  prompt: prompt,
243
287
  temperature: options.temperature,
244
288
  max_tokens: options.max_tokens,
245
- p: 0.9,
246
289
  } as CohereRequestPayload;
247
290
  } else if (contains(options.model, "amazon")) {
248
291
  return {
@@ -279,7 +322,8 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
279
322
  }
280
323
 
281
324
  const s3 = new S3Client({ region: this.options.region, credentials: this.options.credentials });
282
- const upload = await forceUploadFile(s3, dataset.getStream(), this.options.training_bucket, dataset.name);
325
+ const stream = await dataset.getStream();
326
+ const upload = await forceUploadFile(s3, stream, this.options.training_bucket, dataset.name);
283
327
 
284
328
  const service = this.getService();
285
329
  const response = await service.send(new CreateModelCustomizationJobCommand({
@@ -350,11 +394,17 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
350
394
  async _listModels(foundationFilter?: (m: FoundationModelSummary) => boolean): Promise<AIModel[]> {
351
395
  const service = this.getService();
352
396
  const [foundationals, customs] = await Promise.all([
353
- service.listFoundationModels({}),
354
- service.listCustomModels({}),
397
+ service.listFoundationModels({}).catch(() => {
398
+ this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions.");
399
+ return undefined
400
+ }),
401
+ service.listCustomModels({}).catch(() => {
402
+ this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions.");
403
+ return undefined
404
+ }),
355
405
  ]);
356
406
 
357
- if (!foundationals.modelSummaries) {
407
+ if (!foundationals?.modelSummaries) {
358
408
  throw new Error("Foundation models not found");
359
409
  }
360
410
 
@@ -373,9 +423,10 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
373
423
  id: m.modelArn ?? m.modelId,
374
424
  name: `${m.providerName} ${m.modelName}`,
375
425
  provider: this.provider,
376
- description: `id: ${m.modelId}`,
426
+ //description: ``,
377
427
  owner: m.providerName,
378
- canStream: m.responseStreamingSupported ?? false,
428
+ can_stream: m.responseStreamingSupported ?? false,
429
+ is_multimodal: m.inputModalities?.includes("IMAGE") ?? false,
379
430
  tags: m.outputModalities ?? [],
380
431
  };
381
432
 
@@ -383,7 +434,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
383
434
  });
384
435
 
385
436
  //add custom models
386
- if (customs.modelSummaries) {
437
+ if (customs?.modelSummaries) {
387
438
  customs.modelSummaries.forEach((m) => {
388
439
 
389
440
  if (!m.modelArn) {
@@ -395,7 +446,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
395
446
  name: m.modelName ?? m.modelArn,
396
447
  provider: this.provider,
397
448
  description: `Custom model from ${m.baseModelName}`,
398
- isCustom: true,
449
+ is_custom: true,
399
450
  };
400
451
 
401
452
  aimodels.push(model);
@@ -444,55 +495,6 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
444
495
 
445
496
 
446
497
 
447
- interface LLama2RequestPayload {
448
- prompt: string;
449
- temperature: number;
450
- top_p?: number;
451
- max_gen_len: number;
452
- }
453
-
454
- interface ClaudeRequestPayload extends ClaudeMessagesPrompt {
455
- anthropic_version: "bedrock-2023-05-31",
456
- max_tokens: number,
457
- prompt: string;
458
- temperature?: number;
459
- top_p?: number,
460
- top_k?: number,
461
- stop_sequences?: [string];
462
- }
463
-
464
- interface AI21RequestPayload {
465
- prompt: string;
466
- temperature: number;
467
- maxTokens: number;
468
- }
469
-
470
- interface CohereRequestPayload {
471
- prompt: string;
472
- temperature: number;
473
- max_tokens?: number;
474
- p?: number;
475
- }
476
-
477
- interface AmazonRequestPayload {
478
- inputText: string,
479
- textGenerationConfig: {
480
- temperature: number,
481
- topP: number,
482
- maxTokenCount: number,
483
- stopSequences: [string];
484
- };
485
- }
486
-
487
- interface MistralPayload {
488
- prompt: string,
489
- temperature: number,
490
- max_tokens: number,
491
- top_p?: number,
492
- top_k?: number,
493
- }
494
-
495
-
496
498
  function jobInfo(job: GetModelCustomizationJobCommandOutput, jobId: string): TrainingJob {
497
499
  const jobStatus = job.status;
498
500
  let status = TrainingJobStatus.running;
@@ -0,0 +1,67 @@
1
+ import { ClaudeMessagesPrompt } from "@llumiverse/core/formatters";
2
+
3
+ export interface LLama2RequestPayload {
4
+ prompt: string;
5
+ temperature: number;
6
+ top_p?: number;
7
+ max_gen_len: number;
8
+ }
9
+ export interface ClaudeRequestPayload extends ClaudeMessagesPrompt {
10
+ anthropic_version: "bedrock-2023-05-31";
11
+ max_tokens: number;
12
+ prompt: string;
13
+ temperature?: number;
14
+ top_p?: number;
15
+ top_k?: number;
16
+ stop_sequences?: [string];
17
+ }
18
+ export interface AI21RequestPayload {
19
+ prompt: string;
20
+ temperature: number;
21
+ maxTokens: number;
22
+ }
23
+ export interface CohereRequestPayload {
24
+ prompt: string;
25
+ temperature: number;
26
+ max_tokens?: number;
27
+ p?: number;
28
+ }
29
+ export interface AmazonRequestPayload {
30
+ inputText: string;
31
+ textGenerationConfig: {
32
+ temperature: number;
33
+ topP: number;
34
+ maxTokenCount: number;
35
+ stopSequences: [string];
36
+ };
37
+ }
38
+ export interface MistralPayload {
39
+ prompt: string;
40
+ temperature: number;
41
+ max_tokens: number;
42
+ top_p?: number;
43
+ top_k?: number;
44
+ }
45
+
46
+ export interface CohereCommandRPayload {
47
+
48
+ message: string,
49
+ chat_history?: {
50
+ role: 'USER' | 'CHATBOT',
51
+ message: string }[],
52
+ documents?: { title: string, snippet: string }[],
53
+ search_queries_only?: boolean,
54
+ preamble?: string,
55
+ max_tokens: number,
56
+ temperature?: number,
57
+ p?: number,
58
+ k?: number,
59
+ prompt_truncation?: string,
60
+ frequency_penalty?: number,
61
+ presence_penalty?: number,
62
+ seed?: number,
63
+ return_prompt?: boolean,
64
+ stop_sequences?: string[],
65
+ raw_prompting?: boolean
66
+
67
+ }
package/src/groq/index.ts CHANGED
@@ -1,6 +1,6 @@
1
1
  import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
2
2
  import { transformAsyncIterator } from "@llumiverse/core/async";
3
- import { OpenAITextMessage, formatOpenAILikePrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
3
+ import { OpenAITextMessage, formatOpenAILikeTextPrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
4
4
  import Groq from "groq-sdk";
5
5
 
6
6
 
@@ -27,7 +27,7 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMess
27
27
  }
28
28
 
29
29
  // protected canStream(options: ExecutionOptions): Promise<boolean> {
30
- // if (options.resultSchema) {
30
+ // if (options.result_schema) {
31
31
  // // not yet streamign json responses
32
32
  // return Promise.resolve(false);
33
33
  // } else {
@@ -42,17 +42,17 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMess
42
42
  // type: "json_object",
43
43
  // }
44
44
 
45
- // return _options.resultSchema ? responseFormatJson : undefined;
45
+ // return _options.result_schema ? responseFormatJson : undefined;
46
46
  return undefined;
47
47
  }
48
48
 
49
- protected formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAITextMessage[] {
50
- const messages = formatOpenAILikePrompt(segments);
49
+ protected async formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<OpenAITextMessage[]> {
50
+ const messages = formatOpenAILikeTextPrompt(segments);
51
51
  //Add JSON instruction is schema is provided
52
- if (opts.resultSchema) {
52
+ if (opts.result_schema) {
53
53
  messages.push({
54
54
  role: "user",
55
- content: "IMPORTANT: " + getJSONSafetyNotice(opts.resultSchema)
55
+ content: "IMPORTANT: " + getJSONSafetyNotice(opts.result_schema)
56
56
  });
57
57
  }
58
58
  return messages;
@@ -92,7 +92,7 @@ export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptio
92
92
  },
93
93
  });
94
94
 
95
- let finish_reason = res.details?.finish_reason;
95
+ let finish_reason = res.details?.finish_reason as string;
96
96
  if (finish_reason === "eos_token") {
97
97
  finish_reason = "stop";
98
98
  }
package/src/index.ts CHANGED
@@ -1,9 +1,12 @@
1
1
  export * from "./bedrock/index.js";
2
+ export * from "./groq/index.js";
2
3
  export * from "./huggingface_ie.js";
3
4
  export * from "./mistral/index.js";
4
- export * from "./openai.js";
5
+ export * from "./openai/azure.js";
6
+ export * from "./openai/openai.js";
5
7
  export * from "./replicate.js";
6
8
  export * from "./test/index.js";
7
9
  export * from "./togetherai/index.js";
8
10
  export * from "./vertexai/index.js";
9
- export * from "./groq/index.js";
11
+ export * from "./watsonx/index.js";
12
+
@@ -1,6 +1,6 @@
1
1
  import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
2
2
  import { transformSSEStream } from "@llumiverse/core/async";
3
- import { OpenAITextMessage, formatOpenAILikePrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
3
+ import { OpenAITextMessage, formatOpenAILikeTextPrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
4
4
  import { FetchClient } from "api-fetch-client";
5
5
  import { ChatCompletionResponse, CompletionRequestParams, ListModelsResponse, ResponseFormat } from "./types.js";
6
6
 
@@ -42,20 +42,20 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, Open
42
42
  // } as ResponseFormat;
43
43
 
44
44
 
45
- // return _options.resultSchema ? responseFormatJson : responseFormatText;
45
+ // return _options.result_schema ? responseFormatJson : responseFormatText;
46
46
 
47
47
  //TODO remove this when Mistral properly supports the parameters - it makes an error for now
48
48
  // some models like mixtral mistrall tiny or medium are throwing an error when using the response_format parameter
49
49
  return undefined
50
50
  }
51
51
 
52
- protected formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAITextMessage[] {
53
- const messages = formatOpenAILikePrompt(segments);
52
+ protected async formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<OpenAITextMessage[]> {
53
+ const messages = formatOpenAILikeTextPrompt(segments);
54
54
  //Add JSON instruction is schema is provided
55
- if (opts.resultSchema) {
55
+ if (opts.result_schema) {
56
56
  messages.push({
57
57
  role: "user",
58
- content: "IMPORTANT: " + getJSONSafetyNotice(opts.resultSchema)
58
+ content: "IMPORTANT: " + getJSONSafetyNotice(opts.result_schema)
59
59
  });
60
60
  }
61
61
  return messages;
@@ -0,0 +1,54 @@
1
+ import { DefaultAzureCredential, getBearerTokenProvider } from "@azure/identity";
2
+ import { DriverOptions } from "@llumiverse/core";
3
+ import { AzureOpenAI } from "openai";
4
+ import { BaseOpenAIDriver } from "./index.js";
5
+
6
+ export interface AzureOpenAIDriverOptions extends DriverOptions {
7
+
8
+ /**
9
+ * The credentials to use to access Azure OpenAI
10
+ */
11
+ azureADTokenProvider?: any; //type with azure credntials
12
+
13
+ apiKey?: string;
14
+
15
+ endpoint?: string;
16
+
17
+ apiVersion?: string
18
+
19
+ deployment?: string;
20
+
21
+ }
22
+
23
+ export class AzureOpenAIDriver extends BaseOpenAIDriver {
24
+
25
+
26
+ service: AzureOpenAI;
27
+ provider: "azure_openai";
28
+
29
+ constructor(opts: AzureOpenAIDriverOptions) {
30
+ super(opts);
31
+
32
+ if (!opts.azureADTokenProvider && !opts.apiKey) {
33
+ opts.azureADTokenProvider = this.getDefaultAuth();
34
+ }
35
+
36
+ this.service = new AzureOpenAI({
37
+ apiKey: opts.apiKey,
38
+ azureADTokenProvider: opts.azureADTokenProvider,
39
+ endpoint: opts.endpoint,
40
+ apiVersion: opts.apiVersion ?? "2024-05-01-preview",
41
+ deployment: opts.deployment
42
+ });
43
+ this.provider = "azure_openai";
44
+ }
45
+
46
+
47
+ getDefaultAuth() {
48
+ const scope = "https://cognitiveservices.azure.com/.default";
49
+ const azureADTokenProvider = getBearerTokenProvider(new DefaultAzureCredential(), scope);
50
+ return azureADTokenProvider;
51
+ }
52
+
53
+
54
+ }
@@ -15,8 +15,8 @@ import {
15
15
  TrainingPromptOptions
16
16
  } from "@llumiverse/core";
17
17
  import { asyncMap } from "@llumiverse/core/async";
18
- import { formatOpenAILikePrompt } from "@llumiverse/core/formatters";
19
- import OpenAI from "openai";
18
+ import { formatOpenAILikeMultimodalPrompt } from "@llumiverse/core/formatters";
19
+ import OpenAI, { AzureOpenAI } from "openai";
20
20
  import { Stream } from "openai/streaming";
21
21
 
22
22
  const supportFineTunning = new Set([
@@ -27,26 +27,20 @@ const supportFineTunning = new Set([
27
27
  "gpt-4-0613"
28
28
  ]);
29
29
 
30
- export interface OpenAIDriverOptions extends DriverOptions {
31
- apiKey: string;
30
+ export interface BaseOpenAIDriverOptions extends DriverOptions {
32
31
  }
33
32
 
34
- export class OpenAIDriver extends AbstractDriver<
35
- OpenAIDriverOptions,
33
+ export abstract class BaseOpenAIDriver extends AbstractDriver<
34
+ BaseOpenAIDriverOptions,
36
35
  OpenAI.Chat.Completions.ChatCompletionMessageParam[]
37
36
  > {
38
- static PROVIDER = "openai";
39
- inputContentTypes: string[] = ["text/plain"];
40
- generatedContentTypes: string[] = ["text/plain"];
41
- service: OpenAI;
42
- provider = OpenAIDriver.PROVIDER;
37
+ abstract provider: "azure_openai" | "openai";
38
+ abstract service: OpenAI | AzureOpenAI ;
43
39
 
44
- constructor(opts: OpenAIDriverOptions) {
40
+ constructor(opts: BaseOpenAIDriverOptions) {
45
41
  super(opts);
46
- this.service = new OpenAI({
47
- apiKey: opts.apiKey,
48
- });
49
- this.formatPrompt = formatOpenAILikePrompt;
42
+ this.formatPrompt = formatOpenAILikeMultimodalPrompt as any //TODO: better type, we send back OpenAI.Chat.Completions.ChatCompletionMessageParam[] but just not compatbile with Function call that we don't use here
43
+
50
44
  }
51
45
 
52
46
  extractDataFromResponse(
@@ -63,7 +57,7 @@ export class OpenAIDriver extends AbstractDriver<
63
57
  const finish_reason = choice.finish_reason;
64
58
 
65
59
  //if no schema, return content
66
- if (!options.resultSchema) {
60
+ if (!options.result_schema) {
67
61
  return {
68
62
  result: choice.message.content as string,
69
63
  token_usage: tokenInfo,
@@ -86,7 +80,7 @@ export class OpenAIDriver extends AbstractDriver<
86
80
  }
87
81
 
88
82
  async requestCompletionStream(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise<any> {
89
- const mapFn = options.resultSchema
83
+ const mapFn = options.result_schema
90
84
  ? (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
91
85
  return (
92
86
  chunk.choices[0]?.delta?.tool_calls?.[0].function?.arguments ?? ""
@@ -103,18 +97,18 @@ export class OpenAIDriver extends AbstractDriver<
103
97
  temperature: options.temperature,
104
98
  n: 1,
105
99
  max_tokens: options.max_tokens,
106
- tools: options.resultSchema
100
+ tools: options.result_schema
107
101
  ? [
108
102
  {
109
103
  function: {
110
104
  name: "format_output",
111
- parameters: options.resultSchema as any,
105
+ parameters: options.result_schema as any,
112
106
  },
113
107
  type: "function"
114
108
  } as OpenAI.Chat.ChatCompletionTool,
115
109
  ]
116
110
  : undefined,
117
- tool_choice: options.resultSchema
111
+ tool_choice: options.result_schema
118
112
  ? {
119
113
  type: 'function',
120
114
  function: { name: "format_output" }
@@ -125,12 +119,12 @@ export class OpenAIDriver extends AbstractDriver<
125
119
  }
126
120
 
127
121
  async requestCompletion(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise<any> {
128
- const functions = options.resultSchema
122
+ const functions = options.result_schema
129
123
  ? [
130
124
  {
131
125
  function: {
132
126
  name: "format_output",
133
- parameters: options.resultSchema as any,
127
+ parameters: options.result_schema as any,
134
128
  },
135
129
  type: 'function'
136
130
  } as OpenAI.Chat.ChatCompletionTool,
@@ -145,13 +139,13 @@ export class OpenAIDriver extends AbstractDriver<
145
139
  n: 1,
146
140
  max_tokens: options.max_tokens,
147
141
  tools: functions,
148
- tool_choice: options.resultSchema
142
+ tool_choice: options.result_schema
149
143
  ? {
150
144
  type: 'function',
151
145
  function: { name: "format_output" }
152
146
  } : undefined,
153
147
  // functions: functions,
154
- // function_call: options.resultSchema
148
+ // function_call: options.result_schema
155
149
  // ? { name: "format_output" }
156
150
  // : undefined,
157
151
  });
@@ -163,11 +157,11 @@ export class OpenAIDriver extends AbstractDriver<
163
157
  return completion;
164
158
  }
165
159
 
166
- createTrainingPrompt(options: TrainingPromptOptions): string {
160
+ createTrainingPrompt(options: TrainingPromptOptions): Promise<string> {
167
161
  if (options.model.includes("gpt")) {
168
162
  return super.createTrainingPrompt(options);
169
163
  } else {
170
- // babbage, davinci not yet implemented
164
+ // babbage, davinci not yet implemented
171
165
  throw new Error("Unsupported model for training: " + options.model);
172
166
  }
173
167
  }
@@ -217,7 +211,7 @@ export class OpenAIDriver extends AbstractDriver<
217
211
  return this._listModels();
218
212
  }
219
213
 
220
- async _listModels(filter?: (m: OpenAI.Models.Model) => boolean) {
214
+ async _listModels(filter?: (m: OpenAI.Models.Model) => boolean): Promise<AIModel[]> {
221
215
  let result = await this.service.models.list();
222
216
  const models = filter ? result.data.filter(filter) : result.data;
223
217
  return models.map((m) => ({
@@ -226,6 +220,8 @@ export class OpenAIDriver extends AbstractDriver<
226
220
  provider: this.provider,
227
221
  owner: m.owned_by,
228
222
  type: m.object === "model" ? ModelType.Text : ModelType.Unknown,
223
+ can_stream: true,
224
+ is_multimodal: m.id.includes("gpt-4")
229
225
  }));
230
226
  }
231
227
 
@@ -0,0 +1,33 @@
1
+
2
+
3
+ import { DriverOptions } from "@llumiverse/core";
4
+ import OpenAI from "openai";
5
+ import { BaseOpenAIDriver } from "./index.js";
6
+
7
+ export interface OpenAIDriverOptions extends DriverOptions {
8
+
9
+ /**
10
+ * The OpenAI api key
11
+ */
12
+ apiKey?: string; //type with azure credntials
13
+
14
+ }
15
+
16
+
17
+
18
+
19
+ export class OpenAIDriver extends BaseOpenAIDriver {
20
+
21
+ service: OpenAI;
22
+ provider: "openai";
23
+
24
+ constructor(opts: OpenAIDriverOptions) {
25
+ super(opts);
26
+ this.service = new OpenAI({
27
+ apiKey: opts.apiKey
28
+ });
29
+ this.provider = "openai";
30
+ }
31
+
32
+
33
+ }