@llumiverse/drivers 0.18.0 → 0.19.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. package/lib/cjs/bedrock/index.js +19 -22
  2. package/lib/cjs/bedrock/index.js.map +1 -1
  3. package/lib/cjs/huggingface_ie.js +1 -1
  4. package/lib/cjs/huggingface_ie.js.map +1 -1
  5. package/lib/cjs/mistral/index.js +1 -1
  6. package/lib/cjs/mistral/index.js.map +1 -1
  7. package/lib/cjs/openai/index.js +10 -14
  8. package/lib/cjs/openai/index.js.map +1 -1
  9. package/lib/cjs/togetherai/index.js +1 -1
  10. package/lib/cjs/togetherai/index.js.map +1 -1
  11. package/lib/cjs/vertexai/index.js +81 -18
  12. package/lib/cjs/vertexai/index.js.map +1 -1
  13. package/lib/cjs/vertexai/models/claude.js +46 -66
  14. package/lib/cjs/vertexai/models/claude.js.map +1 -1
  15. package/lib/cjs/vertexai/models/gemini.js +413 -80
  16. package/lib/cjs/vertexai/models/gemini.js.map +1 -1
  17. package/lib/cjs/vertexai/models/llama.js +182 -0
  18. package/lib/cjs/vertexai/models/llama.js.map +1 -0
  19. package/lib/cjs/vertexai/models.js +4 -0
  20. package/lib/cjs/vertexai/models.js.map +1 -1
  21. package/lib/cjs/watsonx/index.js +1 -1
  22. package/lib/cjs/watsonx/index.js.map +1 -1
  23. package/lib/cjs/xai/index.js +1 -1
  24. package/lib/cjs/xai/index.js.map +1 -1
  25. package/lib/esm/bedrock/index.js +19 -22
  26. package/lib/esm/bedrock/index.js.map +1 -1
  27. package/lib/esm/huggingface_ie.js +1 -1
  28. package/lib/esm/huggingface_ie.js.map +1 -1
  29. package/lib/esm/mistral/index.js +1 -1
  30. package/lib/esm/mistral/index.js.map +1 -1
  31. package/lib/esm/openai/index.js +12 -16
  32. package/lib/esm/openai/index.js.map +1 -1
  33. package/lib/esm/togetherai/index.js +1 -1
  34. package/lib/esm/togetherai/index.js.map +1 -1
  35. package/lib/esm/vertexai/index.js +81 -18
  36. package/lib/esm/vertexai/index.js.map +1 -1
  37. package/lib/esm/vertexai/models/claude.js +46 -66
  38. package/lib/esm/vertexai/models/claude.js.map +1 -1
  39. package/lib/esm/vertexai/models/gemini.js +409 -76
  40. package/lib/esm/vertexai/models/gemini.js.map +1 -1
  41. package/lib/esm/vertexai/models/llama.js +178 -0
  42. package/lib/esm/vertexai/models/llama.js.map +1 -0
  43. package/lib/esm/vertexai/models.js +4 -0
  44. package/lib/esm/vertexai/models.js.map +1 -1
  45. package/lib/esm/watsonx/index.js +1 -1
  46. package/lib/esm/watsonx/index.js.map +1 -1
  47. package/lib/esm/xai/index.js +1 -1
  48. package/lib/esm/xai/index.js.map +1 -1
  49. package/lib/types/bedrock/index.d.ts.map +1 -1
  50. package/lib/types/groq/index.d.ts +1 -1
  51. package/lib/types/groq/index.d.ts.map +1 -1
  52. package/lib/types/huggingface_ie.d.ts +1 -1
  53. package/lib/types/huggingface_ie.d.ts.map +1 -1
  54. package/lib/types/mistral/index.d.ts +1 -1
  55. package/lib/types/mistral/index.d.ts.map +1 -1
  56. package/lib/types/openai/index.d.ts.map +1 -1
  57. package/lib/types/togetherai/index.d.ts +1 -1
  58. package/lib/types/togetherai/index.d.ts.map +1 -1
  59. package/lib/types/vertexai/index.d.ts +17 -7
  60. package/lib/types/vertexai/index.d.ts.map +1 -1
  61. package/lib/types/vertexai/models/claude.d.ts.map +1 -1
  62. package/lib/types/vertexai/models/gemini.d.ts +9 -6
  63. package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
  64. package/lib/types/vertexai/models/llama.d.ts +20 -0
  65. package/lib/types/vertexai/models/llama.d.ts.map +1 -0
  66. package/lib/types/vertexai/models.d.ts +6 -2
  67. package/lib/types/vertexai/models.d.ts.map +1 -1
  68. package/lib/types/watsonx/index.d.ts +1 -1
  69. package/lib/types/watsonx/index.d.ts.map +1 -1
  70. package/lib/types/xai/index.d.ts +1 -1
  71. package/lib/types/xai/index.d.ts.map +1 -1
  72. package/package.json +16 -16
  73. package/src/bedrock/index.ts +19 -22
  74. package/src/groq/index.ts +1 -1
  75. package/src/huggingface_ie.ts +1 -1
  76. package/src/mistral/index.ts +1 -1
  77. package/src/openai/index.ts +12 -16
  78. package/src/togetherai/index.ts +1 -1
  79. package/src/vertexai/index.ts +95 -22
  80. package/src/vertexai/models/claude.ts +54 -69
  81. package/src/vertexai/models/gemini.ts +473 -93
  82. package/src/vertexai/models/llama.ts +261 -0
  83. package/src/vertexai/models.ts +6 -2
  84. package/src/watsonx/index.ts +1 -1
  85. package/src/xai/index.ts +1 -1
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@llumiverse/drivers",
3
- "version": "0.18.0",
3
+ "version": "0.19.0",
4
4
  "type": "module",
