@llumiverse/drivers 0.14.0 → 0.16.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. package/README.md +3 -3
  2. package/lib/cjs/adobe/firefly.js +119 -0
  3. package/lib/cjs/adobe/firefly.js.map +1 -0
  4. package/lib/cjs/bedrock/converse.js +177 -0
  5. package/lib/cjs/bedrock/converse.js.map +1 -0
  6. package/lib/cjs/bedrock/index.js +329 -228
  7. package/lib/cjs/bedrock/index.js.map +1 -1
  8. package/lib/cjs/bedrock/nova-image-payload.js +207 -0
  9. package/lib/cjs/bedrock/nova-image-payload.js.map +1 -0
  10. package/lib/cjs/groq/index.js +34 -9
  11. package/lib/cjs/groq/index.js.map +1 -1
  12. package/lib/cjs/huggingface_ie.js +28 -12
  13. package/lib/cjs/huggingface_ie.js.map +1 -1
  14. package/lib/cjs/index.js +1 -0
  15. package/lib/cjs/index.js.map +1 -1
  16. package/lib/cjs/mistral/index.js +31 -12
  17. package/lib/cjs/mistral/index.js.map +1 -1
  18. package/lib/cjs/mistral/types.js.map +1 -1
  19. package/lib/cjs/openai/index.js +149 -27
  20. package/lib/cjs/openai/index.js.map +1 -1
  21. package/lib/cjs/replicate.js +16 -18
  22. package/lib/cjs/replicate.js.map +1 -1
  23. package/lib/cjs/test/TestValidationErrorCompletionStream.js.map +1 -1
  24. package/lib/cjs/test/index.js.map +1 -1
  25. package/lib/cjs/togetherai/index.js +40 -10
  26. package/lib/cjs/togetherai/index.js.map +1 -1
  27. package/lib/cjs/vertexai/embeddings/embeddings-image.js +26 -0
  28. package/lib/cjs/vertexai/embeddings/embeddings-image.js.map +1 -0
  29. package/lib/cjs/vertexai/embeddings/embeddings-text.js +1 -1
  30. package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -1
  31. package/lib/cjs/vertexai/index.js +92 -25
  32. package/lib/cjs/vertexai/index.js.map +1 -1
  33. package/lib/cjs/vertexai/models/claude.js +252 -0
  34. package/lib/cjs/vertexai/models/claude.js.map +1 -0
  35. package/lib/cjs/vertexai/models/gemini.js +169 -27
  36. package/lib/cjs/vertexai/models/gemini.js.map +1 -1
  37. package/lib/cjs/vertexai/models/imagen.js +317 -0
  38. package/lib/cjs/vertexai/models/imagen.js.map +1 -0
  39. package/lib/cjs/vertexai/models.js +12 -107
  40. package/lib/cjs/vertexai/models.js.map +1 -1
  41. package/lib/cjs/watsonx/index.js +39 -8
  42. package/lib/cjs/watsonx/index.js.map +1 -1
  43. package/lib/cjs/xai/index.js +71 -0
  44. package/lib/cjs/xai/index.js.map +1 -0
  45. package/lib/esm/adobe/firefly.js +115 -0
  46. package/lib/esm/adobe/firefly.js.map +1 -0
  47. package/lib/esm/bedrock/converse.js +171 -0
  48. package/lib/esm/bedrock/converse.js.map +1 -0
  49. package/lib/esm/bedrock/index.js +331 -230
  50. package/lib/esm/bedrock/index.js.map +1 -1
  51. package/lib/esm/bedrock/nova-image-payload.js +203 -0
  52. package/lib/esm/bedrock/nova-image-payload.js.map +1 -0
  53. package/lib/esm/groq/index.js +34 -9
  54. package/lib/esm/groq/index.js.map +1 -1
  55. package/lib/esm/huggingface_ie.js +29 -13
  56. package/lib/esm/huggingface_ie.js.map +1 -1
  57. package/lib/esm/index.js +1 -0
  58. package/lib/esm/index.js.map +1 -1
  59. package/lib/esm/mistral/index.js +31 -12
  60. package/lib/esm/mistral/index.js.map +1 -1
  61. package/lib/esm/mistral/types.js.map +1 -1
  62. package/lib/esm/openai/index.js +150 -28
  63. package/lib/esm/openai/index.js.map +1 -1
  64. package/lib/esm/replicate.js +17 -19
  65. package/lib/esm/replicate.js.map +1 -1
  66. package/lib/esm/test/TestValidationErrorCompletionStream.js.map +1 -1
  67. package/lib/esm/test/index.js.map +1 -1
  68. package/lib/esm/togetherai/index.js +40 -10
  69. package/lib/esm/togetherai/index.js.map +1 -1
  70. package/lib/esm/vertexai/embeddings/embeddings-image.js +23 -0
  71. package/lib/esm/vertexai/embeddings/embeddings-image.js.map +1 -0
  72. package/lib/esm/vertexai/embeddings/embeddings-text.js +1 -1
  73. package/lib/esm/vertexai/embeddings/embeddings-text.js.map +1 -1
  74. package/lib/esm/vertexai/index.js +93 -27
  75. package/lib/esm/vertexai/index.js.map +1 -1
  76. package/lib/esm/vertexai/models/claude.js +247 -0
  77. package/lib/esm/vertexai/models/claude.js.map +1 -0
  78. package/lib/esm/vertexai/models/gemini.js +170 -28
  79. package/lib/esm/vertexai/models/gemini.js.map +1 -1
  80. package/lib/esm/vertexai/models/imagen.js +310 -0
  81. package/lib/esm/vertexai/models/imagen.js.map +1 -0
  82. package/lib/esm/vertexai/models.js +12 -104
  83. package/lib/esm/vertexai/models.js.map +1 -1
  84. package/lib/esm/watsonx/index.js +39 -8
  85. package/lib/esm/watsonx/index.js.map +1 -1
  86. package/lib/esm/xai/index.js +64 -0
  87. package/lib/esm/xai/index.js.map +1 -0
  88. package/lib/types/adobe/firefly.d.ts +30 -0
  89. package/lib/types/adobe/firefly.d.ts.map +1 -0
  90. package/lib/types/bedrock/converse.d.ts +8 -0
  91. package/lib/types/bedrock/converse.d.ts.map +1 -0
  92. package/lib/types/bedrock/index.d.ts +26 -11
  93. package/lib/types/bedrock/index.d.ts.map +1 -1
  94. package/lib/types/bedrock/nova-image-payload.d.ts +74 -0
  95. package/lib/types/bedrock/nova-image-payload.d.ts.map +1 -0
  96. package/lib/types/bedrock/payloads.d.ts +9 -65
  97. package/lib/types/bedrock/payloads.d.ts.map +1 -1
  98. package/lib/types/groq/index.d.ts +3 -3
  99. package/lib/types/groq/index.d.ts.map +1 -1
  100. package/lib/types/huggingface_ie.d.ts +5 -7
  101. package/lib/types/huggingface_ie.d.ts.map +1 -1
  102. package/lib/types/index.d.ts +1 -0
  103. package/lib/types/index.d.ts.map +1 -1
  104. package/lib/types/mistral/index.d.ts +4 -4
  105. package/lib/types/mistral/index.d.ts.map +1 -1
  106. package/lib/types/mistral/types.d.ts +1 -0
  107. package/lib/types/mistral/types.d.ts.map +1 -1
  108. package/lib/types/openai/index.d.ts +5 -4
  109. package/lib/types/openai/index.d.ts.map +1 -1
  110. package/lib/types/replicate.d.ts +4 -9
  111. package/lib/types/replicate.d.ts.map +1 -1
  112. package/lib/types/test/index.d.ts +2 -2
  113. package/lib/types/test/index.d.ts.map +1 -1
  114. package/lib/types/togetherai/index.d.ts +4 -4
  115. package/lib/types/togetherai/index.d.ts.map +1 -1
  116. package/lib/types/vertexai/embeddings/embeddings-image.d.ts +11 -0
  117. package/lib/types/vertexai/embeddings/embeddings-image.d.ts.map +1 -0
  118. package/lib/types/vertexai/index.d.ts +19 -8
  119. package/lib/types/vertexai/index.d.ts.map +1 -1
  120. package/lib/types/vertexai/models/claude.d.ts +20 -0
  121. package/lib/types/vertexai/models/claude.d.ts.map +1 -0
  122. package/lib/types/vertexai/models/gemini.d.ts +4 -4
  123. package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
  124. package/lib/types/vertexai/models/imagen.d.ts +75 -0
  125. package/lib/types/vertexai/models/imagen.d.ts.map +1 -0
  126. package/lib/types/vertexai/models.d.ts +3 -6
  127. package/lib/types/vertexai/models.d.ts.map +1 -1
  128. package/lib/types/watsonx/index.d.ts +3 -3
  129. package/lib/types/watsonx/index.d.ts.map +1 -1
  130. package/lib/types/xai/index.d.ts +19 -0
  131. package/lib/types/xai/index.d.ts.map +1 -0
  132. package/package.json +24 -23
  133. package/src/adobe/firefly.ts +207 -0
  134. package/src/bedrock/converse.ts +194 -0
  135. package/src/bedrock/index.ts +349 -237
  136. package/src/bedrock/nova-image-payload.ts +309 -0
  137. package/src/bedrock/payloads.ts +12 -66
  138. package/src/groq/index.ts +35 -12
  139. package/src/huggingface_ie.ts +34 -13
  140. package/src/index.ts +1 -0
  141. package/src/mistral/index.ts +34 -12
  142. package/src/mistral/types.ts +2 -1
  143. package/src/openai/index.ts +167 -33
  144. package/src/replicate.ts +21 -20
  145. package/src/test/TestValidationErrorCompletionStream.ts +2 -2
  146. package/src/test/index.ts +3 -2
  147. package/src/togetherai/index.ts +44 -12
  148. package/src/vertexai/embeddings/embeddings-image.ts +50 -0
  149. package/src/vertexai/embeddings/embeddings-text.ts +1 -1
  150. package/src/vertexai/index.ts +114 -37
  151. package/src/vertexai/models/claude.ts +281 -0
  152. package/src/vertexai/models/gemini.ts +181 -31
  153. package/src/vertexai/models/imagen.ts +401 -0
  154. package/src/vertexai/models.ts +16 -120
  155. package/src/watsonx/index.ts +42 -10
  156. package/src/xai/index.ts +110 -0
  157. package/lib/cjs/vertexai/models/codey-chat.js +0 -65
  158. package/lib/cjs/vertexai/models/codey-chat.js.map +0 -1
  159. package/lib/cjs/vertexai/models/codey-text.js +0 -35
  160. package/lib/cjs/vertexai/models/codey-text.js.map +0 -1
  161. package/lib/cjs/vertexai/models/palm-model-base.js +0 -59
  162. package/lib/cjs/vertexai/models/palm-model-base.js.map +0 -1
  163. package/lib/cjs/vertexai/models/palm2-chat.js +0 -65
  164. package/lib/cjs/vertexai/models/palm2-chat.js.map +0 -1
  165. package/lib/cjs/vertexai/models/palm2-text.js +0 -35
  166. package/lib/cjs/vertexai/models/palm2-text.js.map +0 -1
  167. package/lib/cjs/vertexai/utils/tensor.js +0 -86
  168. package/lib/cjs/vertexai/utils/tensor.js.map +0 -1
  169. package/lib/esm/vertexai/models/codey-chat.js +0 -61
  170. package/lib/esm/vertexai/models/codey-chat.js.map +0 -1
  171. package/lib/esm/vertexai/models/codey-text.js +0 -31
  172. package/lib/esm/vertexai/models/codey-text.js.map +0 -1
  173. package/lib/esm/vertexai/models/palm-model-base.js +0 -55
  174. package/lib/esm/vertexai/models/palm-model-base.js.map +0 -1
  175. package/lib/esm/vertexai/models/palm2-chat.js +0 -61
  176. package/lib/esm/vertexai/models/palm2-chat.js.map +0 -1
  177. package/lib/esm/vertexai/models/palm2-text.js +0 -31
  178. package/lib/esm/vertexai/models/palm2-text.js.map +0 -1
  179. package/lib/esm/vertexai/utils/tensor.js +0 -82
  180. package/lib/esm/vertexai/utils/tensor.js.map +0 -1
  181. package/lib/types/vertexai/models/codey-chat.d.ts +0 -51
  182. package/lib/types/vertexai/models/codey-chat.d.ts.map +0 -1
  183. package/lib/types/vertexai/models/codey-text.d.ts +0 -39
  184. package/lib/types/vertexai/models/codey-text.d.ts.map +0 -1
  185. package/lib/types/vertexai/models/palm-model-base.d.ts +0 -61
  186. package/lib/types/vertexai/models/palm-model-base.d.ts.map +0 -1
  187. package/lib/types/vertexai/models/palm2-chat.d.ts +0 -61
  188. package/lib/types/vertexai/models/palm2-chat.d.ts.map +0 -1
  189. package/lib/types/vertexai/models/palm2-text.d.ts +0 -39
  190. package/lib/types/vertexai/models/palm2-text.d.ts.map +0 -1
  191. package/lib/types/vertexai/utils/tensor.d.ts +0 -6
  192. package/lib/types/vertexai/utils/tensor.d.ts.map +0 -1
  193. package/src/vertexai/models/codey-chat.ts +0 -115
  194. package/src/vertexai/models/codey-text.ts +0 -69
  195. package/src/vertexai/models/palm-model-base.ts +0 -128
  196. package/src/vertexai/models/palm2-chat.ts +0 -119
  197. package/src/vertexai/models/palm2-text.ts +0 -69
  198. package/src/vertexai/utils/tensor.ts +0 -82

