@llumiverse/drivers 0.13.0 → 0.14.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. package/README.md +12 -10
  2. package/lib/cjs/bedrock/index.js +46 -14
  3. package/lib/cjs/bedrock/index.js.map +1 -1
  4. package/lib/cjs/bedrock/payloads.js +3 -0
  5. package/lib/cjs/bedrock/payloads.js.map +1 -0
  6. package/lib/cjs/bedrock/s3.js +5 -6
  7. package/lib/cjs/bedrock/s3.js.map +1 -1
  8. package/lib/cjs/groq/index.js +6 -6
  9. package/lib/cjs/groq/index.js.map +1 -1
  10. package/lib/cjs/index.js +2 -1
  11. package/lib/cjs/index.js.map +1 -1
  12. package/lib/cjs/mistral/index.js +5 -5
  13. package/lib/cjs/mistral/index.js.map +1 -1
  14. package/lib/cjs/openai/azure.js +31 -0
  15. package/lib/cjs/openai/azure.js.map +1 -0
  16. package/lib/cjs/{openai.js → openai/index.js} +17 -27
  17. package/lib/cjs/openai/index.js.map +1 -0
  18. package/lib/cjs/openai/openai.js +21 -0
  19. package/lib/cjs/openai/openai.js.map +1 -0
  20. package/lib/cjs/replicate.js +1 -1
  21. package/lib/cjs/replicate.js.map +1 -1
  22. package/lib/cjs/test/index.js +1 -1
  23. package/lib/cjs/test/index.js.map +1 -1
  24. package/lib/cjs/test/utils.js +3 -4
  25. package/lib/cjs/test/utils.js.map +1 -1
  26. package/lib/cjs/togetherai/index.js +2 -2
  27. package/lib/cjs/togetherai/index.js.map +1 -1
  28. package/lib/cjs/vertexai/debug.js +1 -2
  29. package/lib/cjs/vertexai/debug.js.map +1 -1
  30. package/lib/cjs/vertexai/embeddings/embeddings-text.js +1 -2
  31. package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -1
  32. package/lib/cjs/vertexai/index.js +3 -2
  33. package/lib/cjs/vertexai/index.js.map +1 -1
  34. package/lib/cjs/vertexai/models/codey-chat.js +3 -3
  35. package/lib/cjs/vertexai/models/codey-chat.js.map +1 -1
  36. package/lib/cjs/vertexai/models/codey-text.js +2 -2
  37. package/lib/cjs/vertexai/models/codey-text.js.map +1 -1
  38. package/lib/cjs/vertexai/models/gemini.js +37 -21
  39. package/lib/cjs/vertexai/models/gemini.js.map +1 -1
  40. package/lib/cjs/vertexai/models/palm-model-base.js +1 -1
  41. package/lib/cjs/vertexai/models/palm-model-base.js.map +1 -1
  42. package/lib/cjs/vertexai/models/palm2-chat.js +3 -3
  43. package/lib/cjs/vertexai/models/palm2-chat.js.map +1 -1
  44. package/lib/cjs/vertexai/models/palm2-text.js +2 -2
  45. package/lib/cjs/vertexai/models/palm2-text.js.map +1 -1
  46. package/lib/cjs/vertexai/models.js +39 -17
  47. package/lib/cjs/vertexai/models.js.map +1 -1
  48. package/lib/cjs/vertexai/utils/tensor.js +2 -3
  49. package/lib/cjs/vertexai/utils/tensor.js.map +1 -1
  50. package/lib/cjs/watsonx/index.js +4 -4
  51. package/lib/cjs/watsonx/index.js.map +1 -1
  52. package/lib/esm/bedrock/index.js +46 -14
  53. package/lib/esm/bedrock/index.js.map +1 -1
  54. package/lib/esm/bedrock/payloads.js +2 -0
  55. package/lib/esm/bedrock/payloads.js.map +1 -0
  56. package/lib/esm/groq/index.js +7 -7
  57. package/lib/esm/groq/index.js.map +1 -1
  58. package/lib/esm/index.js +2 -1
  59. package/lib/esm/index.js.map +1 -1
  60. package/lib/esm/mistral/index.js +6 -6
  61. package/lib/esm/mistral/index.js.map +1 -1
  62. package/lib/esm/openai/azure.js +27 -0
  63. package/lib/esm/openai/azure.js.map +1 -0
  64. package/lib/esm/{openai.js → openai/index.js} +16 -23
  65. package/lib/esm/openai/index.js.map +1 -0
  66. package/lib/esm/openai/openai.js +14 -0
  67. package/lib/esm/openai/openai.js.map +1 -0
  68. package/lib/esm/replicate.js +1 -1
  69. package/lib/esm/replicate.js.map +1 -1
  70. package/lib/esm/test/index.js +1 -1
  71. package/lib/esm/test/index.js.map +1 -1
  72. package/lib/esm/togetherai/index.js +2 -2
  73. package/lib/esm/togetherai/index.js.map +1 -1
  74. package/lib/esm/vertexai/index.js +3 -2
  75. package/lib/esm/vertexai/index.js.map +1 -1
  76. package/lib/esm/vertexai/models/codey-chat.js +3 -3
  77. package/lib/esm/vertexai/models/codey-chat.js.map +1 -1
  78. package/lib/esm/vertexai/models/codey-text.js +2 -2
  79. package/lib/esm/vertexai/models/codey-text.js.map +1 -1
  80. package/lib/esm/vertexai/models/gemini.js +38 -22
  81. package/lib/esm/vertexai/models/gemini.js.map +1 -1
  82. package/lib/esm/vertexai/models/palm-model-base.js +1 -1
  83. package/lib/esm/vertexai/models/palm-model-base.js.map +1 -1
  84. package/lib/esm/vertexai/models/palm2-chat.js +3 -3
  85. package/lib/esm/vertexai/models/palm2-chat.js.map +1 -1
  86. package/lib/esm/vertexai/models/palm2-text.js +2 -2
  87. package/lib/esm/vertexai/models/palm2-text.js.map +1 -1
  88. package/lib/esm/vertexai/models.js +35 -13
  89. package/lib/esm/vertexai/models.js.map +1 -1
  90. package/lib/esm/watsonx/index.js +4 -4
  91. package/lib/esm/watsonx/index.js.map +1 -1
  92. package/lib/types/bedrock/index.d.ts +4 -46
  93. package/lib/types/bedrock/index.d.ts.map +1 -1
  94. package/lib/types/bedrock/payloads.d.ts +68 -0
  95. package/lib/types/bedrock/payloads.d.ts.map +1 -0
  96. package/lib/types/groq/index.d.ts +1 -1
  97. package/lib/types/groq/index.d.ts.map +1 -1
  98. package/lib/types/index.d.ts +2 -1
  99. package/lib/types/index.d.ts.map +1 -1
  100. package/lib/types/mistral/index.d.ts +1 -1
  101. package/lib/types/mistral/index.d.ts.map +1 -1
  102. package/lib/types/openai/azure.d.ts +20 -0
  103. package/lib/types/openai/azure.d.ts.map +1 -0
  104. package/lib/types/{openai.d.ts → openai/index.d.ts} +10 -20
  105. package/lib/types/openai/index.d.ts.map +1 -0
  106. package/lib/types/openai/openai.d.ts +15 -0
  107. package/lib/types/openai/openai.d.ts.map +1 -0
  108. package/lib/types/test/index.d.ts +2 -2
  109. package/lib/types/test/index.d.ts.map +1 -1
  110. package/lib/types/test/utils.d.ts.map +1 -1
  111. package/lib/types/vertexai/index.d.ts +3 -1
  112. package/lib/types/vertexai/index.d.ts.map +1 -1
  113. package/lib/types/vertexai/models/gemini.d.ts +2 -1
  114. package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
  115. package/lib/types/vertexai/models/palm-model-base.d.ts +1 -1
  116. package/lib/types/vertexai/models/palm-model-base.d.ts.map +1 -1
  117. package/lib/types/vertexai/models.d.ts +1 -1
  118. package/lib/types/vertexai/models.d.ts.map +1 -1
  119. package/lib/types/watsonx/index.d.ts.map +1 -1
  120. package/package.json +24 -18
  121. package/src/bedrock/index.ts +59 -72
  122. package/src/bedrock/payloads.ts +67 -0
  123. package/src/groq/index.ts +7 -7
  124. package/src/index.ts +3 -1
  125. package/src/mistral/index.ts +6 -6
  126. package/src/openai/azure.ts +54 -0
  127. package/src/{openai.ts → openai/index.ts} +24 -28
  128. package/src/openai/openai.ts +33 -0
  129. package/src/replicate.ts +5 -5
  130. package/src/test/index.ts +2 -3
  131. package/src/togetherai/index.ts +2 -2
  132. package/src/vertexai/index.ts +6 -3
  133. package/src/vertexai/models/codey-chat.ts +3 -3
  134. package/src/vertexai/models/codey-text.ts +2 -2
  135. package/src/vertexai/models/gemini.ts +50 -26
  136. package/src/vertexai/models/palm-model-base.ts +1 -2
  137. package/src/vertexai/models/palm2-chat.ts +3 -3
  138. package/src/vertexai/models/palm2-text.ts +2 -2
  139. package/src/vertexai/models.ts +42 -15
  140. package/src/watsonx/index.ts +4 -6
  141. package/lib/cjs/openai.js.map +0 -1
  142. package/lib/esm/openai.js.map +0 -1
  143. package/lib/types/openai.d.ts.map +0 -1
