@llumiverse/drivers 0.14.0 → 0.16.0
This diff reflects the changes between package versions as published to their respective public registries, and is provided for informational purposes only.
- package/README.md +3 -3
- package/lib/cjs/adobe/firefly.js +119 -0
- package/lib/cjs/adobe/firefly.js.map +1 -0
- package/lib/cjs/bedrock/converse.js +177 -0
- package/lib/cjs/bedrock/converse.js.map +1 -0
- package/lib/cjs/bedrock/index.js +329 -228
- package/lib/cjs/bedrock/index.js.map +1 -1
- package/lib/cjs/bedrock/nova-image-payload.js +207 -0
- package/lib/cjs/bedrock/nova-image-payload.js.map +1 -0
- package/lib/cjs/groq/index.js +34 -9
- package/lib/cjs/groq/index.js.map +1 -1
- package/lib/cjs/huggingface_ie.js +28 -12
- package/lib/cjs/huggingface_ie.js.map +1 -1
- package/lib/cjs/index.js +1 -0
- package/lib/cjs/index.js.map +1 -1
- package/lib/cjs/mistral/index.js +31 -12
- package/lib/cjs/mistral/index.js.map +1 -1
- package/lib/cjs/mistral/types.js.map +1 -1
- package/lib/cjs/openai/index.js +149 -27
- package/lib/cjs/openai/index.js.map +1 -1
- package/lib/cjs/replicate.js +16 -18
- package/lib/cjs/replicate.js.map +1 -1
- package/lib/cjs/test/TestValidationErrorCompletionStream.js.map +1 -1
- package/lib/cjs/test/index.js.map +1 -1
- package/lib/cjs/togetherai/index.js +40 -10
- package/lib/cjs/togetherai/index.js.map +1 -1
- package/lib/cjs/vertexai/embeddings/embeddings-image.js +26 -0
- package/lib/cjs/vertexai/embeddings/embeddings-image.js.map +1 -0
- package/lib/cjs/vertexai/embeddings/embeddings-text.js +1 -1
- package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -1
- package/lib/cjs/vertexai/index.js +92 -25
- package/lib/cjs/vertexai/index.js.map +1 -1
- package/lib/cjs/vertexai/models/claude.js +252 -0
- package/lib/cjs/vertexai/models/claude.js.map +1 -0
- package/lib/cjs/vertexai/models/gemini.js +169 -27
- package/lib/cjs/vertexai/models/gemini.js.map +1 -1
- package/lib/cjs/vertexai/models/imagen.js +317 -0
- package/lib/cjs/vertexai/models/imagen.js.map +1 -0
- package/lib/cjs/vertexai/models.js +12 -107
- package/lib/cjs/vertexai/models.js.map +1 -1
- package/lib/cjs/watsonx/index.js +39 -8
- package/lib/cjs/watsonx/index.js.map +1 -1
- package/lib/cjs/xai/index.js +71 -0
- package/lib/cjs/xai/index.js.map +1 -0
- package/lib/esm/adobe/firefly.js +115 -0
- package/lib/esm/adobe/firefly.js.map +1 -0
- package/lib/esm/bedrock/converse.js +171 -0
- package/lib/esm/bedrock/converse.js.map +1 -0
- package/lib/esm/bedrock/index.js +331 -230
- package/lib/esm/bedrock/index.js.map +1 -1
- package/lib/esm/bedrock/nova-image-payload.js +203 -0
- package/lib/esm/bedrock/nova-image-payload.js.map +1 -0
- package/lib/esm/groq/index.js +34 -9
- package/lib/esm/groq/index.js.map +1 -1
- package/lib/esm/huggingface_ie.js +29 -13
- package/lib/esm/huggingface_ie.js.map +1 -1
- package/lib/esm/index.js +1 -0
- package/lib/esm/index.js.map +1 -1
- package/lib/esm/mistral/index.js +31 -12
- package/lib/esm/mistral/index.js.map +1 -1
- package/lib/esm/mistral/types.js.map +1 -1
- package/lib/esm/openai/index.js +150 -28
- package/lib/esm/openai/index.js.map +1 -1
- package/lib/esm/replicate.js +17 -19
- package/lib/esm/replicate.js.map +1 -1
- package/lib/esm/test/TestValidationErrorCompletionStream.js.map +1 -1
- package/lib/esm/test/index.js.map +1 -1
- package/lib/esm/togetherai/index.js +40 -10
- package/lib/esm/togetherai/index.js.map +1 -1
- package/lib/esm/vertexai/embeddings/embeddings-image.js +23 -0
- package/lib/esm/vertexai/embeddings/embeddings-image.js.map +1 -0
- package/lib/esm/vertexai/embeddings/embeddings-text.js +1 -1
- package/lib/esm/vertexai/embeddings/embeddings-text.js.map +1 -1
- package/lib/esm/vertexai/index.js +93 -27
- package/lib/esm/vertexai/index.js.map +1 -1
- package/lib/esm/vertexai/models/claude.js +247 -0
- package/lib/esm/vertexai/models/claude.js.map +1 -0
- package/lib/esm/vertexai/models/gemini.js +170 -28
- package/lib/esm/vertexai/models/gemini.js.map +1 -1
- package/lib/esm/vertexai/models/imagen.js +310 -0
- package/lib/esm/vertexai/models/imagen.js.map +1 -0
- package/lib/esm/vertexai/models.js +12 -104
- package/lib/esm/vertexai/models.js.map +1 -1
- package/lib/esm/watsonx/index.js +39 -8
- package/lib/esm/watsonx/index.js.map +1 -1
- package/lib/esm/xai/index.js +64 -0
- package/lib/esm/xai/index.js.map +1 -0
- package/lib/types/adobe/firefly.d.ts +30 -0
- package/lib/types/adobe/firefly.d.ts.map +1 -0
- package/lib/types/bedrock/converse.d.ts +8 -0
- package/lib/types/bedrock/converse.d.ts.map +1 -0
- package/lib/types/bedrock/index.d.ts +26 -11
- package/lib/types/bedrock/index.d.ts.map +1 -1
- package/lib/types/bedrock/nova-image-payload.d.ts +74 -0
- package/lib/types/bedrock/nova-image-payload.d.ts.map +1 -0
- package/lib/types/bedrock/payloads.d.ts +9 -65
- package/lib/types/bedrock/payloads.d.ts.map +1 -1
- package/lib/types/groq/index.d.ts +3 -3
- package/lib/types/groq/index.d.ts.map +1 -1
- package/lib/types/huggingface_ie.d.ts +5 -7
- package/lib/types/huggingface_ie.d.ts.map +1 -1
- package/lib/types/index.d.ts +1 -0
- package/lib/types/index.d.ts.map +1 -1
- package/lib/types/mistral/index.d.ts +4 -4
- package/lib/types/mistral/index.d.ts.map +1 -1
- package/lib/types/mistral/types.d.ts +1 -0
- package/lib/types/mistral/types.d.ts.map +1 -1
- package/lib/types/openai/index.d.ts +5 -4
- package/lib/types/openai/index.d.ts.map +1 -1
- package/lib/types/replicate.d.ts +4 -9
- package/lib/types/replicate.d.ts.map +1 -1
- package/lib/types/test/index.d.ts +2 -2
- package/lib/types/test/index.d.ts.map +1 -1
- package/lib/types/togetherai/index.d.ts +4 -4
- package/lib/types/togetherai/index.d.ts.map +1 -1
- package/lib/types/vertexai/embeddings/embeddings-image.d.ts +11 -0
- package/lib/types/vertexai/embeddings/embeddings-image.d.ts.map +1 -0
- package/lib/types/vertexai/index.d.ts +19 -8
- package/lib/types/vertexai/index.d.ts.map +1 -1
- package/lib/types/vertexai/models/claude.d.ts +20 -0
- package/lib/types/vertexai/models/claude.d.ts.map +1 -0
- package/lib/types/vertexai/models/gemini.d.ts +4 -4
- package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
- package/lib/types/vertexai/models/imagen.d.ts +75 -0
- package/lib/types/vertexai/models/imagen.d.ts.map +1 -0
- package/lib/types/vertexai/models.d.ts +3 -6
- package/lib/types/vertexai/models.d.ts.map +1 -1
- package/lib/types/watsonx/index.d.ts +3 -3
- package/lib/types/watsonx/index.d.ts.map +1 -1
- package/lib/types/xai/index.d.ts +19 -0
- package/lib/types/xai/index.d.ts.map +1 -0
- package/package.json +24 -23
- package/src/adobe/firefly.ts +207 -0
- package/src/bedrock/converse.ts +194 -0
- package/src/bedrock/index.ts +349 -237
- package/src/bedrock/nova-image-payload.ts +309 -0
- package/src/bedrock/payloads.ts +12 -66
- package/src/groq/index.ts +35 -12
- package/src/huggingface_ie.ts +34 -13
- package/src/index.ts +1 -0
- package/src/mistral/index.ts +34 -12
- package/src/mistral/types.ts +2 -1
- package/src/openai/index.ts +167 -33
- package/src/replicate.ts +21 -20
- package/src/test/TestValidationErrorCompletionStream.ts +2 -2
- package/src/test/index.ts +3 -2
- package/src/togetherai/index.ts +44 -12
- package/src/vertexai/embeddings/embeddings-image.ts +50 -0
- package/src/vertexai/embeddings/embeddings-text.ts +1 -1
- package/src/vertexai/index.ts +114 -37
- package/src/vertexai/models/claude.ts +281 -0
- package/src/vertexai/models/gemini.ts +181 -31
- package/src/vertexai/models/imagen.ts +401 -0
- package/src/vertexai/models.ts +16 -120
- package/src/watsonx/index.ts +42 -10
- package/src/xai/index.ts +110 -0
- package/lib/cjs/vertexai/models/codey-chat.js +0 -65
- package/lib/cjs/vertexai/models/codey-chat.js.map +0 -1
- package/lib/cjs/vertexai/models/codey-text.js +0 -35
- package/lib/cjs/vertexai/models/codey-text.js.map +0 -1
- package/lib/cjs/vertexai/models/palm-model-base.js +0 -59
- package/lib/cjs/vertexai/models/palm-model-base.js.map +0 -1
- package/lib/cjs/vertexai/models/palm2-chat.js +0 -65
- package/lib/cjs/vertexai/models/palm2-chat.js.map +0 -1
- package/lib/cjs/vertexai/models/palm2-text.js +0 -35
- package/lib/cjs/vertexai/models/palm2-text.js.map +0 -1
- package/lib/cjs/vertexai/utils/tensor.js +0 -86
- package/lib/cjs/vertexai/utils/tensor.js.map +0 -1
- package/lib/esm/vertexai/models/codey-chat.js +0 -61
- package/lib/esm/vertexai/models/codey-chat.js.map +0 -1
- package/lib/esm/vertexai/models/codey-text.js +0 -31
- package/lib/esm/vertexai/models/codey-text.js.map +0 -1
- package/lib/esm/vertexai/models/palm-model-base.js +0 -55
- package/lib/esm/vertexai/models/palm-model-base.js.map +0 -1
- package/lib/esm/vertexai/models/palm2-chat.js +0 -61
- package/lib/esm/vertexai/models/palm2-chat.js.map +0 -1
- package/lib/esm/vertexai/models/palm2-text.js +0 -31
- package/lib/esm/vertexai/models/palm2-text.js.map +0 -1
- package/lib/esm/vertexai/utils/tensor.js +0 -82
- package/lib/esm/vertexai/utils/tensor.js.map +0 -1
- package/lib/types/vertexai/models/codey-chat.d.ts +0 -51
- package/lib/types/vertexai/models/codey-chat.d.ts.map +0 -1
- package/lib/types/vertexai/models/codey-text.d.ts +0 -39
- package/lib/types/vertexai/models/codey-text.d.ts.map +0 -1
- package/lib/types/vertexai/models/palm-model-base.d.ts +0 -61
- package/lib/types/vertexai/models/palm-model-base.d.ts.map +0 -1
- package/lib/types/vertexai/models/palm2-chat.d.ts +0 -61
- package/lib/types/vertexai/models/palm2-chat.d.ts.map +0 -1
- package/lib/types/vertexai/models/palm2-text.d.ts +0 -39
- package/lib/types/vertexai/models/palm2-text.d.ts.map +0 -1
- package/lib/types/vertexai/utils/tensor.d.ts +0 -6
- package/lib/types/vertexai/utils/tensor.d.ts.map +0 -1
- package/src/vertexai/models/codey-chat.ts +0 -115
- package/src/vertexai/models/codey-text.ts +0 -69
- package/src/vertexai/models/palm-model-base.ts +0 -128
- package/src/vertexai/models/palm2-chat.ts +0 -119
- package/src/vertexai/models/palm2-text.ts +0 -69
- package/src/vertexai/utils/tensor.ts +0 -82
package/src/vertexai/index.ts
CHANGED

@@ -1,9 +1,15 @@
 import { GenerateContentRequest, VertexAI } from "@google-cloud/vertexai";
-import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsResult, ExecutionOptions,
+import { AIModel, AbstractDriver, Completion, CompletionChunkObject, DriverOptions, EmbeddingsResult, ExecutionOptions, ImageGeneration, Modalities, ModelSearchPayload, PromptSegment } from "@llumiverse/core";
 import { FetchClient } from "api-fetch-client";
-import { GoogleAuthOptions } from "google-auth-library";
+import { GoogleAuth, GoogleAuthOptions } from "google-auth-library";
+import { JSONClient } from "google-auth-library/build/src/auth/googleauth.js";
 import { TextEmbeddingsOptions, getEmbeddingsForText } from "./embeddings/embeddings-text.js";
-import {
+import { getModelDefinition } from "./models.js";
+import { EmbeddingsOptions } from "@llumiverse/core";
+import { getEmbeddingsForImages } from "./embeddings/embeddings-image.js";
+import { v1beta1 } from '@google-cloud/aiplatform';
+import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
+import { ImagenModelDefinition, ImagenPrompt } from "./models/imagen.js";
 
 
 export interface VertexAIDriverOptions extends DriverOptions {

@@ -12,19 +18,31 @@ export interface VertexAIDriverOptions extends DriverOptions {
     googleAuthOptions?: GoogleAuthOptions;
 }
 
-
+//General Prompt type for VertexAI
+export type VertexAIPrompt = GenerateContentRequest | ImagenPrompt;
+
+export function trimModelName(model: string) {
+    const i = model.lastIndexOf('@');
+    return i > -1 ? model.substring(0, i) : model;
+}
+
+export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, VertexAIPrompt> {
     static PROVIDER = "vertexai";
     provider = VertexAIDriver.PROVIDER;
 
-
+    aiplatform: v1beta1.ModelServiceClient;
     vertexai: VertexAI;
     fetchClient: FetchClient;
-
-
-
-    ) {
+    authClient: JSONClient | GoogleAuth<JSONClient>;
+    anthropicClient: AnthropicVertex | undefined;
+
+    constructor( options: VertexAIDriverOptions) {
         super(options);
-
+
+        this.anthropicClient = undefined;
+
+        this.authClient = options.googleAuthOptions?.authClient ?? new GoogleAuth(options.googleAuthOptions);
+
         this.vertexai = new VertexAI({
             project: this.options.project,
             location: this.options.region,

@@ -35,54 +53,113 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Genera
             project: this.options.project,
         }).withAuthCallback(async () => {
             //@ts-ignore
-            const token = await this.
+            const token = await this.authClient.getAccessToken();
             return `Bearer ${token}`;
         });
-
-
-
-
+        this.aiplatform = new v1beta1.ModelServiceClient({
+            projectId: this.options.project,
+            apiEndpoint: `${this.options.region}-${API_BASE_PATH}`,
+        });
+    }
+
+    public getAnthropicClient() : AnthropicVertex {
+        //Lazy initialisation
+        if (!this.anthropicClient) {
+            this.anthropicClient = new AnthropicVertex({region: "us-east5", projectId: process.env.GOOGLE_PROJECT_ID});
+        }
+        return this.anthropicClient;
     }
 
     protected canStream(options: ExecutionOptions): Promise<boolean> {
+        if (options.output_modality == Modalities.image) {
+            return Promise.resolve(false);
+        }
         return Promise.resolve(getModelDefinition(options.model).model.can_stream === true);
     }
 
-    public createPrompt(segments: PromptSegment[], options:
+    public createPrompt(segments: PromptSegment[], options: ExecutionOptions): Promise<VertexAIPrompt> {
+        if (options.model.includes("imagen")) {
+            return new ImagenModelDefinition(options.model).createPrompt(this, segments, options);
+        }
         return getModelDefinition(options.model).createPrompt(this, segments, options);
     }
 
-    async
-    return getModelDefinition(options.model).
+    async requestTextCompletion(prompt: VertexAIPrompt, options: ExecutionOptions): Promise<Completion<any>> {
+        return getModelDefinition(options.model).requestTextCompletion(this, prompt, options);
     }
-    async
-    return getModelDefinition(options.model).
+    async requestTextCompletionStream(prompt: VertexAIPrompt, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
+        return getModelDefinition(options.model).requestTextCompletionStream(this, prompt, options);
+    }
+
+    async requestImageGeneration(_prompt: ImagenPrompt, _options: ExecutionOptions): Promise <Completion<ImageGeneration>> {
+        const splits = _options.model.split("/");
+        const modelName = trimModelName(splits[splits.length - 1]);
+        return new ImagenModelDefinition(modelName).requestImageGeneration(this, _prompt, _options);
     }
 
     async listModels(_params?: ModelSearchPayload): Promise<AIModel<string>[]> {
-
-
-
-
-
-
-        //
-
-
-
-
-
-
-
+        let models: AIModel<string>[] = [];
+        const modelGarden = new v1beta1.ModelGardenServiceClient({
+            projectId: this.options.project,
+            apiEndpoint: `${this.options.region}-${API_BASE_PATH}`,
+        });
+
+        //Project specific deployed models
+        const [response] = await this.aiplatform.listModels({
+            parent: `projects/${this.options.project}/locations/${this.options.region}`,
+        });
+        models = models.concat(response.map(model => ({
+            id: model.name?.split('/').pop() ?? '',
+            name: model.displayName ?? '',
+            provider: 'vertexai',
+        })));
+
+        //Model Garden Publisher models - Pretrained models
+        const publishers = ['google', 'anthropic']
+        const supportedModels = {google: ['gemini','imagen'], anthropic: ['claude']}
+
+        for (const publisher of publishers) {
+            const [response] = await modelGarden.listPublisherModels({
+                parent: `publishers/${publisher}`,
+                orderBy: 'name',
+                //filter: `name eq name`,
+                listAllVersions: true,
+            });
+
+            models = models.concat(response.map(model => ({
+                id: model.name ?? '',
+                name: model.name?.split('/').pop() ?? '',
+                provider: 'vertexai',
+                owner: publisher,
+            })).filter(model => {
+                const modelFamily = supportedModels[publisher as keyof typeof supportedModels];
+                for (const family of modelFamily) {
+                    if (model.name.includes(family)) {
+                        return true;
+                    }
+                }
+            }));
+        }
+
+        return models;
     }
 
     validateConnection(): Promise<boolean> {
         throw new Error("Method not implemented.");
     }
 
-    async generateEmbeddings(options:
-
+    async generateEmbeddings(options: EmbeddingsOptions): Promise<EmbeddingsResult> {
+        if (options.image || options.model?.includes("multimodal")) {
+            if (options.text && options.image) {
+                throw new Error("Text and Image simultaneous embedding not implemented. Submit seperately");
+            }
+            return getEmbeddingsForImages(this, options);
+        }
+        const text_options: TextEmbeddingsOptions = {
+            content: options.text ?? '',
+            model: options.model,
+        }
+        return getEmbeddingsForText(this, text_options);
    }
 
 }
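Below is a brief usage sketch of the reworked driver, for orientation. It is not part of the package: the project id and region are placeholders, and it assumes the driver is importable from the package entry point and falls back to Application Default Credentials through the GoogleAuth path shown in the constructor above.

import { VertexAIDriver } from "@llumiverse/drivers";

// Placeholder project/region values.
const driver = new VertexAIDriver({
    project: "my-gcp-project",
    region: "us-central1",
});

// listModels() now merges two sources: models deployed in the project
// (ModelServiceClient.listModels) and Model Garden publisher models
// (ModelGardenServiceClient.listPublisherModels), keeping only the
// gemini/imagen (google) and claude (anthropic) families.
const models = await driver.listModels();

// The exported trimModelName() strips a trailing "@version" suffix,
// e.g. trimModelName("claude-3-5-sonnet-v2@20241022") === "claude-3-5-sonnet-v2".
console.log(models.map((m) => m.id));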
package/src/vertexai/models/claude.ts
ADDED

@@ -0,0 +1,281 @@
+import * as AnthropicAPI from '@anthropic-ai/sdk';
+import { ContentBlock, Message, TextBlockParam } from "@anthropic-ai/sdk/resources/index.js";
+import { AIModel, Completion, CompletionChunkObject, ExecutionOptions, JSONObject, ModelType, PromptOptions, PromptRole, PromptSegment, ToolUse } from "@llumiverse/core";
+import { asyncMap } from "@llumiverse/core/async";
+import { VertexAIClaudeOptions } from "../../../../core/src/options/vertexai.js";
+import { VertexAIDriver } from "../index.js";
+import { ModelDefinition } from "../models.js";
+
+type MessageParam = AnthropicAPI.Anthropic.MessageParam;
+
+interface ClaudePrompt {
+    messages: MessageParam[];
+    system: TextBlockParam[];
+}
+
+function getFullModelName(model: string): string {
+    if (model.includes("claude-3-5-sonnet-v2")) {
+        return "claude-3-5-sonnet-v2@20241022"
+    } else if (model.includes("claude-3-5-sonnet")) {
+        return "claude-3-5-sonnet@20240620"
+    } else if (model.includes("claude-3-5-haiku")) {
+        return "claude-3-5-haiku@20241022"
+    } else if (model.includes("claude-3-opus")) {
+        return "claude-3-opus@20240229"
+    } else if (model.includes("claude-3-sonnet")) {
+        return "claude-3-sonnet@20240229"
+    } else if (model.includes("claude-3-haike")) {
+        return "claude-3-haiku@20240307"
+    }
+    return model;
+}
+
+function claudeFinishReason(reason: string | undefined) {
+    if (!reason) return undefined;
+    switch (reason) {
+        case 'end_turn': return "stop";
+        case 'max_tokens': return "length";
+        default: return reason; //stop_sequence
+    }
+}
+
+function collectTextParts(content: any) {
+    const out = [];
+
+    for (const block of content) {
+        if (block?.text) {
+            out.push(block.text);
+        }
+    }
+    return out.join('\n');
+}
+
+function maxToken(max_tokens: number | undefined, model: string): number {
+    const contains = (str: string, substr: string) => str.indexOf(substr) !== -1;
+    if (max_tokens) {
+        return max_tokens;
+    } else if (contains(model, "claude-3-5")) {
+        return 8192;
+    } else {
+        return 4096
+    }
+}
+
+export class ClaudeModelDefinition implements ModelDefinition<ClaudePrompt> {
+
+    model: AIModel
+
+    constructor(modelId: string) {
+        this.model = {
+            id: modelId,
+            name: modelId,
+            provider: 'vertexai',
+            type: ModelType.Text,
+            can_stream: true,
+        } as AIModel;
+    }
+
+    async createPrompt(_driver: VertexAIDriver, segments: PromptSegment[], options: PromptOptions): Promise<ClaudePrompt> {
+        // Convert the prompt to the format expected by the Claude API
+        const systemSegments: TextBlockParam[] = segments
+            .filter(segment => segment.role === PromptRole.system)
+            .map(segment => ({
+                text: segment.content,
+                type: 'text'
+            }));
+
+        const safetySegments: TextBlockParam[] = segments
+            .filter(segment => segment.role === PromptRole.safety)
+            .map(segment => ({
+                text: segment.content,
+                type: 'text'
+            }));
+
+        if (options.result_schema) {
+            const schemaSegments: TextBlockParam = {
+                text: "The answer must be a JSON object using the following JSON Schema:\n" + JSON.stringify(options.result_schema),
+                type: 'text'
+            }
+            safetySegments.push(schemaSegments);
+        }
+
+        const messages: MessageParam[] = segments.filter(segment =>
+            segment.role == PromptRole.user
+            || segment.role == PromptRole.assistant
+            || segment.role === PromptRole.tool)
+            .map(segment => {
+                if (segment.role === PromptRole.tool) {
+                    if (!segment.tool_use_id) {
+                        throw new Error("Tool prompt segment must have a tool_use_id");
+                    }
+                    return {
+                        role: 'user',
+                        content: [
+                            {
+                                type: 'tool_result',
+                                tool_use_id: segment.tool_use_id,
+                                content: segment.content || undefined
+                            }
+                        ]
+                    }
+                } else {
+                    return {
+                        role: segment.role !== PromptRole.user ? 'assistant' : 'user',
+                        content: segment.content
+                    }
+                }
+            });
+
+        const system = systemSegments.concat(safetySegments);
+
+        return {
+            messages: messages,
+            system: system
+        }
+    }
+
+    async requestTextCompletion(driver: VertexAIDriver, prompt: ClaudePrompt, options: ExecutionOptions): Promise<Completion> {
+        const client = driver.getAnthropicClient();
+        const splits = options.model.split("/");
+        const modelName = splits[splits.length - 1];
+        options = { ...options, model: modelName };
+        options.model_options = options.model_options as VertexAIClaudeOptions;
+
+        if (options.model_options?._option_id !== "vertexai-claude") {
+            driver.logger.warn("Invalid model options", { options: options.model_options });
+        }
+
+        let conversation = updateConversation(options.conversation as ClaudePrompt, prompt);
+
+        const result = await client.messages.create({
+            ...conversation, // messages, system,
+            tools: options.tools, // we are using the same shape as claude for tools
+            temperature: options.model_options?.temperature,
+            model: modelName,
+            max_tokens: maxToken(options.model_options?.max_tokens, modelName),
+            top_p: options.model_options?.top_p,
+            top_k: options.model_options?.top_k,
+            stop_sequences: options.model_options?.stop_sequence,
+            thinking: options.model_options?.thinking_mode ?
+                {
+                    budget_tokens: options.model_options?.thinking_budget_tokens ?? 1024,
+                    type: "enabled"
+                } : {
+                    type: "disabled"
+                }
+        }) as Message;
+
+        const text = collectTextParts(result.content);
+        const tool_use = collectTools(result.content);
+
+        conversation = updateConversation(options.conversation as ClaudePrompt, createPromptFromResponse(result));
+
+        return {
+            chat: [prompt, { role: result.role, content: result.content }],
+            result: text ?? '',
+            tool_use,
+            token_usage: {
+                prompt: result?.usage.input_tokens,
+                result: result?.usage.output_tokens,
+                total: result?.usage.input_tokens + result?.usage.output_tokens
+            },
+            // make sure we set finish_reason to the correct value (claude is normally setting this by itself)
+            finish_reason: tool_use ? "tool_use" : claudeFinishReason(result?.stop_reason ?? ''),
+            conversation
+        } as Completion;
+    }
+
+    async requestTextCompletionStream(driver: VertexAIDriver, prompt: ClaudePrompt, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
+        const client = driver.getAnthropicClient();
+        const splits = options.model.split("/");
+        const modelName = splits[splits.length - 1];
+        options = { ...options, model: modelName };
+        options.model_options = options.model_options as VertexAIClaudeOptions;
+
+        if (options.model_options?._option_id !== "vertexai-claude") {
+            driver.logger.warn("Invalid model options", { options: options.model_options });
+        }
+
+        const response_stream = await client.messages.stream({
+            ...prompt, // messages, system,
+            tools: options.tools, // we are using the same shape as claude for tools
+            temperature: options.model_options?.temperature,
+            model: modelName,
+            max_tokens: maxToken(options.model_options?.max_tokens, modelName),
+            top_p: options.model_options?.top_p,
+            top_k: options.model_options?.top_k,
+            stop_sequences: options.model_options?.stop_sequence,
+            thinking: options.model_options?.thinking_mode ?
+                {
+                    budget_tokens: options.model_options?.thinking_budget_tokens ?? 1024,
+                    type: "enabled"
+                } : {
+                    type: "disabled"
+                }
+        });
+
+        //Streaming does not give information on the input tokens,
+        //So we use a seperate call to get the input tokens.
+        //Non-critical and model name sensitive so we put it in a try catch block
+        let count_tokens = { input_tokens: 0 };
+        try {
+            count_tokens = await client.messages.countTokens({
+                ...prompt, // messages, system
+                model: getFullModelName(modelName),
+            });
+        } catch (e) {
+            driver.logger.warn("Failed to get token count for model " + modelName);
+        }
+
+        const stream = asyncMap(response_stream, async (item: any) => {
+            return {
+                result: item?.delta?.text ?? '',
+                token_usage: { prompt: count_tokens.input_tokens, result: item?.usage?.output_tokens },
+                finish_reason: claudeFinishReason(item?.delta?.stop_reason ?? ''),
+            }
+        });
+
+        return stream;
+    }
+}
+
+export function collectTools(content: ContentBlock[]): ToolUse[] | undefined {
+    const out: ToolUse[] = [];
+
+    for (const block of content) {
+        if (block?.type === "tool_use") {
+            out.push({
+                id: block.id,
+                name: block.name,
+                input: block.input as JSONObject,
+            });
+        }
+    }
+
+    return out.length > 0 ? out : undefined;
+}
+
+function createPromptFromResponse(response: Message): ClaudePrompt {
+    return {
+        messages: [{
+            role: PromptRole.assistant,
+            content: response.content,
+        }],
+        system: []
+    }
+}
+
+/**
+ * Update the converatation messages
+ * @param prompt
+ * @param response
+ * @returns
+ */
+function updateConversation(conversation: ClaudePrompt | undefined | null, prompt: ClaudePrompt): ClaudePrompt {
+    const baseSystemMessages = conversation ? conversation.system : [];
+    const baseMessages = conversation ? conversation.messages : []
+    return {
+        messages: baseMessages.concat(prompt.messages || []),
+        system: baseSystemMessages.concat(prompt.system || [])
+    };
+}
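For reference, the version pinning and token defaults encoded by the two helpers above work out as follows. getFullModelName and maxToken are module-private, so the calls below only illustrate the mapping the code defines:

getFullModelName("claude-3-5-sonnet-v2"); // "claude-3-5-sonnet-v2@20241022" (v2 branch is checked first)
getFullModelName("claude-3-5-sonnet");    // "claude-3-5-sonnet@20240620"
getFullModelName("claude-2");             // returned unchanged: no branch matches
maxToken(undefined, "claude-3-5-haiku");  // 8192, the claude-3-5 family default
maxToken(undefined, "claude-3-opus");     // 4096, the default for older models
maxToken(2000, "claude-3-opus");          // 2000, an explicit max_tokens always wins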