5
5
  "description": "LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.",
6
6
  "files": [
@@ -48,29 +48,29 @@
48
48
  "vitest": "^3.0.9"
49
49
  },
50
50
  "dependencies": {
51
- "@anthropic-ai/sdk": "^0.39.0",
52
- "@anthropic-ai/vertex-sdk": "^0.7.0",
53
- "@aws-sdk/client-bedrock": "^3.787.0",
54
- "@aws-sdk/client-bedrock-runtime": "^3.787.0",
55
- "@aws-sdk/client-s3": "^3.787.0",
56
- "@aws-sdk/credential-providers": "^3.787.0",
57
- "@aws-sdk/lib-storage": "^3.787.0",
58
- "@aws-sdk/types": "^3.775.0",
59
- "@azure/identity": "^4.9.1",
51
+ "@anthropic-ai/sdk": "^0.52.0",
52
+ "@anthropic-ai/vertex-sdk": "^0.11.4",
53
+ "@aws-sdk/client-bedrock": "^3.816.0",
54
+ "@aws-sdk/client-bedrock-runtime": "^3.816.0",
55
+ "@aws-sdk/client-s3": "^3.816.0",
56
+ "@aws-sdk/credential-providers": "^3.816.0",
57
+ "@aws-sdk/lib-storage": "^3.816.0",
58
+ "@aws-sdk/types": "^3.804.0",
59
+ "@azure/identity": "^4.10.0",
60
60
  "@azure/openai": "2.0.0",
61
61
  "@google-cloud/aiplatform": "^3.35.0",
62
- "@google-cloud/vertexai": "^1.10.0",
62
+ "@google/genai": "^1.0.0",
63
63
  "@huggingface/inference": "2.6.7",
64
- "api-fetch-client": "^0.13.0",
64
+ "@vertesia/api-fetch-client": "^0.60.0",
65
65
  "eventsource": "^4.0.0",
66
66
  "google-auth-library": "^9.14.0",
67
- "groq-sdk": "^0.19.0",
67
+ "groq-sdk": "^0.22.0",
68
68
  "mnemonist": "^0.40.0",
69
69
  "node-web-stream-adapters": "^0.2.0",
70
- "openai": "^4.98.0",
70
+ "openai": "^4.103.0",
71
71
  "replicate": "^1.0.1",
72
- "@llumiverse/common": "0.18.0",
73
- "@llumiverse/core": "0.18.0"
72
+ "@llumiverse/common": "0.19.0",
73
+ "@llumiverse/core": "0.19.0"
74
74
  },
75
75
  "ts_dual_module": {
76
76
  "outDir": "lib"
@@ -295,7 +295,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
295
295
  }
296
296
 
297
297
  preparePayload(prompt: ConverseRequest, options: ExecutionOptions) {
298
- const model_options = options.model_options as TextFallbackOptions;
298
+ const model_options: TextFallbackOptions = options.model_options as TextFallbackOptions ?? { _option_id: "text-fallback" };
299
299
 
300
300
  let additionalField = {};
301
301
 
@@ -305,7 +305,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
305
305
  }
306
306
  //Titan models also exists but does not support any additional options
307
307
  if (options.model.includes("nova")) {
308
- additionalField = { inferenceConfig: { topK: model_options?.top_k } };
308
+ additionalField = { inferenceConfig: { topK: model_options.top_k } };
309
309
  }
