@llumiverse/drivers 0.13.0 → 0.14.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +12 -10
- package/lib/cjs/bedrock/index.js +46 -14
- package/lib/cjs/bedrock/index.js.map +1 -1
- package/lib/cjs/bedrock/payloads.js +3 -0
- package/lib/cjs/bedrock/payloads.js.map +1 -0
- package/lib/cjs/bedrock/s3.js +5 -6
- package/lib/cjs/bedrock/s3.js.map +1 -1
- package/lib/cjs/groq/index.js +6 -6
- package/lib/cjs/groq/index.js.map +1 -1
- package/lib/cjs/index.js +2 -1
- package/lib/cjs/index.js.map +1 -1
- package/lib/cjs/mistral/index.js +5 -5
- package/lib/cjs/mistral/index.js.map +1 -1
- package/lib/cjs/openai/azure.js +31 -0
- package/lib/cjs/openai/azure.js.map +1 -0
- package/lib/cjs/{openai.js → openai/index.js} +17 -27
- package/lib/cjs/openai/index.js.map +1 -0
- package/lib/cjs/openai/openai.js +21 -0
- package/lib/cjs/openai/openai.js.map +1 -0
- package/lib/cjs/replicate.js +1 -1
- package/lib/cjs/replicate.js.map +1 -1
- package/lib/cjs/test/index.js +1 -1
- package/lib/cjs/test/index.js.map +1 -1
- package/lib/cjs/test/utils.js +3 -4
- package/lib/cjs/test/utils.js.map +1 -1
- package/lib/cjs/togetherai/index.js +2 -2
- package/lib/cjs/togetherai/index.js.map +1 -1
- package/lib/cjs/vertexai/debug.js +1 -2
- package/lib/cjs/vertexai/debug.js.map +1 -1
- package/lib/cjs/vertexai/embeddings/embeddings-text.js +1 -2
- package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -1
- package/lib/cjs/vertexai/index.js +3 -2
- package/lib/cjs/vertexai/index.js.map +1 -1
- package/lib/cjs/vertexai/models/codey-chat.js +3 -3
- package/lib/cjs/vertexai/models/codey-chat.js.map +1 -1
- package/lib/cjs/vertexai/models/codey-text.js +2 -2
- package/lib/cjs/vertexai/models/codey-text.js.map +1 -1
- package/lib/cjs/vertexai/models/gemini.js +37 -21
- package/lib/cjs/vertexai/models/gemini.js.map +1 -1
- package/lib/cjs/vertexai/models/palm-model-base.js +1 -1
- package/lib/cjs/vertexai/models/palm-model-base.js.map +1 -1
- package/lib/cjs/vertexai/models/palm2-chat.js +3 -3
- package/lib/cjs/vertexai/models/palm2-chat.js.map +1 -1
- package/lib/cjs/vertexai/models/palm2-text.js +2 -2
- package/lib/cjs/vertexai/models/palm2-text.js.map +1 -1
- package/lib/cjs/vertexai/models.js +39 -17
- package/lib/cjs/vertexai/models.js.map +1 -1
- package/lib/cjs/vertexai/utils/tensor.js +2 -3
- package/lib/cjs/vertexai/utils/tensor.js.map +1 -1
- package/lib/cjs/watsonx/index.js +4 -4
- package/lib/cjs/watsonx/index.js.map +1 -1
- package/lib/esm/bedrock/index.js +46 -14
- package/lib/esm/bedrock/index.js.map +1 -1
- package/lib/esm/bedrock/payloads.js +2 -0
- package/lib/esm/bedrock/payloads.js.map +1 -0
- package/lib/esm/groq/index.js +7 -7
- package/lib/esm/groq/index.js.map +1 -1
- package/lib/esm/index.js +2 -1
- package/lib/esm/index.js.map +1 -1
- package/lib/esm/mistral/index.js +6 -6
- package/lib/esm/mistral/index.js.map +1 -1
- package/lib/esm/openai/azure.js +27 -0
- package/lib/esm/openai/azure.js.map +1 -0
- package/lib/esm/{openai.js → openai/index.js} +16 -23
- package/lib/esm/openai/index.js.map +1 -0
- package/lib/esm/openai/openai.js +14 -0
- package/lib/esm/openai/openai.js.map +1 -0
- package/lib/esm/replicate.js +1 -1
- package/lib/esm/replicate.js.map +1 -1
- package/lib/esm/test/index.js +1 -1
- package/lib/esm/test/index.js.map +1 -1
- package/lib/esm/togetherai/index.js +2 -2
- package/lib/esm/togetherai/index.js.map +1 -1
- package/lib/esm/vertexai/index.js +3 -2
- package/lib/esm/vertexai/index.js.map +1 -1
- package/lib/esm/vertexai/models/codey-chat.js +3 -3
- package/lib/esm/vertexai/models/codey-chat.js.map +1 -1
- package/lib/esm/vertexai/models/codey-text.js +2 -2
- package/lib/esm/vertexai/models/codey-text.js.map +1 -1
- package/lib/esm/vertexai/models/gemini.js +38 -22
- package/lib/esm/vertexai/models/gemini.js.map +1 -1
- package/lib/esm/vertexai/models/palm-model-base.js +1 -1
- package/lib/esm/vertexai/models/palm-model-base.js.map +1 -1
- package/lib/esm/vertexai/models/palm2-chat.js +3 -3
- package/lib/esm/vertexai/models/palm2-chat.js.map +1 -1
- package/lib/esm/vertexai/models/palm2-text.js +2 -2
- package/lib/esm/vertexai/models/palm2-text.js.map +1 -1
- package/lib/esm/vertexai/models.js +35 -13
- package/lib/esm/vertexai/models.js.map +1 -1
- package/lib/esm/watsonx/index.js +4 -4
- package/lib/esm/watsonx/index.js.map +1 -1
- package/lib/types/bedrock/index.d.ts +4 -46
- package/lib/types/bedrock/index.d.ts.map +1 -1
- package/lib/types/bedrock/payloads.d.ts +68 -0
- package/lib/types/bedrock/payloads.d.ts.map +1 -0
- package/lib/types/groq/index.d.ts +1 -1
- package/lib/types/groq/index.d.ts.map +1 -1
- package/lib/types/index.d.ts +2 -1
- package/lib/types/index.d.ts.map +1 -1
- package/lib/types/mistral/index.d.ts +1 -1
- package/lib/types/mistral/index.d.ts.map +1 -1
- package/lib/types/openai/azure.d.ts +20 -0
- package/lib/types/openai/azure.d.ts.map +1 -0
- package/lib/types/{openai.d.ts → openai/index.d.ts} +10 -20
- package/lib/types/openai/index.d.ts.map +1 -0
- package/lib/types/openai/openai.d.ts +15 -0
- package/lib/types/openai/openai.d.ts.map +1 -0
- package/lib/types/test/index.d.ts +2 -2
- package/lib/types/test/index.d.ts.map +1 -1
- package/lib/types/test/utils.d.ts.map +1 -1
- package/lib/types/vertexai/index.d.ts +3 -1
- package/lib/types/vertexai/index.d.ts.map +1 -1
- package/lib/types/vertexai/models/gemini.d.ts +2 -1
- package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
- package/lib/types/vertexai/models/palm-model-base.d.ts +1 -1
- package/lib/types/vertexai/models/palm-model-base.d.ts.map +1 -1
- package/lib/types/vertexai/models.d.ts +1 -1
- package/lib/types/vertexai/models.d.ts.map +1 -1
- package/lib/types/watsonx/index.d.ts.map +1 -1
- package/package.json +24 -18
- package/src/bedrock/index.ts +59 -72
- package/src/bedrock/payloads.ts +67 -0
- package/src/groq/index.ts +7 -7
- package/src/index.ts +3 -1
- package/src/mistral/index.ts +6 -6
- package/src/openai/azure.ts +54 -0
- package/src/{openai.ts → openai/index.ts} +24 -28
- package/src/openai/openai.ts +33 -0
- package/src/replicate.ts +5 -5
- package/src/test/index.ts +2 -3
- package/src/togetherai/index.ts +2 -2
- package/src/vertexai/index.ts +6 -3
- package/src/vertexai/models/codey-chat.ts +3 -3
- package/src/vertexai/models/codey-text.ts +2 -2
- package/src/vertexai/models/gemini.ts +50 -26
- package/src/vertexai/models/palm-model-base.ts +1 -2
- package/src/vertexai/models/palm2-chat.ts +3 -3
- package/src/vertexai/models/palm2-text.ts +2 -2
- package/src/vertexai/models.ts +42 -15
- package/src/watsonx/index.ts +4 -6
- package/lib/cjs/openai.js.map +0 -1
- package/lib/esm/openai.js.map +0 -1
- package/lib/types/openai.d.ts.map +0 -1
package/src/bedrock/index.ts
CHANGED
|
@@ -6,6 +6,7 @@ import { transformAsyncIterator } from "@llumiverse/core/async";
|
|
|
6
6
|
import { ClaudeMessagesPrompt, formatClaudePrompt } from "@llumiverse/core/formatters";
|
|
7
7
|
import { AwsCredentialIdentity, Provider } from "@smithy/types";
|
|
8
8
|
import mnemonist from "mnemonist";
|
|
9
|
+
import { AI21RequestPayload, AmazonRequestPayload, ClaudeRequestPayload, CohereCommandRPayload, CohereRequestPayload, LLama2RequestPayload, MistralPayload } from "./payloads.js";
|
|
9
10
|
import { forceUploadFile } from "./s3.js";
|
|
10
11
|
|
|
11
12
|
const { LRUCache } = mnemonist;
|
|
@@ -23,7 +24,7 @@ export interface BedrockDriverOptions extends DriverOptions {
|
|
|
23
24
|
*/
|
|
24
25
|
region: string;
|
|
25
26
|
/**
|
|
26
|
-
* Tthe bucket name to be used for training.
|
|
27
|
+
* Tthe bucket name to be used for training.
|
|
27
28
|
* It will be created oif nto already exixts
|
|
28
29
|
*/
|
|
29
30
|
training_bucket?: string;
|
|
@@ -62,7 +63,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
62
63
|
this._executor = new BedrockRuntime({
|
|
63
64
|
region: this.options.region,
|
|
64
65
|
credentials: this.options.credentials,
|
|
65
|
-
|
|
66
|
+
|
|
66
67
|
});
|
|
67
68
|
}
|
|
68
69
|
return this._executor;
|
|
@@ -78,14 +79,13 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
78
79
|
return this._service;
|
|
79
80
|
}
|
|
80
81
|
|
|
81
|
-
protected formatPrompt(segments: PromptSegment[], opts: PromptOptions): BedrockPrompt {
|
|
82
|
+
protected async formatPrompt(segments: PromptSegment[], opts: PromptOptions): Promise<BedrockPrompt> {
|
|
82
83
|
//TODO move the anthropic test in abstract driver?
|
|
83
84
|
if (opts.model.includes('anthropic')) {
|
|
84
85
|
//TODO: need to type better the types aren't checked properly by TS
|
|
85
|
-
|
|
86
|
-
return prompt;
|
|
86
|
+
return await formatClaudePrompt(segments, opts.result_schema);
|
|
87
87
|
} else {
|
|
88
|
-
return super.formatPrompt(segments, opts) as string;
|
|
88
|
+
return await super.formatPrompt(segments, opts) as string;
|
|
89
89
|
}
|
|
90
90
|
}
|
|
91
91
|
|
|
@@ -100,20 +100,23 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
100
100
|
// LLAMA2
|
|
101
101
|
return [result.generation, result.stop_reason]; // comes in coirrect format (stop, length)
|
|
102
102
|
} else if (result.generations) {
|
|
103
|
-
//
|
|
103
|
+
// Cohere
|
|
104
104
|
return [result.generations[0].text, cohereFinishReason(result.generations[0].finish_reason)];
|
|
105
|
+
} else if (result.chat_history) {
|
|
106
|
+
//Cohere Command R
|
|
107
|
+
return [result.text, cohereFinishReason(result.finish_reason)];
|
|
105
108
|
} else if (result.completions) {
|
|
106
109
|
//A21
|
|
107
110
|
return [result.completions[0].data?.text, a21FinishReason(result.completions[0].finishReason?.reason)];
|
|
108
111
|
} else if (result.content) {
|
|
109
|
-
//
|
|
112
|
+
// Claude
|
|
110
113
|
//if last prompt.messages is {, add { to the response
|
|
111
|
-
const p =
|
|
114
|
+
const p = prompt as ClaudeMessagesPrompt;
|
|
112
115
|
const lastMessage = (p as ClaudeMessagesPrompt).messages[p.messages.length - 1];
|
|
113
116
|
const res = lastMessage.content[0].text === '{' ? '{' + result.content[0]?.text : result.content[0]?.text;
|
|
114
117
|
|
|
115
|
-
return [
|
|
116
|
-
|
|
118
|
+
return [res, claudeFinishReason(result.stop_reason)];
|
|
119
|
+
|
|
117
120
|
} else if (result.outputs) {
|
|
118
121
|
// mistral
|
|
119
122
|
return [result.outputs[0]?.text, result.outputs[0]?.stop_reason]; // the stop reason is in the expected format ("stop" and "length")
|
|
@@ -146,7 +149,6 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
146
149
|
|
|
147
150
|
const payload = this.preparePayload(prompt, options);
|
|
148
151
|
const executor = this.getExecutor();
|
|
149
|
-
console.log("Requesting completion", JSON.stringify(payload));
|
|
150
152
|
const res = await executor.invokeModel({
|
|
151
153
|
modelId: options.model,
|
|
152
154
|
contentType: "application/json",
|
|
@@ -171,10 +173,9 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
171
173
|
return canStream;
|
|
172
174
|
}
|
|
173
175
|
|
|
174
|
-
async requestCompletionStream(prompt:
|
|
176
|
+
async requestCompletionStream(prompt: BedrockPrompt, options: ExecutionOptions): Promise<AsyncIterable<string>> {
|
|
175
177
|
const payload = this.preparePayload(prompt, options);
|
|
176
178
|
const executor = this.getExecutor();
|
|
177
|
-
console.log("Requesting completion stream", JSON.stringify(payload));
|
|
178
179
|
return executor.invokeModelWithResponseStream({
|
|
179
180
|
modelId: options.model,
|
|
180
181
|
contentType: "application/json",
|
|
@@ -186,12 +187,24 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
186
187
|
}
|
|
187
188
|
const decoder = new TextDecoder();
|
|
188
189
|
|
|
190
|
+
const addBracket = () => {
|
|
191
|
+
if (typeof prompt === 'object' && (prompt as ClaudeMessagesPrompt).messages) {
|
|
192
|
+
const p = prompt as ClaudeMessagesPrompt;
|
|
193
|
+
const lastMessage = p.messages[p.messages.length - 1];
|
|
194
|
+
return lastMessage.content[0].text === '{';
|
|
195
|
+
}
|
|
196
|
+
return false;
|
|
197
|
+
};
|
|
198
|
+
|
|
189
199
|
return transformAsyncIterator(res.body, (stream: ResponseStream) => {
|
|
190
200
|
const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
|
|
201
|
+
//console.log("Debug Segment for model " + options.model, JSON.stringify(segment));
|
|
191
202
|
if (segment.delta) { // who is this?
|
|
192
203
|
return segment.delta.text || '';
|
|
193
204
|
} else if (segment.completion) { // who is this?
|
|
194
205
|
return segment.completion;
|
|
206
|
+
} else if (segment.text) { //cohere
|
|
207
|
+
return segment.text;
|
|
195
208
|
} else if (segment.completions) {
|
|
196
209
|
return segment.completions[0].data?.text;
|
|
197
210
|
} else if (segment.generation) {
|
|
@@ -211,7 +224,9 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
211
224
|
segment.toString();
|
|
212
225
|
}
|
|
213
226
|
|
|
214
|
-
}
|
|
227
|
+
},
|
|
228
|
+
() => addBracket() ? '{' : ''
|
|
229
|
+
);
|
|
215
230
|
|
|
216
231
|
}).catch((err) => {
|
|
217
232
|
this.logger.error("[Bedrock] Failed to stream", err);
|
|
@@ -235,11 +250,22 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
235
250
|
max_gen_len: options.max_tokens,
|
|
236
251
|
} as LLama2RequestPayload
|
|
237
252
|
} else if (contains(options.model, "claude")) {
|
|
253
|
+
|
|
254
|
+
const maxToken = () => {
|
|
255
|
+
if (options.max_tokens) {
|
|
256
|
+
return options.max_tokens;
|
|
257
|
+
|
|
258
|
+
} else if (contains(options.model, "claude-3-5")) {
|
|
259
|
+
return 8192;
|
|
260
|
+
} else {
|
|
261
|
+
return 4096
|
|
262
|
+
}
|
|
263
|
+
}
|
|
238
264
|
return {
|
|
239
265
|
anthropic_version: "bedrock-2023-05-31",
|
|
240
266
|
...(prompt as ClaudeMessagesPrompt),
|
|
241
267
|
temperature: options.temperature,
|
|
242
|
-
max_tokens:
|
|
268
|
+
max_tokens: maxToken(),
|
|
243
269
|
} as ClaudeRequestPayload;
|
|
244
270
|
} else if (contains(options.model, "ai21")) {
|
|
245
271
|
return {
|
|
@@ -247,12 +273,19 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
247
273
|
temperature: options.temperature,
|
|
248
274
|
maxTokens: options.max_tokens,
|
|
249
275
|
} as AI21RequestPayload;
|
|
250
|
-
} else if (contains(options.model, "
|
|
276
|
+
} else if (contains(options.model, "command-r-plus")) {
|
|
277
|
+
return {
|
|
278
|
+
message: prompt as string,
|
|
279
|
+
max_tokens: options.max_tokens,
|
|
280
|
+
temperature: options.temperature,
|
|
281
|
+
} as CohereCommandRPayload;
|
|
282
|
+
|
|
283
|
+
}
|
|
284
|
+
else if (contains(options.model, "cohere")) {
|
|
251
285
|
return {
|
|
252
286
|
prompt: prompt,
|
|
253
287
|
temperature: options.temperature,
|
|
254
288
|
max_tokens: options.max_tokens,
|
|
255
|
-
p: 0.9,
|
|
256
289
|
} as CohereRequestPayload;
|
|
257
290
|
} else if (contains(options.model, "amazon")) {
|
|
258
291
|
return {
|
|
@@ -289,7 +322,8 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
289
322
|
}
|
|
290
323
|
|
|
291
324
|
const s3 = new S3Client({ region: this.options.region, credentials: this.options.credentials });
|
|
292
|
-
const
|
|
325
|
+
const stream = await dataset.getStream();
|
|
326
|
+
const upload = await forceUploadFile(s3, stream, this.options.training_bucket, dataset.name);
|
|
293
327
|
|
|
294
328
|
const service = this.getService();
|
|
295
329
|
const response = await service.send(new CreateModelCustomizationJobCommand({
|
|
@@ -363,10 +397,11 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
363
397
|
service.listFoundationModels({}).catch(() => {
|
|
364
398
|
this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions.");
|
|
365
399
|
return undefined
|
|
366
|
-
|
|
400
|
+
}),
|
|
367
401
|
service.listCustomModels({}).catch(() => {
|
|
368
402
|
this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions.");
|
|
369
|
-
return undefined
|
|
403
|
+
return undefined
|
|
404
|
+
}),
|
|
370
405
|
]);
|
|
371
406
|
|
|
372
407
|
if (!foundationals?.modelSummaries) {
|
|
@@ -390,7 +425,8 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
390
425
|
provider: this.provider,
|
|
391
426
|
//description: ``,
|
|
392
427
|
owner: m.providerName,
|
|
393
|
-
|
|
428
|
+
can_stream: m.responseStreamingSupported ?? false,
|
|
429
|
+
is_multimodal: m.inputModalities?.includes("IMAGE") ?? false,
|
|
394
430
|
tags: m.outputModalities ?? [],
|
|
395
431
|
};
|
|
396
432
|
|
|
@@ -410,7 +446,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
410
446
|
name: m.modelName ?? m.modelArn,
|
|
411
447
|
provider: this.provider,
|
|
412
448
|
description: `Custom model from ${m.baseModelName}`,
|
|
413
|
-
|
|
449
|
+
is_custom: true,
|
|
414
450
|
};
|
|
415
451
|
|
|
416
452
|
aimodels.push(model);
|
|
@@ -459,55 +495,6 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
459
495
|
|
|
460
496
|
|
|
461
497
|
|
|
462
|
-
interface LLama2RequestPayload {
|
|
463
|
-
prompt: string;
|
|
464
|
-
temperature: number;
|
|
465
|
-
top_p?: number;
|
|
466
|
-
max_gen_len: number;
|
|
467
|
-
}
|
|
468
|
-
|
|
469
|
-
interface ClaudeRequestPayload extends ClaudeMessagesPrompt {
|
|
470
|
-
anthropic_version: "bedrock-2023-05-31",
|
|
471
|
-
max_tokens: number,
|
|
472
|
-
prompt: string;
|
|
473
|
-
temperature?: number;
|
|
474
|
-
top_p?: number,
|
|
475
|
-
top_k?: number,
|
|
476
|
-
stop_sequences?: [string];
|
|
477
|
-
}
|
|
478
|
-
|
|
479
|
-
interface AI21RequestPayload {
|
|
480
|
-
prompt: string;
|
|
481
|
-
temperature: number;
|
|
482
|
-
maxTokens: number;
|
|
483
|
-
}
|
|
484
|
-
|
|
485
|
-
interface CohereRequestPayload {
|
|
486
|
-
prompt: string;
|
|
487
|
-
temperature: number;
|
|
488
|
-
max_tokens?: number;
|
|
489
|
-
p?: number;
|
|
490
|
-
}
|
|
491
|
-
|
|
492
|
-
interface AmazonRequestPayload {
|
|
493
|
-
inputText: string,
|
|
494
|
-
textGenerationConfig: {
|
|
495
|
-
temperature: number,
|
|
496
|
-
topP: number,
|
|
497
|
-
maxTokenCount: number,
|
|
498
|
-
stopSequences: [string];
|
|
499
|
-
};
|
|
500
|
-
}
|
|
501
|
-
|
|
502
|
-
interface MistralPayload {
|
|
503
|
-
prompt: string,
|
|
504
|
-
temperature: number,
|
|
505
|
-
max_tokens: number,
|
|
506
|
-
top_p?: number,
|
|
507
|
-
top_k?: number,
|
|
508
|
-
}
|
|
509
|
-
|
|
510
|
-
|
|
511
498
|
function jobInfo(job: GetModelCustomizationJobCommandOutput, jobId: string): TrainingJob {
|
|
512
499
|
const jobStatus = job.status;
|
|
513
500
|
let status = TrainingJobStatus.running;
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import { ClaudeMessagesPrompt } from "@llumiverse/core/formatters";
|
|
2
|
+
|
|
3
|
+
export interface LLama2RequestPayload {
|
|
4
|
+
prompt: string;
|
|
5
|
+
temperature: number;
|
|
6
|
+
top_p?: number;
|
|
7
|
+
max_gen_len: number;
|
|
8
|
+
}
|
|
9
|
+
export interface ClaudeRequestPayload extends ClaudeMessagesPrompt {
|
|
10
|
+
anthropic_version: "bedrock-2023-05-31";
|
|
11
|
+
max_tokens: number;
|
|
12
|
+
prompt: string;
|
|
13
|
+
temperature?: number;
|
|
14
|
+
top_p?: number;
|
|
15
|
+
top_k?: number;
|
|
16
|
+
stop_sequences?: [string];
|
|
17
|
+
}
|
|
18
|
+
export interface AI21RequestPayload {
|
|
19
|
+
prompt: string;
|
|
20
|
+
temperature: number;
|
|
21
|
+
maxTokens: number;
|
|
22
|
+
}
|
|
23
|
+
export interface CohereRequestPayload {
|
|
24
|
+
prompt: string;
|
|
25
|
+
temperature: number;
|
|
26
|
+
max_tokens?: number;
|
|
27
|
+
p?: number;
|
|
28
|
+
}
|
|
29
|
+
export interface AmazonRequestPayload {
|
|
30
|
+
inputText: string;
|
|
31
|
+
textGenerationConfig: {
|
|
32
|
+
temperature: number;
|
|
33
|
+
topP: number;
|
|
34
|
+
maxTokenCount: number;
|
|
35
|
+
stopSequences: [string];
|
|
36
|
+
};
|
|
37
|
+
}
|
|
38
|
+
export interface MistralPayload {
|
|
39
|
+
prompt: string;
|
|
40
|
+
temperature: number;
|
|
41
|
+
max_tokens: number;
|
|
42
|
+
top_p?: number;
|
|
43
|
+
top_k?: number;
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
export interface CohereCommandRPayload {
|
|
47
|
+
|
|
48
|
+
message: string,
|
|
49
|
+
chat_history?: {
|
|
50
|
+
role: 'USER' | 'CHATBOT',
|
|
51
|
+
message: string }[],
|
|
52
|
+
documents?: { title: string, snippet: string }[],
|
|
53
|
+
search_queries_only?: boolean,
|
|
54
|
+
preamble?: string,
|
|
55
|
+
max_tokens: number,
|
|
56
|
+
temperature?: number,
|
|
57
|
+
p?: number,
|
|
58
|
+
k?: number,
|
|
59
|
+
prompt_truncation?: string,
|
|
60
|
+
frequency_penalty?: number,
|
|
61
|
+
presence_penalty?: number,
|
|
62
|
+
seed?: number,
|
|
63
|
+
return_prompt?: boolean,
|
|
64
|
+
stop_sequences?: string[],
|
|
65
|
+
raw_prompting?: boolean
|
|
66
|
+
|
|
67
|
+
}
|
package/src/groq/index.ts
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
|
|
2
2
|
import { transformAsyncIterator } from "@llumiverse/core/async";
|
|
3
|
-
import { OpenAITextMessage,
|
|
3
|
+
import { OpenAITextMessage, formatOpenAILikeTextPrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
|
|
4
4
|
import Groq from "groq-sdk";
|
|
5
5
|
|
|
6
6
|
|
|
@@ -27,7 +27,7 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMess
|
|
|
27
27
|
}
|
|
28
28
|
|
|
29
29
|
// protected canStream(options: ExecutionOptions): Promise<boolean> {
|
|
30
|
-
// if (options.
|
|
30
|
+
// if (options.result_schema) {
|
|
31
31
|
// // not yet streamign json responses
|
|
32
32
|
// return Promise.resolve(false);
|
|
33
33
|
// } else {
|
|
@@ -42,17 +42,17 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMess
|
|
|
42
42
|
// type: "json_object",
|
|
43
43
|
// }
|
|
44
44
|
|
|
45
|
-
// return _options.
|
|
45
|
+
// return _options.result_schema ? responseFormatJson : undefined;
|
|
46
46
|
return undefined;
|
|
47
47
|
}
|
|
48
48
|
|
|
49
|
-
protected formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAITextMessage[] {
|
|
50
|
-
const messages =
|
|
49
|
+
protected async formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<OpenAITextMessage[]> {
|
|
50
|
+
const messages = formatOpenAILikeTextPrompt(segments);
|
|
51
51
|
//Add JSON instruction is schema is provided
|
|
52
|
-
if (opts.
|
|
52
|
+
if (opts.result_schema) {
|
|
53
53
|
messages.push({
|
|
54
54
|
role: "user",
|
|
55
|
-
content: "IMPORTANT: " + getJSONSafetyNotice(opts.
|
|
55
|
+
content: "IMPORTANT: " + getJSONSafetyNotice(opts.result_schema)
|
|
56
56
|
});
|
|
57
57
|
}
|
|
58
58
|
return messages;
|
package/src/index.ts
CHANGED
|
@@ -2,9 +2,11 @@ export * from "./bedrock/index.js";
|
|
|
2
2
|
export * from "./groq/index.js";
|
|
3
3
|
export * from "./huggingface_ie.js";
|
|
4
4
|
export * from "./mistral/index.js";
|
|
5
|
-
export * from "./openai.js";
|
|
5
|
+
export * from "./openai/azure.js";
|
|
6
|
+
export * from "./openai/openai.js";
|
|
6
7
|
export * from "./replicate.js";
|
|
7
8
|
export * from "./test/index.js";
|
|
8
9
|
export * from "./togetherai/index.js";
|
|
9
10
|
export * from "./vertexai/index.js";
|
|
10
11
|
export * from "./watsonx/index.js";
|
|
12
|
+
|
package/src/mistral/index.ts
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
|
|
2
2
|
import { transformSSEStream } from "@llumiverse/core/async";
|
|
3
|
-
import { OpenAITextMessage,
|
|
3
|
+
import { OpenAITextMessage, formatOpenAILikeTextPrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
|
|
4
4
|
import { FetchClient } from "api-fetch-client";
|
|
5
5
|
import { ChatCompletionResponse, CompletionRequestParams, ListModelsResponse, ResponseFormat } from "./types.js";
|
|
6
6
|
|
|
@@ -42,20 +42,20 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, Open
|
|
|
42
42
|
// } as ResponseFormat;
|
|
43
43
|
|
|
44
44
|
|
|
45
|
-
// return _options.
|
|
45
|
+
// return _options.result_schema ? responseFormatJson : responseFormatText;
|
|
46
46
|
|
|
47
47
|
//TODO remove this when Mistral properly supports the parameters - it makes an error for now
|
|
48
48
|
// some models like mixtral mistrall tiny or medium are throwing an error when using the response_format parameter
|
|
49
49
|
return undefined
|
|
50
50
|
}
|
|
51
51
|
|
|
52
|
-
protected formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAITextMessage[] {
|
|
53
|
-
const messages =
|
|
52
|
+
protected async formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<OpenAITextMessage[]> {
|
|
53
|
+
const messages = formatOpenAILikeTextPrompt(segments);
|
|
54
54
|
//Add JSON instruction is schema is provided
|
|
55
|
-
if (opts.
|
|
55
|
+
if (opts.result_schema) {
|
|
56
56
|
messages.push({
|
|
57
57
|
role: "user",
|
|
58
|
-
content: "IMPORTANT: " + getJSONSafetyNotice(opts.
|
|
58
|
+
content: "IMPORTANT: " + getJSONSafetyNotice(opts.result_schema)
|
|
59
59
|
});
|
|
60
60
|
}
|
|
61
61
|
return messages;
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import { DefaultAzureCredential, getBearerTokenProvider } from "@azure/identity";
|
|
2
|
+
import { DriverOptions } from "@llumiverse/core";
|
|
3
|
+
import { AzureOpenAI } from "openai";
|
|
4
|
+
import { BaseOpenAIDriver } from "./index.js";
|
|
5
|
+
|
|
6
|
+
export interface AzureOpenAIDriverOptions extends DriverOptions {
|
|
7
|
+
|
|
8
|
+
/**
|
|
9
|
+
* The credentials to use to access Azure OpenAI
|
|
10
|
+
*/
|
|
11
|
+
azureADTokenProvider?: any; //type with azure credntials
|
|
12
|
+
|
|
13
|
+
apiKey?: string;
|
|
14
|
+
|
|
15
|
+
endpoint?: string;
|
|
16
|
+
|
|
17
|
+
apiVersion?: string
|
|
18
|
+
|
|
19
|
+
deployment?: string;
|
|
20
|
+
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
export class AzureOpenAIDriver extends BaseOpenAIDriver {
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
service: AzureOpenAI;
|
|
27
|
+
provider: "azure_openai";
|
|
28
|
+
|
|
29
|
+
constructor(opts: AzureOpenAIDriverOptions) {
|
|
30
|
+
super(opts);
|
|
31
|
+
|
|
32
|
+
if (!opts.azureADTokenProvider && !opts.apiKey) {
|
|
33
|
+
opts.azureADTokenProvider = this.getDefaultAuth();
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
this.service = new AzureOpenAI({
|
|
37
|
+
apiKey: opts.apiKey,
|
|
38
|
+
azureADTokenProvider: opts.azureADTokenProvider,
|
|
39
|
+
endpoint: opts.endpoint,
|
|
40
|
+
apiVersion: opts.apiVersion ?? "2024-05-01-preview",
|
|
41
|
+
deployment: opts.deployment
|
|
42
|
+
});
|
|
43
|
+
this.provider = "azure_openai";
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
getDefaultAuth() {
|
|
48
|
+
const scope = "https://cognitiveservices.azure.com/.default";
|
|
49
|
+
const azureADTokenProvider = getBearerTokenProvider(new DefaultAzureCredential(), scope);
|
|
50
|
+
return azureADTokenProvider;
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
}
|
|
@@ -15,8 +15,8 @@ import {
|
|
|
15
15
|
TrainingPromptOptions
|
|
16
16
|
} from "@llumiverse/core";
|
|
17
17
|
import { asyncMap } from "@llumiverse/core/async";
|
|
18
|
-
import {
|
|
19
|
-
import OpenAI from "openai";
|
|
18
|
+
import { formatOpenAILikeMultimodalPrompt } from "@llumiverse/core/formatters";
|
|
19
|
+
import OpenAI, { AzureOpenAI } from "openai";
|
|
20
20
|
import { Stream } from "openai/streaming";
|
|
21
21
|
|
|
22
22
|
const supportFineTunning = new Set([
|
|
@@ -27,26 +27,20 @@ const supportFineTunning = new Set([
|
|
|
27
27
|
"gpt-4-0613"
|
|
28
28
|
]);
|
|
29
29
|
|
|
30
|
-
export interface
|
|
31
|
-
apiKey: string;
|
|
30
|
+
export interface BaseOpenAIDriverOptions extends DriverOptions {
|
|
32
31
|
}
|
|
33
32
|
|
|
34
|
-
export class
|
|
35
|
-
|
|
33
|
+
export abstract class BaseOpenAIDriver extends AbstractDriver<
|
|
34
|
+
BaseOpenAIDriverOptions,
|
|
36
35
|
OpenAI.Chat.Completions.ChatCompletionMessageParam[]
|
|
37
36
|
> {
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
generatedContentTypes: string[] = ["text/plain"];
|
|
41
|
-
service: OpenAI;
|
|
42
|
-
provider = OpenAIDriver.PROVIDER;
|
|
37
|
+
abstract provider: "azure_openai" | "openai";
|
|
38
|
+
abstract service: OpenAI | AzureOpenAI ;
|
|
43
39
|
|
|
44
|
-
constructor(opts:
|
|
40
|
+
constructor(opts: BaseOpenAIDriverOptions) {
|
|
45
41
|
super(opts);
|
|
46
|
-
this.
|
|
47
|
-
|
|
48
|
-
});
|
|
49
|
-
this.formatPrompt = formatOpenAILikePrompt;
|
|
42
|
+
this.formatPrompt = formatOpenAILikeMultimodalPrompt as any //TODO: better type, we send back OpenAI.Chat.Completions.ChatCompletionMessageParam[] but just not compatbile with Function call that we don't use here
|
|
43
|
+
|
|
50
44
|
}
|
|
51
45
|
|
|
52
46
|
extractDataFromResponse(
|
|
@@ -63,7 +57,7 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
63
57
|
const finish_reason = choice.finish_reason;
|
|
64
58
|
|
|
65
59
|
//if no schema, return content
|
|
66
|
-
if (!options.
|
|
60
|
+
if (!options.result_schema) {
|
|
67
61
|
return {
|
|
68
62
|
result: choice.message.content as string,
|
|
69
63
|
token_usage: tokenInfo,
|
|
@@ -86,7 +80,7 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
86
80
|
}
|
|
87
81
|
|
|
88
82
|
async requestCompletionStream(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise<any> {
|
|
89
|
-
const mapFn = options.
|
|
83
|
+
const mapFn = options.result_schema
|
|
90
84
|
? (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
|
|
91
85
|
return (
|
|
92
86
|
chunk.choices[0]?.delta?.tool_calls?.[0].function?.arguments ?? ""
|
|
@@ -103,18 +97,18 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
103
97
|
temperature: options.temperature,
|
|
104
98
|
n: 1,
|
|
105
99
|
max_tokens: options.max_tokens,
|
|
106
|
-
tools: options.
|
|
100
|
+
tools: options.result_schema
|
|
107
101
|
? [
|
|
108
102
|
{
|
|
109
103
|
function: {
|
|
110
104
|
name: "format_output",
|
|
111
|
-
parameters: options.
|
|
105
|
+
parameters: options.result_schema as any,
|
|
112
106
|
},
|
|
113
107
|
type: "function"
|
|
114
108
|
} as OpenAI.Chat.ChatCompletionTool,
|
|
115
109
|
]
|
|
116
110
|
: undefined,
|
|
117
|
-
tool_choice: options.
|
|
111
|
+
tool_choice: options.result_schema
|
|
118
112
|
? {
|
|
119
113
|
type: 'function',
|
|
120
114
|
function: { name: "format_output" }
|
|
@@ -125,12 +119,12 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
125
119
|
}
|
|
126
120
|
|
|
127
121
|
async requestCompletion(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise<any> {
|
|
128
|
-
const functions = options.
|
|
122
|
+
const functions = options.result_schema
|
|
129
123
|
? [
|
|
130
124
|
{
|
|
131
125
|
function: {
|
|
132
126
|
name: "format_output",
|
|
133
|
-
parameters: options.
|
|
127
|
+
parameters: options.result_schema as any,
|
|
134
128
|
},
|
|
135
129
|
type: 'function'
|
|
136
130
|
} as OpenAI.Chat.ChatCompletionTool,
|
|
@@ -145,13 +139,13 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
145
139
|
n: 1,
|
|
146
140
|
max_tokens: options.max_tokens,
|
|
147
141
|
tools: functions,
|
|
148
|
-
tool_choice: options.
|
|
142
|
+
tool_choice: options.result_schema
|
|
149
143
|
? {
|
|
150
144
|
type: 'function',
|
|
151
145
|
function: { name: "format_output" }
|
|
152
146
|
} : undefined,
|
|
153
147
|
// functions: functions,
|
|
154
|
-
// function_call: options.
|
|
148
|
+
// function_call: options.result_schema
|
|
155
149
|
// ? { name: "format_output" }
|
|
156
150
|
// : undefined,
|
|
157
151
|
});
|
|
@@ -163,11 +157,11 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
163
157
|
return completion;
|
|
164
158
|
}
|
|
165
159
|
|
|
166
|
-
createTrainingPrompt(options: TrainingPromptOptions): string {
|
|
160
|
+
createTrainingPrompt(options: TrainingPromptOptions): Promise<string> {
|
|
167
161
|
if (options.model.includes("gpt")) {
|
|
168
162
|
return super.createTrainingPrompt(options);
|
|
169
163
|
} else {
|
|
170
|
-
// babbage, davinci not yet implemented
|
|
164
|
+
// babbage, davinci not yet implemented
|
|
171
165
|
throw new Error("Unsupported model for training: " + options.model);
|
|
172
166
|
}
|
|
173
167
|
}
|
|
@@ -217,7 +211,7 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
217
211
|
return this._listModels();
|
|
218
212
|
}
|
|
219
213
|
|
|
220
|
-
async _listModels(filter?: (m: OpenAI.Models.Model) => boolean) {
|
|
214
|
+
async _listModels(filter?: (m: OpenAI.Models.Model) => boolean): Promise<AIModel[]> {
|
|
221
215
|
let result = await this.service.models.list();
|
|
222
216
|
const models = filter ? result.data.filter(filter) : result.data;
|
|
223
217
|
return models.map((m) => ({
|
|
@@ -226,6 +220,8 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
226
220
|
provider: this.provider,
|
|
227
221
|
owner: m.owned_by,
|
|
228
222
|
type: m.object === "model" ? ModelType.Text : ModelType.Unknown,
|
|
223
|
+
can_stream: true,
|
|
224
|
+
is_multimodal: m.id.includes("gpt-4")
|
|
229
225
|
}));
|
|
230
226
|
}
|
|
231
227
|
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
import { DriverOptions } from "@llumiverse/core";
|
|
4
|
+
import OpenAI from "openai";
|
|
5
|
+
import { BaseOpenAIDriver } from "./index.js";
|
|
6
|
+
|
|
7
|
+
export interface OpenAIDriverOptions extends DriverOptions {
|
|
8
|
+
|
|
9
|
+
/**
|
|
10
|
+
* The OpenAI api key
|
|
11
|
+
*/
|
|
12
|
+
apiKey?: string; //type with azure credntials
|
|
13
|
+
|
|
14
|
+
}
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
export class OpenAIDriver extends BaseOpenAIDriver {
|
|
20
|
+
|
|
21
|
+
service: OpenAI;
|
|
22
|
+
provider: "openai";
|
|
23
|
+
|
|
24
|
+
constructor(opts: OpenAIDriverOptions) {
|
|
25
|
+
super(opts);
|
|
26
|
+
this.service = new OpenAI({
|
|
27
|
+
apiKey: opts.apiKey
|
|
28
|
+
});
|
|
29
|
+
this.provider = "openai";
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
}
|