@llumiverse/drivers 0.19.0 → 0.21.0
This diff shows the content changes between publicly released versions of the package, as they appear in their respective public registries; it is provided for informational purposes only.
- package/lib/cjs/azure/azure_foundry.js +379 -0
- package/lib/cjs/azure/azure_foundry.js.map +1 -0
- package/lib/cjs/bedrock/converse.js +181 -123
- package/lib/cjs/bedrock/converse.js.map +1 -1
- package/lib/cjs/bedrock/index.js +157 -72
- package/lib/cjs/bedrock/index.js.map +1 -1
- package/lib/cjs/groq/index.js +91 -10
- package/lib/cjs/groq/index.js.map +1 -1
- package/lib/cjs/index.js +2 -1
- package/lib/cjs/index.js.map +1 -1
- package/lib/cjs/mistral/index.js +2 -1
- package/lib/cjs/mistral/index.js.map +1 -1
- package/lib/cjs/openai/azure_openai.js +72 -0
- package/lib/cjs/openai/azure_openai.js.map +1 -0
- package/lib/cjs/openai/index.js +6 -9
- package/lib/cjs/openai/index.js.map +1 -1
- package/lib/cjs/openai/openai.js +2 -2
- package/lib/cjs/openai/openai.js.map +1 -1
- package/lib/cjs/openai/openai_format.js +138 -0
- package/lib/cjs/openai/openai_format.js.map +1 -0
- package/lib/cjs/vertexai/index.js +1 -0
- package/lib/cjs/vertexai/index.js.map +1 -1
- package/lib/cjs/vertexai/models/claude.js +229 -118
- package/lib/cjs/vertexai/models/claude.js.map +1 -1
- package/lib/cjs/vertexai/models/gemini.js +110 -70
- package/lib/cjs/vertexai/models/gemini.js.map +1 -1
- package/lib/cjs/vertexai/models/imagen.js +2 -2
- package/lib/cjs/vertexai/models/imagen.js.map +1 -1
- package/lib/cjs/watsonx/index.js +11 -11
- package/lib/cjs/watsonx/index.js.map +1 -1
- package/lib/cjs/xai/index.js +3 -3
- package/lib/cjs/xai/index.js.map +1 -1
- package/lib/esm/azure/azure_foundry.js +373 -0
- package/lib/esm/azure/azure_foundry.js.map +1 -0
- package/lib/esm/bedrock/converse.js +180 -122
- package/lib/esm/bedrock/converse.js.map +1 -1
- package/lib/esm/bedrock/index.js +158 -73
- package/lib/esm/bedrock/index.js.map +1 -1
- package/lib/esm/groq/index.js +91 -10
- package/lib/esm/groq/index.js.map +1 -1
- package/lib/esm/index.js +2 -1
- package/lib/esm/index.js.map +1 -1
- package/lib/esm/mistral/index.js +2 -1
- package/lib/esm/mistral/index.js.map +1 -1
- package/lib/esm/openai/azure_openai.js +68 -0
- package/lib/esm/openai/azure_openai.js.map +1 -0
- package/lib/esm/openai/index.js +5 -8
- package/lib/esm/openai/index.js.map +1 -1
- package/lib/esm/openai/openai.js +2 -2
- package/lib/esm/openai/openai.js.map +1 -1
- package/lib/esm/openai/openai_format.js +134 -0
- package/lib/esm/openai/openai_format.js.map +1 -0
- package/lib/esm/src/adobe/firefly.js +115 -0
- package/lib/esm/src/adobe/firefly.js.map +1 -0
- package/lib/esm/src/bedrock/converse.js +278 -0
- package/lib/esm/src/bedrock/converse.js.map +1 -0
- package/lib/esm/src/bedrock/index.js +797 -0
- package/lib/esm/src/bedrock/index.js.map +1 -0
- package/lib/esm/src/bedrock/nova-image-payload.js +203 -0
- package/lib/esm/src/bedrock/nova-image-payload.js.map +1 -0
- package/lib/esm/src/bedrock/payloads.js +2 -0
- package/lib/esm/src/bedrock/payloads.js.map +1 -0
- package/lib/esm/src/bedrock/s3.js +99 -0
- package/lib/esm/src/bedrock/s3.js.map +1 -0
- package/lib/esm/src/groq/index.js +130 -0
- package/lib/esm/src/groq/index.js.map +1 -0
- package/lib/esm/src/huggingface_ie.js +196 -0
- package/lib/esm/src/huggingface_ie.js.map +1 -0
- package/lib/esm/src/index.js +13 -0
- package/lib/esm/src/index.js.map +1 -0
- package/lib/esm/src/mistral/index.js +167 -0
- package/lib/esm/src/mistral/index.js.map +1 -0
- package/lib/esm/src/mistral/types.js +80 -0
- package/lib/esm/src/mistral/types.js.map +1 -0
- package/{src/openai/azure.ts → lib/esm/src/openai/azure.js} +7 -34
- package/lib/esm/src/openai/azure.js.map +1 -0
- package/lib/esm/src/openai/index.js +463 -0
- package/lib/esm/src/openai/index.js.map +1 -0
- package/lib/esm/src/openai/openai.js +14 -0
- package/lib/esm/src/openai/openai.js.map +1 -0
- package/lib/esm/src/replicate.js +268 -0
- package/lib/esm/src/replicate.js.map +1 -0
- package/lib/esm/src/test/TestErrorCompletionStream.js +16 -0
- package/lib/esm/src/test/TestErrorCompletionStream.js.map +1 -0
- package/lib/esm/src/test/TestValidationErrorCompletionStream.js +20 -0
- package/lib/esm/src/test/TestValidationErrorCompletionStream.js.map +1 -0
- package/lib/esm/src/test/index.js +91 -0
- package/lib/esm/src/test/index.js.map +1 -0
- package/lib/esm/src/test/utils.js +25 -0
- package/lib/esm/src/test/utils.js.map +1 -0
- package/lib/esm/src/togetherai/index.js +122 -0
- package/lib/esm/src/togetherai/index.js.map +1 -0
- package/lib/esm/src/togetherai/interfaces.js +2 -0
- package/lib/esm/src/togetherai/interfaces.js.map +1 -0
- package/lib/esm/src/vertexai/debug.js +6 -0
- package/lib/esm/src/vertexai/debug.js.map +1 -0
- package/lib/esm/src/vertexai/embeddings/embeddings-image.js +24 -0
- package/lib/esm/src/vertexai/embeddings/embeddings-image.js.map +1 -0
- package/lib/esm/src/vertexai/embeddings/embeddings-text.js +20 -0
- package/lib/esm/src/vertexai/embeddings/embeddings-text.js.map +1 -0
- package/lib/esm/src/vertexai/index.js +270 -0
- package/lib/esm/src/vertexai/index.js.map +1 -0
- package/lib/esm/src/vertexai/models/claude.js +370 -0
- package/lib/esm/src/vertexai/models/claude.js.map +1 -0
- package/lib/esm/src/vertexai/models/gemini.js +700 -0
- package/lib/esm/src/vertexai/models/gemini.js.map +1 -0
- package/lib/esm/src/vertexai/models/imagen.js +310 -0
- package/lib/esm/src/vertexai/models/imagen.js.map +1 -0
- package/lib/esm/src/vertexai/models/llama.js +178 -0
- package/lib/esm/src/vertexai/models/llama.js.map +1 -0
- package/lib/esm/src/vertexai/models.js +21 -0
- package/lib/esm/src/vertexai/models.js.map +1 -0
- package/lib/esm/src/watsonx/index.js +157 -0
- package/lib/esm/src/watsonx/index.js.map +1 -0
- package/lib/esm/src/watsonx/interfaces.js +2 -0
- package/lib/esm/src/watsonx/interfaces.js.map +1 -0
- package/lib/esm/src/xai/index.js +64 -0
- package/lib/esm/src/xai/index.js.map +1 -0
- package/lib/esm/tsconfig.tsbuildinfo +1 -0
- package/lib/esm/vertexai/index.js +1 -0
- package/lib/esm/vertexai/index.js.map +1 -1
- package/lib/esm/vertexai/models/claude.js +230 -119
- package/lib/esm/vertexai/models/claude.js.map +1 -1
- package/lib/esm/vertexai/models/gemini.js +109 -70
- package/lib/esm/vertexai/models/gemini.js.map +1 -1
- package/lib/esm/vertexai/models/imagen.js +2 -2
- package/lib/esm/vertexai/models/imagen.js.map +1 -1
- package/lib/esm/watsonx/index.js +11 -11
- package/lib/esm/watsonx/index.js.map +1 -1
- package/lib/esm/xai/index.js +2 -2
- package/lib/esm/xai/index.js.map +1 -1
- package/lib/types/azure/azure_foundry.d.ts +50 -0
- package/lib/types/azure/azure_foundry.d.ts.map +1 -0
- package/lib/types/bedrock/converse.d.ts +2 -2
- package/lib/types/bedrock/converse.d.ts.map +1 -1
- package/lib/types/bedrock/index.d.ts +5 -5
- package/lib/types/bedrock/index.d.ts.map +1 -1
- package/lib/types/groq/index.d.ts +5 -5
- package/lib/types/groq/index.d.ts.map +1 -1
- package/lib/types/index.d.ts +2 -1
- package/lib/types/index.d.ts.map +1 -1
- package/lib/types/mistral/index.d.ts +2 -2
- package/lib/types/mistral/index.d.ts.map +1 -1
- package/lib/types/openai/azure_openai.d.ts +25 -0
- package/lib/types/openai/azure_openai.d.ts.map +1 -0
- package/lib/types/openai/index.d.ts +6 -7
- package/lib/types/openai/index.d.ts.map +1 -1
- package/lib/types/openai/openai.d.ts +2 -2
- package/lib/types/openai/openai.d.ts.map +1 -1
- package/lib/types/openai/openai_format.d.ts +19 -0
- package/lib/types/openai/openai_format.d.ts.map +1 -0
- package/lib/types/src/adobe/firefly.d.ts +29 -0
- package/lib/types/src/bedrock/converse.d.ts +8 -0
- package/lib/types/src/bedrock/index.d.ts +57 -0
- package/lib/types/src/bedrock/nova-image-payload.d.ts +73 -0
- package/lib/types/src/bedrock/payloads.d.ts +11 -0
- package/lib/types/src/bedrock/s3.d.ts +22 -0
- package/lib/types/src/groq/index.d.ts +23 -0
- package/lib/types/src/huggingface_ie.d.ts +31 -0
- package/lib/types/src/index.d.ts +12 -0
- package/lib/types/src/mistral/index.d.ts +24 -0
- package/lib/types/src/mistral/types.d.ts +131 -0
- package/lib/types/src/openai/azure.d.ts +19 -0
- package/lib/types/src/openai/index.d.ts +25 -0
- package/lib/types/src/openai/openai.d.ts +14 -0
- package/lib/types/src/replicate.d.ts +44 -0
- package/lib/types/src/test/TestErrorCompletionStream.d.ts +8 -0
- package/lib/types/src/test/TestValidationErrorCompletionStream.d.ts +8 -0
- package/lib/types/src/test/index.d.ts +23 -0
- package/lib/types/src/test/utils.d.ts +4 -0
- package/lib/types/src/togetherai/index.d.ts +22 -0
- package/lib/types/src/togetherai/interfaces.d.ts +95 -0
- package/lib/types/src/vertexai/debug.d.ts +1 -0
- package/lib/types/src/vertexai/embeddings/embeddings-image.d.ts +10 -0
- package/lib/types/src/vertexai/embeddings/embeddings-text.d.ts +9 -0
- package/lib/types/src/vertexai/index.d.ts +49 -0
- package/lib/types/src/vertexai/models/claude.d.ts +17 -0
- package/lib/types/src/vertexai/models/gemini.d.ts +16 -0
- package/lib/types/src/vertexai/models/imagen.d.ts +74 -0
- package/lib/types/src/vertexai/models/llama.d.ts +19 -0
- package/lib/types/src/vertexai/models.d.ts +14 -0
- package/lib/types/src/watsonx/index.d.ts +26 -0
- package/lib/types/src/watsonx/interfaces.d.ts +64 -0
- package/lib/types/src/xai/index.d.ts +18 -0
- package/lib/types/vertexai/index.d.ts +2 -3
- package/lib/types/vertexai/index.d.ts.map +1 -1
- package/lib/types/vertexai/models/claude.d.ts +5 -7
- package/lib/types/vertexai/models/claude.d.ts.map +1 -1
- package/lib/types/vertexai/models/gemini.d.ts +4 -2
- package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
- package/lib/types/vertexai/models.d.ts +2 -2
- package/lib/types/vertexai/models.d.ts.map +1 -1
- package/lib/types/xai/index.d.ts.map +1 -1
- package/package.json +20 -16
- package/src/azure/azure_foundry.ts +450 -0
- package/src/bedrock/converse.ts +194 -129
- package/src/bedrock/index.ts +182 -84
- package/src/groq/index.ts +107 -16
- package/src/index.ts +2 -1
- package/src/mistral/index.ts +3 -2
- package/src/openai/azure_openai.ts +92 -0
- package/src/openai/index.ts +19 -22
- package/src/openai/openai.ts +2 -5
- package/src/openai/openai_format.ts +165 -0
- package/src/vertexai/index.ts +3 -3
- package/src/vertexai/models/claude.ts +270 -138
- package/src/vertexai/models/gemini.ts +120 -77
- package/src/vertexai/models/imagen.ts +3 -3
- package/src/vertexai/models.ts +2 -2
- package/src/watsonx/index.ts +17 -17
- package/src/xai/index.ts +2 -3
package/lib/esm/src/bedrock/index.js
@@ -0,0 +1,797 @@
+import { Bedrock, CreateModelCustomizationJobCommand, GetModelCustomizationJobCommand, ModelCustomizationJobStatus, ModelModality, StopModelCustomizationJobCommand } from "@aws-sdk/client-bedrock";
+import { BedrockRuntime } from "@aws-sdk/client-bedrock-runtime";
+import { S3Client } from "@aws-sdk/client-s3";
+import { AbstractDriver, Modalities, TrainingJobStatus, getMaxTokensLimitBedrock, modelModalitiesToArray, getModelCapabilities } from "@llumiverse/core";
+import { transformAsyncIterator } from "@llumiverse/core/async";
+import { formatNovaPrompt } from "@llumiverse/core/formatters";
+import { LRUCache } from "mnemonist";
+import { converseConcatMessages, converseJSONprefill, converseSystemToMessages, formatConversePrompt } from "./converse.js";
+import { formatNovaImageGenerationPayload, NovaImageGenerationTaskType } from "./nova-image-payload.js";
+import { forceUploadFile } from "./s3.js";
+const supportStreamingCache = new LRUCache(4096);
+var BedrockModelType;
+(function (BedrockModelType) {
+    BedrockModelType["FoundationModel"] = "foundation-model";
+    BedrockModelType["InferenceProfile"] = "inference-profile";
+    BedrockModelType["CustomModel"] = "custom-model";
+    BedrockModelType["Unknown"] = "unknown";
+})(BedrockModelType || (BedrockModelType = {}));
+;
+function converseFinishReason(reason) {
+    //Possible values:
+    //end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
+    if (!reason)
+        return undefined;
+    switch (reason) {
+        case 'end_turn': return "stop";
+        case 'max_tokens': return "length";
+        default: return reason;
+    }
+}
+//Used to get a max_token value when not specified in the model options. Claude requires it to be set.
+function maxTokenFallbackClaude(option) {
+    const modelOptions = option.model_options;
+    if (modelOptions && typeof modelOptions.max_tokens === "number") {
+        return modelOptions.max_tokens;
+    }
+    else {
+        // Fallback to the default max tokens limit for the model
+        if (option.model.includes('claude-3-7-sonnet') && (modelOptions?.thinking_budget_tokens ?? 0) < 64000) {
+            return 64000; // Claude 3.7 can go up to 128k with a beta header, but when no max tokens is specified, we default to 64k.
+        }
+        return getMaxTokensLimitBedrock(option.model) ?? 8192; // Should always return a number for claude, 8192 is to satisfy the TypeScript type checker
+    }
+}
+export class BedrockDriver extends AbstractDriver {
+    static PROVIDER = "bedrock";
+    provider = BedrockDriver.PROVIDER;
+    _executor;
+    _service;
+    _service_region;
+    constructor(options) {
+        super(options);
+        if (!options.region) {
+            throw new Error("No region found. Set the region in the environment's endpoint URL.");
+        }
+    }
+    getExecutor() {
+        if (!this._executor) {
+            this._executor = new BedrockRuntime({
+                region: this.options.region,
+                credentials: this.options.credentials,
+            });
+        }
+        return this._executor;
+    }
+    getService(region = this.options.region) {
+        if (!this._service || this._service_region != region) {
+            this._service = new Bedrock({
+                region: region,
+                credentials: this.options.credentials,
+            });
+            this._service_region = region;
+        }
+        return this._service;
+    }
+    async formatPrompt(segments, opts) {
+        if (opts.model.includes("canvas")) {
+            return await formatNovaPrompt(segments, opts.result_schema);
+        }
+        return await formatConversePrompt(segments, opts);
+    }
+    static getExtractedExecution(result, _prompt, options) {
+        let resultText = "";
+        let reasoning = "";
+        if (result.output?.message?.content) {
+            for (const content of result.output.message.content) {
+                // Get text output
+                if (content.text) {
+                    resultText += content.text;
+                }
+                // Get reasoning content only if include_thoughts is true
+                if (content.reasoningContent && options) {
+                    const claudeOptions = options.model_options;
+                    if (claudeOptions?.include_thoughts) {
+                        if (content.reasoningContent.reasoningText) {
+                            reasoning += content.reasoningContent.reasoningText.text;
+                        }
+                        else if (content.reasoningContent.redactedContent) {
+                            // Handle redacted thinking content
+                            const redactedData = new TextDecoder().decode(content.reasoningContent.redactedContent);
+                            reasoning += `[Redacted thinking: ${redactedData}]`;
+                        }
+                    }
+                }
+            }
+            // Add spacing if we have reasoning content
+            if (reasoning) {
+                reasoning += '\n\n';
+            }
+        }
+        const completionResult = {
+            result: reasoning + resultText,
+            token_usage: {
+                prompt: result.usage?.inputTokens,
+                result: result.usage?.outputTokens,
+                total: result.usage?.totalTokens,
+            },
+            finish_reason: converseFinishReason(result.stopReason),
+        };
+        return completionResult;
+    }
+    ;
+    static getExtractedStream(result, _prompt, options) {
+        let output = "";
+        let reasoning = "";
+        let stop_reason = "";
+        let token_usage;
+        // Check if we should include thoughts
+        const shouldIncludeThoughts = options && options.model_options?.include_thoughts;
+        // Handle content block start events (for reasoning blocks)
+        if (result.contentBlockStart) {
+            // Handle redacted content at block start
+            if (result.contentBlockStart.start && 'reasoningContent' in result.contentBlockStart.start && shouldIncludeThoughts) {
+                const reasoningStart = result.contentBlockStart.start;
+                if (reasoningStart.reasoningContent?.redactedContent) {
+                    const redactedData = new TextDecoder().decode(reasoningStart.reasoningContent.redactedContent);
+                    reasoning = `[Redacted thinking: ${redactedData}]`;
+                }
+            }
+        }
+        // Handle content block deltas (text and reasoning)
+        if (result.contentBlockDelta) {
+            const delta = result.contentBlockDelta.delta;
+            if (delta?.text) {
+                output = delta.text;
+            }
+            else if (delta?.reasoningContent && shouldIncludeThoughts) {
+                if (delta.reasoningContent.text) {
+                    reasoning = delta.reasoningContent.text;
+                }
+                else if (delta.reasoningContent.redactedContent) {
+                    const redactedData = new TextDecoder().decode(delta.reasoningContent.redactedContent);
+                    reasoning = `[Redacted thinking: ${redactedData}]`;
+                }
+                else if (delta.reasoningContent.signature) {
+                    // Handle signature updates for reasoning content - end of thinking
+                    reasoning = "\n\n";
+                }
+            }
+        }
+        // Handle content block stop events
+        if (result.contentBlockStop) {
+            // Content block ended - could be end of reasoning or text block
+            // Add minimal spacing for reasoning blocks if not already present
+            if (reasoning && !reasoning.endsWith('\n\n') && shouldIncludeThoughts) {
+                reasoning += '\n\n';
+            }
+        }
+        if (result.messageStop) {
+            stop_reason = result.messageStop.stopReason ?? "";
+        }
+        if (result.metadata) {
+            token_usage = {
+                prompt: result.metadata.usage?.inputTokens,
+                result: result.metadata.usage?.outputTokens,
+                total: result.metadata.usage?.totalTokens,
+            };
+        }
+        const completionResult = {
+            result: reasoning + output,
+            token_usage: token_usage,
+            finish_reason: converseFinishReason(stop_reason),
+        };
+        return completionResult;
+    }
+    ;
+    extractRegion(modelString, defaultRegion) {
+        // Match region in full ARN pattern
+        const arnMatch = modelString.match(/arn:aws[^:]*:bedrock:([^:]+):/);
+        if (arnMatch) {
+            return arnMatch[1];
+        }
+        // Match common AWS regions directly in string
+        const regionMatch = modelString.match(/(?:us|eu|ap|sa|ca|me|af)[-](east|west|central|south|north|southeast|southwest|northeast|northwest)[-][1-9]/);
+        if (regionMatch) {
+            return regionMatch[0];
+        }
+        return defaultRegion;
+    }
+    async getCanStream(model, type) {
+        let canStream = false;
+        let error = null;
+        const region = this.extractRegion(model, this.options.region);
+        if (type == BedrockModelType.FoundationModel || type == BedrockModelType.Unknown) {
+            try {
+                const response = await this.getService(region).getFoundationModel({
+                    modelIdentifier: model
+                });
+                canStream = response.modelDetails?.responseStreamingSupported ?? false;
+                return canStream;
+            }
+            catch (e) {
+                error = e;
+            }
+        }
+        if (type == BedrockModelType.InferenceProfile || type == BedrockModelType.Unknown) {
+            try {
+                const response = await this.getService(region).getInferenceProfile({
+                    inferenceProfileIdentifier: model
+                });
+                canStream = await this.getCanStream(response.models?.[0].modelArn ?? "", BedrockModelType.FoundationModel);
+                return canStream;
+            }
+            catch (e) {
+                error = e;
+            }
+        }
+        if (type == BedrockModelType.CustomModel || type == BedrockModelType.Unknown) {
+            try {
+                const response = await this.getService(region).getCustomModel({
+                    modelIdentifier: model
+                });
+                canStream = await this.getCanStream(response.baseModelArn ?? "", BedrockModelType.FoundationModel);
+                return canStream;
+            }
+            catch (e) {
+                error = e;
+            }
+        }
+        if (error) {
+            console.warn("Error on canStream check for model: " + model + " region detected: " + region, error);
+        }
+        return canStream;
+    }
+    async canStream(options) {
+        let canStream = supportStreamingCache.get(options.model);
+        if (canStream == null) {
+            let type = BedrockModelType.Unknown;
+            if (options.model.includes("foundation-model")) {
+                type = BedrockModelType.FoundationModel;
+            }
+            else if (options.model.includes("inference-profile")) {
+                type = BedrockModelType.InferenceProfile;
+            }
+            else if (options.model.includes("custom-model")) {
+                type = BedrockModelType.CustomModel;
+            }
+            canStream = await this.getCanStream(options.model, type);
+            supportStreamingCache.set(options.model, canStream);
+        }
+        return canStream;
+    }
+    async requestTextCompletion(prompt, options) {
+        let conversation = updateConversation(options.conversation, prompt);
+        const payload = this.preparePayload(conversation, options);
+        const executor = this.getExecutor();
+        const res = await executor.converse({
+            ...payload,
+        });
+        conversation = updateConversation(conversation, {
+            messages: [res.output?.message ?? { content: [{ text: "" }], role: "assistant" }],
+            modelId: prompt.modelId,
+        });
+        let tool_use = undefined;
+        //Get tool requests
+        if (res.stopReason == "tool_use") {
+            tool_use = res.output?.message?.content?.reduce((tools, c) => {
+                if (c.toolUse) {
+                    tools.push({
+                        tool_name: c.toolUse.name ?? "",
+                        tool_input: c.toolUse.input,
+                        id: c.toolUse.toolUseId ?? "",
+                    });
+                }
+                return tools;
+            }, []);
+            //If no tools were used, set to undefined
+            if (tool_use && tool_use.length == 0) {
+                tool_use = undefined;
+            }
+        }
+        const completion = {
+            ...BedrockDriver.getExtractedExecution(res, prompt, options),
+            original_response: options.include_original_response ? res : undefined,
+            conversation: conversation,
+            tool_use: tool_use,
+        };
+        return completion;
+    }
+    async requestTextCompletionStream(prompt, options) {
+        const payload = this.preparePayload(prompt, options);
+        const executor = this.getExecutor();
+        return executor.converseStream({
+            ...payload,
+        }).then((res) => {
+            const stream = res.stream;
+            if (!stream) {
+                throw new Error("[Bedrock] Stream not found in response");
+            }
+            return transformAsyncIterator(stream, (streamSegment) => {
+                return BedrockDriver.getExtractedStream(streamSegment, prompt, options);
+            });
+        }).catch((err) => {
+            this.logger.error("[Bedrock] Failed to stream", err);
+            throw err;
+        });
+    }
+    preparePayload(prompt, options) {
+        const model_options = options.model_options ?? { _option_id: "text-fallback" };
+        let additionalField = {};
+        if (options.model.includes("amazon")) {
+            if (options.result_schema) {
+                prompt.messages = converseJSONprefill(prompt.messages);
+            }
+            //Titan models also exists but does not support any additional options
+            if (options.model.includes("nova")) {
+                additionalField = { inferenceConfig: { topK: model_options.top_k } };
+            }
+        }
+        else if (options.model.includes("claude")) {
+            const claude_options = options.model_options;
+            const thinking = claude_options.thinking_mode ?? false;
+            if (options.result_schema && !thinking) {
+                prompt.messages = converseJSONprefill(prompt.messages);
+            }
+            if (options.model.includes("claude-3-7") || options.model.includes("-4-")) {
+                additionalField = {
+                    ...additionalField,
+                    reasoning_config: {
+                        type: thinking ? "enabled" : "disabled",
+                        budget_tokens: thinking ? (claude_options.thinking_budget_tokens ?? 1024) : undefined,
+                    }
+                };
+                if (thinking && options.model.includes("claude-3-7-sonnet") &&
+                    ((claude_options.max_tokens ?? 0) > 64000 || (claude_options.thinking_budget_tokens ?? 0) > 64000)) {
+                    additionalField = {
+                        ...additionalField,
+                        anthropic_beta: ["output-128k-2025-02-19"]
+                    };
+                }
+            }
+            //Needs max_tokens to be set
+            if (!model_options.max_tokens) {
+                model_options.max_tokens = maxTokenFallbackClaude(options);
+            }
+            additionalField = { ...additionalField, top_k: model_options.top_k };
+        }
+        else if (options.model.includes("meta")) {
+            //LLaMA models support no additional options
+        }
+        else if (options.model.includes("mistral")) {
+            //7B instruct and 8x7B instruct
+            if (options.model.includes("7b")) {
+                additionalField = { top_k: model_options.top_k };
+                //Does not support system messages
+                if (prompt.system && prompt.system?.length != 0) {
+                    prompt.messages?.push(converseSystemToMessages(prompt.system));
+                    prompt.system = undefined;
+                    prompt.messages = converseConcatMessages(prompt.messages);
+                }
+                if (options.result_schema) {
+                    prompt.messages = converseJSONprefill(prompt.messages);
+                }
+            }
+            else {
+                //Other models such as Mistral Small,Large and Large 2
+                //Support no additional fields.
+            }
+        }
+        else if (options.model.includes("ai21")) {
+            //Jamba models support no additional options
+            //Jurassic 2 models do.
+            if (options.model.includes("j2")) {
+                additionalField = {
+                    presencePenalty: { scale: model_options.presence_penalty },
+                    frequencyPenalty: { scale: model_options.frequency_penalty },
+                };
+                //Does not support system messages
+                if (prompt.system && prompt.system?.length != 0) {
+                    prompt.messages?.push(converseSystemToMessages(prompt.system));
+                    prompt.system = undefined;
+                    prompt.messages = converseConcatMessages(prompt.messages);
+                }
+            }
+        }
+        else if (options.model.includes("cohere.command")) {
+            // If last message is "```json", remove it.
+            //Command R and R plus
+            if (options.model.includes("cohere.command-r")) {
+                additionalField = {
+                    k: model_options.top_k,
+                    frequency_penalty: model_options.frequency_penalty,
+                    presence_penalty: model_options.presence_penalty,
+                };
+            }
+            else {
+                // Command non-R
+                additionalField = { k: model_options.top_k };
+                //Does not support system messages
+                if (prompt.system && prompt.system?.length != 0) {
+                    prompt.messages?.push(converseSystemToMessages(prompt.system));
+                    prompt.system = undefined;
+                    prompt.messages = converseConcatMessages(prompt.messages);
+                }
+            }
+        }
+        else if (options.model.includes("palmyra")) {
+            const palmyraOptions = options.model_options;
+            additionalField = {
+                seed: palmyraOptions?.seed,
+                presence_penalty: palmyraOptions?.presence_penalty,
+                frequency_penalty: palmyraOptions?.frequency_penalty,
+                min_tokens: palmyraOptions?.min_tokens,
+            };
+        }
+        else if (options.model.includes("deepseek")) {
+            //DeepSeek models support no additional options
+        }
+        //If last message is "```json", add corresponding ``` as a stop sequence.
+        if (prompt.messages && prompt.messages.length > 0) {
+            if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") {
+                let stopSeq = model_options.stop_sequence;
+                if (!stopSeq) {
+                    model_options.stop_sequence = ["```"];
+                }
+                else if (!stopSeq.includes("```")) {
+                    stopSeq.push("```");
+                    model_options.stop_sequence = stopSeq;
+                }
+            }
+        }
+        const tool_defs = getToolDefinitions(options.tools);
+        const request = {
+            messages: prompt.messages,
+            system: prompt.system,
+            modelId: options.model,
+            inferenceConfig: {
+                maxTokens: model_options.max_tokens,
+                temperature: model_options.temperature,
+                topP: model_options.top_p,
+                stopSequences: model_options.stop_sequence,
+            },
+            additionalModelRequestFields: {
+                ...additionalField,
+            }
+        };
+        //Only add tools if they are defined
+        if (tool_defs) {
+            request.toolConfig = {
+                tools: tool_defs,
+            };
+        }
+        return request;
+    }
+    async requestImageGeneration(prompt, options) {
+        if (options.output_modality !== Modalities.image) {
+            throw new Error(`Image generation requires image output_modality`);
+        }
+        if (options.model_options?._option_id !== "bedrock-nova-canvas") {
+            this.logger.warn("Invalid model options", { options: options.model_options });
+        }
+        const model_options = options.model_options;
+        const executor = this.getExecutor();
+        const taskType = model_options.taskType ?? NovaImageGenerationTaskType.TEXT_IMAGE;
+        this.logger.info("Task type: " + taskType);
+        if (typeof prompt === "string") {
+            throw new Error("Bad prompt format");
+        }
+        const payload = await formatNovaImageGenerationPayload(taskType, prompt, options);
+        const res = await executor.invokeModel({
+            modelId: options.model,
+            contentType: "application/json",
+            accept: "application/json",
+            body: JSON.stringify(payload),
+        }, {
+            requestTimeout: 60000 * 5
+        });
+        const decoder = new TextDecoder();
+        const body = decoder.decode(res.body);
+        const result = JSON.parse(body);
+        return {
+            error: result.error,
+            result: {
+                images: result.images,
+            }
+        };
+    }
+    async startTraining(dataset, options) {
+        //convert options.params to Record<string, string>
+        const params = {};
+        for (const [key, value] of Object.entries(options.params || {})) {
+            params[key] = String(value);
+        }
+        if (!this.options.training_bucket) {
+            throw new Error("Training cannot nbe used since the 'training_bucket' property was not specified in driver options");
+        }
+        const s3 = new S3Client({ region: this.options.region, credentials: this.options.credentials });
+        const stream = await dataset.getStream();
+        const upload = await forceUploadFile(s3, stream, this.options.training_bucket, dataset.name);
+        const service = this.getService();
+        const response = await service.send(new CreateModelCustomizationJobCommand({
+            jobName: options.name + "-job",
+            customModelName: options.name,
+            roleArn: this.options.training_role_arn || undefined,
+            baseModelIdentifier: options.model,
+            clientRequestToken: "llumiverse-" + Date.now(),
+            trainingDataConfig: {
+                s3Uri: `s3://${upload.Bucket}/${upload.Key}`,
+            },
+            outputDataConfig: undefined,
+            hyperParameters: params,
+            //TODO not supported?
+            //customizationType: "FINE_TUNING",
+        }));
+        const job = await service.send(new GetModelCustomizationJobCommand({
+            jobIdentifier: response.jobArn
+        }));
+        return jobInfo(job, response.jobArn);
+    }
+    async cancelTraining(jobId) {
+        const service = this.getService();
+        await service.send(new StopModelCustomizationJobCommand({
+            jobIdentifier: jobId
+        }));
+        const job = await service.send(new GetModelCustomizationJobCommand({
+            jobIdentifier: jobId
+        }));
+        return jobInfo(job, jobId);
+    }
+    async getTrainingJob(jobId) {
+        const service = this.getService();
+        const job = await service.send(new GetModelCustomizationJobCommand({
+            jobIdentifier: jobId
+        }));
+        return jobInfo(job, jobId);
+    }
+    // ===================== management API ==================
+    async validateConnection() {
+        const service = this.getService();
+        this.logger.debug("[Bedrock] validating connection", service.config.credentials.name);
+        //return true as if the client has been initialized, it means the connection is valid
+        return true;
+    }
+    async listTrainableModels() {
+        this.logger.debug("[Bedrock] listing trainable models");
+        return this._listModels(m => m.customizationsSupported ? m.customizationsSupported.includes("FINE_TUNING") : false);
+    }
+    async listModels() {
+        this.logger.debug("[Bedrock] listing models");
+        // exclude trainable models since they are not executable
+        // exclude embedding models, not to be used for typical completions.
+        const filter = (m) => (m.inferenceTypesSupported?.includes("ON_DEMAND") && !m.outputModalities?.includes("EMBEDDING")) ?? false;
+        return this._listModels(filter);
+    }
+    async _listModels(foundationFilter) {
+        const service = this.getService();
+        const [foundationModelsList, customModelsList, inferenceProfilesList] = await Promise.all([
+            service.listFoundationModels({}).catch(() => {
+                this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions.");
+                return undefined;
+            }),
+            service.listCustomModels({}).catch(() => {
+                this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions.");
+                return undefined;
+            }),
+            service.listInferenceProfiles({}).catch(() => {
+                this.logger.warn("[Bedrock] Can't list inference profiles. Check if the user has the right permissions.");
+                return undefined;
+            }),
+        ]);
+        if (!foundationModelsList?.modelSummaries) {
+            throw new Error("Foundation models not found");
+        }
+        let foundationModels = foundationModelsList.modelSummaries || [];
+        if (foundationFilter) {
+            foundationModels = foundationModels.filter(foundationFilter);
+        }
+        const supportedPublishers = ["amazon", "anthropic", "cohere", "ai21", "mistral", "meta", "deepseek", "writer"];
+        const unsupportedModelsByPublisher = {
+            amazon: ["titan-image-generator", "nova-reel", "nova-sonic", "rerank"],
+            anthropic: [],
+            cohere: ["rerank"],
+            ai21: [],
+            mistral: [],
+            meta: [],
+            deepseek: [],
+            writer: [],
+        };
+        // Helper function to check if model should be filtered out
+        const shouldIncludeModel = (modelId, providerName) => {
+            if (!modelId || !providerName)
+                return false;
+            const normalizedProvider = providerName.toLowerCase();
+            // Check if provider is supported
+            const isProviderSupported = supportedPublishers.some(provider => normalizedProvider.includes(provider));
+            if (!isProviderSupported)
+                return false;
+            // Check if model is in the unsupported list for its provider
+            for (const provider of supportedPublishers) {
+                if (normalizedProvider.includes(provider)) {
+                    const unsupportedModels = unsupportedModelsByPublisher[provider] || [];
+                    return !unsupportedModels.some(unsupported => modelId.toLowerCase().includes(unsupported));
+                }
+            }
+            return true;
+        };
+        foundationModels = foundationModels.filter(m => shouldIncludeModel(m.modelId, m.providerName));
+        const aiModels = foundationModels.map((m) => {
+            if (!m.modelId) {
+                throw new Error("modelId not found");
+            }
+            const modelCapability = getModelCapabilities(m.modelArn ?? m.modelId, this.provider);
+            const model = {
+                id: m.modelArn ?? m.modelId,
+                name: `${m.providerName} ${m.modelName}`,
+                provider: this.provider,
+                //description: ``,
+                owner: m.providerName,
+                can_stream: m.responseStreamingSupported ?? false,
+                input_modalities: m.inputModalities ? formatAmazonModalities(m.inputModalities) : modelModalitiesToArray(modelCapability.input),
+                output_modalities: m.outputModalities ? formatAmazonModalities(m.outputModalities) : modelModalitiesToArray(modelCapability.input),
+                tool_support: modelCapability.tool_support,
+            };
+            return model;
+        });
+        //add custom models
+        if (customModelsList?.modelSummaries) {
+            customModelsList.modelSummaries.forEach((m) => {
+                if (!m.modelArn) {
+                    throw new Error("Model ID not found");
+                }
+                const modelCapability = getModelCapabilities(m.modelArn, this.provider);
+                const model = {
+                    id: m.modelArn,
+                    name: m.modelName ?? m.modelArn,
+                    provider: this.provider,
+                    description: `Custom model from ${m.baseModelName}`,
+                    is_custom: true,
+                    input_modalities: modelModalitiesToArray(modelCapability.input),
+                    output_modalities: modelModalitiesToArray(modelCapability.output),
+                    tool_support: modelCapability.tool_support,
+                };
+                aiModels.push(model);
+                this.validateConnection;
+            });
+        }
+        //add inference profiles
+        if (inferenceProfilesList?.inferenceProfileSummaries) {
+            inferenceProfilesList.inferenceProfileSummaries.forEach((p) => {
+                if (!p.inferenceProfileArn) {
+                    throw new Error("Profile ARN not found");
+                }
+                // Apply the same filtering logic to inference profiles based on their name
+                const profileId = p.inferenceProfileId || "";
+                const profileName = p.inferenceProfileName || "";
+                // Extract provider name from profile name or ID
+                let providerName = "";
+                for (const provider of supportedPublishers) {
+                    if (profileName.toLowerCase().includes(provider) || profileId.toLowerCase().includes(provider)) {
+                        providerName = provider;
+                        break;
+                    }
+                }
+                const modelCapability = getModelCapabilities(p.inferenceProfileArn ?? p.inferenceProfileId, this.provider);
+                if (providerName && shouldIncludeModel(profileId, providerName)) {
+                    const model = {
+                        id: p.inferenceProfileArn ?? p.inferenceProfileId,
+                        name: p.inferenceProfileName ?? p.inferenceProfileArn,
+                        provider: this.provider,
+                        input_modalities: modelModalitiesToArray(modelCapability.input),
+                        output_modalities: modelModalitiesToArray(modelCapability.output),
+                        tool_support: modelCapability.tool_support,
+                    };
+                    aiModels.push(model);
+                }
+            });
+        }
+        return aiModels;
+    }
+    async generateEmbeddings({ text, image, model }) {
+        this.logger.info("[Bedrock] Generating embeddings with model " + model);
+        const defaultModel = image ? "amazon.titan-embed-image-v1" : "amazon.titan-embed-text-v2:0";
+        const modelID = model ?? defaultModel;
+        const invokeBody = {
+            inputText: text,
+            inputImage: image
+        };
+        const executor = this.getExecutor();
+        const res = await executor.invokeModel({
+            modelId: modelID,
+            contentType: "application/json",
+            body: JSON.stringify(invokeBody),
+        });
+        const decoder = new TextDecoder();
+        const body = decoder.decode(res.body);
+        const result = JSON.parse(body);
+        if (!result.embedding) {
+            throw new Error("Embeddings not found");
+        }
+        return {
+            values: result.embedding,
+            model: modelID,
+            token_count: result.inputTextTokenCount
+        };
+    }
+}
+function jobInfo(job, jobId) {
+    const jobStatus = job.status;
+    let status = TrainingJobStatus.running;
+    let details;
+    if (jobStatus === ModelCustomizationJobStatus.COMPLETED) {
+        status = TrainingJobStatus.succeeded;
+    }
+    else if (jobStatus === ModelCustomizationJobStatus.FAILED) {
+        status = TrainingJobStatus.failed;
+        details = job.failureMessage || "error";
+    }
+    else if (jobStatus === ModelCustomizationJobStatus.STOPPED) {
+        status = TrainingJobStatus.cancelled;
+    }
+    else {
+        status = TrainingJobStatus.running;
+        details = jobStatus;
+    }
+    job.baseModelArn;
+    return {
+        id: jobId,
+        model: job.outputModelArn,
+        status,
+        details
+    };
+}
+function getToolDefinitions(tools) {
+    return tools ? tools.map(getToolDefinition) : undefined;
+}
+function getToolDefinition(tool) {
+    return {
+        toolSpec: {
+            name: tool.name,
+            description: tool.description,
+            inputSchema: {
+                json: tool.input_schema,
+            }
+        }
+    };
+}
+/**
+ * Update the conversation messages
+ * @param prompt
+ * @param response
+ * @returns
+ */
+function updateConversation(conversation, prompt) {
+    const combinedMessages = [...(conversation?.messages || []), ...(prompt.messages || [])];
+    const combinedSystem = prompt.system || conversation?.system;
+    return {
+        modelId: prompt?.modelId || conversation?.modelId,
+        messages: combinedMessages.length > 0 ? combinedMessages : [],
+        system: combinedSystem && combinedSystem.length > 0 ? combinedSystem : undefined,
+    };
+}
+function formatAmazonModalities(modalities) {
+    const standardizedModalities = [];
+    for (const modality of modalities) {
+        if (modality === ModelModality.TEXT) {
+            standardizedModalities.push("text");
+        }
+        else if (modality === ModelModality.IMAGE) {
+            standardizedModalities.push("image");
+        }
+        else if (modality === ModelModality.EMBEDDING) {
+            standardizedModalities.push("embedding");
+        }
+        else if (modality == "SPEECH") {
+            standardizedModalities.push("audio");
+        }
+        else if (modality == "VIDEO") {
+            standardizedModalities.push("video");
+        }
+        else {
+            // Handle other modalities as needed
+            standardizedModalities.push(modality.toString().toLowerCase());
+        }
+    }
+    return standardizedModalities;
+}
+//# sourceMappingURL=index.js.map
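
Since the hunk above is compiled output, a brief orientation: the new driver routes all Bedrock text generation through the Converse API (converse/converseStream), probes streaming support per model via the management API (foundation model, inference profile, or custom model lookup), and memoizes the answer in a module-level 4096-entry LRU cache. The TypeScript sketch below shows one plausible way to exercise it. It is an illustration only: the root re-export of BedrockDriver, the PromptRole import, the model id, and the minimal options objects are assumptions, not part of this diff.

// Hypothetical usage sketch -- not code from the package.
import { BedrockDriver } from "@llumiverse/drivers"; // re-export path assumed
import { PromptRole } from "@llumiverse/core";       // prompt segment role enum, assumed

async function main() {
    const driver = new BedrockDriver({
        region: "us-east-1", // required: the constructor throws when region is missing
        // credentials omitted: the AWS SDK default provider chain applies
    });
    const model = "anthropic.claude-3-7-sonnet-20250219-v1:0"; // illustrative id

    // formatPrompt() sends Canvas models to the Nova formatter and everything
    // else to the Converse formatter (see the hunk above).
    const prompt = await driver.formatPrompt(
        [{ role: PromptRole.user, content: "Summarize the Converse API in one sentence." }],
        { model },
    );

    // canStream() hits the Bedrock management API once per model id and then
    // serves later calls from supportStreamingCache.
    if (await driver.canStream({ model })) {
        const stream = await driver.requestTextCompletionStream(prompt, { model });
        for await (const chunk of stream) {
            process.stdout.write(chunk.result ?? ""); // chunks follow getExtractedStream()'s shape
        }
        console.log();
    } else {
        const completion = await driver.requestTextCompletion(prompt, { model });
        console.log(completion.result);
    }
}

main().catch(console.error);

The streaming cache is the design choice worth noting here: getFoundationModel, getInferenceProfile, and getCustomModel are comparatively slow management-plane calls, so caching the boolean per model id keeps the hot completion path free of extra round trips.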