package/src/vertexai/index.ts

@@ -1,9 +1,15 @@
 import { GenerateContentRequest, VertexAI } from "@google-cloud/vertexai";
-import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsResult, ExecutionOptions, ModelSearchPayload, PromptOptions, PromptSegment } from "@llumiverse/core";
+import { AIModel, AbstractDriver, Completion, CompletionChunkObject, DriverOptions, EmbeddingsResult, ExecutionOptions, ImageGeneration, Modalities, ModelSearchPayload, PromptSegment } from "@llumiverse/core";
 import { FetchClient } from "api-fetch-client";
-import { GoogleAuthOptions } from "google-auth-library";
+import { GoogleAuth, GoogleAuthOptions } from "google-auth-library";
+import { JSONClient } from "google-auth-library/build/src/auth/googleauth.js";
 import { TextEmbeddingsOptions, getEmbeddingsForText } from "./embeddings/embeddings-text.js";
-import { BuiltinModels, getModelDefinition } from "./models.js";
+import { getModelDefinition } from "./models.js";
+import { EmbeddingsOptions } from "@llumiverse/core";
+import { getEmbeddingsForImages } from "./embeddings/embeddings-image.js";
+import { v1beta1 } from '@google-cloud/aiplatform';
+import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
+import { ImagenModelDefinition, ImagenPrompt } from "./models/imagen.js";
 
 
 export interface VertexAIDriverOptions extends DriverOptions {
@@ -12,19 +18,31 @@ export interface VertexAIDriverOptions extends DriverOptions {
     googleAuthOptions?: GoogleAuthOptions;
 }
 
-export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, GenerateContentRequest> {
+//General Prompt type for VertexAI
+export type VertexAIPrompt = GenerateContentRequest | ImagenPrompt;
+
+export function trimModelName(model: string) {
+    const i = model.lastIndexOf('@');
+    return i > -1 ? model.substring(0, i) : model;
+}
+
+export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, VertexAIPrompt> {
     static PROVIDER = "vertexai";
     provider = VertexAIDriver.PROVIDER;
 
-    //aiplatform: v1.ModelServiceClient;
+    aiplatform: v1beta1.ModelServiceClient;
     vertexai: VertexAI;
     fetchClient: FetchClient;
-
-    constructor(
-        options: VertexAIDriverOptions
-    ) {
+    authClient: JSONClient | GoogleAuth<JSONClient>;
+    anthropicClient: AnthropicVertex | undefined;
+
+    constructor( options: VertexAIDriverOptions) {
         super(options);
-        //this.aiplatform = new v1.ModelServiceClient();
+
+        this.anthropicClient = undefined;
+
+        this.authClient = options.googleAuthOptions?.authClient ?? new GoogleAuth(options.googleAuthOptions);
+
         this.vertexai = new VertexAI({
             project: this.options.project,
             location: this.options.region,
@@ -35,54 +53,113 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Genera
             project: this.options.project,
         }).withAuthCallback(async () => {
             //@ts-ignore
-            const token = await this.options.googleAuthOptions?.authClient?.getAccessToken();
+            const token = await this.authClient.getAccessToken();
            return `Bearer ${token}`;
        });
-        // this.aiplatform = new v1.ModelServiceClient({
-        //     projectId: this.options.project,
-        //     apiEndpoint: `${this.options.region}-${API_BASE_PATH}`,
-        // });
+        this.aiplatform = new v1beta1.ModelServiceClient({
+            projectId: this.options.project,
+            apiEndpoint: `${this.options.region}-${API_BASE_PATH}`,
+        });
+    }
+
+    public getAnthropicClient() : AnthropicVertex {
+        //Lazy initialisation
+        if (!this.anthropicClient) {
+            this.anthropicClient = new AnthropicVertex({region: "us-east5", projectId: process.env.GOOGLE_PROJECT_ID});
+        }
+        return this.anthropicClient;
     }
 
     protected canStream(options: ExecutionOptions): Promise<boolean> {
+        if (options.output_modality == Modalities.image) {
+            return Promise.resolve(false);
+        }
         return Promise.resolve(getModelDefinition(options.model).model.can_stream === true);
     }
 
-    public createPrompt(segments: PromptSegment[], options: PromptOptions): Promise<GenerateContentRequest> {
+    public createPrompt(segments: PromptSegment[], options: ExecutionOptions): Promise<VertexAIPrompt> {
+        if (options.model.includes("imagen")) {
+            return new ImagenModelDefinition(options.model).createPrompt(this, segments, options);
+        }
         return getModelDefinition(options.model).createPrompt(this, segments, options);
     }
 
-    async requestCompletion(prompt: GenerateContentRequest, options: ExecutionOptions): Promise<Completion<any>> {
-        return getModelDefinition(options.model).requestCompletion(this, prompt, options);
+    async requestTextCompletion(prompt: VertexAIPrompt, options: ExecutionOptions): Promise<Completion<any>> {
+        return getModelDefinition(options.model).requestTextCompletion(this, prompt, options);
     }
-    async requestCompletionStream(prompt: GenerateContentRequest, options: ExecutionOptions): Promise<AsyncIterable<string>> {
-        return getModelDefinition(options.model).requestCompletionStream(this, prompt, options);
+    async requestTextCompletionStream(prompt: VertexAIPrompt, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
+        return getModelDefinition(options.model).requestTextCompletionStream(this, prompt, options);
+    }
+
+    async requestImageGeneration(_prompt: ImagenPrompt, _options: ExecutionOptions): Promise <Completion<ImageGeneration>> {
+        const splits = _options.model.split("/");
+        const modelName = trimModelName(splits[splits.length - 1]);
+        return new ImagenModelDefinition(modelName).requestImageGeneration(this, _prompt, _options);
    }
 
     async listModels(_params?: ModelSearchPayload): Promise<AIModel<string>[]> {
-        return BuiltinModels;
-        // try {
-        //     const response = await this.fetchClient.get('/publishers/google/models/gemini-pro');
-        //     console.log('>>>>>>>>', response);
-        // } catch (err: any) {
-        //     console.error('+++++VETREX ERROR++++++', err);
-        //     throw err;
-        // }
-
-        // TODO uncomment this to use apiplatform instead of the fetch client
-        // const response = await this.aiplatform.listModels({
-        //     parent: `projects/${this.options.project}/locations/${this.options.region}`,
-        // });
-
-        return []; //TODO
+        let models: AIModel<string>[] = [];
+        const modelGarden = new v1beta1.ModelGardenServiceClient({
+            projectId: this.options.project,
+            apiEndpoint: `${this.options.region}-${API_BASE_PATH}`,
+        });
+
+        //Project specific deployed models
+        const [response] = await this.aiplatform.listModels({
+            parent: `projects/${this.options.project}/locations/${this.options.region}`,
+        });
+        models = models.concat(response.map(model => ({
+            id: model.name?.split('/').pop() ?? '',
+            name: model.displayName ?? '',
+            provider: 'vertexai',
+        })));
+
+        //Model Garden Publisher models - Pretrained models
+        const publishers = ['google', 'anthropic']
+        const supportedModels = {google: ['gemini','imagen'], anthropic: ['claude']}
+
+        for (const publisher of publishers) {
+            const [response] = await modelGarden.listPublisherModels({
+                parent: `publishers/${publisher}`,
+                orderBy: 'name',
+                //filter: `name eq name`,
+                listAllVersions: true,
+            });
+
+            models = models.concat(response.map(model => ({
+                id: model.name ?? '',
+                name: model.name?.split('/').pop() ?? '',
+                provider: 'vertexai',
+                owner: publisher,
+            })).filter(model => {
+                const modelFamily = supportedModels[publisher as keyof typeof supportedModels];
+                for (const family of modelFamily) {
+                    if (model.name.includes(family)) {
+                        return true;
+                    }
+                }
+            }));
+        }
+
+        return models;
     }
 
     validateConnection(): Promise<boolean> {
         throw new Error("Method not implemented.");
     }
 
-    async generateEmbeddings(options: TextEmbeddingsOptions): Promise<EmbeddingsResult> {
-        return getEmbeddingsForText(this, options);
+    async generateEmbeddings(options: EmbeddingsOptions): Promise<EmbeddingsResult> {
+        if (options.image || options.model?.includes("multimodal")) {
+            if (options.text && options.image) {
+                throw new Error("Text and Image simultaneous embedding not implemented. Submit seperately");
+            }
+            return getEmbeddingsForImages(this, options);
+        }
+        const text_options: TextEmbeddingsOptions = {
+            content: options.text ?? '',
+            model: options.model,
+        }
+        return getEmbeddingsForText(this, text_options);
     }
 
 }

package/src/vertexai/models/claude.ts

@@ -0,0 +1,281 @@
+import * as AnthropicAPI from '@anthropic-ai/sdk';
+import { ContentBlock, Message, TextBlockParam } from "@anthropic-ai/sdk/resources/index.js";
+import { AIModel, Completion, CompletionChunkObject, ExecutionOptions, JSONObject, ModelType, PromptOptions, PromptRole, PromptSegment, ToolUse } from "@llumiverse/core";
+import { asyncMap } from "@llumiverse/core/async";
+import { VertexAIClaudeOptions } from "../../../../core/src/options/vertexai.js";
+import { VertexAIDriver } from "../index.js";
+import { ModelDefinition } from "../models.js";
+
+type MessageParam = AnthropicAPI.Anthropic.MessageParam;
+
+interface ClaudePrompt {
+    messages: MessageParam[];
+    system: TextBlockParam[];
+}
+
+function getFullModelName(model: string): string {
+    if (model.includes("claude-3-5-sonnet-v2")) {
+        return "claude-3-5-sonnet-v2@20241022"
+    } else if (model.includes("claude-3-5-sonnet")) {
+        return "claude-3-5-sonnet@20240620"
+    } else if (model.includes("claude-3-5-haiku")) {
+        return "claude-3-5-haiku@20241022"
+    } else if (model.includes("claude-3-opus")) {
+        return "claude-3-opus@20240229"
+    } else if (model.includes("claude-3-sonnet")) {
+        return "claude-3-sonnet@20240229"
+    } else if (model.includes("claude-3-haike")) {
+        return "claude-3-haiku@20240307"
+    }
+    return model;
+}
+
+function claudeFinishReason(reason: string | undefined) {
+    if (!reason) return undefined;
+    switch (reason) {
+        case 'end_turn': return "stop";
+        case 'max_tokens': return "length";
+        default: return reason; //stop_sequence
+    }
+}
+
+function collectTextParts(content: any) {
+    const out = [];
+
+    for (const block of content) {
+        if (block?.text) {
+            out.push(block.text);
+        }
+    }
+    return out.join('\n');
+}
+
+function maxToken(max_tokens: number | undefined, model: string): number {
+    const contains = (str: string, substr: string) => str.indexOf(substr) !== -1;
+    if (max_tokens) {
+        return max_tokens;
+    } else if (contains(model, "claude-3-5")) {
+        return 8192;
+    } else {
+        return 4096
+    }
+}
+
+export class ClaudeModelDefinition implements ModelDefinition<ClaudePrompt> {
+
+    model: AIModel
+
+    constructor(modelId: string) {
+        this.model = {
+            id: modelId,
+            name: modelId,
+            provider: 'vertexai',
+            type: ModelType.Text,
+            can_stream: true,
+        } as AIModel;
+    }
+
+    async createPrompt(_driver: VertexAIDriver, segments: PromptSegment[], options: PromptOptions): Promise<ClaudePrompt> {
+        // Convert the prompt to the format expected by the Claude API
+        const systemSegments: TextBlockParam[] = segments
+            .filter(segment => segment.role === PromptRole.system)
+            .map(segment => ({
+                text: segment.content,
+                type: 'text'
+            }));
+
+        const safetySegments: TextBlockParam[] = segments
+            .filter(segment => segment.role === PromptRole.safety)
+            .map(segment => ({
+                text: segment.content,
+                type: 'text'
+            }));
+
+        if (options.result_schema) {
+            const schemaSegments: TextBlockParam = {
+                text: "The answer must be a JSON object using the following JSON Schema:\n" + JSON.stringify(options.result_schema),
+                type: 'text'
+            }
+            safetySegments.push(schemaSegments);
+        }
+
+        const messages: MessageParam[] = segments.filter(segment =>
+            segment.role == PromptRole.user
+            || segment.role == PromptRole.assistant
+            || segment.role === PromptRole.tool)
+            .map(segment => {
+                if (segment.role === PromptRole.tool) {
+                    if (!segment.tool_use_id) {
+                        throw new Error("Tool prompt segment must have a tool_use_id");
+                    }
+                    return {
+                        role: 'user',
+                        content: [
+                            {
+                                type: 'tool_result',
+                                tool_use_id: segment.tool_use_id,
+                                content: segment.content || undefined
+                            }
+                        ]
+                    }
+                } else {
+                    return {
+                        role: segment.role !== PromptRole.user ? 'assistant' : 'user',
+                        content: segment.content
+                    }
+                }
+            });
+
+        const system = systemSegments.concat(safetySegments);
+
+        return {
+            messages: messages,
+            system: system
+        }
+    }
+
+    async requestTextCompletion(driver: VertexAIDriver, prompt: ClaudePrompt, options: ExecutionOptions): Promise<Completion> {
+        const client = driver.getAnthropicClient();
+        const splits = options.model.split("/");
+        const modelName = splits[splits.length - 1];
+        options = { ...options, model: modelName };
+        options.model_options = options.model_options as VertexAIClaudeOptions;
+
+        if (options.model_options?._option_id !== "vertexai-claude") {
+            driver.logger.warn("Invalid model options", { options: options.model_options });
+        }
+
+        let conversation = updateConversation(options.conversation as ClaudePrompt, prompt);
+
+        const result = await client.messages.create({
+            ...conversation, // messages, system,
+            tools: options.tools, // we are using the same shape as claude for tools
+            temperature: options.model_options?.temperature,
+            model: modelName,
+            max_tokens: maxToken(options.model_options?.max_tokens, modelName),
+            top_p: options.model_options?.top_p,
+            top_k: options.model_options?.top_k,
+            stop_sequences: options.model_options?.stop_sequence,
+            thinking: options.model_options?.thinking_mode ?
+                {
+                    budget_tokens: options.model_options?.thinking_budget_tokens ?? 1024,
+                    type: "enabled"
+                } : {
+                    type: "disabled"
+                }
+        }) as Message;
+
+        const text = collectTextParts(result.content);
+        const tool_use = collectTools(result.content);
+
+        conversation = updateConversation(options.conversation as ClaudePrompt, createPromptFromResponse(result));
+
+        return {
+            chat: [prompt, { role: result.role, content: result.content }],
+            result: text ?? '',
+            tool_use,
+            token_usage: {
+                prompt: result?.usage.input_tokens,
+                result: result?.usage.output_tokens,
+                total: result?.usage.input_tokens + result?.usage.output_tokens
+            },
+            // make sure we set finish_reason to the correct value (claude is normally setting this by itself)
+            finish_reason: tool_use ? "tool_use" : claudeFinishReason(result?.stop_reason ?? ''),
+            conversation
+        } as Completion;
+    }
+
+    async requestTextCompletionStream(driver: VertexAIDriver, prompt: ClaudePrompt, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
+        const client = driver.getAnthropicClient();
+        const splits = options.model.split("/");
+        const modelName = splits[splits.length - 1];
+        options = { ...options, model: modelName };
+        options.model_options = options.model_options as VertexAIClaudeOptions;
+
+        if (options.model_options?._option_id !== "vertexai-claude") {
+            driver.logger.warn("Invalid model options", { options: options.model_options });
+        }
+
+        const response_stream = await client.messages.stream({
+            ...prompt, // messages, system,
+            tools: options.tools, // we are using the same shape as claude for tools
+            temperature: options.model_options?.temperature,
+            model: modelName,
+            max_tokens: maxToken(options.model_options?.max_tokens, modelName),
+            top_p: options.model_options?.top_p,
+            top_k: options.model_options?.top_k,
+            stop_sequences: options.model_options?.stop_sequence,
+            thinking: options.model_options?.thinking_mode ?
+                {
+                    budget_tokens: options.model_options?.thinking_budget_tokens ?? 1024,
+                    type: "enabled"
+                } : {
+                    type: "disabled"
+                }
+        });
+
+        //Streaming does not give information on the input tokens,
+        //So we use a seperate call to get the input tokens.
+        //Non-critical and model name sensitive so we put it in a try catch block
+        let count_tokens = { input_tokens: 0 };
+        try {
+            count_tokens = await client.messages.countTokens({
+                ...prompt, // messages, system
+                model: getFullModelName(modelName),
+            });
+        } catch (e) {
+            driver.logger.warn("Failed to get token count for model " + modelName);
+        }
+
+        const stream = asyncMap(response_stream, async (item: any) => {
+            return {
+                result: item?.delta?.text ?? '',
+                token_usage: { prompt: count_tokens.input_tokens, result: item?.usage?.output_tokens },
+                finish_reason: claudeFinishReason(item?.delta?.stop_reason ?? ''),
+            }
+        });
+
+        return stream;
+    }
+}
+
+export function collectTools(content: ContentBlock[]): ToolUse[] | undefined {
+    const out: ToolUse[] = [];
+
+    for (const block of content) {
+        if (block?.type === "tool_use") {
+            out.push({
+                id: block.id,
+                name: block.name,
+                input: block.input as JSONObject,
+            });
+        }
+    }
+
+    return out.length > 0 ? out : undefined;
+}
+
+function createPromptFromResponse(response: Message): ClaudePrompt {
+    return {
+        messages: [{
+            role: PromptRole.assistant,
+            content: response.content,
+        }],
+        system: []
+    }
+}
+
+/**
+ * Update the converatation messages
+ * @param prompt
+ * @param response
+ * @returns
+ */
+function updateConversation(conversation: ClaudePrompt | undefined | null, prompt: ClaudePrompt): ClaudePrompt {
+    const baseSystemMessages = conversation ? conversation.system : [];
+    const baseMessages = conversation ? conversation.messages : []
+    return {
+        messages: baseMessages.concat(prompt.messages || []),
+        system: baseSystemMessages.concat(prompt.system || [])
+    };
+}