@@ -6,6 +6,7 @@ import { transformAsyncIterator } from "@llumiverse/core/async";
6
6
  import { ClaudeMessagesPrompt, formatClaudePrompt } from "@llumiverse/core/formatters";
7
7
  import { AwsCredentialIdentity, Provider } from "@smithy/types";
8
8
  import mnemonist from "mnemonist";
9
+ import { AI21RequestPayload, AmazonRequestPayload, ClaudeRequestPayload, CohereCommandRPayload, CohereRequestPayload, LLama2RequestPayload, MistralPayload } from "./payloads.js";
9
10
  import { forceUploadFile } from "./s3.js";
10
11
 
11
12
  const { LRUCache } = mnemonist;
@@ -23,7 +24,7 @@ export interface BedrockDriverOptions extends DriverOptions {
23
24
  */
24
25
  region: string;
25
26
  /**
26
- * Tthe bucket name to be used for training.
27
+ * Tthe bucket name to be used for training.
27
28
  * It will be created oif nto already exixts
28
29
  */
29
30
  training_bucket?: string;
@@ -62,7 +63,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
62
63
  this._executor = new BedrockRuntime({
63
64
  region: this.options.region,
64
65
  credentials: this.options.credentials,
65
-
66
+
66
67
  });
67
68
  }
