@llumiverse/drivers 0.18.0 → 0.19.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/lib/cjs/bedrock/index.js +19 -22
- package/lib/cjs/bedrock/index.js.map +1 -1
- package/lib/cjs/huggingface_ie.js +1 -1
- package/lib/cjs/huggingface_ie.js.map +1 -1
- package/lib/cjs/mistral/index.js +1 -1
- package/lib/cjs/mistral/index.js.map +1 -1
- package/lib/cjs/openai/index.js +10 -14
- package/lib/cjs/openai/index.js.map +1 -1
- package/lib/cjs/togetherai/index.js +1 -1
- package/lib/cjs/togetherai/index.js.map +1 -1
- package/lib/cjs/vertexai/index.js +81 -18
- package/lib/cjs/vertexai/index.js.map +1 -1
- package/lib/cjs/vertexai/models/claude.js +46 -66
- package/lib/cjs/vertexai/models/claude.js.map +1 -1
- package/lib/cjs/vertexai/models/gemini.js +413 -80
- package/lib/cjs/vertexai/models/gemini.js.map +1 -1
- package/lib/cjs/vertexai/models/llama.js +182 -0
- package/lib/cjs/vertexai/models/llama.js.map +1 -0
- package/lib/cjs/vertexai/models.js +4 -0
- package/lib/cjs/vertexai/models.js.map +1 -1
- package/lib/cjs/watsonx/index.js +1 -1
- package/lib/cjs/watsonx/index.js.map +1 -1
- package/lib/cjs/xai/index.js +1 -1
- package/lib/cjs/xai/index.js.map +1 -1
- package/lib/esm/bedrock/index.js +19 -22
- package/lib/esm/bedrock/index.js.map +1 -1
- package/lib/esm/huggingface_ie.js +1 -1
- package/lib/esm/huggingface_ie.js.map +1 -1
- package/lib/esm/mistral/index.js +1 -1
- package/lib/esm/mistral/index.js.map +1 -1
- package/lib/esm/openai/index.js +12 -16
- package/lib/esm/openai/index.js.map +1 -1
- package/lib/esm/togetherai/index.js +1 -1
- package/lib/esm/togetherai/index.js.map +1 -1
- package/lib/esm/vertexai/index.js +81 -18
- package/lib/esm/vertexai/index.js.map +1 -1
- package/lib/esm/vertexai/models/claude.js +46 -66
- package/lib/esm/vertexai/models/claude.js.map +1 -1
- package/lib/esm/vertexai/models/gemini.js +409 -76
- package/lib/esm/vertexai/models/gemini.js.map +1 -1
- package/lib/esm/vertexai/models/llama.js +178 -0
- package/lib/esm/vertexai/models/llama.js.map +1 -0
- package/lib/esm/vertexai/models.js +4 -0
- package/lib/esm/vertexai/models.js.map +1 -1
- package/lib/esm/watsonx/index.js +1 -1
- package/lib/esm/watsonx/index.js.map +1 -1
- package/lib/esm/xai/index.js +1 -1
- package/lib/esm/xai/index.js.map +1 -1
- package/lib/types/bedrock/index.d.ts.map +1 -1
- package/lib/types/groq/index.d.ts +1 -1
- package/lib/types/groq/index.d.ts.map +1 -1
- package/lib/types/huggingface_ie.d.ts +1 -1
- package/lib/types/huggingface_ie.d.ts.map +1 -1
- package/lib/types/mistral/index.d.ts +1 -1
- package/lib/types/mistral/index.d.ts.map +1 -1
- package/lib/types/openai/index.d.ts.map +1 -1
- package/lib/types/togetherai/index.d.ts +1 -1
- package/lib/types/togetherai/index.d.ts.map +1 -1
- package/lib/types/vertexai/index.d.ts +17 -7
- package/lib/types/vertexai/index.d.ts.map +1 -1
- package/lib/types/vertexai/models/claude.d.ts.map +1 -1
- package/lib/types/vertexai/models/gemini.d.ts +9 -6
- package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
- package/lib/types/vertexai/models/llama.d.ts +20 -0
- package/lib/types/vertexai/models/llama.d.ts.map +1 -0
- package/lib/types/vertexai/models.d.ts +6 -2
- package/lib/types/vertexai/models.d.ts.map +1 -1
- package/lib/types/watsonx/index.d.ts +1 -1
- package/lib/types/watsonx/index.d.ts.map +1 -1
- package/lib/types/xai/index.d.ts +1 -1
- package/lib/types/xai/index.d.ts.map +1 -1
- package/package.json +16 -16
- package/src/bedrock/index.ts +19 -22
- package/src/groq/index.ts +1 -1
- package/src/huggingface_ie.ts +1 -1
- package/src/mistral/index.ts +1 -1
- package/src/openai/index.ts +12 -16
- package/src/togetherai/index.ts +1 -1
- package/src/vertexai/index.ts +95 -22
- package/src/vertexai/models/claude.ts +54 -69
- package/src/vertexai/models/gemini.ts +473 -93
- package/src/vertexai/models/llama.ts +261 -0
- package/src/vertexai/models.ts +6 -2
- package/src/watsonx/index.ts +1 -1
- package/src/xai/index.ts +1 -1
package/package.json
CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "@llumiverse/drivers",
-  "version": "0.18.0",
+  "version": "0.19.0",
   "type": "module",
   "description": "LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.",
   "files": [
@@ -48,29 +48,29 @@
     "vitest": "^3.0.9"
   },
   "dependencies": {
-    "@anthropic-ai/sdk": "^0.
-    "@anthropic-ai/vertex-sdk": "^0.
-    "@aws-sdk/client-bedrock": "^3.
-    "@aws-sdk/client-bedrock-runtime": "^3.
-    "@aws-sdk/client-s3": "^3.
-    "@aws-sdk/credential-providers": "^3.
-    "@aws-sdk/lib-storage": "^3.
-    "@aws-sdk/types": "^3.
-    "@azure/identity": "^4.
+    "@anthropic-ai/sdk": "^0.52.0",
+    "@anthropic-ai/vertex-sdk": "^0.11.4",
+    "@aws-sdk/client-bedrock": "^3.816.0",
+    "@aws-sdk/client-bedrock-runtime": "^3.816.0",
+    "@aws-sdk/client-s3": "^3.816.0",
+    "@aws-sdk/credential-providers": "^3.816.0",
+    "@aws-sdk/lib-storage": "^3.816.0",
+    "@aws-sdk/types": "^3.804.0",
+    "@azure/identity": "^4.10.0",
     "@azure/openai": "2.0.0",
     "@google-cloud/aiplatform": "^3.35.0",
-    "@google
+    "@google/genai": "^1.0.0",
     "@huggingface/inference": "2.6.7",
-    "api-fetch-client": "^0.
+    "@vertesia/api-fetch-client": "^0.60.0",
     "eventsource": "^4.0.0",
     "google-auth-library": "^9.14.0",
-    "groq-sdk": "^0.
+    "groq-sdk": "^0.22.0",
     "mnemonist": "^0.40.0",
     "node-web-stream-adapters": "^0.2.0",
-    "openai": "^4.
+    "openai": "^4.103.0",
     "replicate": "^1.0.1",
-    "@llumiverse/common": "0.
-    "@llumiverse/core": "0.
+    "@llumiverse/common": "0.19.0",
+    "@llumiverse/core": "0.19.0"
   },
   "ts_dual_module": {
     "outDir": "lib"
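Two dependency changes recur through the source diffs below: the HTTP client moves from the unscoped api-fetch-client to the scoped @vertesia/api-fetch-client, and Gemini access moves from @google-cloud/vertexai to the new @google/genai SDK. A minimal sketch of the import-side migration for the fetch client, assuming the class surface is unchanged across the rename:

    // 0.18.0
    // import { FetchClient } from "api-fetch-client";

    // 0.19.0: same class, new scoped package name
    import { FetchClient } from "@vertesia/api-fetch-client";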
package/src/bedrock/index.ts
CHANGED
@@ -295,7 +295,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
     }

     preparePayload(prompt: ConverseRequest, options: ExecutionOptions) {
-        const model_options = options.model_options as TextFallbackOptions;
+        const model_options: TextFallbackOptions = options.model_options as TextFallbackOptions ?? { _option_id: "text-fallback" };

         let additionalField = {};

@@ -305,7 +305,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
             }
             //Titan models also exists but does not support any additional options
             if (options.model.includes("nova")) {
-                additionalField = { inferenceConfig: { topK: model_options
+                additionalField = { inferenceConfig: { topK: model_options.top_k } };
             }
         } else if (options.model.includes("claude")) {
             if (options.result_schema) {
@@ -313,18 +313,15 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
             }
             if (options.model.includes("claude-3-7")) {
                 const thinking_options = options.model_options as BedrockClaudeOptions;
-                const thinking = thinking_options
-                if (!model_options?.max_tokens) {
-                    model_options.max_tokens = thinking ? 128000 : 8192;
-                }
+                const thinking = thinking_options.thinking_mode ?? false;
                 additionalField = {
                     ...additionalField,
                     reasoning_config: {
                         type: thinking ? "enabled" : "disabled",
-                        budget_tokens: thinking_options
+                        budget_tokens: thinking_options.thinking_budget_tokens,
                     }
                 };
-                if (thinking && (thinking_options
+                if (thinking && (thinking_options.thinking_budget_tokens ?? 0) > 64000) {
                     additionalField = {
                         ...additionalField,
                         anthorpic_beta: ["output-128k-2025-02-19"]
@@ -332,16 +329,16 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
                 }
             }
             //Needs max_tokens to be set
-            if (!model_options
+            if (!model_options.max_tokens) {
                 model_options.max_tokens = getMaxTokensLimit(options.model, model_options);
             }
-            additionalField = { ...additionalField, top_k: model_options
+            additionalField = { ...additionalField, top_k: model_options.top_k };
         } else if (options.model.includes("meta")) {
             //LLaMA models support no additional options
         } else if (options.model.includes("mistral")) {
             //7B instruct and 8x7B instruct
             if (options.model.includes("7b")) {
-                additionalField = { top_k: model_options
+                additionalField = { top_k: model_options.top_k };
                 //Does not support system messages
                 if (prompt.system && prompt.system?.length != 0) {
                     prompt.messages?.push(converseSystemToMessages(prompt.system));
@@ -360,8 +357,8 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
             //Jurassic 2 models do.
             if (options.model.includes("j2")) {
                 additionalField = {
-                    presencePenalty: { scale: model_options
-                    frequencyPenalty: { scale: model_options
+                    presencePenalty: { scale: model_options.presence_penalty },
+                    frequencyPenalty: { scale: model_options.frequency_penalty },
                 };
                 //Does not support system messages
                 if (prompt.system && prompt.system?.length != 0) {
@@ -375,13 +372,13 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
             //Command R and R plus
             if (options.model.includes("cohere.command-r")) {
                 additionalField = {
-                    k: model_options
-                    frequency_penalty: model_options
-                    presence_penalty: model_options
+                    k: model_options.top_k,
+                    frequency_penalty: model_options.frequency_penalty,
+                    presence_penalty: model_options.presence_penalty,
                 };
             } else {
                 // Command non-R
-                additionalField = { k: model_options
+                additionalField = { k: model_options.top_k };
                 //Does not support system messages
                 if (prompt.system && prompt.system?.length != 0) {
                     prompt.messages?.push(converseSystemToMessages(prompt.system));
@@ -404,7 +401,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
         //If last message is "```json", add corresponding ``` as a stop sequence.
         if (prompt.messages && prompt.messages.length > 0) {
             if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") {
-                let stopSeq = model_options
+                let stopSeq = model_options.stop_sequence;
                 if (!stopSeq) {
                     model_options.stop_sequence = ["```"];
                 } else if (!stopSeq.includes("```")) {
@@ -421,10 +418,10 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
             system: prompt.system,
             modelId: options.model,
             inferenceConfig: {
-                maxTokens: model_options
-                temperature: model_options
-                topP: model_options
-                stopSequences: model_options
+                maxTokens: model_options.max_tokens,
+                temperature: model_options.temperature,
+                topP: model_options.top_p,
+                stopSequences: model_options.stop_sequence,
             } satisfies InferenceConfiguration,
             additionalModelRequestFields: {
                 ...additionalField,
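The thread running through this file is that preparePayload now normalizes model_options once, instead of optional-chaining every access. A minimal sketch of the pattern, with the option type reduced to the fields this diff touches (the real TextFallbackOptions lives in @llumiverse/core):

    interface TextFallbackOptions {
        _option_id: "text-fallback";
        max_tokens?: number;
        temperature?: number;
        top_p?: number;
        top_k?: number;
        presence_penalty?: number;
        frequency_penalty?: number;
        stop_sequence?: string[];
    }

    function normalizeOptions(model_options?: TextFallbackOptions): TextFallbackOptions {
        // 0.18.0 left model_options possibly undefined; 0.19.0 falls back to a bare
        // "text-fallback" object once, so plain property access is safe afterwards.
        return model_options ?? { _option_id: "text-fallback" };
    }

    const opts = normalizeOptions(undefined);
    console.log(opts.top_k); // undefined, but no TypeError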
package/src/groq/index.ts
CHANGED
@@ -35,7 +35,7 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMess
     // }
     // }

-    getResponseFormat(_options: ExecutionOptions):
+    getResponseFormat(_options: ExecutionOptions): undefined {
         //TODO: when forcing json_object type the streaming is not supported.
         // either implement canStream as above or comment the code below:
         // const responseFormatJson: Groq.Chat.Completions.CompletionCreateParams.ResponseFormat = {
package/src/huggingface_ie.ts
CHANGED
@@ -14,7 +14,7 @@ import {
     TextFallbackOptions,
 } from "@llumiverse/core";
 import { transformAsyncIterator } from "@llumiverse/core/async";
-import { FetchClient } from "api-fetch-client";
+import { FetchClient } from "@vertesia/api-fetch-client";

 export interface HuggingFaceIEDriverOptions extends DriverOptions {
     apiKey: string;
package/src/mistral/index.ts
CHANGED
@@ -1,7 +1,7 @@
 import { AIModel, AbstractDriver, Completion, CompletionChunk, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment, TextFallbackOptions } from "@llumiverse/core";
 import { transformSSEStream } from "@llumiverse/core/async";
 import { OpenAITextMessage, formatOpenAILikeTextPrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
-import { FetchClient } from "api-fetch-client";
+import { FetchClient } from "@vertesia/api-fetch-client";
 import { ChatCompletionResponse, CompletionRequestParams, ListModelsResponse, ResponseFormat } from "./types.js";

 //TODO retry on 429
package/src/openai/index.ts
CHANGED
@@ -19,9 +19,10 @@ import {
     TrainingPromptOptions,
     getModelCapabilities,
     modelModalitiesToArray,
+    supportsToolUse,
 } from "@llumiverse/core";
 import { asyncMap } from "@llumiverse/core/async";
-import { formatOpenAILikeMultimodalPrompt
+import { formatOpenAILikeMultimodalPrompt } from "@llumiverse/core/formatters";
 import OpenAI, { AzureOpenAI } from "openai";
 import { Stream } from "openai/streaming";

@@ -87,7 +88,7 @@ export abstract class BaseOpenAIDriver extends AbstractDriver<
         }

         const toolDefs = getToolDefinitions(options.tools);
-        const useTools: boolean = toolDefs ?
+        const useTools: boolean = toolDefs ? supportsToolUse(options.model, "openai", true) : false;

         const mapFn = (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
             let result = undefined
@@ -167,7 +168,7 @@ export abstract class BaseOpenAIDriver extends AbstractDriver<
         insert_image_detail(prompt, model_options?.image_detail ?? "auto");

         const toolDefs = getToolDefinitions(options.tools);
-        const useTools: boolean = toolDefs ?
+        const useTools: boolean = toolDefs ? supportsToolUse(options.model, "openai") : false;

         let conversation = updateConversation(options.conversation as OpenAIMessageBlock[], prompt);

@@ -289,7 +290,8 @@ export abstract class BaseOpenAIDriver extends AbstractDriver<

         //Some of these use the completions API instead of the chat completions API.
         //Others are for non-text input modalities. Therefore common to both.
-        const wordBlacklist = ["embed", "whisper", "transcribe", "audio", "moderation", "tts",
+        const wordBlacklist = ["embed", "whisper", "transcribe", "audio", "moderation", "tts",
+            "realtime", "dall-e", "babbage", "davinci", "codex", "o1-pro"];

         if (this.provider === "azure_openai") {
             //Azure OpenAI has additional information about the models
@@ -415,20 +417,14 @@ function convertRoles(messages: OpenAIMessageBlock[], model: string): OpenAIMess
     return messages
 }

-
-
-    if (!list_check && model.includes("gpt-4o") && !model.includes("gpt-4o-2024-05-13")) {
-        return true;
-    }
-    return list_check
-}
-
+//Structured output support is typically aligned with tool use support
+//Not true for realtime models, which do not support structured output, but do support tool use.
 function supportsSchema(model: string): boolean {
-    const
-    if (
-    return
+    const realtimeModel = model.includes("realtime");
+    if (realtimeModel) {
+        return false;
     }
-    return
+    return supportsToolUse(model, "openai");
 }

 function getToolDefinitions(tools: ToolDefinition[] | undefined | null): OpenAI.ChatCompletionTool[] | undefined {
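The rewritten supportsSchema drops the hard-coded gpt-4o checks and reuses the shared tool-use capability table, carving out realtime models. Hypothetical calls, with model names for illustration only:

    supportsSchema("gpt-4o-2024-08-06");       // whatever supportsToolUse(model, "openai") reports
    supportsSchema("gpt-4o-realtime-preview"); // false: realtime models are special-cased out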
package/src/togetherai/index.ts
CHANGED
@@ -1,6 +1,6 @@
 import { AIModel, AbstractDriver, Completion, CompletionChunk, DriverOptions, EmbeddingsResult, ExecutionOptions, TextFallbackOptions } from "@llumiverse/core";
 import { transformSSEStream } from "@llumiverse/core/async";
-import { FetchClient } from "api-fetch-client";
+import { FetchClient } from "@vertesia/api-fetch-client";
 import { TextCompletion, TogetherModelInfo } from "./interfaces.js";

 interface TogetherAIDriverOptions extends DriverOptions {
package/src/vertexai/index.ts
CHANGED
@@ -1,9 +1,8 @@
-import { GenerateContentRequest, VertexAI } from "@google-cloud/vertexai";
 import {
     AIModel,
     AbstractDriver,
     Completion,
-
+    CompletionChunk,
     DriverOptions,
     EmbeddingsResult,
     ExecutionOptions,
@@ -14,7 +13,7 @@ import {
     getModelCapabilities,
     modelModalitiesToArray,
 } from "@llumiverse/core";
-import { FetchClient } from "api-fetch-client";
+import { FetchClient } from "@vertesia/api-fetch-client";
 import { GoogleAuth, GoogleAuthOptions } from "google-auth-library";
 import { JSONClient } from "google-auth-library/build/src/auth/googleauth.js";
 import { TextEmbeddingsOptions, getEmbeddingsForText } from "./embeddings/embeddings-text.js";
@@ -24,6 +23,7 @@ import { getEmbeddingsForImages } from "./embeddings/embeddings-image.js";
 import { v1beta1 } from "@google-cloud/aiplatform";
 import { AnthropicVertex } from "@anthropic-ai/vertex-sdk";
 import { ImagenModelDefinition, ImagenPrompt } from "./models/imagen.js";
+import { GoogleGenAI, Content, Tool } from "@google/genai";

 export interface VertexAIDriverOptions extends DriverOptions {
     project: string;
@@ -31,8 +31,14 @@ export interface VertexAIDriverOptions extends DriverOptions {
     googleAuthOptions?: GoogleAuthOptions;
 }

+export interface GenerateContentPrompt {
+    contents: Content[];
+    system?: string;
+    tools?: Tool[];
+}
+
 //General Prompt type for VertexAI
-export type VertexAIPrompt =
+export type VertexAIPrompt = ImagenPrompt | GenerateContentPrompt;

 export function trimModelName(model: string) {
     const i = model.lastIndexOf("@");
@@ -46,8 +52,9 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
     aiplatform: v1beta1.ModelServiceClient | undefined;
     anthropicClient: AnthropicVertex | undefined;
     fetchClient: FetchClient | undefined;
+    googleGenAI: GoogleGenAI | undefined;
+    llamaClient: FetchClient & { region?: string } | undefined;
     modelGarden: v1beta1.ModelGardenServiceClient | undefined;
-    vertexai: VertexAI | undefined;

     authClient: JSONClient | GoogleAuth<JSONClient>;

@@ -57,12 +64,28 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
         this.aiplatform = undefined;
         this.anthropicClient = undefined;
         this.fetchClient = undefined
+        this.googleGenAI = undefined;
         this.modelGarden = undefined;
-        this.
+        this.llamaClient = undefined;

         this.authClient = options.googleAuthOptions?.authClient ?? new GoogleAuth(options.googleAuthOptions);
     }

+    public getGoogleGenAIClient(): GoogleGenAI {
+        //Lazy initialisation
+        if (!this.googleGenAI) {
+            this.googleGenAI = new GoogleGenAI({
+                project: this.options.project,
+                location: this.options.region,
+                vertexai: true,
+                googleAuthOptions: {
+                    authClient: this.authClient as JSONClient,
+                }
+            });
+        }
+        return this.googleGenAI;
+    }
+
     public getFetchClient(): FetchClient {
         //Lazy initialisation
         if (!this.fetchClient) {
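getGoogleGenAIClient replaces the removed getVertexAIClient (dropped in a later hunk) and follows the same lazy-initialisation pattern as the other client getters. A sketch of how a model definition can drive Gemini through it, based on the public @google/genai surface; the model name and prompt are illustrative:

    async function demo(driver: VertexAIDriver) {
        const ai = driver.getGoogleGenAIClient();
        // GoogleGenAI routes to Vertex AI because the getter sets vertexai: true
        const response = await ai.models.generateContent({
            model: "gemini-2.0-flash-001", // illustrative model name
            contents: "Say hello",
        });
        console.log(response.text);
    }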
@@ -78,6 +101,24 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
         return this.fetchClient;
     }

+    public getLLamaClient(region: string = "us-central1"): FetchClient {
+        //Lazy initialisation
+        if (!this.llamaClient || this.llamaClient["region"] !== region) {
+            this.llamaClient = createFetchClient({
+                region: region,
+                project: this.options.project,
+                apiVersion: "v1beta1",
+            }).withAuthCallback(async () => {
+                const accessTokenResponse = await this.authClient.getAccessToken();
+                const token = typeof accessTokenResponse === 'string' ? accessTokenResponse : accessTokenResponse?.token;
+                return `Bearer ${token}`;
+            });
+            // Store the region for potential client reuse
+            this.llamaClient["region"] = region;
+        }
+        return this.llamaClient;
+    }
+
     public getAnthropicClient(): AnthropicVertex {
         //Lazy initialisation
         if (!this.anthropicClient) {
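getLLamaClient caches one client per region, since Llama Models-as-a-Service endpoints are regional; asking for a different region rebuilds the client. A sketch of a call against the OpenAI-compatible MaaS route; the URL path and payload shape are assumptions based on Vertex AI's MaaS documentation, not something this diff shows:

    async function chatWithLlama(driver: VertexAIDriver, project: string, region = "us-central1") {
        const client = driver.getLLamaClient(region);
        // Assumed OpenAI-compatible chat completions route exposed by Vertex AI MaaS
        return client.post(
            `/projects/${project}/locations/${region}/endpoints/openapi/chat/completions`,
            {
                payload: {
                    model: "meta/llama-3.3-70b-instruct-maas", // illustrative
                    messages: [{ role: "user", content: "Say hello" }],
                },
            },
        );
    }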
@@ -89,18 +130,6 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
         return this.anthropicClient;
     }

-    public getVertexAIClient(): VertexAI {
-        //Lazy initialisation
-        if (!this.vertexai) {
-            this.vertexai = new VertexAI({
-                project: this.options.project,
-                location: this.options.region,
-                googleAuthOptions: this.options.googleAuthOptions,
-            });
-        }
-        return this.vertexai;
-    }
-
     public getAIPlatformClient(): v1beta1.ModelServiceClient {
         //Lazy initialisation
         if (!this.aiplatform) {
@@ -125,6 +154,18 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
         return this.modelGarden;
     }

+    validateResult(result: Completion, options: ExecutionOptions) {
+        // Optionally preprocess the result before validation
+        const modelDef = getModelDefinition(options.model);
+        if (typeof modelDef.preValidationProcessing === "function") {
+            const processed = modelDef.preValidationProcessing(result, options);
+            result = processed.result;
+            options = processed.options;
+        }
+
+        super.validateResult(result, options);
+    }
+
     protected canStream(options: ExecutionOptions): Promise<boolean> {
         if (options.output_modality == Modalities.image) {
             return Promise.resolve(false);
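The validateResult override lets a model definition opt into a preValidationProcessing hook that can rewrite the completion (and options) before schema validation runs. A sketch of such a hook; only the (result, options) to { result, options } contract comes from this diff, the fence-stripping body is illustrative:

    import { Completion, ExecutionOptions } from "@llumiverse/core";

    const exampleDefinition = {
        // ...the usual ModelDefinition members...
        preValidationProcessing(result: Completion, options: ExecutionOptions) {
            // e.g. strip markdown fences a model may wrap around JSON output
            if (typeof result.result === "string") {
                result.result = result.result.replace(/^```(?:json)?\s*|\s*```$/g, "");
            }
            return { result, options };
        },
    };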
@@ -145,7 +186,7 @@
     async requestTextCompletionStream(
         prompt: VertexAIPrompt,
         options: ExecutionOptions,
-    ): Promise<AsyncIterable<
+    ): Promise<AsyncIterable<CompletionChunk>> {
         return getModelDefinition(options.model).requestTextCompletionStream(this, prompt, options);
     }

@@ -178,14 +219,31 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
         );

         //Model Garden Publisher models - Pretrained models
-        const publishers = ["google", "anthropic"];
-
+        const publishers = ["google", "anthropic", "meta"];
+        // Meta "maas" models are LLama Models-As-A-Service. Non-maas models are not pre-deployed.
+        const supportedModels = { google: ["gemini", "imagen"], anthropic: ["claude"], meta: ["maas"] };
+        // Additional models not in the listings, but we want to include
+        // TODO: Remove once the models are available in the listing API, or no longer needed
+        const additionalModels = {
+            google: ["imagen-3.0-fast-generate-001"],
+            anthropic: [],
+            meta: [
+                "llama-4-maverick-17b-128e-instruct-maas",
+                "llama-4-scout-17b-16e-instruct-maas",
+                "llama-3.3-70b-instruct-maas",
+                "llama-3.2-90b-vision-instruct-maas",
+                "llama-3.1-405b-instruct-maas",
+                "llama-3.1-70b-instruct-maas",
+                "llama-3.1-8b-instruct-maas",
+            ],
+        }

         //Used to exclude retired models that are still in the listing API but not available for use.
         //Or models we do not support yet
         const unsupportedModelsByPublisher = {
             google: ["gemini-pro", "gemini-ultra"],
             anthropic: [],
+            meta: [],
         };

         for (const publisher of publishers) {
@@ -228,13 +286,28 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
                     tool_support: modelCapability.tool_support,
                 } satisfies AIModel<string>;
             }));
+
+            // Add additional models that are not in the listing
+            for (const additionalModel of additionalModels[publisher as keyof typeof additionalModels]) {
+                const publisherModelName = `publishers/${publisher}/models/${additionalModel}`;
+                const modelCapability = getModelCapabilities(additionalModel, "vertexai");
+                models.push({
+                    id: publisherModelName,
+                    name: additionalModel,
+                    provider: 'vertexai',
+                    owner: publisher,
+                    input_modalities: modelModalitiesToArray(modelCapability.input),
+                    output_modalities: modelModalitiesToArray(modelCapability.output),
+                    tool_support: modelCapability.tool_support,
+                } satisfies AIModel<string>);
+            }
         }

         //Remove duplicates
         const uniqueModels = Array.from(new Set(models.map(a => a.id)))
             .map(id => {
                 return models.find(a => a.id === id) ?? {} as AIModel<string>;
-            });
+            }).sort((a, b) => a.id.localeCompare(b.id));

         return uniqueModels;
     }
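The hand-maintained additionalModels entries are pushed with the same publishers/<owner>/models/<name> id scheme as the listed models, so callers cannot tell them apart, and the new sort keeps the merged listing stable. A sketch, assuming a constructed driver exposing the usual listModels surface:

    const llamaMaas = (await driver.listModels())
        .filter((m) => m.owner === "meta");
    // ids like "publishers/meta/models/llama-3.3-70b-instruct-maas"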
package/src/vertexai/models/claude.ts
CHANGED
@@ -1,5 +1,5 @@
 import * as AnthropicAPI from '@anthropic-ai/sdk';
-import { ContentBlock, ContentBlockParam, Message, TextBlockParam } from "@anthropic-ai/sdk/resources/index.js";
+import { ContentBlock, ContentBlockParam, ImageBlockParam, Message, TextBlockParam } from "@anthropic-ai/sdk/resources/index.js";
 import {
     AIModel, Completion, CompletionChunkObject, ExecutionOptions, JSONObject, ModelType,
     PromptOptions, PromptRole, PromptSegment, readStreamAsBase64, ToolUse, VertexAIClaudeOptions
@@ -15,23 +15,6 @@ interface ClaudePrompt {
     system: TextBlockParam[];
 }

-function getFullModelName(model: string): string {
-    if (model.includes("claude-3-5-sonnet-v2")) {
-        return "claude-3-5-sonnet-v2@20241022"
-    } else if (model.includes("claude-3-5-sonnet")) {
-        return "claude-3-5-sonnet@20240620"
-    } else if (model.includes("claude-3-5-haiku")) {
-        return "claude-3-5-haiku@20241022"
-    } else if (model.includes("claude-3-opus")) {
-        return "claude-3-opus@20240229"
-    } else if (model.includes("claude-3-sonnet")) {
-        return "claude-3-sonnet@20240229"
-    } else if (model.includes("claude-3-haiku")) {
-        return "claude-3-haiku@20240307"
-    }
-    return model;
-}
-
 function claudeFinishReason(reason: string | undefined) {
     if (!reason) return undefined;
     switch (reason) {
@@ -63,6 +46,36 @@ function maxToken(max_tokens: number | undefined, model: string): number {
     }
 }

+async function collectImageBlocks(segment: PromptSegment, contentBlocks: ContentBlockParam[]) {
+    for (const file of segment.files || []) {
+        if (file.mime_type?.startsWith("image/")) {
+            const allowedTypes = ["image/png", "image/jpeg", "image/gif", "image/webp"];
+            if (!allowedTypes.includes(file.mime_type)) {
+                throw new Error(`Unsupported image type: ${file.mime_type}`);
+            }
+            const mimeType = String(file.mime_type) as "image/png" | "image/jpeg" | "image/gif" | "image/webp";
+
+            contentBlocks.push({
+                type: 'image',
+                source: {
+                    type: 'base64',
+                    data: await readStreamAsBase64(await file.getStream()),
+                    media_type: mimeType
+                }
+            });
+        } else if (file.mime_type?.startsWith("text/")) {
+            contentBlocks.push({
+                source: {
+                    type: 'text',
+                    data: await readStreamAsBase64(await file.getStream()),
+                    media_type: 'text/plain'
+                },
+                type: 'document'
+            });
+        }
+    }
+}
+
 export class ClaudeModelDefinition implements ModelDefinition<ClaudePrompt> {

     model: AIModel
@@ -111,60 +124,38 @@ export class ClaudeModelDefinition implements ModelDefinition<ClaudePrompt> {
                 if (!segment.tool_use_id) {
                     throw new Error("Tool prompt segment must have a tool_use_id");
                 }
+
+                const imageBlocks: ImageBlockParam[] = [];
+                await collectImageBlocks(segment, imageBlocks);
+
                 messages.push({
                     role: 'user',
-                    content: [
-
-
-
-
-
-
+                    content: [{
+                        type: 'tool_result',
+                        tool_use_id: segment.tool_use_id,
+                        content: [{
+                            type: 'text',
+                            text: segment.content || ''
+                        }, ...imageBlocks]
+                    }]
                 });
+
             } else {
                 const contentBlocks: ContentBlockParam[] = [];
-
-                if (file.mime_type?.startsWith("image/")) {
-                    const allowedTypes = ["image/png", "image/jpeg", "image/gif", "image/webp"];
-                    if (!allowedTypes.includes(file.mime_type)) {
-                        throw new Error(`Unsupported image type: ${file.mime_type}`);
-                    }
-
-                    contentBlocks.push({
-                        type: 'image',
-                        source: {
-                            type: 'base64',
-                            data: await readStreamAsBase64(await file.getStream()),
-                            media_type: file.mime_type as any
-                        }
-                    });
-                } else if (file.mime_type?.startsWith("text/")) {
-                    contentBlocks.push({
-                        source: {
-                            type: 'text',
-                            data: await readStreamAsBase64(await file.getStream()),
-                            media_type: 'text/plain'
-                        },
-                        type : 'document'
-                    });
-                }
-            }
-
+                collectImageBlocks(segment, contentBlocks);
                 if (segment.content) {
                     contentBlocks.push({
                         type: 'text',
                         text: segment.content
                     });
                 }
-
-
                 messages.push({
                     role: segment.role === PromptRole.assistant ? 'assistant' : 'user',
                     content: contentBlocks
                 });
             }
         }
-
+
         const system = systemSegments.concat(safetySegments);

         return {
@@ -253,23 +244,17 @@ export class ClaudeModelDefinition implements ModelDefinition<ClaudePrompt> {
             }
         });

-        //Streaming does not give information on the input tokens,
-        //So we use a separate call to get the input tokens.
-        //Non-critical and model name sensitive so we put it in a try catch block
-        let count_tokens = { input_tokens: 0 };
-        try {
-            count_tokens = await client.messages.countTokens({
-                ...prompt, // messages, system
-                model: getFullModelName(modelName),
-            });
-        } catch (e) {
-            driver.logger.warn("Failed to get token count for model " + modelName);
-        }
-
         const stream = asyncMap(response_stream, async (item: any) => {
+            if (item.type == "message_start") {
+                return {
+                    result: '',
+                    token_usage: { prompt: item?.message?.usage?.input_tokens, result: item?.message?.usage?.output_tokens },
+                    finish_reason: undefined,
+                }
+            }
             return {
                 result: item?.delta?.text ?? '',
-                token_usage: {
+                token_usage: { result: item?.usage?.output_tokens },
                 finish_reason: claudeFinishReason(item?.delta?.stop_reason ?? ''),
             }
         });