@llumiverse/drivers 0.12.3 → 0.14.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +13 -11
- package/lib/cjs/bedrock/index.js +63 -18
- package/lib/cjs/bedrock/index.js.map +1 -1
- package/lib/cjs/bedrock/payloads.js +3 -0
- package/lib/cjs/bedrock/payloads.js.map +1 -0
- package/lib/cjs/bedrock/s3.js +5 -6
- package/lib/cjs/bedrock/s3.js.map +1 -1
- package/lib/cjs/groq/index.js +6 -6
- package/lib/cjs/groq/index.js.map +1 -1
- package/lib/cjs/huggingface_ie.js.map +1 -1
- package/lib/cjs/index.js +4 -2
- package/lib/cjs/index.js.map +1 -1
- package/lib/cjs/mistral/index.js +5 -5
- package/lib/cjs/mistral/index.js.map +1 -1
- package/lib/cjs/openai/azure.js +31 -0
- package/lib/cjs/openai/azure.js.map +1 -0
- package/lib/cjs/{openai.js → openai/index.js} +17 -27
- package/lib/cjs/openai/index.js.map +1 -0
- package/lib/cjs/openai/openai.js +21 -0
- package/lib/cjs/openai/openai.js.map +1 -0
- package/lib/cjs/replicate.js +1 -1
- package/lib/cjs/replicate.js.map +1 -1
- package/lib/cjs/test/index.js +1 -1
- package/lib/cjs/test/index.js.map +1 -1
- package/lib/cjs/test/utils.js +3 -4
- package/lib/cjs/test/utils.js.map +1 -1
- package/lib/cjs/togetherai/index.js +2 -2
- package/lib/cjs/togetherai/index.js.map +1 -1
- package/lib/cjs/vertexai/debug.js +1 -2
- package/lib/cjs/vertexai/debug.js.map +1 -1
- package/lib/cjs/vertexai/embeddings/embeddings-text.js +1 -2
- package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -1
- package/lib/cjs/vertexai/index.js +3 -2
- package/lib/cjs/vertexai/index.js.map +1 -1
- package/lib/cjs/vertexai/models/codey-chat.js +3 -3
- package/lib/cjs/vertexai/models/codey-chat.js.map +1 -1
- package/lib/cjs/vertexai/models/codey-text.js +2 -2
- package/lib/cjs/vertexai/models/codey-text.js.map +1 -1
- package/lib/cjs/vertexai/models/gemini.js +43 -27
- package/lib/cjs/vertexai/models/gemini.js.map +1 -1
- package/lib/cjs/vertexai/models/palm-model-base.js +1 -1
- package/lib/cjs/vertexai/models/palm-model-base.js.map +1 -1
- package/lib/cjs/vertexai/models/palm2-chat.js +3 -3
- package/lib/cjs/vertexai/models/palm2-chat.js.map +1 -1
- package/lib/cjs/vertexai/models/palm2-text.js +2 -2
- package/lib/cjs/vertexai/models/palm2-text.js.map +1 -1
- package/lib/cjs/vertexai/models.js +39 -17
- package/lib/cjs/vertexai/models.js.map +1 -1
- package/lib/cjs/vertexai/utils/tensor.js +2 -3
- package/lib/cjs/vertexai/utils/tensor.js.map +1 -1
- package/lib/cjs/watsonx/index.js +124 -0
- package/lib/cjs/watsonx/index.js.map +1 -0
- package/lib/cjs/watsonx/interfaces.js +3 -0
- package/lib/cjs/watsonx/interfaces.js.map +1 -0
- package/lib/esm/bedrock/index.js +63 -18
- package/lib/esm/bedrock/index.js.map +1 -1
- package/lib/esm/bedrock/payloads.js +2 -0
- package/lib/esm/bedrock/payloads.js.map +1 -0
- package/lib/esm/groq/index.js +7 -7
- package/lib/esm/groq/index.js.map +1 -1
- package/lib/esm/huggingface_ie.js.map +1 -1
- package/lib/esm/index.js +4 -2
- package/lib/esm/index.js.map +1 -1
- package/lib/esm/mistral/index.js +6 -6
- package/lib/esm/mistral/index.js.map +1 -1
- package/lib/esm/openai/azure.js +27 -0
- package/lib/esm/openai/azure.js.map +1 -0
- package/lib/esm/{openai.js → openai/index.js} +16 -23
- package/lib/esm/openai/index.js.map +1 -0
- package/lib/esm/openai/openai.js +14 -0
- package/lib/esm/openai/openai.js.map +1 -0
- package/lib/esm/replicate.js +1 -1
- package/lib/esm/replicate.js.map +1 -1
- package/lib/esm/test/index.js +1 -1
- package/lib/esm/test/index.js.map +1 -1
- package/lib/esm/togetherai/index.js +2 -2
- package/lib/esm/togetherai/index.js.map +1 -1
- package/lib/esm/vertexai/index.js +3 -2
- package/lib/esm/vertexai/index.js.map +1 -1
- package/lib/esm/vertexai/models/codey-chat.js +3 -3
- package/lib/esm/vertexai/models/codey-chat.js.map +1 -1
- package/lib/esm/vertexai/models/codey-text.js +2 -2
- package/lib/esm/vertexai/models/codey-text.js.map +1 -1
- package/lib/esm/vertexai/models/gemini.js +44 -28
- package/lib/esm/vertexai/models/gemini.js.map +1 -1
- package/lib/esm/vertexai/models/palm-model-base.js +1 -1
- package/lib/esm/vertexai/models/palm-model-base.js.map +1 -1
- package/lib/esm/vertexai/models/palm2-chat.js +3 -3
- package/lib/esm/vertexai/models/palm2-chat.js.map +1 -1
- package/lib/esm/vertexai/models/palm2-text.js +2 -2
- package/lib/esm/vertexai/models/palm2-text.js.map +1 -1
- package/lib/esm/vertexai/models.js +35 -13
- package/lib/esm/vertexai/models.js.map +1 -1
- package/lib/esm/watsonx/index.js +120 -0
- package/lib/esm/watsonx/index.js.map +1 -0
- package/lib/esm/watsonx/interfaces.js +2 -0
- package/lib/esm/watsonx/interfaces.js.map +1 -0
- package/lib/types/bedrock/index.d.ts +4 -46
- package/lib/types/bedrock/index.d.ts.map +1 -1
- package/lib/types/bedrock/payloads.d.ts +68 -0
- package/lib/types/bedrock/payloads.d.ts.map +1 -0
- package/lib/types/groq/index.d.ts +1 -1
- package/lib/types/groq/index.d.ts.map +1 -1
- package/lib/types/huggingface_ie.d.ts +5 -5
- package/lib/types/index.d.ts +4 -2
- package/lib/types/index.d.ts.map +1 -1
- package/lib/types/mistral/index.d.ts +1 -1
- package/lib/types/mistral/index.d.ts.map +1 -1
- package/lib/types/openai/azure.d.ts +20 -0
- package/lib/types/openai/azure.d.ts.map +1 -0
- package/lib/types/{openai.d.ts → openai/index.d.ts} +10 -20
- package/lib/types/openai/index.d.ts.map +1 -0
- package/lib/types/openai/openai.d.ts +15 -0
- package/lib/types/openai/openai.d.ts.map +1 -0
- package/lib/types/test/index.d.ts +2 -2
- package/lib/types/test/index.d.ts.map +1 -1
- package/lib/types/test/utils.d.ts.map +1 -1
- package/lib/types/vertexai/index.d.ts +3 -1
- package/lib/types/vertexai/index.d.ts.map +1 -1
- package/lib/types/vertexai/models/gemini.d.ts +2 -1
- package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
- package/lib/types/vertexai/models/palm-model-base.d.ts +1 -1
- package/lib/types/vertexai/models/palm-model-base.d.ts.map +1 -1
- package/lib/types/vertexai/models.d.ts +1 -1
- package/lib/types/vertexai/models.d.ts.map +1 -1
- package/lib/types/watsonx/index.d.ts +27 -0
- package/lib/types/watsonx/index.d.ts.map +1 -0
- package/lib/types/watsonx/interfaces.d.ts +61 -0
- package/lib/types/watsonx/interfaces.d.ts.map +1 -0
- package/package.json +24 -18
- package/src/bedrock/index.ts +72 -70
- package/src/bedrock/payloads.ts +67 -0
- package/src/groq/index.ts +7 -7
- package/src/huggingface_ie.ts +1 -1
- package/src/index.ts +5 -2
- package/src/mistral/index.ts +6 -6
- package/src/openai/azure.ts +54 -0
- package/src/{openai.ts → openai/index.ts} +24 -28
- package/src/openai/openai.ts +33 -0
- package/src/replicate.ts +5 -5
- package/src/test/index.ts +2 -3
- package/src/togetherai/index.ts +2 -2
- package/src/vertexai/index.ts +6 -3
- package/src/vertexai/models/codey-chat.ts +3 -3
- package/src/vertexai/models/codey-text.ts +2 -2
- package/src/vertexai/models/gemini.ts +56 -32
- package/src/vertexai/models/palm-model-base.ts +1 -2
- package/src/vertexai/models/palm2-chat.ts +3 -3
- package/src/vertexai/models/palm2-text.ts +2 -2
- package/src/vertexai/models.ts +42 -15
- package/src/watsonx/index.ts +161 -0
- package/src/watsonx/interfaces.ts +71 -0
- package/lib/cjs/openai.js.map +0 -1
- package/lib/esm/openai.js.map +0 -1
- package/lib/types/openai.d.ts.map +0 -1
package/src/bedrock/index.ts
CHANGED
|
@@ -6,6 +6,7 @@ import { transformAsyncIterator } from "@llumiverse/core/async";
|
|
|
6
6
|
import { ClaudeMessagesPrompt, formatClaudePrompt } from "@llumiverse/core/formatters";
|
|
7
7
|
import { AwsCredentialIdentity, Provider } from "@smithy/types";
|
|
8
8
|
import mnemonist from "mnemonist";
|
|
9
|
+
import { AI21RequestPayload, AmazonRequestPayload, ClaudeRequestPayload, CohereCommandRPayload, CohereRequestPayload, LLama2RequestPayload, MistralPayload } from "./payloads.js";
|
|
9
10
|
import { forceUploadFile } from "./s3.js";
|
|
10
11
|
|
|
11
12
|
const { LRUCache } = mnemonist;
|
|
@@ -23,7 +24,7 @@ export interface BedrockDriverOptions extends DriverOptions {
|
|
|
23
24
|
*/
|
|
24
25
|
region: string;
|
|
25
26
|
/**
|
|
26
|
-
* Tthe bucket name to be used for training.
|
|
27
|
+
* Tthe bucket name to be used for training.
|
|
27
28
|
* It will be created oif nto already exixts
|
|
28
29
|
*/
|
|
29
30
|
training_bucket?: string;
|
|
@@ -62,6 +63,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
62
63
|
this._executor = new BedrockRuntime({
|
|
63
64
|
region: this.options.region,
|
|
64
65
|
credentials: this.options.credentials,
|
|
66
|
+
|
|
65
67
|
});
|
|
66
68
|
}
|
|
67
69
|
return this._executor;
|
|
@@ -77,13 +79,13 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
77
79
|
return this._service;
|
|
78
80
|
}
|
|
79
81
|
|
|
80
|
-
protected formatPrompt(segments: PromptSegment[], opts: PromptOptions): BedrockPrompt {
|
|
82
|
+
protected async formatPrompt(segments: PromptSegment[], opts: PromptOptions): Promise<BedrockPrompt> {
|
|
81
83
|
//TODO move the anthropic test in abstract driver?
|
|
82
84
|
if (opts.model.includes('anthropic')) {
|
|
83
85
|
//TODO: need to type better the types aren't checked properly by TS
|
|
84
|
-
return formatClaudePrompt(segments, opts.
|
|
86
|
+
return await formatClaudePrompt(segments, opts.result_schema);
|
|
85
87
|
} else {
|
|
86
|
-
return super.formatPrompt(segments, opts) as string;
|
|
88
|
+
return await super.formatPrompt(segments, opts) as string;
|
|
87
89
|
}
|
|
88
90
|
}
|
|
89
91
|
|
|
@@ -98,14 +100,23 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
98
100
|
// LLAMA2
|
|
99
101
|
return [result.generation, result.stop_reason]; // comes in coirrect format (stop, length)
|
|
100
102
|
} else if (result.generations) {
|
|
101
|
-
//
|
|
103
|
+
// Cohere
|
|
102
104
|
return [result.generations[0].text, cohereFinishReason(result.generations[0].finish_reason)];
|
|
105
|
+
} else if (result.chat_history) {
|
|
106
|
+
//Cohere Command R
|
|
107
|
+
return [result.text, cohereFinishReason(result.finish_reason)];
|
|
103
108
|
} else if (result.completions) {
|
|
104
109
|
//A21
|
|
105
110
|
return [result.completions[0].data?.text, a21FinishReason(result.completions[0].finishReason?.reason)];
|
|
106
111
|
} else if (result.content) {
|
|
107
|
-
//
|
|
108
|
-
|
|
112
|
+
// Claude
|
|
113
|
+
//if last prompt.messages is {, add { to the response
|
|
114
|
+
const p = prompt as ClaudeMessagesPrompt;
|
|
115
|
+
const lastMessage = (p as ClaudeMessagesPrompt).messages[p.messages.length - 1];
|
|
116
|
+
const res = lastMessage.content[0].text === '{' ? '{' + result.content[0]?.text : result.content[0]?.text;
|
|
117
|
+
|
|
118
|
+
return [res, claudeFinishReason(result.stop_reason)];
|
|
119
|
+
|
|
109
120
|
} else if (result.outputs) {
|
|
110
121
|
// mistral
|
|
111
122
|
return [result.outputs[0]?.text, result.outputs[0]?.stop_reason]; // the stop reason is in the expected format ("stop" and "length")
|
|
@@ -162,7 +173,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
162
173
|
return canStream;
|
|
163
174
|
}
|
|
164
175
|
|
|
165
|
-
async requestCompletionStream(prompt:
|
|
176
|
+
async requestCompletionStream(prompt: BedrockPrompt, options: ExecutionOptions): Promise<AsyncIterable<string>> {
|
|
166
177
|
const payload = this.preparePayload(prompt, options);
|
|
167
178
|
const executor = this.getExecutor();
|
|
168
179
|
return executor.invokeModelWithResponseStream({
|
|
@@ -176,12 +187,24 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
176
187
|
}
|
|
177
188
|
const decoder = new TextDecoder();
|
|
178
189
|
|
|
190
|
+
const addBracket = () => {
|
|
191
|
+
if (typeof prompt === 'object' && (prompt as ClaudeMessagesPrompt).messages) {
|
|
192
|
+
const p = prompt as ClaudeMessagesPrompt;
|
|
193
|
+
const lastMessage = p.messages[p.messages.length - 1];
|
|
194
|
+
return lastMessage.content[0].text === '{';
|
|
195
|
+
}
|
|
196
|
+
return false;
|
|
197
|
+
};
|
|
198
|
+
|
|
179
199
|
return transformAsyncIterator(res.body, (stream: ResponseStream) => {
|
|
180
200
|
const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
|
|
201
|
+
//console.log("Debug Segment for model " + options.model, JSON.stringify(segment));
|
|
181
202
|
if (segment.delta) { // who is this?
|
|
182
203
|
return segment.delta.text || '';
|
|
183
204
|
} else if (segment.completion) { // who is this?
|
|
184
205
|
return segment.completion;
|
|
206
|
+
} else if (segment.text) { //cohere
|
|
207
|
+
return segment.text;
|
|
185
208
|
} else if (segment.completions) {
|
|
186
209
|
return segment.completions[0].data?.text;
|
|
187
210
|
} else if (segment.generation) {
|
|
@@ -201,7 +224,9 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
201
224
|
segment.toString();
|
|
202
225
|
}
|
|
203
226
|
|
|
204
|
-
}
|
|
227
|
+
},
|
|
228
|
+
() => addBracket() ? '{' : ''
|
|
229
|
+
);
|
|
205
230
|
|
|
206
231
|
}).catch((err) => {
|
|
207
232
|
this.logger.error("[Bedrock] Failed to stream", err);
|
|
@@ -224,12 +249,23 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
224
249
|
temperature: options.temperature,
|
|
225
250
|
max_gen_len: options.max_tokens,
|
|
226
251
|
} as LLama2RequestPayload
|
|
227
|
-
} else if (contains(options.model, "
|
|
252
|
+
} else if (contains(options.model, "claude")) {
|
|
253
|
+
|
|
254
|
+
const maxToken = () => {
|
|
255
|
+
if (options.max_tokens) {
|
|
256
|
+
return options.max_tokens;
|
|
257
|
+
|
|
258
|
+
} else if (contains(options.model, "claude-3-5")) {
|
|
259
|
+
return 8192;
|
|
260
|
+
} else {
|
|
261
|
+
return 4096
|
|
262
|
+
}
|
|
263
|
+
}
|
|
228
264
|
return {
|
|
229
265
|
anthropic_version: "bedrock-2023-05-31",
|
|
230
266
|
...(prompt as ClaudeMessagesPrompt),
|
|
231
267
|
temperature: options.temperature,
|
|
232
|
-
max_tokens:
|
|
268
|
+
max_tokens: maxToken(),
|
|
233
269
|
} as ClaudeRequestPayload;
|
|
234
270
|
} else if (contains(options.model, "ai21")) {
|
|
235
271
|
return {
|
|
@@ -237,12 +273,19 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
237
273
|
temperature: options.temperature,
|
|
238
274
|
maxTokens: options.max_tokens,
|
|
239
275
|
} as AI21RequestPayload;
|
|
240
|
-
} else if (contains(options.model, "
|
|
276
|
+
} else if (contains(options.model, "command-r-plus")) {
|
|
277
|
+
return {
|
|
278
|
+
message: prompt as string,
|
|
279
|
+
max_tokens: options.max_tokens,
|
|
280
|
+
temperature: options.temperature,
|
|
281
|
+
} as CohereCommandRPayload;
|
|
282
|
+
|
|
283
|
+
}
|
|
284
|
+
else if (contains(options.model, "cohere")) {
|
|
241
285
|
return {
|
|
242
286
|
prompt: prompt,
|
|
243
287
|
temperature: options.temperature,
|
|
244
288
|
max_tokens: options.max_tokens,
|
|
245
|
-
p: 0.9,
|
|
246
289
|
} as CohereRequestPayload;
|
|
247
290
|
} else if (contains(options.model, "amazon")) {
|
|
248
291
|
return {
|
|
@@ -279,7 +322,8 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
279
322
|
}
|
|
280
323
|
|
|
281
324
|
const s3 = new S3Client({ region: this.options.region, credentials: this.options.credentials });
|
|
282
|
-
const
|
|
325
|
+
const stream = await dataset.getStream();
|
|
326
|
+
const upload = await forceUploadFile(s3, stream, this.options.training_bucket, dataset.name);
|
|
283
327
|
|
|
284
328
|
const service = this.getService();
|
|
285
329
|
const response = await service.send(new CreateModelCustomizationJobCommand({
|
|
@@ -350,11 +394,17 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
350
394
|
async _listModels(foundationFilter?: (m: FoundationModelSummary) => boolean): Promise<AIModel[]> {
|
|
351
395
|
const service = this.getService();
|
|
352
396
|
const [foundationals, customs] = await Promise.all([
|
|
353
|
-
service.listFoundationModels({})
|
|
354
|
-
|
|
397
|
+
service.listFoundationModels({}).catch(() => {
|
|
398
|
+
this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions.");
|
|
399
|
+
return undefined
|
|
400
|
+
}),
|
|
401
|
+
service.listCustomModels({}).catch(() => {
|
|
402
|
+
this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions.");
|
|
403
|
+
return undefined
|
|
404
|
+
}),
|
|
355
405
|
]);
|
|
356
406
|
|
|
357
|
-
if (!foundationals
|
|
407
|
+
if (!foundationals?.modelSummaries) {
|
|
358
408
|
throw new Error("Foundation models not found");
|
|
359
409
|
}
|
|
360
410
|
|
|
@@ -373,9 +423,10 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
373
423
|
id: m.modelArn ?? m.modelId,
|
|
374
424
|
name: `${m.providerName} ${m.modelName}`,
|
|
375
425
|
provider: this.provider,
|
|
376
|
-
description:
|
|
426
|
+
//description: ``,
|
|
377
427
|
owner: m.providerName,
|
|
378
|
-
|
|
428
|
+
can_stream: m.responseStreamingSupported ?? false,
|
|
429
|
+
is_multimodal: m.inputModalities?.includes("IMAGE") ?? false,
|
|
379
430
|
tags: m.outputModalities ?? [],
|
|
380
431
|
};
|
|
381
432
|
|
|
@@ -383,7 +434,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
383
434
|
});
|
|
384
435
|
|
|
385
436
|
//add custom models
|
|
386
|
-
if (customs
|
|
437
|
+
if (customs?.modelSummaries) {
|
|
387
438
|
customs.modelSummaries.forEach((m) => {
|
|
388
439
|
|
|
389
440
|
if (!m.modelArn) {
|
|
@@ -395,7 +446,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
395
446
|
name: m.modelName ?? m.modelArn,
|
|
396
447
|
provider: this.provider,
|
|
397
448
|
description: `Custom model from ${m.baseModelName}`,
|
|
398
|
-
|
|
449
|
+
is_custom: true,
|
|
399
450
|
};
|
|
400
451
|
|
|
401
452
|
aimodels.push(model);
|
|
@@ -444,55 +495,6 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
444
495
|
|
|
445
496
|
|
|
446
497
|
|
|
447
|
-
interface LLama2RequestPayload {
|
|
448
|
-
prompt: string;
|
|
449
|
-
temperature: number;
|
|
450
|
-
top_p?: number;
|
|
451
|
-
max_gen_len: number;
|
|
452
|
-
}
|
|
453
|
-
|
|
454
|
-
interface ClaudeRequestPayload extends ClaudeMessagesPrompt {
|
|
455
|
-
anthropic_version: "bedrock-2023-05-31",
|
|
456
|
-
max_tokens: number,
|
|
457
|
-
prompt: string;
|
|
458
|
-
temperature?: number;
|
|
459
|
-
top_p?: number,
|
|
460
|
-
top_k?: number,
|
|
461
|
-
stop_sequences?: [string];
|
|
462
|
-
}
|
|
463
|
-
|
|
464
|
-
interface AI21RequestPayload {
|
|
465
|
-
prompt: string;
|
|
466
|
-
temperature: number;
|
|
467
|
-
maxTokens: number;
|
|
468
|
-
}
|
|
469
|
-
|
|
470
|
-
interface CohereRequestPayload {
|
|
471
|
-
prompt: string;
|
|
472
|
-
temperature: number;
|
|
473
|
-
max_tokens?: number;
|
|
474
|
-
p?: number;
|
|
475
|
-
}
|
|
476
|
-
|
|
477
|
-
interface AmazonRequestPayload {
|
|
478
|
-
inputText: string,
|
|
479
|
-
textGenerationConfig: {
|
|
480
|
-
temperature: number,
|
|
481
|
-
topP: number,
|
|
482
|
-
maxTokenCount: number,
|
|
483
|
-
stopSequences: [string];
|
|
484
|
-
};
|
|
485
|
-
}
|
|
486
|
-
|
|
487
|
-
interface MistralPayload {
|
|
488
|
-
prompt: string,
|
|
489
|
-
temperature: number,
|
|
490
|
-
max_tokens: number,
|
|
491
|
-
top_p?: number,
|
|
492
|
-
top_k?: number,
|
|
493
|
-
}
|
|
494
|
-
|
|
495
|
-
|
|
496
498
|
function jobInfo(job: GetModelCustomizationJobCommandOutput, jobId: string): TrainingJob {
|
|
497
499
|
const jobStatus = job.status;
|
|
498
500
|
let status = TrainingJobStatus.running;
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import { ClaudeMessagesPrompt } from "@llumiverse/core/formatters";
|
|
2
|
+
|
|
3
|
+
export interface LLama2RequestPayload {
|
|
4
|
+
prompt: string;
|
|
5
|
+
temperature: number;
|
|
6
|
+
top_p?: number;
|
|
7
|
+
max_gen_len: number;
|
|
8
|
+
}
|
|
9
|
+
export interface ClaudeRequestPayload extends ClaudeMessagesPrompt {
|
|
10
|
+
anthropic_version: "bedrock-2023-05-31";
|
|
11
|
+
max_tokens: number;
|
|
12
|
+
prompt: string;
|
|
13
|
+
temperature?: number;
|
|
14
|
+
top_p?: number;
|
|
15
|
+
top_k?: number;
|
|
16
|
+
stop_sequences?: [string];
|
|
17
|
+
}
|
|
18
|
+
export interface AI21RequestPayload {
|
|
19
|
+
prompt: string;
|
|
20
|
+
temperature: number;
|
|
21
|
+
maxTokens: number;
|
|
22
|
+
}
|
|
23
|
+
export interface CohereRequestPayload {
|
|
24
|
+
prompt: string;
|
|
25
|
+
temperature: number;
|
|
26
|
+
max_tokens?: number;
|
|
27
|
+
p?: number;
|
|
28
|
+
}
|
|
29
|
+
export interface AmazonRequestPayload {
|
|
30
|
+
inputText: string;
|
|
31
|
+
textGenerationConfig: {
|
|
32
|
+
temperature: number;
|
|
33
|
+
topP: number;
|
|
34
|
+
maxTokenCount: number;
|
|
35
|
+
stopSequences: [string];
|
|
36
|
+
};
|
|
37
|
+
}
|
|
38
|
+
export interface MistralPayload {
|
|
39
|
+
prompt: string;
|
|
40
|
+
temperature: number;
|
|
41
|
+
max_tokens: number;
|
|
42
|
+
top_p?: number;
|
|
43
|
+
top_k?: number;
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
export interface CohereCommandRPayload {
|
|
47
|
+
|
|
48
|
+
message: string,
|
|
49
|
+
chat_history?: {
|
|
50
|
+
role: 'USER' | 'CHATBOT',
|
|
51
|
+
message: string }[],
|
|
52
|
+
documents?: { title: string, snippet: string }[],
|
|
53
|
+
search_queries_only?: boolean,
|
|
54
|
+
preamble?: string,
|
|
55
|
+
max_tokens: number,
|
|
56
|
+
temperature?: number,
|
|
57
|
+
p?: number,
|
|
58
|
+
k?: number,
|
|
59
|
+
prompt_truncation?: string,
|
|
60
|
+
frequency_penalty?: number,
|
|
61
|
+
presence_penalty?: number,
|
|
62
|
+
seed?: number,
|
|
63
|
+
return_prompt?: boolean,
|
|
64
|
+
stop_sequences?: string[],
|
|
65
|
+
raw_prompting?: boolean
|
|
66
|
+
|
|
67
|
+
}
|
package/src/groq/index.ts
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
|
|
2
2
|
import { transformAsyncIterator } from "@llumiverse/core/async";
|
|
3
|
-
import { OpenAITextMessage,
|
|
3
|
+
import { OpenAITextMessage, formatOpenAILikeTextPrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
|
|
4
4
|
import Groq from "groq-sdk";
|
|
5
5
|
|
|
6
6
|
|
|
@@ -27,7 +27,7 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMess
|
|
|
27
27
|
}
|
|
28
28
|
|
|
29
29
|
// protected canStream(options: ExecutionOptions): Promise<boolean> {
|
|
30
|
-
// if (options.
|
|
30
|
+
// if (options.result_schema) {
|
|
31
31
|
// // not yet streamign json responses
|
|
32
32
|
// return Promise.resolve(false);
|
|
33
33
|
// } else {
|
|
@@ -42,17 +42,17 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMess
|
|
|
42
42
|
// type: "json_object",
|
|
43
43
|
// }
|
|
44
44
|
|
|
45
|
-
// return _options.
|
|
45
|
+
// return _options.result_schema ? responseFormatJson : undefined;
|
|
46
46
|
return undefined;
|
|
47
47
|
}
|
|
48
48
|
|
|
49
|
-
protected formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAITextMessage[] {
|
|
50
|
-
const messages =
|
|
49
|
+
protected async formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<OpenAITextMessage[]> {
|
|
50
|
+
const messages = formatOpenAILikeTextPrompt(segments);
|
|
51
51
|
//Add JSON instruction is schema is provided
|
|
52
|
-
if (opts.
|
|
52
|
+
if (opts.result_schema) {
|
|
53
53
|
messages.push({
|
|
54
54
|
role: "user",
|
|
55
|
-
content: "IMPORTANT: " + getJSONSafetyNotice(opts.
|
|
55
|
+
content: "IMPORTANT: " + getJSONSafetyNotice(opts.result_schema)
|
|
56
56
|
});
|
|
57
57
|
}
|
|
58
58
|
return messages;
|
package/src/huggingface_ie.ts
CHANGED
|
@@ -92,7 +92,7 @@ export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptio
|
|
|
92
92
|
},
|
|
93
93
|
});
|
|
94
94
|
|
|
95
|
-
let finish_reason = res.details?.finish_reason;
|
|
95
|
+
let finish_reason = res.details?.finish_reason as string;
|
|
96
96
|
if (finish_reason === "eos_token") {
|
|
97
97
|
finish_reason = "stop";
|
|
98
98
|
}
|
package/src/index.ts
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
1
1
|
export * from "./bedrock/index.js";
|
|
2
|
+
export * from "./groq/index.js";
|
|
2
3
|
export * from "./huggingface_ie.js";
|
|
3
4
|
export * from "./mistral/index.js";
|
|
4
|
-
export * from "./openai.js";
|
|
5
|
+
export * from "./openai/azure.js";
|
|
6
|
+
export * from "./openai/openai.js";
|
|
5
7
|
export * from "./replicate.js";
|
|
6
8
|
export * from "./test/index.js";
|
|
7
9
|
export * from "./togetherai/index.js";
|
|
8
10
|
export * from "./vertexai/index.js";
|
|
9
|
-
export * from "./
|
|
11
|
+
export * from "./watsonx/index.js";
|
|
12
|
+
|
package/src/mistral/index.ts
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment } from "@llumiverse/core";
|
|
2
2
|
import { transformSSEStream } from "@llumiverse/core/async";
|
|
3
|
-
import { OpenAITextMessage,
|
|
3
|
+
import { OpenAITextMessage, formatOpenAILikeTextPrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
|
|
4
4
|
import { FetchClient } from "api-fetch-client";
|
|
5
5
|
import { ChatCompletionResponse, CompletionRequestParams, ListModelsResponse, ResponseFormat } from "./types.js";
|
|
6
6
|
|
|
@@ -42,20 +42,20 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, Open
|
|
|
42
42
|
// } as ResponseFormat;
|
|
43
43
|
|
|
44
44
|
|
|
45
|
-
// return _options.
|
|
45
|
+
// return _options.result_schema ? responseFormatJson : responseFormatText;
|
|
46
46
|
|
|
47
47
|
//TODO remove this when Mistral properly supports the parameters - it makes an error for now
|
|
48
48
|
// some models like mixtral mistrall tiny or medium are throwing an error when using the response_format parameter
|
|
49
49
|
return undefined
|
|
50
50
|
}
|
|
51
51
|
|
|
52
|
-
protected formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): OpenAITextMessage[] {
|
|
53
|
-
const messages =
|
|
52
|
+
protected async formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<OpenAITextMessage[]> {
|
|
53
|
+
const messages = formatOpenAILikeTextPrompt(segments);
|
|
54
54
|
//Add JSON instruction is schema is provided
|
|
55
|
-
if (opts.
|
|
55
|
+
if (opts.result_schema) {
|
|
56
56
|
messages.push({
|
|
57
57
|
role: "user",
|
|
58
|
-
content: "IMPORTANT: " + getJSONSafetyNotice(opts.
|
|
58
|
+
content: "IMPORTANT: " + getJSONSafetyNotice(opts.result_schema)
|
|
59
59
|
});
|
|
60
60
|
}
|
|
61
61
|
return messages;
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import { DefaultAzureCredential, getBearerTokenProvider } from "@azure/identity";
|
|
2
|
+
import { DriverOptions } from "@llumiverse/core";
|
|
3
|
+
import { AzureOpenAI } from "openai";
|
|
4
|
+
import { BaseOpenAIDriver } from "./index.js";
|
|
5
|
+
|
|
6
|
+
export interface AzureOpenAIDriverOptions extends DriverOptions {
|
|
7
|
+
|
|
8
|
+
/**
|
|
9
|
+
* The credentials to use to access Azure OpenAI
|
|
10
|
+
*/
|
|
11
|
+
azureADTokenProvider?: any; //type with azure credntials
|
|
12
|
+
|
|
13
|
+
apiKey?: string;
|
|
14
|
+
|
|
15
|
+
endpoint?: string;
|
|
16
|
+
|
|
17
|
+
apiVersion?: string
|
|
18
|
+
|
|
19
|
+
deployment?: string;
|
|
20
|
+
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
export class AzureOpenAIDriver extends BaseOpenAIDriver {
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
service: AzureOpenAI;
|
|
27
|
+
provider: "azure_openai";
|
|
28
|
+
|
|
29
|
+
constructor(opts: AzureOpenAIDriverOptions) {
|
|
30
|
+
super(opts);
|
|
31
|
+
|
|
32
|
+
if (!opts.azureADTokenProvider && !opts.apiKey) {
|
|
33
|
+
opts.azureADTokenProvider = this.getDefaultAuth();
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
this.service = new AzureOpenAI({
|
|
37
|
+
apiKey: opts.apiKey,
|
|
38
|
+
azureADTokenProvider: opts.azureADTokenProvider,
|
|
39
|
+
endpoint: opts.endpoint,
|
|
40
|
+
apiVersion: opts.apiVersion ?? "2024-05-01-preview",
|
|
41
|
+
deployment: opts.deployment
|
|
42
|
+
});
|
|
43
|
+
this.provider = "azure_openai";
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
getDefaultAuth() {
|
|
48
|
+
const scope = "https://cognitiveservices.azure.com/.default";
|
|
49
|
+
const azureADTokenProvider = getBearerTokenProvider(new DefaultAzureCredential(), scope);
|
|
50
|
+
return azureADTokenProvider;
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
}
|
|
@@ -15,8 +15,8 @@ import {
|
|
|
15
15
|
TrainingPromptOptions
|
|
16
16
|
} from "@llumiverse/core";
|
|
17
17
|
import { asyncMap } from "@llumiverse/core/async";
|
|
18
|
-
import {
|
|
19
|
-
import OpenAI from "openai";
|
|
18
|
+
import { formatOpenAILikeMultimodalPrompt } from "@llumiverse/core/formatters";
|
|
19
|
+
import OpenAI, { AzureOpenAI } from "openai";
|
|
20
20
|
import { Stream } from "openai/streaming";
|
|
21
21
|
|
|
22
22
|
const supportFineTunning = new Set([
|
|
@@ -27,26 +27,20 @@ const supportFineTunning = new Set([
|
|
|
27
27
|
"gpt-4-0613"
|
|
28
28
|
]);
|
|
29
29
|
|
|
30
|
-
export interface
|
|
31
|
-
apiKey: string;
|
|
30
|
+
export interface BaseOpenAIDriverOptions extends DriverOptions {
|
|
32
31
|
}
|
|
33
32
|
|
|
34
|
-
export class
|
|
35
|
-
|
|
33
|
+
export abstract class BaseOpenAIDriver extends AbstractDriver<
|
|
34
|
+
BaseOpenAIDriverOptions,
|
|
36
35
|
OpenAI.Chat.Completions.ChatCompletionMessageParam[]
|
|
37
36
|
> {
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
generatedContentTypes: string[] = ["text/plain"];
|
|
41
|
-
service: OpenAI;
|
|
42
|
-
provider = OpenAIDriver.PROVIDER;
|
|
37
|
+
abstract provider: "azure_openai" | "openai";
|
|
38
|
+
abstract service: OpenAI | AzureOpenAI ;
|
|
43
39
|
|
|
44
|
-
constructor(opts:
|
|
40
|
+
constructor(opts: BaseOpenAIDriverOptions) {
|
|
45
41
|
super(opts);
|
|
46
|
-
this.
|
|
47
|
-
|
|
48
|
-
});
|
|
49
|
-
this.formatPrompt = formatOpenAILikePrompt;
|
|
42
|
+
this.formatPrompt = formatOpenAILikeMultimodalPrompt as any //TODO: better type, we send back OpenAI.Chat.Completions.ChatCompletionMessageParam[] but just not compatbile with Function call that we don't use here
|
|
43
|
+
|
|
50
44
|
}
|
|
51
45
|
|
|
52
46
|
extractDataFromResponse(
|
|
@@ -63,7 +57,7 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
63
57
|
const finish_reason = choice.finish_reason;
|
|
64
58
|
|
|
65
59
|
//if no schema, return content
|
|
66
|
-
if (!options.
|
|
60
|
+
if (!options.result_schema) {
|
|
67
61
|
return {
|
|
68
62
|
result: choice.message.content as string,
|
|
69
63
|
token_usage: tokenInfo,
|
|
@@ -86,7 +80,7 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
86
80
|
}
|
|
87
81
|
|
|
88
82
|
async requestCompletionStream(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise<any> {
|
|
89
|
-
const mapFn = options.
|
|
83
|
+
const mapFn = options.result_schema
|
|
90
84
|
? (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
|
|
91
85
|
return (
|
|
92
86
|
chunk.choices[0]?.delta?.tool_calls?.[0].function?.arguments ?? ""
|
|
@@ -103,18 +97,18 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
103
97
|
temperature: options.temperature,
|
|
104
98
|
n: 1,
|
|
105
99
|
max_tokens: options.max_tokens,
|
|
106
|
-
tools: options.
|
|
100
|
+
tools: options.result_schema
|
|
107
101
|
? [
|
|
108
102
|
{
|
|
109
103
|
function: {
|
|
110
104
|
name: "format_output",
|
|
111
|
-
parameters: options.
|
|
105
|
+
parameters: options.result_schema as any,
|
|
112
106
|
},
|
|
113
107
|
type: "function"
|
|
114
108
|
} as OpenAI.Chat.ChatCompletionTool,
|
|
115
109
|
]
|
|
116
110
|
: undefined,
|
|
117
|
-
tool_choice: options.
|
|
111
|
+
tool_choice: options.result_schema
|
|
118
112
|
? {
|
|
119
113
|
type: 'function',
|
|
120
114
|
function: { name: "format_output" }
|
|
@@ -125,12 +119,12 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
125
119
|
}
|
|
126
120
|
|
|
127
121
|
async requestCompletion(prompt: OpenAI.Chat.Completions.ChatCompletionMessageParam[], options: ExecutionOptions): Promise<any> {
|
|
128
|
-
const functions = options.
|
|
122
|
+
const functions = options.result_schema
|
|
129
123
|
? [
|
|
130
124
|
{
|
|
131
125
|
function: {
|
|
132
126
|
name: "format_output",
|
|
133
|
-
parameters: options.
|
|
127
|
+
parameters: options.result_schema as any,
|
|
134
128
|
},
|
|
135
129
|
type: 'function'
|
|
136
130
|
} as OpenAI.Chat.ChatCompletionTool,
|
|
@@ -145,13 +139,13 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
145
139
|
n: 1,
|
|
146
140
|
max_tokens: options.max_tokens,
|
|
147
141
|
tools: functions,
|
|
148
|
-
tool_choice: options.
|
|
142
|
+
tool_choice: options.result_schema
|
|
149
143
|
? {
|
|
150
144
|
type: 'function',
|
|
151
145
|
function: { name: "format_output" }
|
|
152
146
|
} : undefined,
|
|
153
147
|
// functions: functions,
|
|
154
|
-
// function_call: options.
|
|
148
|
+
// function_call: options.result_schema
|
|
155
149
|
// ? { name: "format_output" }
|
|
156
150
|
// : undefined,
|
|
157
151
|
});
|
|
@@ -163,11 +157,11 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
163
157
|
return completion;
|
|
164
158
|
}
|
|
165
159
|
|
|
166
|
-
createTrainingPrompt(options: TrainingPromptOptions): string {
|
|
160
|
+
createTrainingPrompt(options: TrainingPromptOptions): Promise<string> {
|
|
167
161
|
if (options.model.includes("gpt")) {
|
|
168
162
|
return super.createTrainingPrompt(options);
|
|
169
163
|
} else {
|
|
170
|
-
// babbage, davinci not yet implemented
|
|
164
|
+
// babbage, davinci not yet implemented
|
|
171
165
|
throw new Error("Unsupported model for training: " + options.model);
|
|
172
166
|
}
|
|
173
167
|
}
|
|
@@ -217,7 +211,7 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
217
211
|
return this._listModels();
|
|
218
212
|
}
|
|
219
213
|
|
|
220
|
-
async _listModels(filter?: (m: OpenAI.Models.Model) => boolean) {
|
|
214
|
+
async _listModels(filter?: (m: OpenAI.Models.Model) => boolean): Promise<AIModel[]> {
|
|
221
215
|
let result = await this.service.models.list();
|
|
222
216
|
const models = filter ? result.data.filter(filter) : result.data;
|
|
223
217
|
return models.map((m) => ({
|
|
@@ -226,6 +220,8 @@ export class OpenAIDriver extends AbstractDriver<
|
|
|
226
220
|
provider: this.provider,
|
|
227
221
|
owner: m.owned_by,
|
|
228
222
|
type: m.object === "model" ? ModelType.Text : ModelType.Unknown,
|
|
223
|
+
can_stream: true,
|
|
224
|
+
is_multimodal: m.id.includes("gpt-4")
|
|
229
225
|
}));
|
|
230
226
|
}
|
|
231
227
|
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
import { DriverOptions } from "@llumiverse/core";
|
|
4
|
+
import OpenAI from "openai";
|
|
5
|
+
import { BaseOpenAIDriver } from "./index.js";
|
|
6
|
+
|
|
7
|
+
export interface OpenAIDriverOptions extends DriverOptions {
|
|
8
|
+
|
|
9
|
+
/**
|
|
10
|
+
* The OpenAI api key
|
|
11
|
+
*/
|
|
12
|
+
apiKey?: string; //type with azure credntials
|
|
13
|
+
|
|
14
|
+
}
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
export class OpenAIDriver extends BaseOpenAIDriver {
|
|
20
|
+
|
|
21
|
+
service: OpenAI;
|
|
22
|
+
provider: "openai";
|
|
23
|
+
|
|
24
|
+
constructor(opts: OpenAIDriverOptions) {
|
|
25
|
+
super(opts);
|
|
26
|
+
this.service = new OpenAI({
|
|
27
|
+
apiKey: opts.apiKey
|
|
28
|
+
});
|
|
29
|
+
this.provider = "openai";
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
}
|