68
69
  return this._executor;
@@ -78,14 +79,13 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
78
79
  return this._service;
79
80
  }
80
81
 
81
- protected formatPrompt(segments: PromptSegment[], opts: PromptOptions): BedrockPrompt {
82
+ protected async formatPrompt(segments: PromptSegment[], opts: PromptOptions): Promise<BedrockPrompt> {
82
83
  //TODO move the anthropic test in abstract driver?
83
84
  if (opts.model.includes('anthropic')) {
84
85
  //TODO: need to type better the types aren't checked properly by TS
85
- const prompt = formatClaudePrompt(segments, opts.resultSchema);
86
- return prompt;
86
+ return await formatClaudePrompt(segments, opts.result_schema);
87
87
  } else {
88
- return super.formatPrompt(segments, opts) as string;
88
+ return await super.formatPrompt(segments, opts) as string;
89
89
  }
90
90
  }
91
91
 
@@ -100,20 +100,23 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
100
100
  // LLAMA2
101
101
  return [result.generation, result.stop_reason]; // comes in coirrect format (stop, length)
102
102
  } else if (result.generations) {
103
- // COHERE
103
+ // Cohere
104
104
  return [result.generations[0].text, cohereFinishReason(result.generations[0].finish_reason)];
105
+ } else if (result.chat_history) {
106
+ //Cohere Command R
107
+ return [result.text, cohereFinishReason(result.finish_reason)];
105
108
  } else if (result.completions) {
106
109
  //A21
107
110
  return [result.completions[0].data?.text, a21FinishReason(result.completions[0].finishReason?.reason)];
108
111
  } else if (result.content) {
109
- // anthropic claude
112
+ // Claude
110
113
  //if last prompt.messages is {, add { to the response
111
- const p = prompt as ClaudeMessagesPrompt;
114
+ const p = prompt as ClaudeMessagesPrompt;
112
115
  const lastMessage = (p as ClaudeMessagesPrompt).messages[p.messages.length - 1];
113
116
  const res = lastMessage.content[0].text === '{' ? '{' + result.content[0]?.text : result.content[0]?.text;
114
117
 
115
- return [ res, claudeFinishReason(result.stop_reason)];
116
-
118
+ return [res, claudeFinishReason(result.stop_reason)];
119
+
117
120
  } else if (result.outputs) {
118
121
  // mistral
119
122
  return [result.outputs[0]?.text, result.outputs[0]?.stop_reason]; // the stop reason is in the expected format ("stop" and "length")
@@ -146,7 +149,6 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
146
149
 
147
150
  const payload = this.preparePayload(prompt, options);
148
151
  const executor = this.getExecutor();
149
- console.log("Requesting completion", JSON.stringify(payload));
150
152
  const res = await executor.invokeModel({
151
153
  modelId: options.model,
152
154
  contentType: "application/json",
@@ -171,10 +173,9 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
171
173
  return canStream;
172
174
  }
173
175
 
174
- async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<string>> {
176
+ async requestCompletionStream(prompt: BedrockPrompt, options: ExecutionOptions): Promise<AsyncIterable<string>> {
175
177
  const payload = this.preparePayload(prompt, options);
176
178
  const executor = this.getExecutor();
177
- console.log("Requesting completion stream", JSON.stringify(payload));
178
179
  return executor.invokeModelWithResponseStream({
179
180
  modelId: options.model,
180
181
  contentType: "application/json",
@@ -186,12 +187,24 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
186
187
  }
187
188
  const decoder = new TextDecoder();
188
189
 
190
+ const addBracket = () => {
191
+ if (typeof prompt === 'object' && (prompt as ClaudeMessagesPrompt).messages) {
192
+ const p = prompt as ClaudeMessagesPrompt;
193
+ const lastMessage = p.messages[p.messages.length - 1];
194
+ return lastMessage.content[0].text === '{';
195
+ }
196
+ return false;
197
+ };
198
+
189
199
  return transformAsyncIterator(res.body, (stream: ResponseStream) => {
190
200
  const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
201
+ //console.log("Debug Segment for model " + options.model, JSON.stringify(segment));
191
202
  if (segment.delta) { // who is this?
192
203
  return segment.delta.text || '';
193
204
  } else if (segment.completion) { // who is this?
194
205
  return segment.completion;
206
+ } else if (segment.text) { //cohere
207
+ return segment.text;
195
208
  } else if (segment.completions) {
196
209
  return segment.completions[0].data?.text;
197
210
  } else if (segment.generation) {
@@ -211,7 +224,9 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
211
224
  segment.toString();
212
225
  }
213
226
 
214
- });
227
+ },
228
+ () => addBracket() ? '{' : ''
229
+ );
215
230
 
216
231
  }).catch((err) => {
217
232
  this.logger.error("[Bedrock] Failed to stream", err);
@@ -235,11 +250,22 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
235
250
  max_gen_len: options.max_tokens,
236
251
  } as LLama2RequestPayload
237
252
  } else if (contains(options.model, "claude")) {
253
+
254
+ const maxToken = () => {
255
+ if (options.max_tokens) {
256
+ return options.max_tokens;
257
+
258
+ } else if (contains(options.model, "claude-3-5")) {
259
+ return 8192;
260
+ } else {
261
+ return 4096
262
+ }
263
+ }
238
264
  return {
239
265
  anthropic_version: "bedrock-2023-05-31",
240
266
  ...(prompt as ClaudeMessagesPrompt),
241
267
  temperature: options.temperature,
242
- max_tokens: options.max_tokens,
268
+ max_tokens: maxToken(),
243
269
  } as ClaudeRequestPayload;
244
270
  } else if (contains(options.model, "ai21")) {
245
271
  return {
@@ -247,12 +273,19 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
247
273
  temperature: options.temperature,
248
274
  maxTokens: options.max_tokens,
249
275
  } as AI21RequestPayload;
250
- } else if (contains(options.model, "cohere")) {
276
+ } else if (contains(options.model, "command-r-plus")) {
277
+ return {
278
+ message: prompt as string,
279
+ max_tokens: options.max_tokens,
280
+ temperature: options.temperature,
281
+ } as CohereCommandRPayload;
282
+
283
+ }
284
+ else if (contains(options.model, "cohere")) {
251
285
  return {
252
286
  prompt: prompt,
253
287
  temperature: options.temperature,
254
288
  max_tokens: options.max_tokens,
255
- p: 0.9,
256
289
  } as CohereRequestPayload;
257
290
  } else if (contains(options.model, "amazon")) {
258
291
  return {
@@ -289,7 +322,8 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
289
322
  }
290
323
 
291
324
  const s3 = new S3Client({ region: this.options.region, credentials: this.options.credentials });
292
- const upload = await forceUploadFile(s3, dataset.getStream(), this.options.training_bucket, dataset.name);
325
+ const stream = await dataset.getStream();
326
+ const upload = await forceUploadFile(s3, stream, this.options.training_bucket, dataset.name);
293
327
 
294
328
  const service = this.getService();
295
329
  const response = await service.send(new CreateModelCustomizationJobCommand({
@@ -363,10 +397,11 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
363
397
  service.listFoundationModels({}).catch(() => {
364
398
  this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions.");
365
399
  return undefined
366
- }),
400
+ }),
367
401
  service.listCustomModels({}).catch(() => {
368
402
  this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions.");
369
- return undefined}),
403
+ return undefined
404
+ }),
370
405
  ]);