310
310
  } else if (options.model.includes("claude")) {
311
311
  if (options.result_schema) {
@@ -313,18 +313,15 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
313
313
  }
314
314
  if (options.model.includes("claude-3-7")) {
315
315
  const thinking_options = options.model_options as BedrockClaudeOptions;
316
- const thinking = thinking_options?.thinking_mode ?? false;
317
- if (!model_options?.max_tokens) {
318
- model_options.max_tokens = thinking ? 128000 : 8192;
319
- }
316
+ const thinking = thinking_options.thinking_mode ?? false;
320
317
  additionalField = {
321
318
  ...additionalField,
322
319
  reasoning_config: {
323
320
  type: thinking ? "enabled" : "disabled",
324
- budget_tokens: thinking_options?.thinking_budget_tokens,
321
+ budget_tokens: thinking_options.thinking_budget_tokens,
325
322
  }
326
323
  };
327
- if (thinking && (thinking_options?.thinking_budget_tokens ?? 0) > 64000) {
324
+ if (thinking && (thinking_options.thinking_budget_tokens ?? 0) > 64000) {
328
325
  additionalField = {
329
326
  ...additionalField,
330
327
  anthorpic_beta: ["output-128k-2025-02-19"]
@@ -332,16 +329,16 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
332
329
  }
333
330
  }
334
331
  //Needs max_tokens to be set
335
- if (!model_options?.max_tokens) {
332
+ if (!model_options.max_tokens) {
336
333
  model_options.max_tokens = getMaxTokensLimit(options.model, model_options);
337
334
  }
338
- additionalField = { ...additionalField, top_k: model_options?.top_k };
335
+ additionalField = { ...additionalField, top_k: model_options.top_k };
339
336
  } else if (options.model.includes("meta")) {
340
337
  //LLaMA models support no additional options
341
338
  } else if (options.model.includes("mistral")) {
342
339
  //7B instruct and 8x7B instruct
343
340
  if (options.model.includes("7b")) {
344
- additionalField = { top_k: model_options?.top_k };
341
+ additionalField = { top_k: model_options.top_k };
345
342
  //Does not support system messages
346
343
  if (prompt.system && prompt.system?.length != 0) {
347
344
  prompt.messages?.push(converseSystemToMessages(prompt.system));
@@ -360,8 +357,8 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
360
357
  //Jurassic 2 models do.
361
358
  if (options.model.includes("j2")) {
362
359
  additionalField = {
363
- presencePenalty: { scale: model_options?.presence_penalty },
364
- frequencyPenalty: { scale: model_options?.frequency_penalty },
360
+ presencePenalty: { scale: model_options.presence_penalty },
361
+ frequencyPenalty: { scale: model_options.frequency_penalty },
365
362
  };
366
363
  //Does not support system messages
367
364
  if (prompt.system && prompt.system?.length != 0) {
@@ -375,13 +372,13 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
375
372
  //Command R and R plus
376
373
  if (options.model.includes("cohere.command-r")) {
377
374
  additionalField = {
378
- k: model_options?.top_k,
379
- frequency_penalty: model_options?.frequency_penalty,
380
- presence_penalty: model_options?.presence_penalty,
375
+ k: model_options.top_k,
376
+ frequency_penalty: model_options.frequency_penalty,
377
+ presence_penalty: model_options.presence_penalty,
381
378
  };
382
379
  } else {
383
380
  // Command non-R
384
- additionalField = { k: model_options?.top_k };
381
+ additionalField = { k: model_options.top_k };
385
382
  //Does not support system messages
386
383
  if (prompt.system && prompt.system?.length != 0) {
387
384
  prompt.messages?.push(converseSystemToMessages(prompt.system));
@@ -404,7 +401,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
404
401
  //If last message is "```json", add corresponding ``` as a stop sequence.
405
402
  if (prompt.messages && prompt.messages.length > 0) {
406
403
  if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") {
407
- let stopSeq = model_options?.stop_sequence;
404
+ let stopSeq = model_options.stop_sequence;
408
405
  if (!stopSeq) {
409
406
  model_options.stop_sequence = ["```"];
410
407
  } else if (!stopSeq.includes("```")) {
@@ -421,10 +418,10 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
421
418
  system: prompt.system,
422
419
  modelId: options.model,
423
420
  inferenceConfig: {
424
- maxTokens: model_options?.max_tokens,
425
- temperature: model_options?.temperature,
426
- topP: model_options?.top_p,
427
- stopSequences: model_options?.stop_sequence,
421
+ maxTokens: model_options.max_tokens,
422
+ temperature: model_options.temperature,
423
+ topP: model_options.top_p,
424
+ stopSequences: model_options.stop_sequence,
428
425
  } satisfies InferenceConfiguration,
429
426
  additionalModelRequestFields: {
430
427
  ...additionalField,
package/src/groq/index.ts CHANGED
@@ -35,7 +35,7 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, OpenAITextMess
35
35
  // }
36
36
  // }
37
37
 
38
- getResponseFormat(_options: ExecutionOptions): Groq.Chat.Completions.CompletionCreateParams.ResponseFormat | undefined {
38
+ getResponseFormat(_options: ExecutionOptions): undefined {
39
39
  //TODO: when forcing json_object type the streaming is not supported.
40
40
  // either implement canStream as above or comment the code below:
41
41
  // const responseFormatJson: Groq.Chat.Completions.CompletionCreateParams.ResponseFormat = {
@@ -14,7 +14,7 @@ import {
14
14
  TextFallbackOptions,
15
15
  } from "@llumiverse/core";
16
16
  import { transformAsyncIterator } from "@llumiverse/core/async";
17
- import { FetchClient } from "api-fetch-client";
17
+ import { FetchClient } from "@vertesia/api-fetch-client";
18
18
 
19
19
  export interface HuggingFaceIEDriverOptions extends DriverOptions {
20
20
  apiKey: string;
@@ -1,7 +1,7 @@
1
1
  import { AIModel, AbstractDriver, Completion, CompletionChunk, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment, TextFallbackOptions } from "@llumiverse/core";
2
2
  import { transformSSEStream } from "@llumiverse/core/async";
3
3
  import { OpenAITextMessage, formatOpenAILikeTextPrompt, getJSONSafetyNotice } from "@llumiverse/core/formatters";
4
- import { FetchClient } from "api-fetch-client";
4
+ import { FetchClient } from "@vertesia/api-fetch-client";
5
5
  import { ChatCompletionResponse, CompletionRequestParams, ListModelsResponse, ResponseFormat } from "./types.js";
6
6
 
7
7
  //TODO retry on 429
@@ -19,9 +19,10 @@ import {
19
19
  TrainingPromptOptions,
20
20
  getModelCapabilities,
21
21
  modelModalitiesToArray,
22
+ supportsToolUse,
22
23
  } from "@llumiverse/core";
23
24
  import { asyncMap } from "@llumiverse/core/async";
24
- import { formatOpenAILikeMultimodalPrompt, noStructuredOutputModels } from "@llumiverse/core/formatters";
25
+ import { formatOpenAILikeMultimodalPrompt } from "@llumiverse/core/formatters";
25
26
  import OpenAI, { AzureOpenAI } from "openai";
26
27
  import { Stream } from "openai/streaming";
27
28
 
@@ -87,7 +88,7 @@ export abstract class BaseOpenAIDriver extends AbstractDriver<
87
88
  }
88
89
 
89
90
  const toolDefs = getToolDefinitions(options.tools);
90
- const useTools: boolean = toolDefs ? supportsTools(options.model) : false;
91
+ const useTools: boolean = toolDefs ? supportsToolUse(options.model, "openai", true) : false;
91
92
 
92
93
  const mapFn = (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
93
94
  let result = undefined
@@ -167,7 +168,7 @@ export abstract class BaseOpenAIDriver extends AbstractDriver<
167
168
  insert_image_detail(prompt, model_options?.image_detail ?? "auto");
168
169
 
169
170
  const toolDefs = getToolDefinitions(options.tools);
170
- const useTools: boolean = toolDefs ? supportsTools(options.model) : false;
171
+ const useTools: boolean = toolDefs ? supportsToolUse(options.model, "openai") : false;
171
172
 
172
173
  let conversation = updateConversation(options.conversation as OpenAIMessageBlock[], prompt);
173
174
 
@@ -289,7 +290,8 @@ export abstract class BaseOpenAIDriver extends AbstractDriver<
289
290
 
290
291
  //Some of these use the completions API instead of the chat completions API.
291
292
  //Others are for non-text input modalities. Therefore common to both.
292
- const wordBlacklist = ["embed", "whisper", "transcribe", "audio", "moderation", "tts", "realtime", "dall-e", "babbage", "davinci"];
293
+ const wordBlacklist = ["embed", "whisper", "transcribe", "audio", "moderation", "tts",
294
+ "realtime", "dall-e", "babbage", "davinci", "codex", "o1-pro"];
293
295
 
294
296
  if (this.provider === "azure_openai") {
295
297
  //Azure OpenAI has additional information about the models
@@ -415,20 +417,14 @@ function convertRoles(messages: OpenAIMessageBlock[], model: string): OpenAIMess
415
417
  return messages
416
418
  }
417
419
 
418
- function supportsTools(model: string): boolean {
419
- const list_check = !noStructuredOutputModels.some((m) => model.includes(m));
420
- if (!list_check && model.includes("gpt-4o") && !model.includes("gpt-4o-2024-05-13")) {
421
- return true;
422
- }
423
- return list_check
424
- }
425
-
420
+ //Structured output support is typically aligned with tool use support
421
+ //Not true for realtime models, which do not support structured output, but do support tool use.
426
422
  function supportsSchema(model: string): boolean {
427
- const list_check = !noStructuredOutputModels.some((m) => model.includes(m));
428
- if (!list_check && model.includes("gpt-4o") && !model.includes("gpt-4o-2024-05-13")) {
429
- return true;
423
+ const realtimeModel = model.includes("realtime");
424
+ if (realtimeModel) {
425
+ return false;
430
426
  }
431
- return list_check
427
+ return supportsToolUse(model, "openai");
432
428
  }
433
429
 
434
430
  function getToolDefinitions(tools: ToolDefinition[] | undefined | null): OpenAI.ChatCompletionTool[] | undefined {
@@ -1,6 +1,6 @@
1
1
  import { AIModel, AbstractDriver, Completion, CompletionChunk, DriverOptions, EmbeddingsResult, ExecutionOptions, TextFallbackOptions } from "@llumiverse/core";
2
2
  import { transformSSEStream } from "@llumiverse/core/async";
3
- import { FetchClient } from "api-fetch-client";
3
+ import { FetchClient } from "@vertesia/api-fetch-client";
4
4
  import { TextCompletion, TogetherModelInfo } from "./interfaces.js";
5
5
 
6
6
  interface TogetherAIDriverOptions extends DriverOptions {
@@ -1,9 +1,8 @@
1
- import { GenerateContentRequest, VertexAI } from "@google-cloud/vertexai";
2
1
  import {
3
2
  AIModel,
4
3
  AbstractDriver,
5
4
  Completion,
6
- CompletionChunkObject,
5
+ CompletionChunk,
7
6
  DriverOptions,
8
7
  EmbeddingsResult,
9
8
  ExecutionOptions,
@@ -14,7 +13,7 @@ import {
14
13
  getModelCapabilities,
15
14
  modelModalitiesToArray,
16
15
  } from "@llumiverse/core";
17
- import { FetchClient } from "api-fetch-client";
16
+ import { FetchClient } from "@vertesia/api-fetch-client";
18
17
  import { GoogleAuth, GoogleAuthOptions } from "google-auth-library";
19
18
  import { JSONClient } from "google-auth-library/build/src/auth/googleauth.js";
20
19
  import { TextEmbeddingsOptions, getEmbeddingsForText } from "./embeddings/embeddings-text.js";
@@ -24,6 +23,7 @@ import { getEmbeddingsForImages } from "./embeddings/embeddings-image.js";
24
23
  import { v1beta1 } from "@google-cloud/aiplatform";
25
24
  import { AnthropicVertex } from "@anthropic-ai/vertex-sdk";
26
25
  import { ImagenModelDefinition, ImagenPrompt } from "./models/imagen.js";
26
+ import { GoogleGenAI, Content, Tool } from "@google/genai";
27
27
 
28
28
  export interface VertexAIDriverOptions extends DriverOptions {
29
29
  project: string;
@@ -31,8 +31,14 @@ export interface VertexAIDriverOptions extends DriverOptions {
31
31
  googleAuthOptions?: GoogleAuthOptions;
32
32
  }
33
33
 
34
+ export interface GenerateContentPrompt {
35
+ contents: Content[];
36
+ system?: string;
37
+ tools?: Tool[];
38
+ }
39
+
34
40
  //General Prompt type for VertexAI
35
- export type VertexAIPrompt = GenerateContentRequest | ImagenPrompt;
41
+ export type VertexAIPrompt = ImagenPrompt | GenerateContentPrompt;
36
42
 
37
43
  export function trimModelName(model: string) {
38
44
  const i = model.lastIndexOf("@");
@@ -46,8 +52,9 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
46
52
  aiplatform: v1beta1.ModelServiceClient | undefined;
47
53
  anthropicClient: AnthropicVertex | undefined;
48
54
  fetchClient: FetchClient | undefined;
55
+ googleGenAI: GoogleGenAI | undefined;
56
+ llamaClient: FetchClient & { region?: string } | undefined;
49
57
  modelGarden: v1beta1.ModelGardenServiceClient | undefined;
50
- vertexai: VertexAI | undefined;
51
58
 
52
59
  authClient: JSONClient | GoogleAuth<JSONClient>;
53
60
 
@@ -57,12 +64,28 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
57
64
  this.aiplatform = undefined;
58
65
  this.anthropicClient = undefined;
59
66
  this.fetchClient = undefined
67
+ this.googleGenAI = undefined;
60
68
  this.modelGarden = undefined;
61
- this.vertexai = undefined;
69
+ this.llamaClient = undefined;
62
70
 
63
71
  this.authClient = options.googleAuthOptions?.authClient ?? new GoogleAuth(options.googleAuthOptions);
64
72
  }
65
73
 
74
+ public getGoogleGenAIClient(): GoogleGenAI {
75
+ //Lazy initialisation
76
+ if (!this.googleGenAI) {
77
+ this.googleGenAI = new GoogleGenAI({
78
+ project: this.options.project,
79
+ location: this.options.region,
80
+ vertexai: true,
81
+ googleAuthOptions: {
82
+ authClient: this.authClient as JSONClient,
83
+ }
84
+ });
85
+ }
86
+ return this.googleGenAI;
87
+ }
88
+
66
89
  public getFetchClient(): FetchClient {
67
90
  //Lazy initialisation
68
91
  if (!this.fetchClient) {
@@ -78,6 +101,24 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
78
101
  return this.fetchClient;
79
102
  }
80
103
 
104
+ public getLLamaClient(region: string = "us-central1"): FetchClient {
105
+ //Lazy initialisation
106
+ if (!this.llamaClient || this.llamaClient["region"] !== region) {
107
+ this.llamaClient = createFetchClient({
108
+ region: region,
109
+ project: this.options.project,
110
+ apiVersion: "v1beta1",
111
+ }).withAuthCallback(async () => {
112
+ const accessTokenResponse = await this.authClient.getAccessToken();
113
+ const token = typeof accessTokenResponse === 'string' ? accessTokenResponse : accessTokenResponse?.token;
114
+ return `Bearer ${token}`;
115
+ });
116
+ // Store the region for potential client reuse
117
+ this.llamaClient["region"] = region;
118
+ }
119
+ return this.llamaClient;
120
+ }
121
+
81
122
  public getAnthropicClient(): AnthropicVertex {
82
123
  //Lazy initialisation
83
124
  if (!this.anthropicClient) {
@@ -89,18 +130,6 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
89
130
  return this.anthropicClient;
90
131
  }
91
132
 
92
- public getVertexAIClient(): VertexAI {
93
- //Lazy initialisation
94
- if (!this.vertexai) {
95
- this.vertexai = new VertexAI({
96
- project: this.options.project,
97
- location: this.options.region,
98
- googleAuthOptions: this.options.googleAuthOptions,
99
- });
100
- }
101
- return this.vertexai;
102
- }
103
-
104
133
  public getAIPlatformClient(): v1beta1.ModelServiceClient {
105
134
  //Lazy initialisation
106
135
  if (!this.aiplatform) {
@@ -125,6 +154,18 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
125
154
  return this.modelGarden;
126
155
  }
127
156
 
157
+ validateResult(result: Completion, options: ExecutionOptions) {
158
+ // Optionally preprocess the result before validation
159
+ const modelDef = getModelDefinition(options.model);
160
+ if (typeof modelDef.preValidationProcessing === "function") {
161
+ const processed = modelDef.preValidationProcessing(result, options);
162
+ result = processed.result;
163
+ options = processed.options;
164
+ }
165
+
166
+ super.validateResult(result, options);
167
+ }
168
+
128
169
  protected canStream(options: ExecutionOptions): Promise<boolean> {
129
170
  if (options.output_modality == Modalities.image) {
130
171
  return Promise.resolve(false);
@@ -145,7 +186,7 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
145
186
  async requestTextCompletionStream(
146
187
  prompt: VertexAIPrompt,
147
188
  options: ExecutionOptions,
148
- ): Promise<AsyncIterable<CompletionChunkObject>> {
189
+ ): Promise<AsyncIterable<CompletionChunk>> {
149
190
  return getModelDefinition(options.model).requestTextCompletionStream(this, prompt, options);
150
191
  }
151
192
 
@@ -178,14 +219,31 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
178
219
  );
179
220
 
180
221
  //Model Garden Publisher models - Pretrained models
181
- const publishers = ["google", "anthropic"];
182
- const supportedModels = { google: ["gemini", "imagen"], anthropic: ["claude"] };
222
+ const publishers = ["google", "anthropic", "meta"];
223
+ // Meta "maas" models are LLama Models-As-A-Service. Non-maas models are not pre-deployed.
224
+ const supportedModels = { google: ["gemini", "imagen"], anthropic: ["claude"], meta: ["maas"] };
225
+ // Additional models not in the listings, but we want to include
226
+ // TODO: Remove once the models are available in the listing API, or no longer needed
227
+ const additionalModels = {
228
+ google: ["imagen-3.0-fast-generate-001"],
229
+ anthropic: [],
230
+ meta: [
231
+ "llama-4-maverick-17b-128e-instruct-maas",
232
+ "llama-4-scout-17b-16e-instruct-maas",
233
+ "llama-3.3-70b-instruct-maas",
234
+ "llama-3.2-90b-vision-instruct-maas",
235
+ "llama-3.1-405b-instruct-maas",
236
+ "llama-3.1-70b-instruct-maas",
237
+ "llama-3.1-8b-instruct-maas",
238
+ ],
239
+ }
183
240
 
184
241
  //Used to exclude retired models that are still in the listing API but not available for use.
185
242
  //Or models we do not support yet
186
243
  const unsupportedModelsByPublisher = {
187
244
  google: ["gemini-pro", "gemini-ultra"],
188
245
  anthropic: [],
246
+ meta: [],
189
247
  };
190
248
 
191
249
  for (const publisher of publishers) {
@@ -228,13 +286,28 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
228
286
  tool_support: modelCapability.tool_support,
229
287
  } satisfies AIModel<string>;
230
288
  }));
289
+
290
+ // Add additional models that are not in the listing
291
+ for (const additionalModel of additionalModels[publisher as keyof typeof additionalModels]) {
292
+ const publisherModelName = `publishers/${publisher}/models/${additionalModel}`;
293
+ const modelCapability = getModelCapabilities(additionalModel, "vertexai");
294
+ models.push({
295
+ id: publisherModelName,
296
+ name: additionalModel,
297
+ provider: 'vertexai',
298
+ owner: publisher,
299
+ input_modalities: modelModalitiesToArray(modelCapability.input),
300
+ output_modalities: modelModalitiesToArray(modelCapability.output),
301
+ tool_support: modelCapability.tool_support,
302
+ } satisfies AIModel<string>);
303
+ }
231
304
  }
232
305
 
233
306
  //Remove duplicates
234
307
  const uniqueModels = Array.from(new Set(models.map(a => a.id)))
235
308
  .map(id => {
236
309
  return models.find(a => a.id === id) ?? {} as AIModel<string>;
237
- });
310
+ }).sort((a, b) => a.id.localeCompare(b.id));
238
311
 
239
312
  return uniqueModels;
240
313
  }
@@ -1,5 +1,5 @@
1
1
  import * as AnthropicAPI from '@anthropic-ai/sdk';
2
- import { ContentBlock, ContentBlockParam, Message, TextBlockParam } from "@anthropic-ai/sdk/resources/index.js";
2
+ import { ContentBlock, ContentBlockParam, ImageBlockParam, Message, TextBlockParam } from "@anthropic-ai/sdk/resources/index.js";
3
3
  import {
4
4
  AIModel, Completion, CompletionChunkObject, ExecutionOptions, JSONObject, ModelType,
5
5
  PromptOptions, PromptRole, PromptSegment, readStreamAsBase64, ToolUse, VertexAIClaudeOptions
@@ -15,23 +15,6 @@ interface ClaudePrompt {
15
15
  system: TextBlockParam[];
16
16
  }
17
17
 
18
- function getFullModelName(model: string): string {
19
- if (model.includes("claude-3-5-sonnet-v2")) {
20
- return "claude-3-5-sonnet-v2@20241022"
21
- } else if (model.includes("claude-3-5-sonnet")) {
22
- return "claude-3-5-sonnet@20240620"
23
- } else if (model.includes("claude-3-5-haiku")) {
24
- return "claude-3-5-haiku@20241022"
25
- } else if (model.includes("claude-3-opus")) {
26
- return "claude-3-opus@20240229"
27
- } else if (model.includes("claude-3-sonnet")) {
28
- return "claude-3-sonnet@20240229"
29
- } else if (model.includes("claude-3-haiku")) {
30
- return "claude-3-haiku@20240307"
31
- }
32
- return model;
33
- }
34
-
35
18
  function claudeFinishReason(reason: string | undefined) {
36
19
  if (!reason) return undefined;
37
20
  switch (reason) {
@@ -63,6 +46,36 @@ function maxToken(max_tokens: number | undefined, model: string): number {
63
46
  }
64
47
  }
65
48
 
49
+ async function collectImageBlocks(segment: PromptSegment, contentBlocks: ContentBlockParam[]) {
50
+ for (const file of segment.files || []) {
51
+ if (file.mime_type?.startsWith("image/")) {
52
+ const allowedTypes = ["image/png", "image/jpeg", "image/gif", "image/webp"];
53
+ if (!allowedTypes.includes(file.mime_type)) {
54
+ throw new Error(`Unsupported image type: ${file.mime_type}`);
55
+ }
56
+ const mimeType = String(file.mime_type) as "image/png" | "image/jpeg" | "image/gif" | "image/webp";
57
+
58
+ contentBlocks.push({
59
+ type: 'image',
60
+ source: {
61
+ type: 'base64',
62
+ data: await readStreamAsBase64(await file.getStream()),
63
+ media_type: mimeType
64
+ }
65
+ });
66
+ } else if (file.mime_type?.startsWith("text/")) {
67
+ contentBlocks.push({
68
+ source: {
69
+ type: 'text',
70
+ data: await readStreamAsBase64(await file.getStream()),
71
+ media_type: 'text/plain'
72
+ },
73
+ type: 'document'
74
+ });
75
+ }
76
+ }
77
+ }
78
+
66
79
  export class ClaudeModelDefinition implements ModelDefinition<ClaudePrompt> {
67
80
 
68
81
  model: AIModel
@@ -111,60 +124,38 @@ export class ClaudeModelDefinition implements ModelDefinition<ClaudePrompt> {
111
124
  if (!segment.tool_use_id) {
112
125
  throw new Error("Tool prompt segment must have a tool_use_id");
113
126
  }
127
+
128
+ const imageBlocks: ImageBlockParam[] = [];
129
+ await collectImageBlocks(segment, imageBlocks);
130
+
114
131
  messages.push({
115
132
  role: 'user',
116
- content: [
117
- {
118
- type: 'tool_result',
119
- tool_use_id: segment.tool_use_id,
120
- content: segment.content || undefined
121
- }
122
- ]
133
+ content: [{
134
+ type: 'tool_result',
135
+ tool_use_id: segment.tool_use_id,
136
+ content: [{
137
+ type: 'text',
138
+ text: segment.content || ''
139
+ }, ...imageBlocks]
140
+ }]
123
141
  });
142
+
124
143
  } else {
125
144
  const contentBlocks: ContentBlockParam[] = [];
126
- for (const file of segment.files || []) {
127
- if (file.mime_type?.startsWith("image/")) {
128
- const allowedTypes = ["image/png", "image/jpeg", "image/gif", "image/webp"];
129
- if (!allowedTypes.includes(file.mime_type)) {
130
- throw new Error(`Unsupported image type: ${file.mime_type}`);
131
- }
132
-
133
- contentBlocks.push({
134
- type: 'image',
135
- source: {
136
- type: 'base64',
137
- data: await readStreamAsBase64(await file.getStream()),
138
- media_type: file.mime_type as any
139
- }
140
- });
141
- } else if (file.mime_type?.startsWith("text/")) {
142
- contentBlocks.push({
143
- source: {
144
- type: 'text',
145
- data: await readStreamAsBase64(await file.getStream()),
146
- media_type: 'text/plain'
147
- },
148
- type : 'document'
149
- });
150
- }
151
- }
152
-
145
+ collectImageBlocks(segment, contentBlocks);
153
146
  if (segment.content) {
154
147
  contentBlocks.push({
155
148
  type: 'text',
156
149
  text: segment.content
157
150
  });
158
151
  }
159
-
160
-
161
152
  messages.push({
162
153
  role: segment.role === PromptRole.assistant ? 'assistant' : 'user',
163
154
  content: contentBlocks
164
155
  });
165
156
  }
166
157
  }
167
-
158
+
168
159
  const system = systemSegments.concat(safetySegments);
169
160
 
170
161
  return {
@@ -253,23 +244,17 @@ export class ClaudeModelDefinition implements ModelDefinition<ClaudePrompt> {
253
244
  }
254
245
  });
255
246
 
256
- //Streaming does not give information on the input tokens,
257
- //So we use a separate call to get the input tokens.
258
- //Non-critical and model name sensitive so we put it in a try catch block
259
- let count_tokens = { input_tokens: 0 };
260
- try {
261
- count_tokens = await client.messages.countTokens({
262
- ...prompt, // messages, system
263
- model: getFullModelName(modelName),
264
- });
265
- } catch (e) {
266
- driver.logger.warn("Failed to get token count for model " + modelName);
267
- }
268
-
269
247
  const stream = asyncMap(response_stream, async (item: any) => {
248
+ if (item.type == "message_start") {
249
+ return {
250
+ result: '',
251
+ token_usage: { prompt: item?.message?.usage?.input_tokens, result: item?.message?.usage?.output_tokens },
252
+ finish_reason: undefined,
253
+ }
254
+ }
270
255
  return {
271
256
  result: item?.delta?.text ?? '',
272
- token_usage: { prompt: count_tokens.input_tokens, result: item?.usage?.output_tokens },
257
+ token_usage: { result: item?.usage?.output_tokens },
273
258
  finish_reason: claudeFinishReason(item?.delta?.stop_reason ?? ''),
274
259
  }
275
260
  });