@llumiverse/drivers 0.22.1 → 0.23.0-dev-20251118
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/package.json +23 -18
- package/src/adobe/firefly.ts +2 -2
- package/src/azure/azure_foundry.ts +11 -11
- package/src/bedrock/index.ts +4 -4
- package/src/groq/index.ts +2 -2
- package/src/huggingface_ie.ts +9 -10
- package/src/index.ts +2 -2
- package/src/mistral/index.ts +2 -2
- package/src/openai/azure_openai.ts +5 -5
- package/src/openai/index.ts +3 -3
- package/src/replicate.ts +7 -7
- package/src/togetherai/index.ts +2 -2
- package/src/vertexai/index.ts +41 -37
- package/src/vertexai/models/claude.ts +4 -4
- package/src/vertexai/models/imagen.ts +4 -4
- package/src/watsonx/index.ts +3 -3
- /package/src/{test → test-driver}/TestErrorCompletionStream.ts +0 -0
- /package/src/{test → test-driver}/TestValidationErrorCompletionStream.ts +0 -0
- /package/src/{test → test-driver}/index.ts +0 -0
- /package/src/{test → test-driver}/utils.ts +0 -0
package/README.md
CHANGED
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@llumiverse/drivers",
|
|
3
|
-
"version": "0.
|
|
3
|
+
"version": "0.23.0-dev-20251118",
|
|
4
4
|
"type": "module",
|
|
5
5
|
"description": "LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.",
|
|
6
6
|
"files": [
|
|
@@ -47,41 +47,46 @@
|
|
|
47
47
|
],
|
|
48
48
|
"devDependencies": {
|
|
49
49
|
"dotenv": "^16.6.1",
|
|
50
|
-
"rimraf": "^6.0
|
|
50
|
+
"rimraf": "^6.1.0",
|
|
51
51
|
"ts-dual-module": "^0.6.3",
|
|
52
|
-
"typescript": "^5.9.
|
|
52
|
+
"typescript": "^5.9.3",
|
|
53
53
|
"vitest": "^3.2.4"
|
|
54
54
|
},
|
|
55
55
|
"dependencies": {
|
|
56
|
-
"@anthropic-ai/sdk": "^0.
|
|
56
|
+
"@anthropic-ai/sdk": "^0.68.0",
|
|
57
57
|
"@anthropic-ai/vertex-sdk": "^0.14.0",
|
|
58
|
-
"@aws-sdk/client-bedrock": "^3.
|
|
59
|
-
"@aws-sdk/client-bedrock-runtime": "^3.
|
|
60
|
-
"@aws-sdk/client-s3": "^3.
|
|
61
|
-
"@aws-sdk/credential-providers": "^3.
|
|
62
|
-
"@aws-sdk/lib-storage": "^3.
|
|
63
|
-
"@aws-sdk/types": "^3.
|
|
58
|
+
"@aws-sdk/client-bedrock": "^3.922.0",
|
|
59
|
+
"@aws-sdk/client-bedrock-runtime": "^3.922.0",
|
|
60
|
+
"@aws-sdk/client-s3": "^3.922.0",
|
|
61
|
+
"@aws-sdk/credential-providers": "^3.922.0",
|
|
62
|
+
"@aws-sdk/lib-storage": "^3.922.0",
|
|
63
|
+
"@aws-sdk/types": "^3.922.0",
|
|
64
64
|
"@azure-rest/ai-inference": "1.0.0-beta.6",
|
|
65
65
|
"@azure/ai-projects": "1.0.0-beta.10",
|
|
66
|
-
"@azure/core-auth": "^1.10.
|
|
66
|
+
"@azure/core-auth": "^1.10.1",
|
|
67
67
|
"@azure/core-sse": "^2.3.0",
|
|
68
|
-
"@azure/identity": "^4.
|
|
68
|
+
"@azure/identity": "^4.13.0",
|
|
69
69
|
"@azure/openai": "2.0.0",
|
|
70
|
-
"@google-cloud/aiplatform": "^
|
|
71
|
-
"@google/genai": "^1.
|
|
72
|
-
"@huggingface/inference": "
|
|
70
|
+
"@google-cloud/aiplatform": "^5.12.0",
|
|
71
|
+
"@google/genai": "^1.28.0",
|
|
72
|
+
"@huggingface/inference": "4.13.0",
|
|
73
73
|
"@llumiverse/common": "workspace:*",
|
|
74
74
|
"@llumiverse/core": "workspace:*",
|
|
75
|
-
"@vertesia/api-fetch-client": "^0.
|
|
75
|
+
"@vertesia/api-fetch-client": "^0.79.0",
|
|
76
76
|
"eventsource": "^4.0.0",
|
|
77
|
-
"google-auth-library": "^
|
|
77
|
+
"google-auth-library": "^10.5.0",
|
|
78
78
|
"groq-sdk": "^0.34.0",
|
|
79
79
|
"mnemonist": "^0.40.3",
|
|
80
80
|
"node-web-stream-adapters": "^0.2.1",
|
|
81
81
|
"openai": "^4.104.0",
|
|
82
|
-
"replicate": "^1.
|
|
82
|
+
"replicate": "^1.3.1"
|
|
83
83
|
},
|
|
84
84
|
"ts_dual_module": {
|
|
85
85
|
"outDir": "lib"
|
|
86
|
+
},
|
|
87
|
+
"pnpm": {
|
|
88
|
+
"overrides": {
|
|
89
|
+
"google-auth-library": "^10.5.0"
|
|
90
|
+
}
|
|
86
91
|
}
|
|
87
92
|
}
|
package/src/adobe/firefly.ts
CHANGED
|
@@ -134,7 +134,7 @@ export class FireflyDriver extends AbstractDriver<FireflyDriverOptions> {
|
|
|
134
134
|
};
|
|
135
135
|
|
|
136
136
|
} catch (error: any) {
|
|
137
|
-
this.logger.error("[Firefly] Image generation failed"
|
|
137
|
+
this.logger.error({ error }, "[Firefly] Image generation failed");
|
|
138
138
|
return {
|
|
139
139
|
result: [],
|
|
140
140
|
error: {
|
|
@@ -188,7 +188,7 @@ export class FireflyDriver extends AbstractDriver<FireflyDriverOptions> {
|
|
|
188
188
|
});
|
|
189
189
|
return response.ok;
|
|
190
190
|
} catch (error) {
|
|
191
|
-
this.logger.error("[Firefly] Connection validation failed"
|
|
191
|
+
this.logger.error({ error }, "[Firefly] Connection validation failed");
|
|
192
192
|
return false;
|
|
193
193
|
}
|
|
194
194
|
}
|
|
@@ -54,7 +54,7 @@ export class AzureFoundryDriver extends AbstractDriver<AzureFoundryDriverOptions
|
|
|
54
54
|
opts.azureADTokenProvider = new DefaultAzureCredential();
|
|
55
55
|
}
|
|
56
56
|
} catch (error) {
|
|
57
|
-
this.logger.error("Failed to initialize Azure AD token provider:"
|
|
57
|
+
this.logger.error({ error }, "Failed to initialize Azure AD token provider:");
|
|
58
58
|
throw new Error("Failed to initialize Azure AD token provider");
|
|
59
59
|
}
|
|
60
60
|
|
|
@@ -89,7 +89,7 @@ export class AzureFoundryDriver extends AbstractDriver<AzureFoundryDriverOptions
|
|
|
89
89
|
deployment = await this.service.deployments.get(deploymentName);
|
|
90
90
|
this.logger.debug(`[Azure Foundry] Deployment ${deploymentName} found`);
|
|
91
91
|
} catch (deploymentError) {
|
|
92
|
-
this.logger.error(`[Azure Foundry] Deployment ${deploymentName} not found
|
|
92
|
+
this.logger.error({ deploymentError }, `[Azure Foundry] Deployment ${deploymentName} not found:`);
|
|
93
93
|
}
|
|
94
94
|
|
|
95
95
|
return (deployment as ModelDeployment).modelPublisher == "OpenAI";
|
|
@@ -131,7 +131,7 @@ export class AzureFoundryDriver extends AbstractDriver<AzureFoundryDriverOptions
|
|
|
131
131
|
}
|
|
132
132
|
});
|
|
133
133
|
if (response.status !== "200") {
|
|
134
|
-
this.logger.error(`[Azure Foundry] Chat completion request failed
|
|
134
|
+
this.logger.error({ response }, `[Azure Foundry] Chat completion request failed:`);
|
|
135
135
|
throw new Error(`Chat completion request failed with status ${response.status}: ${response.body}`);
|
|
136
136
|
}
|
|
137
137
|
|
|
@@ -213,12 +213,12 @@ export class AzureFoundryDriver extends AbstractDriver<AzureFoundryDriverOptions
|
|
|
213
213
|
|
|
214
214
|
yield chunk;
|
|
215
215
|
} catch (parseError) {
|
|
216
|
-
this.logger.warn(`[Azure Foundry] Failed to parse streaming response
|
|
216
|
+
this.logger.warn({ parseError }, `[Azure Foundry] Failed to parse streaming response:`);
|
|
217
217
|
continue;
|
|
218
218
|
}
|
|
219
219
|
}
|
|
220
220
|
} catch (error) {
|
|
221
|
-
this.logger.error(`[Azure Foundry] Streaming error
|
|
221
|
+
this.logger.error({ error }, `[Azure Foundry] Streaming error:`);
|
|
222
222
|
throw error;
|
|
223
223
|
}
|
|
224
224
|
}
|
|
@@ -233,7 +233,7 @@ export class AzureFoundryDriver extends AbstractDriver<AzureFoundryDriverOptions
|
|
|
233
233
|
|
|
234
234
|
const choice = result.choices?.[0];
|
|
235
235
|
if (!choice) {
|
|
236
|
-
this.logger
|
|
236
|
+
this.logger.error({ result }, "[Azure Foundry] No choices in response");
|
|
237
237
|
throw new Error("No choices in response");
|
|
238
238
|
}
|
|
239
239
|
|
|
@@ -241,7 +241,7 @@ export class AzureFoundryDriver extends AbstractDriver<AzureFoundryDriverOptions
|
|
|
241
241
|
const toolCalls = choice.message?.tool_calls;
|
|
242
242
|
|
|
243
243
|
if (!data && !toolCalls) {
|
|
244
|
-
this.logger
|
|
244
|
+
this.logger.error({ result }, "[Azure Foundry] Response is not valid");
|
|
245
245
|
throw new Error("Response is not valid: no content or tool calls");
|
|
246
246
|
}
|
|
247
247
|
|
|
@@ -291,7 +291,7 @@ export class AzureFoundryDriver extends AbstractDriver<AzureFoundryDriverOptions
|
|
|
291
291
|
|
|
292
292
|
return true;
|
|
293
293
|
} catch (error) {
|
|
294
|
-
this.logger.error("Azure Foundry connection validation failed:"
|
|
294
|
+
this.logger.error({ error }, "Azure Foundry connection validation failed:");
|
|
295
295
|
return false;
|
|
296
296
|
}
|
|
297
297
|
}
|
|
@@ -328,7 +328,7 @@ export class AzureFoundryDriver extends AbstractDriver<AzureFoundryDriverOptions
|
|
|
328
328
|
}
|
|
329
329
|
});
|
|
330
330
|
} catch (error) {
|
|
331
|
-
this.logger.error("Azure Foundry text embeddings error:"
|
|
331
|
+
this.logger.error({ error }, "Azure Foundry text embeddings error:");
|
|
332
332
|
throw error;
|
|
333
333
|
}
|
|
334
334
|
|
|
@@ -365,7 +365,7 @@ export class AzureFoundryDriver extends AbstractDriver<AzureFoundryDriverOptions
|
|
|
365
365
|
}
|
|
366
366
|
});
|
|
367
367
|
} catch (error) {
|
|
368
|
-
this.logger.error("Azure Foundry image embeddings error:"
|
|
368
|
+
this.logger.error({ error }, "Azure Foundry image embeddings error:");
|
|
369
369
|
throw error;
|
|
370
370
|
}
|
|
371
371
|
if (isUnexpected(response)) {
|
|
@@ -395,7 +395,7 @@ export class AzureFoundryDriver extends AbstractDriver<AzureFoundryDriverOptions
|
|
|
395
395
|
// List all deployments in the Azure AI Foundry project
|
|
396
396
|
deploymentsIterable = this.service.deployments.list();
|
|
397
397
|
} catch (error) {
|
|
398
|
-
this.logger.error("Failed to list deployments:"
|
|
398
|
+
this.logger.error({ error }, "Failed to list deployments:");
|
|
399
399
|
throw new Error("Failed to list deployments in Azure AI Foundry project");
|
|
400
400
|
}
|
|
401
401
|
const deployments: DeploymentUnion[] = [];
|
package/src/bedrock/index.ts
CHANGED
|
@@ -169,7 +169,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
169
169
|
const type = Object.keys(content).find(
|
|
170
170
|
key => key !== '$unknown' && content[key as keyof typeof content] !== undefined
|
|
171
171
|
);
|
|
172
|
-
this.logger.info("[Bedrock] Unsupported content response type:"
|
|
172
|
+
this.logger.info({ type }, "[Bedrock] Unsupported content response type:");
|
|
173
173
|
}
|
|
174
174
|
}
|
|
175
175
|
|
|
@@ -235,7 +235,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
235
235
|
const type = Object.keys(delta).find(
|
|
236
236
|
key => key !== '$unknown' && (delta as any)[key] !== undefined
|
|
237
237
|
);
|
|
238
|
-
this.logger.info("[Bedrock] Unsupported content response type:"
|
|
238
|
+
this.logger.info({ type }, "[Bedrock] Unsupported content response type:");
|
|
239
239
|
}
|
|
240
240
|
}
|
|
241
241
|
|
|
@@ -512,7 +512,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
512
512
|
});
|
|
513
513
|
|
|
514
514
|
}).catch((err) => {
|
|
515
|
-
this.logger.error("[Bedrock] Failed to stream"
|
|
515
|
+
this.logger.error({ error: err }, "[Bedrock] Failed to stream");
|
|
516
516
|
throw err;
|
|
517
517
|
});
|
|
518
518
|
}
|
|
@@ -673,7 +673,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
|
|
|
673
673
|
throw new Error(`Image generation requires image output_modality`);
|
|
674
674
|
}
|
|
675
675
|
if (options.model_options?._option_id !== "bedrock-nova-canvas") {
|
|
676
|
-
this.logger.warn(
|
|
676
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
677
677
|
}
|
|
678
678
|
const model_options = options.model_options as NovaCanvasOptions;
|
|
679
679
|
|
package/src/groq/index.ts
CHANGED
|
@@ -195,7 +195,7 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, ChatCompletion
|
|
|
195
195
|
|
|
196
196
|
async requestTextCompletion(messages: ChatCompletionMessageParam[], options: ExecutionOptions): Promise<Completion> {
|
|
197
197
|
if (options.model_options?._option_id !== "text-fallback" && options.model_options?._option_id !== "groq-deepseek-thinking") {
|
|
198
|
-
this.logger.warn(
|
|
198
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
199
199
|
}
|
|
200
200
|
options.model_options = options.model_options as TextFallbackOptions;
|
|
201
201
|
|
|
@@ -251,7 +251,7 @@ export class GroqDriver extends AbstractDriver<GroqDriverOptions, ChatCompletion
|
|
|
251
251
|
|
|
252
252
|
async requestTextCompletionStream(messages: ChatCompletionMessageParam[], options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
|
|
253
253
|
if (options.model_options?._option_id !== "text-fallback") {
|
|
254
|
-
this.logger.warn(
|
|
254
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
255
255
|
}
|
|
256
256
|
options.model_options = options.model_options as TextFallbackOptions;
|
|
257
257
|
|
package/src/huggingface_ie.ts
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import {
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
TextGenerationStreamOutput
|
|
2
|
+
InferenceClient,
|
|
3
|
+
TextGenerationStreamOutput,
|
|
5
4
|
} from "@huggingface/inference";
|
|
6
5
|
import {
|
|
7
6
|
AIModel,
|
|
@@ -25,7 +24,7 @@ export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptio
|
|
|
25
24
|
static PROVIDER = "huggingface_ie";
|
|
26
25
|
provider = HuggingFaceIEDriver.PROVIDER;
|
|
27
26
|
service: FetchClient;
|
|
28
|
-
_executor?:
|
|
27
|
+
_executor?: InferenceClient;
|
|
29
28
|
|
|
30
29
|
constructor(
|
|
31
30
|
options: HuggingFaceIEDriverOptions
|
|
@@ -60,7 +59,8 @@ export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptio
|
|
|
60
59
|
`Endpoint ${model} is not running - current status: ${endpoint.status}`
|
|
61
60
|
);
|
|
62
61
|
|
|
63
|
-
|
|
62
|
+
// Use the new InferenceClient and bind it to the endpoint URL
|
|
63
|
+
this._executor = new InferenceClient(this.options.apiKey).endpoint(
|
|
64
64
|
endpoint.url
|
|
65
65
|
);
|
|
66
66
|
}
|
|
@@ -69,7 +69,7 @@ export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptio
|
|
|
69
69
|
|
|
70
70
|
async requestTextCompletionStream(prompt: string, options: ExecutionOptions) {
|
|
71
71
|
if (options.model_options?._option_id !== "text-fallback") {
|
|
72
|
-
this.logger.warn(
|
|
72
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
73
73
|
}
|
|
74
74
|
options.model_options = options.model_options as TextFallbackOptions;
|
|
75
75
|
|
|
@@ -82,8 +82,7 @@ export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptio
|
|
|
82
82
|
},
|
|
83
83
|
});
|
|
84
84
|
|
|
85
|
-
|
|
86
|
-
return transformAsyncIterator(req, (val: TextGenerationStreamOutput) => {
|
|
85
|
+
return transformAsyncIterator(req, (val: TextGenerationStreamOutput): CompletionChunkObject => {
|
|
87
86
|
//special like <s> are not part of the result
|
|
88
87
|
if (val.token.special) return { result: [] };
|
|
89
88
|
let finish_reason = val.details?.finish_reason as string;
|
|
@@ -96,13 +95,13 @@ export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptio
|
|
|
96
95
|
token_usage: {
|
|
97
96
|
result: val.details?.generated_tokens ?? 0,
|
|
98
97
|
}
|
|
99
|
-
}
|
|
98
|
+
};
|
|
100
99
|
});
|
|
101
100
|
}
|
|
102
101
|
|
|
103
102
|
async requestTextCompletion(prompt: string, options: ExecutionOptions) {
|
|
104
103
|
if (options.model_options?._option_id !== "text-fallback") {
|
|
105
|
-
this.logger.warn(
|
|
104
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
106
105
|
}
|
|
107
106
|
options.model_options = options.model_options as TextFallbackOptions;
|
|
108
107
|
|
package/src/index.ts
CHANGED
|
@@ -1,12 +1,12 @@
|
|
|
1
|
+
export * from "./azure/azure_foundry.js";
|
|
1
2
|
export * from "./bedrock/index.js";
|
|
2
3
|
export * from "./groq/index.js";
|
|
3
4
|
export * from "./huggingface_ie.js";
|
|
4
5
|
export * from "./mistral/index.js";
|
|
5
6
|
export * from "./openai/azure_openai.js";
|
|
6
|
-
export * from "./azure/azure_foundry.js";
|
|
7
7
|
export * from "./openai/openai.js";
|
|
8
8
|
export * from "./replicate.js";
|
|
9
|
-
export * from "./test/index.js";
|
|
9
|
+
export * from "./test-driver/index.js";
|
|
10
10
|
export * from "./togetherai/index.js";
|
|
11
11
|
export * from "./vertexai/index.js";
|
|
12
12
|
export * from "./watsonx/index.js";
|
package/src/mistral/index.ts
CHANGED
|
@@ -64,7 +64,7 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, Open
|
|
|
64
64
|
|
|
65
65
|
async requestTextCompletion(messages: OpenAITextMessage[], options: ExecutionOptions): Promise<Completion> {
|
|
66
66
|
if (options.model_options?._option_id !== "text-fallback") {
|
|
67
|
-
this.logger.warn(
|
|
67
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
68
68
|
}
|
|
69
69
|
options.model_options = options.model_options as TextFallbackOptions;
|
|
70
70
|
|
|
@@ -95,7 +95,7 @@ export class MistralAIDriver extends AbstractDriver<MistralAIDriverOptions, Open
|
|
|
95
95
|
|
|
96
96
|
async requestTextCompletionStream(messages: OpenAITextMessage[], options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
|
|
97
97
|
if (options.model_options?._option_id !== "text-fallback") {
|
|
98
|
-
this.logger.warn(
|
|
98
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
99
99
|
}
|
|
100
100
|
options.model_options = options.model_options as TextFallbackOptions;
|
|
101
101
|
|
|
@@ -9,7 +9,7 @@ export interface AzureOpenAIDriverOptions extends DriverOptions {
|
|
|
9
9
|
* The credentials to use to access Azure OpenAI
|
|
10
10
|
*/
|
|
11
11
|
azureADTokenProvider?: any; //type with azure credentials
|
|
12
|
-
|
|
12
|
+
|
|
13
13
|
apiKey?: string;
|
|
14
14
|
|
|
15
15
|
endpoint?: string;
|
|
@@ -41,7 +41,7 @@ export class AzureOpenAIDriver extends BaseOpenAIDriver {
|
|
|
41
41
|
|
|
42
42
|
this.service = new AzureOpenAI({
|
|
43
43
|
apiKey: opts.apiKey,
|
|
44
|
-
azureADTokenProvider: opts.azureADTokenProvider,
|
|
44
|
+
azureADTokenProvider: opts.azureADTokenProvider,
|
|
45
45
|
endpoint: opts.endpoint,
|
|
46
46
|
apiVersion: opts.apiVersion ?? "2024-10-21",
|
|
47
47
|
deployment: opts.deployment
|
|
@@ -56,7 +56,7 @@ export class AzureOpenAIDriver extends BaseOpenAIDriver {
|
|
|
56
56
|
const azureADTokenProvider = getBearerTokenProvider(new DefaultAzureCredential(), scope);
|
|
57
57
|
return azureADTokenProvider;
|
|
58
58
|
}
|
|
59
|
-
|
|
59
|
+
|
|
60
60
|
async listModels(): Promise<AIModel[]> {
|
|
61
61
|
return this._listModels();
|
|
62
62
|
}
|
|
@@ -65,7 +65,7 @@ export class AzureOpenAIDriver extends BaseOpenAIDriver {
|
|
|
65
65
|
if (!this.service.deploymentName) {
|
|
66
66
|
throw new Error("A specific deployment is not set. Azure OpenAI cannot list deployments. Update your endpoint URL to include the deployment name, e.g., https://your-resource.openai.azure.com/openai/deployments/your-deployment/chat/completions");
|
|
67
67
|
}
|
|
68
|
-
|
|
68
|
+
|
|
69
69
|
//Do a test execution to check if the model works and to get the model ID.
|
|
70
70
|
let modelID = this.service.deploymentName;
|
|
71
71
|
try {
|
|
@@ -76,7 +76,7 @@ export class AzureOpenAIDriver extends BaseOpenAIDriver {
|
|
|
76
76
|
});
|
|
77
77
|
modelID = testResponse.model;
|
|
78
78
|
} catch (error) {
|
|
79
|
-
this.logger.error("Failed to test model for Azure OpenAI listing :"
|
|
79
|
+
this.logger.error({ error }, "Failed to test model for Azure OpenAI listing :");
|
|
80
80
|
}
|
|
81
81
|
const modelCapability = getModelCapabilities(modelID, "openai");
|
|
82
82
|
return [{
|
package/src/openai/index.ts
CHANGED
|
@@ -76,7 +76,7 @@ export abstract class BaseOpenAIDriver extends AbstractDriver<
|
|
|
76
76
|
const data = choice.message.content ?? undefined;
|
|
77
77
|
|
|
78
78
|
if (!data && !tools) {
|
|
79
|
-
this.logger
|
|
79
|
+
this.logger.error({ result }, "[OpenAI] Response is not valid");
|
|
80
80
|
throw new Error("Response is not valid: no data");
|
|
81
81
|
}
|
|
82
82
|
|
|
@@ -90,7 +90,7 @@ export abstract class BaseOpenAIDriver extends AbstractDriver<
|
|
|
90
90
|
|
|
91
91
|
async requestTextCompletionStream(prompt: ChatCompletionMessageParam[], options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
|
|
92
92
|
if (options.model_options?._option_id !== "openai-text" && options.model_options?._option_id !== "openai-thinking") {
|
|
93
|
-
this.logger.warn(
|
|
93
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
94
94
|
}
|
|
95
95
|
|
|
96
96
|
const toolDefs = getToolDefinitions(options.tools);
|
|
@@ -163,7 +163,7 @@ export abstract class BaseOpenAIDriver extends AbstractDriver<
|
|
|
163
163
|
|
|
164
164
|
async requestTextCompletion(prompt: ChatCompletionMessageParam[], options: ExecutionOptions): Promise<Completion> {
|
|
165
165
|
if (options.model_options?._option_id !== "openai-text" && options.model_options?._option_id !== "openai-thinking") {
|
|
166
|
-
this.logger.warn(
|
|
166
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
167
167
|
}
|
|
168
168
|
|
|
169
169
|
convertRoles(prompt, options.model);
|
package/src/replicate.ts
CHANGED
|
@@ -15,7 +15,7 @@ import {
|
|
|
15
15
|
} from "@llumiverse/core";
|
|
16
16
|
import { EventStream } from "@llumiverse/core/async";
|
|
17
17
|
import { EventSource } from "eventsource";
|
|
18
|
-
import Replicate, { Prediction } from "replicate";
|
|
18
|
+
import Replicate, { Prediction, Training } from "replicate";
|
|
19
19
|
|
|
20
20
|
let cachedTrainableModels: AIModel[] | undefined;
|
|
21
21
|
let cachedTrainableModelsTimestamp: number = 0;
|
|
@@ -66,7 +66,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
|
|
|
66
66
|
|
|
67
67
|
async requestTextCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
|
|
68
68
|
if (options.model_options?._option_id !== "text-fallback") {
|
|
69
|
-
this.logger.warn(
|
|
69
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
70
70
|
}
|
|
71
71
|
options.model_options = options.model_options as TextFallbackOptions;
|
|
72
72
|
|
|
@@ -88,7 +88,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
|
|
|
88
88
|
|
|
89
89
|
const source = new EventSource(prediction.urls.stream!);
|
|
90
90
|
source.addEventListener("output", (e: any) => {
|
|
91
|
-
stream.push({result: [{ type: "text", value: e.data }] });
|
|
91
|
+
stream.push({ result: [{ type: "text", value: e.data }] });
|
|
92
92
|
});
|
|
93
93
|
source.addEventListener("error", (e: any) => {
|
|
94
94
|
let error: any;
|
|
@@ -97,7 +97,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
|
|
|
97
97
|
} catch (error) {
|
|
98
98
|
error = JSON.stringify(e);
|
|
99
99
|
}
|
|
100
|
-
this.logger
|
|
100
|
+
this.logger.error({ e, error }, "Error in SSE stream");
|
|
101
101
|
});
|
|
102
102
|
source.addEventListener("done", () => {
|
|
103
103
|
try {
|
|
@@ -111,7 +111,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
|
|
|
111
111
|
|
|
112
112
|
async requestTextCompletion(prompt: string, options: ExecutionOptions) {
|
|
113
113
|
if (options.model_options?._option_id !== "text-fallback") {
|
|
114
|
-
this.logger.warn(
|
|
114
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
115
115
|
}
|
|
116
116
|
options.model_options = options.model_options as TextFallbackOptions;
|
|
117
117
|
const model = ReplicateDriver.parseModelId(options.model);
|
|
@@ -236,7 +236,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
|
|
|
236
236
|
this.service.models.versions.list(owner, model),
|
|
237
237
|
]);
|
|
238
238
|
|
|
239
|
-
if (!rModel || !versions || versions.length === 0) {
|
|
239
|
+
if (!rModel || !versions || (versions as any).results?.length === 0) {
|
|
240
240
|
throw new Error("Model not found or no versions available");
|
|
241
241
|
}
|
|
242
242
|
|
|
@@ -289,7 +289,7 @@ export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
|
|
|
289
289
|
|
|
290
290
|
}
|
|
291
291
|
|
|
292
|
-
function jobInfo(job: Prediction, modelName?: string): TrainingJob {
|
|
292
|
+
function jobInfo(job: Prediction | Training, modelName?: string): TrainingJob {
|
|
293
293
|
// 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled'
|
|
294
294
|
const jobStatus = job.status;
|
|
295
295
|
let details: string | undefined;
|
package/src/togetherai/index.ts
CHANGED
|
@@ -31,7 +31,7 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
|
|
|
31
31
|
|
|
32
32
|
async requestTextCompletion(prompt: string, options: ExecutionOptions): Promise<Completion> {
|
|
33
33
|
if (options.model_options?._option_id !== "text-fallback") {
|
|
34
|
-
this.logger.warn(
|
|
34
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
35
35
|
}
|
|
36
36
|
options.model_options = options.model_options as TextFallbackOptions;
|
|
37
37
|
|
|
@@ -73,7 +73,7 @@ export class TogetherAIDriver extends AbstractDriver<TogetherAIDriverOptions, st
|
|
|
73
73
|
|
|
74
74
|
async requestTextCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
|
|
75
75
|
if (options.model_options?._option_id !== "text-fallback") {
|
|
76
|
-
this.logger.warn(
|
|
76
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
77
77
|
}
|
|
78
78
|
options.model_options = options.model_options as TextFallbackOptions;
|
|
79
79
|
|
package/src/vertexai/index.ts
CHANGED
|
@@ -1,9 +1,13 @@
|
|
|
1
|
+
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk";
|
|
2
|
+
import { PredictionServiceClient, v1beta1 } from "@google-cloud/aiplatform";
|
|
3
|
+
import { Content, GoogleGenAI, Model } from "@google/genai";
|
|
1
4
|
import {
|
|
2
5
|
AIModel,
|
|
3
6
|
AbstractDriver,
|
|
4
7
|
Completion,
|
|
5
8
|
CompletionChunkObject,
|
|
6
9
|
DriverOptions,
|
|
10
|
+
EmbeddingsOptions,
|
|
7
11
|
EmbeddingsResult,
|
|
8
12
|
ExecutionOptions,
|
|
9
13
|
Modalities,
|
|
@@ -13,17 +17,12 @@ import {
|
|
|
13
17
|
modelModalitiesToArray,
|
|
14
18
|
} from "@llumiverse/core";
|
|
15
19
|
import { FetchClient } from "@vertesia/api-fetch-client";
|
|
16
|
-
import { GoogleAuth, GoogleAuthOptions } from "google-auth-library";
|
|
17
|
-
import {
|
|
20
|
+
import { GoogleAuth, GoogleAuthOptions, AuthClient } from "google-auth-library";
|
|
21
|
+
import { getEmbeddingsForImages } from "./embeddings/embeddings-image.js";
|
|
18
22
|
import { TextEmbeddingsOptions, getEmbeddingsForText } from "./embeddings/embeddings-text.js";
|
|
19
23
|
import { getModelDefinition } from "./models.js";
|
|
20
|
-
import {
|
|
21
|
-
import { getEmbeddingsForImages } from "./embeddings/embeddings-image.js";
|
|
22
|
-
import { PredictionServiceClient, v1beta1 } from "@google-cloud/aiplatform";
|
|
23
|
-
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk";
|
|
24
|
+
import { ANTHROPIC_REGIONS, NON_GLOBAL_ANTHROPIC_MODELS } from "./models/claude.js";
|
|
24
25
|
import { ImagenModelDefinition, ImagenPrompt } from "./models/imagen.js";
|
|
25
|
-
import { GoogleGenAI, Content, Model } from "@google/genai";
|
|
26
|
-
import { NON_GLOBAL_ANTHROPIC_MODELS, ANTHROPIC_REGIONS } from "./models/claude.js";
|
|
27
26
|
|
|
28
27
|
export interface VertexAIDriverOptions extends DriverOptions {
|
|
29
28
|
project: string;
|
|
@@ -56,7 +55,8 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
|
|
|
56
55
|
modelGarden: v1beta1.ModelGardenServiceClient | undefined;
|
|
57
56
|
imagenClient: PredictionServiceClient | undefined;
|
|
58
57
|
|
|
59
|
-
|
|
58
|
+
googleAuth: GoogleAuth<any>;
|
|
59
|
+
private authClientPromise: Promise<AuthClient> | undefined;
|
|
60
60
|
|
|
61
61
|
constructor(options: VertexAIDriverOptions) {
|
|
62
62
|
super(options);
|
|
@@ -69,7 +69,15 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
|
|
|
69
69
|
this.llamaClient = undefined;
|
|
70
70
|
this.imagenClient = undefined;
|
|
71
71
|
|
|
72
|
-
this.
|
|
72
|
+
this.googleAuth = new GoogleAuth(options.googleAuthOptions) as GoogleAuth<any>;
|
|
73
|
+
this.authClientPromise = undefined;
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
private async getAuthClient(): Promise<AuthClient> {
|
|
77
|
+
if (!this.authClientPromise) {
|
|
78
|
+
this.authClientPromise = this.googleAuth.getClient();
|
|
79
|
+
}
|
|
80
|
+
return this.authClientPromise;
|
|
73
81
|
}
|
|
74
82
|
|
|
75
83
|
public getGoogleGenAIClient(region: string = this.options.region): GoogleGenAI {
|
|
@@ -80,8 +88,8 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
|
|
|
80
88
|
project: this.options.project,
|
|
81
89
|
location: region,
|
|
82
90
|
vertexai: true,
|
|
83
|
-
googleAuthOptions: {
|
|
84
|
-
|
|
91
|
+
googleAuthOptions: this.options.googleAuthOptions || {
|
|
92
|
+
scopes: ["https://www.googleapis.com/auth/cloud-platform"],
|
|
85
93
|
}
|
|
86
94
|
});
|
|
87
95
|
}
|
|
@@ -90,8 +98,8 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
|
|
|
90
98
|
project: this.options.project,
|
|
91
99
|
location: region,
|
|
92
100
|
vertexai: true,
|
|
93
|
-
googleAuthOptions: {
|
|
94
|
-
|
|
101
|
+
googleAuthOptions: this.options.googleAuthOptions || {
|
|
102
|
+
scopes: ["https://www.googleapis.com/auth/cloud-platform"],
|
|
95
103
|
}
|
|
96
104
|
});
|
|
97
105
|
}
|
|
@@ -105,8 +113,7 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
|
|
|
105
113
|
region: this.options.region,
|
|
106
114
|
project: this.options.project,
|
|
107
115
|
}).withAuthCallback(async () => {
|
|
108
|
-
const
|
|
109
|
-
const token = typeof accessTokenResponse === 'string' ? accessTokenResponse : accessTokenResponse?.token;
|
|
116
|
+
const token = await this.googleAuth.getAccessToken();
|
|
110
117
|
return `Bearer ${token}`;
|
|
111
118
|
});
|
|
112
119
|
}
|
|
@@ -121,8 +128,7 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
|
|
|
121
128
|
project: this.options.project,
|
|
122
129
|
apiVersion: "v1beta1",
|
|
123
130
|
}).withAuthCallback(async () => {
|
|
124
|
-
const
|
|
125
|
-
const token = typeof accessTokenResponse === 'string' ? accessTokenResponse : accessTokenResponse?.token;
|
|
131
|
+
const token = await this.googleAuth.getAccessToken();
|
|
126
132
|
return `Bearer ${token}`;
|
|
127
133
|
});
|
|
128
134
|
// Store the region for potential client reuse
|
|
@@ -131,7 +137,7 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
|
|
|
131
137
|
return this.llamaClient;
|
|
132
138
|
}
|
|
133
139
|
|
|
134
|
-
public getAnthropicClient(region: string = this.options.region): AnthropicVertex {
|
|
140
|
+
public async getAnthropicClient(region: string = this.options.region): Promise<AnthropicVertex> {
|
|
135
141
|
// Extract region prefix and map if it exists in ANTHROPIC_REGIONS, otherwise use as-is
|
|
136
142
|
const getRegionPrefix = (r: string) => r.split('-')[0];
|
|
137
143
|
const regionPrefix = getRegionPrefix(region);
|
|
@@ -140,17 +146,16 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
|
|
|
140
146
|
const defaultRegionPrefix = getRegionPrefix(this.options.region);
|
|
141
147
|
const defaultMappedRegion = ANTHROPIC_REGIONS[defaultRegionPrefix] || this.options.region;
|
|
142
148
|
|
|
149
|
+
// Get auth client to avoid version mismatch with GoogleAuth generic types
|
|
150
|
+
const authClient = await this.getAuthClient();
|
|
151
|
+
|
|
143
152
|
// If mapped region is different from default mapped region, create one-off client
|
|
144
153
|
if (mappedRegion !== defaultMappedRegion) {
|
|
145
154
|
return new AnthropicVertex({
|
|
146
155
|
timeout: 20 * 60 * 10000, // Set to 20 minutes, 10 minute default, setting this disables long request error: https://github.com/anthropics/anthropic-sdk-typescript?#long-requests
|
|
147
156
|
region: mappedRegion,
|
|
148
157
|
projectId: this.options.project,
|
|
149
|
-
|
|
150
|
-
scopes: ["https://www.googleapis.com/auth/cloud-platform"],
|
|
151
|
-
authClient: this.authClient as JSONClient,
|
|
152
|
-
projectId: this.options.project,
|
|
153
|
-
}),
|
|
158
|
+
authClient,
|
|
154
159
|
});
|
|
155
160
|
}
|
|
156
161
|
|
|
@@ -160,48 +165,47 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
|
|
|
160
165
|
timeout: 20 * 60 * 10000, // Set to 20 minutes, 10 minute default, setting this disables long request error: https://github.com/anthropics/anthropic-sdk-typescript?#long-requests
|
|
161
166
|
region: mappedRegion,
|
|
162
167
|
projectId: this.options.project,
|
|
163
|
-
|
|
164
|
-
scopes: ["https://www.googleapis.com/auth/cloud-platform"],
|
|
165
|
-
authClient: this.authClient as JSONClient,
|
|
166
|
-
projectId: this.options.project,
|
|
167
|
-
}),
|
|
168
|
+
authClient,
|
|
168
169
|
});
|
|
169
170
|
}
|
|
170
171
|
return this.anthropicClient;
|
|
171
172
|
}
|
|
172
173
|
|
|
173
|
-
public getAIPlatformClient(): v1beta1.ModelServiceClient {
|
|
174
|
+
public async getAIPlatformClient(): Promise<v1beta1.ModelServiceClient> {
|
|
174
175
|
//Lazy initialization
|
|
175
176
|
if (!this.aiplatform) {
|
|
177
|
+
const authClient = await this.getAuthClient();
|
|
176
178
|
this.aiplatform = new v1beta1.ModelServiceClient({
|
|
177
179
|
projectId: this.options.project,
|
|
178
180
|
apiEndpoint: `${this.options.region}-${API_BASE_PATH}`,
|
|
179
|
-
authClient
|
|
181
|
+
authClient,
|
|
180
182
|
});
|
|
181
183
|
}
|
|
182
184
|
return this.aiplatform;
|
|
183
185
|
}
|
|
184
186
|
|
|
185
|
-
public getModelGardenClient(): v1beta1.ModelGardenServiceClient {
|
|
187
|
+
public async getModelGardenClient(): Promise<v1beta1.ModelGardenServiceClient> {
|
|
186
188
|
//Lazy initialization
|
|
187
189
|
if (!this.modelGarden) {
|
|
190
|
+
const authClient = await this.getAuthClient();
|
|
188
191
|
this.modelGarden = new v1beta1.ModelGardenServiceClient({
|
|
189
192
|
projectId: this.options.project,
|
|
190
193
|
apiEndpoint: `${this.options.region}-${API_BASE_PATH}`,
|
|
191
|
-
authClient
|
|
194
|
+
authClient,
|
|
192
195
|
});
|
|
193
196
|
}
|
|
194
197
|
return this.modelGarden;
|
|
195
198
|
}
|
|
196
199
|
|
|
197
|
-
public getImagenClient(): PredictionServiceClient {
|
|
200
|
+
public async getImagenClient(): Promise<PredictionServiceClient> {
|
|
198
201
|
//Lazy initialization
|
|
199
202
|
if (!this.imagenClient) {
|
|
200
203
|
// TODO: make location configurable, fixed to us-central1 for now
|
|
204
|
+
const authClient = await this.getAuthClient();
|
|
201
205
|
this.imagenClient = new PredictionServiceClient({
|
|
202
206
|
projectId: this.options.project,
|
|
203
207
|
apiEndpoint: `us-central1-${API_BASE_PATH}`,
|
|
204
|
-
authClient
|
|
208
|
+
authClient,
|
|
205
209
|
});
|
|
206
210
|
}
|
|
207
211
|
return this.imagenClient;
|
|
@@ -263,8 +267,8 @@ export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, Vertex
|
|
|
263
267
|
|
|
264
268
|
async listModels(_params?: ModelSearchPayload): Promise<AIModel<string>[]> {
|
|
265
269
|
// Get clients
|
|
266
|
-
const modelGarden = this.getModelGardenClient();
|
|
267
|
-
const aiplatform = this.getAIPlatformClient();
|
|
270
|
+
const modelGarden = await this.getModelGardenClient();
|
|
271
|
+
const aiplatform = await this.getAIPlatformClient();
|
|
268
272
|
const globalGenAiClient = this.getGoogleGenAIClient("global");
|
|
269
273
|
|
|
270
274
|
let models: AIModel<string>[] = [];
|
|
@@ -269,11 +269,11 @@ export class ClaudeModelDefinition implements ModelDefinition<ClaudePrompt> {
|
|
|
269
269
|
const modelName = splits[splits.length - 1];
|
|
270
270
|
options = { ...options, model: modelName };
|
|
271
271
|
|
|
272
|
-
const client = driver.getAnthropicClient(region);
|
|
272
|
+
const client = await driver.getAnthropicClient(region);
|
|
273
273
|
options.model_options = options.model_options as VertexAIClaudeOptions;
|
|
274
274
|
|
|
275
275
|
if (options.model_options?._option_id !== "vertexai-claude") {
|
|
276
|
-
driver.logger.warn(
|
|
276
|
+
driver.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
277
277
|
}
|
|
278
278
|
|
|
279
279
|
let conversation = updateConversation(options.conversation as ClaudePrompt, prompt);
|
|
@@ -314,11 +314,11 @@ export class ClaudeModelDefinition implements ModelDefinition<ClaudePrompt> {
|
|
|
314
314
|
const modelName = splits[splits.length - 1];
|
|
315
315
|
options = { ...options, model: modelName };
|
|
316
316
|
|
|
317
|
-
const client = driver.getAnthropicClient(region);
|
|
317
|
+
const client = await driver.getAnthropicClient(region);
|
|
318
318
|
const model_options = options.model_options as VertexAIClaudeOptions | undefined;
|
|
319
319
|
|
|
320
320
|
if (model_options?._option_id !== "vertexai-claude") {
|
|
321
|
-
driver.logger.warn(
|
|
321
|
+
driver.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
322
322
|
}
|
|
323
323
|
|
|
324
324
|
const { payload, requestOptions } = getClaudePayload(options, prompt);
|
|
@@ -324,7 +324,7 @@ export class ImagenModelDefinition {
|
|
|
324
324
|
|
|
325
325
|
async requestImageGeneration(driver: VertexAIDriver, prompt: ImagenPrompt, options: ExecutionOptions): Promise<Completion> {
|
|
326
326
|
if (options.model_options?._option_id !== "vertexai-imagen") {
|
|
327
|
-
driver.logger.warn(
|
|
327
|
+
driver.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
328
328
|
}
|
|
329
329
|
options.model_options = options.model_options as ImagenOptions | undefined;
|
|
330
330
|
|
|
@@ -336,7 +336,7 @@ export class ImagenModelDefinition {
|
|
|
336
336
|
|
|
337
337
|
driver.logger.info("Task type: " + taskType);
|
|
338
338
|
|
|
339
|
-
|
|
339
|
+
const modelName = options.model.split("/").pop() ?? '';
|
|
340
340
|
|
|
341
341
|
// Configure the parent resource
|
|
342
342
|
// TODO: make location configurable, fixed to us-central1 for now
|
|
@@ -348,7 +348,7 @@ export class ImagenModelDefinition {
|
|
|
348
348
|
}
|
|
349
349
|
const instances = [instanceValue];
|
|
350
350
|
|
|
351
|
-
let parameter: any = getImagenParameters(taskType, options.model_options ?? {_option_id: "vertexai-imagen"});
|
|
351
|
+
let parameter: any = getImagenParameters(taskType, options.model_options ?? { _option_id: "vertexai-imagen" });
|
|
352
352
|
parameter.negativePrompt = prompt.negativePrompt ?? undefined;
|
|
353
353
|
|
|
354
354
|
const numberOfImages = options.model_options?.number_of_images ?? 1;
|
|
@@ -366,7 +366,7 @@ export class ImagenModelDefinition {
|
|
|
366
366
|
parameters,
|
|
367
367
|
};
|
|
368
368
|
|
|
369
|
-
const client = driver.getImagenClient();
|
|
369
|
+
const client = await driver.getImagenClient();
|
|
370
370
|
|
|
371
371
|
// Predict request
|
|
372
372
|
const [response] = await client.predict(request, { timeout: 120000 * numberOfImages }); //Extended timeout for image generation
|
package/src/watsonx/index.ts
CHANGED
|
@@ -31,7 +31,7 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>
|
|
|
31
31
|
|
|
32
32
|
async requestTextCompletion(prompt: string, options: ExecutionOptions): Promise<Completion> {
|
|
33
33
|
if (options.model_options?._option_id !== "text-fallback") {
|
|
34
|
-
this.logger.warn(
|
|
34
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
35
35
|
}
|
|
36
36
|
options.model_options = options.model_options as TextFallbackOptions | undefined;
|
|
37
37
|
|
|
@@ -66,7 +66,7 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>
|
|
|
66
66
|
|
|
67
67
|
async requestTextCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
|
|
68
68
|
if (options.model_options?._option_id !== "text-fallback") {
|
|
69
|
-
this.logger.warn(
|
|
69
|
+
this.logger.warn({ options: options.model_options }, "Invalid model options");
|
|
70
70
|
}
|
|
71
71
|
options.model_options = options.model_options as TextFallbackOptions | undefined;
|
|
72
72
|
const payload: WatsonxTextGenerationPayload = {
|
|
@@ -152,7 +152,7 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>
|
|
|
152
152
|
return this.listModels()
|
|
153
153
|
.then(() => true)
|
|
154
154
|
.catch((err) => {
|
|
155
|
-
this.logger.warn("Failed to connect to WatsonX"
|
|
155
|
+
this.logger.warn({ error: err }, "Failed to connect to WatsonX");
|
|
156
156
|
return false
|
|
157
157
|
});
|
|
158
158
|
}
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|