371
406
 
372
407
  if (!foundationals?.modelSummaries) {
@@ -390,7 +425,8 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
390
425
  provider: this.provider,
391
426
  //description: ``,
392
427
  owner: m.providerName,
393
- canStream: m.responseStreamingSupported ?? false,
428
+ can_stream: m.responseStreamingSupported ?? false,
429
+ is_multimodal: m.inputModalities?.includes("IMAGE") ?? false,
394
430
  tags: m.outputModalities ?? [],
395
431
  };
396
432
 
@@ -410,7 +446,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
410
446
  name: m.modelName ?? m.modelArn,
411
447
  provider: this.provider,
412
448
  description: `Custom model from ${m.baseModelName}`,
413
- isCustom: true,
449
+ is_custom: true,
414
450
  };
415
451
 
416
452
  aimodels.push(model);
@@ -459,55 +495,6 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
459
495
 
460
496
 
461
497
 
462
- interface LLama2RequestPayload {
463
- prompt: string;
464
- temperature: number;
465
- top_p?: number;
466
- max_gen_len: number;
467
- }
468
-
469
- interface ClaudeRequestPayload extends ClaudeMessagesPrompt {
470
- anthropic_version: "bedrock-2023-05-31",
471
- max_tokens: number,
472
- prompt: string;
473
- temperature?: number;
474
- top_p?: number,
475
- top_k?: number,
476
- stop_sequences?: [string];
477
- }
478
-
479
- interface AI21RequestPayload {
480
- prompt: string;
481
- temperature: number;
482
- maxTokens: number;
483
- }
484
-
485
- interface CohereRequestPayload {
486
- prompt: string;
487
- temperature: number;
488
- max_tokens?: number;
489
- p?: number;
490
- }
491
-
492
- interface AmazonRequestPayload {
493
- inputText: string,
494
- textGenerationConfig: {
495
- temperature: number,
496
- topP: number,
497
- maxTokenCount: number,
498
- stopSequences: [string];
499
- };
500
- }
501
-
502
- interface MistralPayload {
503
- prompt: string,
504
- temperature: number,
505
- max_tokens: number,
506
- top_p?: number,
507
- top_k?: number,
508
- }
509
-
510
-
511
498
  function jobInfo(job: GetModelCustomizationJobCommandOutput, jobId: string): TrainingJob {
512
499
  const jobStatus = job.status;
513
500
  let status = TrainingJobStatus.running;
@@ -0,0 +1,67 @@
1
+ import { ClaudeMessagesPrompt } from "@llumiverse/core/formatters";
2
+
3
+ export interface LLama2RequestPayload {
4
+ prompt: string;
5
+ temperature: number;
6
+ top_p?: number;
7
+ max_gen_len: number;
8
+ }
9
+ export interface ClaudeRequestPayload extends ClaudeMessagesPrompt {
10
+ anthropic_version: "bedrock-2023-05-31";
11
+ max_tokens: number;
12
+ prompt: string;
13
+ temperature?: number;
14
+ top_p?: number;
15
+ top_k?: number;
16
+ stop_sequences?: [string];
17
+ }
18
+ export interface AI21RequestPayload {
19
+ prompt: string;
20
+ temperature: number;
21
+ maxTokens: number;
22
+ }
23
+ export interface CohereRequestPayload {
24
+ prompt: string;
25
+ temperature: number;
26
+ max_tokens?: number;
27
+ p?: number;
28
+ }
29
+ export interface AmazonRequestPayload {
30
+ inputText: string;
31
+ textGenerationConfig: {
32
+ temperature: number;
33
+ topP: number;
34
+ maxTokenCount: number;
35
+ stopSequences: [string];
36
+ };
37
+ }
38
+ export interface MistralPayload {
39
+ prompt: string;
40
+ temperature: number;
41
+ max_tokens: number;
42
+ top_p?: number;
43
+ top_k?: number;
44
+ }
45
+
46
+ export interface CohereCommandRPayload {
47
+
48
+ message: string,
49
+ chat_history?: {
50
+ role: 'USER' | 'CHATBOT',
51
+ message: string }[],
52
+ documents?: { title: string, snippet: string }[],
53
+ search_queries_only?: boolean,
54
+ preamble?: string,
55
+ max_tokens: number,
56
+ temperature?: number,
57
+ p?: number,
58
+ k?: number,
59
+ prompt_truncation?: string,
60
+ frequency_penalty?: number,
61
+ presence_penalty?: number,
62
+ seed?: number,
63
+ return_prompt?: boolean,
64
+ stop_sequences?: string[],
65
+ raw_prompting?: boolean
66
+
67
+ }
package/src/groq/index.ts CHANGED
@@ -1,6 +1,6 @@
1
1
  import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
2
2
  import { transformAsyncIterator } from "@llumiverse/core/async";
3
- import { OpenAITextMessage, formatOpenAILikePrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
3
+ import { OpenAITextMessage, formatOpenAILikeTextPrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
4
4
  import Groq from "groq-sdk";
5
5
 
6
6
 
@@ -27,7 +27,7 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMess
27
27
  }
28
28
 
29
29
  // protected canStream(options: ExecutionOptions): Promise<boolean> {
30
- // if (options.resultSchema) {
30
+ // if (options.result_schema) {
31
31
  // // not yet streamign json responses
32
32
  // return Promise.resolve(false);
33
33
  // } else {
@@ -42,17 +42,17 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMess
42
42
  // type: "json_object",
43
43
  // }
44
44
 
45
- // return _options.resultSchema ? responseFormatJson : undefined;
45
+ // return _options.result_schema ? responseFormatJson : undefined;
46
46
  return undefined;
47
47
  }
48
48
 
49
- protected formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAITextMessage[] {
50
- const messages = formatOpenAILikePrompt(segments);
49
+ protected async formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<OpenAITextMessage[]> {
50
+ const messages = formatOpenAILikeTextPrompt(segments);
51
51
  //Add JSON instruction is schema is provided
52
- if (opts.resultSchema) {
52
+ if (opts.result_schema) {
53
53
  messages.push({
54
54
  role: "user",
55
- content: "IMPORTANT: " + getJSONSafetyNotice(opts.resultSchema)
55
+ content: "IMPORTANT: " + getJSONSafetyNotice(opts.result_schema)
56
56
  });
57
57
  }
58
58
  return messages;
package/src/index.ts CHANGED
@@ -2,9 +2,11 @@ export * from "./bedrock/index.js";
2
2
  export * from "./groq/index.js";
3
3
  export * from "./huggingface_ie.js";
4
4
  export * from "./mistral/index.js";
5
- export * from "./openai.js";
5
+ export * from "./openai/azure.js";
6
+ export * from "./openai/openai.js";
6
7
  export * from "./replicate.js";
7
8
  export * from "./test/index.js";
8
9
  export * from "./togetherai/index.js";
9
10
  export * from "./vertexai/index.js";
10
11
  export * from "./watsonx/index.js";
12
+
@@ -1,6 +1,6 @@
1
1
  import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
2
2
  import { transformSSEStream } from "@llumiverse/core/async";
3
- import { OpenAITextMessage, formatOpenAILikePrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
3
+ import { OpenAITextMessage, formatOpenAILikeTextPrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
4
4
  import { FetchClient } from "api-fetch-client";
5
5
  import { ChatCompletionResponse, CompletionRequestParams, ListModelsResponse, ResponseFormat } from "./types.js";
6
6
 
@@ -42,20 +42,20 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, Open
42
42
  // } as ResponseFormat;
43
43
 
44
44
 
45
- // return _options.resultSchema ? responseFormatJson : responseFormatText;
45
+ // return _options.result_schema ? responseFormatJson : responseFormatText;
46
46
 
47
47
  //TODO remove this when Mistral properly supports the parameters - it makes an error for now
48
48
  // some models like mixtral mistrall tiny or medium are throwing an error when using the response_format parameter
49
49
  return undefined
50
50
  }
51
51
 
52
- protected formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAITextMessage[] {
53
- const messages = formatOpenAILikePrompt(segments);
52
+ protected async formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<OpenAITextMessage[]> {
53
+ const messages = formatOpenAILikeTextPrompt(segments);
54
54
  //Add JSON instruction is schema is provided
55
- if (opts.resultSchema) {
55
+ if (opts.result_schema) {
56
56
  messages.push({
57
57
  role: "user",
58
- content: "IMPORTANT: " + getJSONSafetyNotice(opts.resultSchema)
58
+ content: "IMPORTANT: " + getJSONSafetyNotice(opts.result_schema)
59
59
  });
60
60
  }
61
61
  return messages;
@@ -0,0 +1,54 @@
1
+ import { DefaultAzureCredential, getBearerTokenProvider } from "@azure/identity";
2
+ import { DriverOptions } from "@llumiverse/core";
3
+ import { AzureOpenAI } from "openai";
4
+ import { BaseOpenAIDriver } from "./index.js";
5
+
6
+ export interface AzureOpenAIDriverOptions extends DriverOptions {
7
+
8
+ /**
9
+ * The credentials to use to access Azure OpenAI
10
+ */
11
+ azureADTokenProvider?: any; //type with azure credntials
12
+
13
+ apiKey?: string;
14
+
15
+ endpoint?: string;
16
+
17
+ apiVersion?: string
18
+
19
+ deployment?: string;
20
+
21
+ }
22
+
23
+ export class AzureOpenAIDriver extends BaseOpenAIDriver {
24
+
25
+
26
+ service: AzureOpenAI;
27
+ provider: "azure_openai";
28
+
29
+ constructor(opts: AzureOpenAIDriverOptions) {
30
+ super(opts);
31
+
32
+ if (!opts.azureADTokenProvider && !opts.apiKey) {
33
+ opts.azureADTokenProvider = this.getDefaultAuth();
34
+ }
35
+
36
+ this.service = new AzureOpenAI({
37
+ apiKey: opts.apiKey,
38
+ azureADTokenProvider: opts.azureADTokenProvider,
39
+ endpoint: opts.endpoint,
40
+ apiVersion: opts.apiVersion ?? "2024-05-01-preview",
41
+ deployment: opts.deployment
42
+ });
43
+ this.provider = "azure_openai";
44
+ }
45
+
46
+
47
+ getDefaultAuth() {
48
+ const scope = "https://cognitiveservices.azure.com/.default";
49
+ const azureADTokenProvider = getBearerTokenProvider(new DefaultAzureCredential(), scope);
50
+ return azureADTokenProvider;
51
+ }
52
+
53
+
54
+ }
@@ -15,8 +15,8 @@ import {
15
15
  TrainingPromptOptions
16
16
  } from "@llumiverse/core";
17
17
  import { asyncMap } from "@llumiverse/core/async";
18
- import { formatOpenAILikePrompt } from "@llumiverse/core/formatters";
19
- import OpenAI from "openai";
18
+ import { formatOpenAILikeMultimodalPrompt } from "@llumiverse/core/formatters";
19
+ import OpenAI, { AzureOpenAI } from "openai";
20
20
  import { Stream } from "openai/streaming";
21
21
 
22
22
  const supportFineTunning = new Set([
@@ -27,26 +27,20 @@ const supportFineTunning = new Set([
27
27
  "gpt-4-0613"
28
28
  ]);
29
29
 
30
- export interface OpenAIDriverOptions extends DriverOptions {
31
- apiKey: string;
30
+ export interface BaseOpenAIDriverOptions extends DriverOptions {
32
31
  }
33
32
 
34
- export class OpenAIDriver extends AbstractDriver<
35
- OpenAIDriverOptions,
33
+ export abstract class BaseOpenAIDriver extends AbstractDriver<
34
+ BaseOpenAIDriverOptions,
36
35
  OpenAI.Chat.Completions.ChatCompletionMessageParam[]
37
36
  > {
38
- static PROVIDER = "openai";
39
- inputContentTypes: string[] = ["text/plain"];
40
- generatedContentTypes: string[] = ["text/plain"];
41
- service: OpenAI;
42
- provider = OpenAIDriver.PROVIDER;
37
+ abstract provider: "azure_openai" | "openai";
38
+ abstract service: OpenAI | AzureOpenAI ;
43
39
 
44
- constructor(opts: OpenAIDriverOptions) {
40
+ constructor(opts: BaseOpenAIDriverOptions) {
45
41
  super(opts);
46
- this.service = new OpenAI({
47
- apiKey: opts.apiKey,
48
- });
49
- this.formatPrompt = formatOpenAILikePrompt;
42
+ this.formatPrompt = formatOpenAILikeMultimodalPrompt as any //TODO: better type, we send back OpenAI.Chat.Completions.ChatCompletionMessageParam[] but just not compatbile with Function call that we don't use here
43
+
50
44
  }
51
45
 
52
46
  extractDataFromResponse(
@@ -63,7 +57,7 @@ export class OpenAIDriver extends AbstractDriver<
63
57
  const finish_reason = choice.finish_reason;
64
58
 
65
59
  //if no schema, return content
66
- if (!options.resultSchema) {
60
+ if (!options.result_schema) {
67
61
  return {
68
62
  result: choice.message.content as string,
69
63
  token_usage: tokenInfo,
@@ -86,7 +80,7 @@ export class OpenAIDriver extends AbstractDriver<
86
80
  }
87
81
 
88
82
  async requestCompletionStream(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise<any> {
89
- const mapFn = options.resultSchema
83
+ const mapFn = options.result_schema
90
84
  ? (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
91
85
  return (
92
86
  chunk.choices[0]?.delta?.tool_calls?.[0].function?.arguments ?? ""
@@ -103,18 +97,18 @@ export class OpenAIDriver extends AbstractDriver<
103
97
  temperature: options.temperature,
104
98
  n: 1,
105
99
  max_tokens: options.max_tokens,
106
- tools: options.resultSchema
100
+ tools: options.result_schema
107
101
  ? [
108
102
  {
109
103
  function: {
110
104
  name: "format_output",
111
- parameters: options.resultSchema as any,
105
+ parameters: options.result_schema as any,
112
106
  },
113
107
  type: "function"
114
108
  } as OpenAI.Chat.ChatCompletionTool,
115
109
  ]
116
110
  : undefined,
117
- tool_choice: options.resultSchema
111
+ tool_choice: options.result_schema
118
112
  ? {
119
113
  type: 'function',
120
114
  function: { name: "format_output" }
@@ -125,12 +119,12 @@ export class OpenAIDriver extends AbstractDriver<
125
119
  }
126
120
 
127
121
  async requestCompletion(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise<any> {
128
- const functions = options.resultSchema
122
+ const functions = options.result_schema
129
123
  ? [
130
124
  {
131
125
  function: {
132
126
  name: "format_output",
133
- parameters: options.resultSchema as any,
127
+ parameters: options.result_schema as any,
134
128
  },
135
129
  type: 'function'
136
130
  } as OpenAI.Chat.ChatCompletionTool,
@@ -145,13 +139,13 @@ export class OpenAIDriver extends AbstractDriver<
145
139
  n: 1,
146
140
  max_tokens: options.max_tokens,
147
141
  tools: functions,
148
- tool_choice: options.resultSchema
142
+ tool_choice: options.result_schema
149
143
  ? {
150
144
  type: 'function',
151
145
  function: { name: "format_output" }
152
146
  } : undefined,
153
147
  // functions: functions,
154
- // function_call: options.resultSchema
148
+ // function_call: options.result_schema
155
149
  // ? { name: "format_output" }
156
150
  // : undefined,
157
151
  });
@@ -163,11 +157,11 @@ export class OpenAIDriver extends AbstractDriver<
163
157
  return completion;
164
158
  }
165
159
 
166
- createTrainingPrompt(options: TrainingPromptOptions): string {
160
+ createTrainingPrompt(options: TrainingPromptOptions): Promise<string> {
167
161
  if (options.model.includes("gpt")) {
168
162
  return super.createTrainingPrompt(options);
169
163
  } else {
170
- // babbage, davinci not yet implemented
164
+ // babbage, davinci not yet implemented
171
165
  throw new Error("Unsupported model for training: " + options.model);
172
166
  }
173
167
  }
@@ -217,7 +211,7 @@ export class OpenAIDriver extends AbstractDriver<
217
211
  return this._listModels();
218
212
  }
219
213
 
220
- async _listModels(filter?: (m: OpenAI.Models.Model) => boolean) {
214
+ async _listModels(filter?: (m: OpenAI.Models.Model) => boolean): Promise<AIModel[]> {
221
215
  let result = await this.service.models.list();
222
216
  const models = filter ? result.data.filter(filter) : result.data;
223
217
  return models.map((m) => ({
@@ -226,6 +220,8 @@ export class OpenAIDriver extends AbstractDriver<
226
220
  provider: this.provider,
227
221
  owner: m.owned_by,
228
222
  type: m.object === "model" ? ModelType.Text : ModelType.Unknown,
223
+ can_stream: true,
224
+ is_multimodal: m.id.includes("gpt-4")
229
225
  }));
230
226
  }
231
227
 
@@ -0,0 +1,33 @@
1
+
2
+
3
+ import { DriverOptions } from "@llumiverse/core";
4
+ import OpenAI from "openai";
5
+ import { BaseOpenAIDriver } from "./index.js";
6
+
7
+ export interface OpenAIDriverOptions extends DriverOptions {
8
+
9
+ /**
10
+ * The OpenAI api key
11
+ */
12
+ apiKey?: string; //type with azure credntials
13
+
14
+ }
15
+
16
+
17
+
18
+
19
+ export class OpenAIDriver extends BaseOpenAIDriver {
20
+
21
+ service: OpenAI;
22
+ provider: "openai";
23
+
24
+ constructor(opts: OpenAIDriverOptions) {
25
+ super(opts);
26
+ this.service = new OpenAI({
27
+ apiKey: opts.apiKey
28
+ });
29
+ this.provider = "openai";
30
+ }
31
+
32
+
33
+ }