@llumiverse/drivers 0.15.0 → 0.17.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -3
- package/lib/cjs/adobe/firefly.js +119 -0
- package/lib/cjs/adobe/firefly.js.map +1 -0
- package/lib/cjs/bedrock/converse.js +177 -0
- package/lib/cjs/bedrock/converse.js.map +1 -0
- package/lib/cjs/bedrock/index.js +338 -234
- package/lib/cjs/bedrock/index.js.map +1 -1
- package/lib/cjs/bedrock/nova-image-payload.js +207 -0
- package/lib/cjs/bedrock/nova-image-payload.js.map +1 -0
- package/lib/cjs/groq/index.js +34 -9
- package/lib/cjs/groq/index.js.map +1 -1
- package/lib/cjs/huggingface_ie.js +28 -12
- package/lib/cjs/huggingface_ie.js.map +1 -1
- package/lib/cjs/index.js +1 -0
- package/lib/cjs/index.js.map +1 -1
- package/lib/cjs/mistral/index.js +32 -13
- package/lib/cjs/mistral/index.js.map +1 -1
- package/lib/cjs/mistral/types.js.map +1 -1
- package/lib/cjs/openai/index.js +164 -29
- package/lib/cjs/openai/index.js.map +1 -1
- package/lib/cjs/replicate.js +19 -34
- package/lib/cjs/replicate.js.map +1 -1
- package/lib/cjs/test/TestValidationErrorCompletionStream.js.map +1 -1
- package/lib/cjs/test/index.js.map +1 -1
- package/lib/cjs/togetherai/index.js +40 -10
- package/lib/cjs/togetherai/index.js.map +1 -1
- package/lib/cjs/vertexai/embeddings/embeddings-image.js +26 -0
- package/lib/cjs/vertexai/embeddings/embeddings-image.js.map +1 -0
- package/lib/cjs/vertexai/embeddings/embeddings-text.js +1 -1
- package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -1
- package/lib/cjs/vertexai/index.js +134 -35
- package/lib/cjs/vertexai/index.js.map +1 -1
- package/lib/cjs/vertexai/models/claude.js +252 -0
- package/lib/cjs/vertexai/models/claude.js.map +1 -0
- package/lib/cjs/vertexai/models/gemini.js +172 -25
- package/lib/cjs/vertexai/models/gemini.js.map +1 -1
- package/lib/cjs/vertexai/models/imagen.js +317 -0
- package/lib/cjs/vertexai/models/imagen.js.map +1 -0
- package/lib/cjs/vertexai/models.js +12 -64
- package/lib/cjs/vertexai/models.js.map +1 -1
- package/lib/cjs/watsonx/index.js +47 -10
- package/lib/cjs/watsonx/index.js.map +1 -1
- package/lib/cjs/xai/index.js +71 -0
- package/lib/cjs/xai/index.js.map +1 -0
- package/lib/esm/adobe/firefly.js +115 -0
- package/lib/esm/adobe/firefly.js.map +1 -0
- package/lib/esm/bedrock/converse.js +171 -0
- package/lib/esm/bedrock/converse.js.map +1 -0
- package/lib/esm/bedrock/index.js +339 -232
- package/lib/esm/bedrock/index.js.map +1 -1
- package/lib/esm/bedrock/nova-image-payload.js +203 -0
- package/lib/esm/bedrock/nova-image-payload.js.map +1 -0
- package/lib/esm/groq/index.js +34 -9
- package/lib/esm/groq/index.js.map +1 -1
- package/lib/esm/huggingface_ie.js +29 -13
- package/lib/esm/huggingface_ie.js.map +1 -1
- package/lib/esm/index.js +1 -0
- package/lib/esm/index.js.map +1 -1
- package/lib/esm/mistral/index.js +32 -13
- package/lib/esm/mistral/index.js.map +1 -1
- package/lib/esm/mistral/types.js.map +1 -1
- package/lib/esm/openai/index.js +165 -30
- package/lib/esm/openai/index.js.map +1 -1
- package/lib/esm/replicate.js +19 -34
- package/lib/esm/replicate.js.map +1 -1
- package/lib/esm/test/TestValidationErrorCompletionStream.js.map +1 -1
- package/lib/esm/test/index.js.map +1 -1
- package/lib/esm/togetherai/index.js +40 -10
- package/lib/esm/togetherai/index.js.map +1 -1
- package/lib/esm/vertexai/embeddings/embeddings-image.js +23 -0
- package/lib/esm/vertexai/embeddings/embeddings-image.js.map +1 -0
- package/lib/esm/vertexai/embeddings/embeddings-text.js +1 -1
- package/lib/esm/vertexai/embeddings/embeddings-text.js.map +1 -1
- package/lib/esm/vertexai/index.js +135 -37
- package/lib/esm/vertexai/index.js.map +1 -1
- package/lib/esm/vertexai/models/claude.js +247 -0
- package/lib/esm/vertexai/models/claude.js.map +1 -0
- package/lib/esm/vertexai/models/gemini.js +173 -26
- package/lib/esm/vertexai/models/gemini.js.map +1 -1
- package/lib/esm/vertexai/models/imagen.js +310 -0
- package/lib/esm/vertexai/models/imagen.js.map +1 -0
- package/lib/esm/vertexai/models.js +12 -61
- package/lib/esm/vertexai/models.js.map +1 -1
- package/lib/esm/watsonx/index.js +47 -10
- package/lib/esm/watsonx/index.js.map +1 -1
- package/lib/esm/xai/index.js +64 -0
- package/lib/esm/xai/index.js.map +1 -0
- package/lib/types/adobe/firefly.d.ts +30 -0
- package/lib/types/adobe/firefly.d.ts.map +1 -0
- package/lib/types/bedrock/converse.d.ts +8 -0
- package/lib/types/bedrock/converse.d.ts.map +1 -0
- package/lib/types/bedrock/index.d.ts +27 -12
- package/lib/types/bedrock/index.d.ts.map +1 -1
- package/lib/types/bedrock/nova-image-payload.d.ts +74 -0
- package/lib/types/bedrock/nova-image-payload.d.ts.map +1 -0
- package/lib/types/bedrock/payloads.d.ts +9 -65
- package/lib/types/bedrock/payloads.d.ts.map +1 -1
- package/lib/types/groq/index.d.ts +3 -3
- package/lib/types/groq/index.d.ts.map +1 -1
- package/lib/types/huggingface_ie.d.ts +5 -7
- package/lib/types/huggingface_ie.d.ts.map +1 -1
- package/lib/types/index.d.ts +1 -0
- package/lib/types/index.d.ts.map +1 -1
- package/lib/types/mistral/index.d.ts +4 -4
- package/lib/types/mistral/index.d.ts.map +1 -1
- package/lib/types/mistral/types.d.ts +1 -0
- package/lib/types/mistral/types.d.ts.map +1 -1
- package/lib/types/openai/index.d.ts +5 -4
- package/lib/types/openai/index.d.ts.map +1 -1
- package/lib/types/replicate.d.ts +4 -9
- package/lib/types/replicate.d.ts.map +1 -1
- package/lib/types/test/index.d.ts +2 -2
- package/lib/types/test/index.d.ts.map +1 -1
- package/lib/types/togetherai/index.d.ts +4 -4
- package/lib/types/togetherai/index.d.ts.map +1 -1
- package/lib/types/vertexai/embeddings/embeddings-image.d.ts +11 -0
- package/lib/types/vertexai/embeddings/embeddings-image.d.ts.map +1 -0
- package/lib/types/vertexai/index.d.ts +21 -8
- package/lib/types/vertexai/index.d.ts.map +1 -1
- package/lib/types/vertexai/models/claude.d.ts +20 -0
- package/lib/types/vertexai/models/claude.d.ts.map +1 -0
- package/lib/types/vertexai/models/gemini.d.ts +4 -4
- package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
- package/lib/types/vertexai/models/imagen.d.ts +75 -0
- package/lib/types/vertexai/models/imagen.d.ts.map +1 -0
- package/lib/types/vertexai/models.d.ts +3 -6
- package/lib/types/vertexai/models.d.ts.map +1 -1
- package/lib/types/watsonx/index.d.ts +3 -3
- package/lib/types/watsonx/index.d.ts.map +1 -1
- package/lib/types/watsonx/interfaces.d.ts +4 -0
- package/lib/types/watsonx/interfaces.d.ts.map +1 -1
- package/lib/types/xai/index.d.ts +19 -0
- package/lib/types/xai/index.d.ts.map +1 -0
- package/package.json +25 -26
- package/src/adobe/firefly.ts +207 -0
- package/src/bedrock/converse.ts +194 -0
- package/src/bedrock/index.ts +359 -240
- package/src/bedrock/nova-image-payload.ts +309 -0
- package/src/bedrock/payloads.ts +12 -66
- package/src/groq/index.ts +35 -12
- package/src/huggingface_ie.ts +34 -13
- package/src/index.ts +1 -0
- package/src/mistral/index.ts +35 -13
- package/src/mistral/types.ts +2 -1
- package/src/openai/index.ts +186 -35
- package/src/replicate.ts +24 -35
- package/src/test/TestValidationErrorCompletionStream.ts +2 -2
- package/src/test/index.ts +3 -2
- package/src/togetherai/index.ts +44 -12
- package/src/vertexai/embeddings/embeddings-image.ts +50 -0
- package/src/vertexai/embeddings/embeddings-text.ts +1 -1
- package/src/vertexai/index.ts +186 -46
- package/src/vertexai/models/claude.ts +281 -0
- package/src/vertexai/models/gemini.ts +186 -29
- package/src/vertexai/models/imagen.ts +401 -0
- package/src/vertexai/models.ts +16 -78
- package/src/watsonx/index.ts +50 -12
- package/src/watsonx/interfaces.ts +4 -0
- package/src/xai/index.ts +110 -0
package/lib/esm/bedrock/index.js
CHANGED
|
@@ -1,18 +1,39 @@
|
|
|
1
1
|
import { Bedrock, CreateModelCustomizationJobCommand, GetModelCustomizationJobCommand, ModelCustomizationJobStatus, StopModelCustomizationJobCommand } from "@aws-sdk/client-bedrock";
|
|
2
2
|
import { BedrockRuntime } from "@aws-sdk/client-bedrock-runtime";
|
|
3
3
|
import { S3Client } from "@aws-sdk/client-s3";
|
|
4
|
-
import { AbstractDriver, TrainingJobStatus } from "@llumiverse/core";
|
|
4
|
+
import { AbstractDriver, Modalities, TrainingJobStatus } from "@llumiverse/core";
|
|
5
5
|
import { transformAsyncIterator } from "@llumiverse/core/async";
|
|
6
|
-
import {
|
|
7
|
-
import
|
|
6
|
+
import { formatNovaPrompt } from "@llumiverse/core/formatters";
|
|
7
|
+
import { LRUCache } from "mnemonist";
|
|
8
|
+
import { converseConcatMessages, converseRemoveJSONprefill, converseSystemToMessages, fortmatConversePrompt } from "./converse.js";
|
|
9
|
+
import { formatNovaImageGenerationPayload, NovaImageGenerationTaskType } from "./nova-image-payload.js";
|
|
8
10
|
import { forceUploadFile } from "./s3.js";
|
|
9
|
-
const { LRUCache } = mnemonist;
|
|
10
11
|
const supportStreamingCache = new LRUCache(4096);
|
|
12
|
+
/**
 * Kinds of Bedrock model identifier the driver distinguishes when probing for
 * streaming support. The first three values are the same substrings that
 * `canStream` searches for in `options.model`; `Unknown` is used when none of
 * them is present.
 *
 * Frozen so the lookup table cannot be modified at runtime.
 */
const BedrockModelType = Object.freeze({
    FoundationModel: "foundation-model",
    InferenceProfile: "inference-profile",
    CustomModel: "custom-model",
    Unknown: "unknown",
});
|
|
20
|
+
/**
 * Maps a Converse stop reason onto the finish_reason values used by this driver.
 *
 * Stop reasons listed for Converse:
 * end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
 *
 * @param reason stop reason reported by Converse; may be empty or missing
 * @returns undefined for a falsy reason, "stop" for end_turn, "length" for
 *          max_tokens, and the reason unchanged for anything else
 */
function converseFinishReason(reason) {
    if (!reason) {
        return undefined;
    }
    if (reason === 'end_turn') {
        return "stop";
    }
    return reason === 'max_tokens' ? "length" : reason;
}
|
|
11
31
|
export class BedrockDriver extends AbstractDriver {
|
|
12
32
|
static PROVIDER = "bedrock";
|
|
13
33
|
provider = BedrockDriver.PROVIDER;
|
|
14
34
|
_executor;
|
|
15
35
|
_service;
|
|
36
|
+
_service_region;
|
|
16
37
|
constructor(options) {
|
|
17
38
|
super(options);
|
|
18
39
|
if (!options.region) {
|
|
@@ -28,241 +49,334 @@ export class BedrockDriver extends AbstractDriver {
|
|
|
28
49
|
}
|
|
29
50
|
return this._executor;
|
|
30
51
|
}
|
|
31
|
-
getService() {
|
|
32
|
-
if (!this._service) {
|
|
52
|
+
getService(region = this.options.region) {
|
|
53
|
+
if (!this._service || this._service_region != region) {
|
|
33
54
|
this._service = new Bedrock({
|
|
34
|
-
region:
|
|
55
|
+
region: region,
|
|
35
56
|
credentials: this.options.credentials,
|
|
36
57
|
});
|
|
58
|
+
this._service_region = region;
|
|
37
59
|
}
|
|
38
60
|
return this._service;
|
|
39
61
|
}
|
|
40
62
|
async formatPrompt(segments, opts) {
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
//TODO: need to type better the types aren't checked properly by TS
|
|
44
|
-
return await formatClaudePrompt(segments, opts.result_schema);
|
|
45
|
-
}
|
|
46
|
-
else {
|
|
47
|
-
return await super.formatPrompt(segments, opts);
|
|
63
|
+
if (opts.model.includes("canvas")) {
|
|
64
|
+
return await formatNovaPrompt(segments, opts.result_schema);
|
|
48
65
|
}
|
|
66
|
+
return await fortmatConversePrompt(segments, opts.result_schema);
|
|
49
67
|
}
|
|
50
|
-
|
|
51
|
-
const decoder = new TextDecoder();
|
|
52
|
-
const body = decoder.decode(response.body);
|
|
53
|
-
const result = JSON.parse(body);
|
|
54
|
-
const getTextAnsStopReason = () => {
|
|
55
|
-
if (result.generation) {
|
|
56
|
-
// LLAMA2
|
|
57
|
-
return [result.generation, result.stop_reason]; // comes in coirrect format (stop, length)
|
|
58
|
-
}
|
|
59
|
-
else if (result.generations) {
|
|
60
|
-
// Cohere
|
|
61
|
-
return [result.generations[0].text, cohereFinishReason(result.generations[0].finish_reason)];
|
|
62
|
-
}
|
|
63
|
-
else if (result.chat_history) {
|
|
64
|
-
//Cohere Command R
|
|
65
|
-
return [result.text, cohereFinishReason(result.finish_reason)];
|
|
66
|
-
}
|
|
67
|
-
else if (result.completions) {
|
|
68
|
-
//A21
|
|
69
|
-
return [result.completions[0].data?.text, a21FinishReason(result.completions[0].finishReason?.reason)];
|
|
70
|
-
}
|
|
71
|
-
else if (result.content) {
|
|
72
|
-
// Claude
|
|
73
|
-
//if last prompt.messages is {, add { to the response
|
|
74
|
-
const p = prompt;
|
|
75
|
-
const lastMessage = p.messages[p.messages.length - 1];
|
|
76
|
-
const res = lastMessage.content[0].text === '{' ? '{' + result.content[0]?.text : result.content[0]?.text;
|
|
77
|
-
return [res, claudeFinishReason(result.stop_reason)];
|
|
78
|
-
}
|
|
79
|
-
else if (result.outputs) {
|
|
80
|
-
// mistral
|
|
81
|
-
return [result.outputs[0]?.text, result.outputs[0]?.stop_reason]; // the stop reason is in the expected format ("stop" and "length")
|
|
82
|
-
}
|
|
83
|
-
else if (result.results) {
|
|
84
|
-
// Amazon Titan
|
|
85
|
-
return [result.results[0]?.outputText ?? '', titanFinishReason(result.results[0]?.completionReason)];
|
|
86
|
-
}
|
|
87
|
-
else if (result.completion) { // TODO: who uses this?
|
|
88
|
-
return [result.completion];
|
|
89
|
-
}
|
|
90
|
-
else {
|
|
91
|
-
return [result.toString()];
|
|
92
|
-
}
|
|
93
|
-
};
|
|
94
|
-
const [text, finish_reason] = getTextAnsStopReason();
|
|
95
|
-
const promptLength = typeof prompt === 'string' ? prompt.length :
|
|
96
|
-
(prompt.system || '').length + prompt.messages.reduce((acc, m) => acc + m.content.length, 0);
|
|
68
|
+
static getExtractedExecuton(result, _prompt) {
|
|
97
69
|
return {
|
|
98
|
-
result: text,
|
|
70
|
+
result: result.output?.message?.content?.map(c => c.text).join("\n") ?? "",
|
|
99
71
|
token_usage: {
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
total:
|
|
72
|
+
prompt: result.usage?.inputTokens,
|
|
73
|
+
result: result.usage?.outputTokens,
|
|
74
|
+
total: result.usage?.totalTokens,
|
|
103
75
|
},
|
|
104
|
-
finish_reason
|
|
76
|
+
finish_reason: converseFinishReason(result.stopReason),
|
|
77
|
+
};
|
|
78
|
+
}
|
|
79
|
+
;
|
|
80
|
+
static getExtractedStream(result, _prompt) {
|
|
81
|
+
let output = "";
|
|
82
|
+
let stop_reason = "";
|
|
83
|
+
let token_usage;
|
|
84
|
+
if (result.contentBlockDelta) {
|
|
85
|
+
output = result.contentBlockDelta.delta?.text ?? "";
|
|
86
|
+
}
|
|
87
|
+
if (result.messageStop) {
|
|
88
|
+
stop_reason = result.messageStop.stopReason ?? "";
|
|
89
|
+
}
|
|
90
|
+
if (result.metadata) {
|
|
91
|
+
token_usage = {
|
|
92
|
+
prompt: result.metadata.usage?.inputTokens,
|
|
93
|
+
result: result.metadata.usage?.outputTokens,
|
|
94
|
+
total: result.metadata.usage?.totalTokens,
|
|
95
|
+
};
|
|
96
|
+
}
|
|
97
|
+
return {
|
|
98
|
+
result: output,
|
|
99
|
+
token_usage: token_usage,
|
|
100
|
+
finish_reason: converseFinishReason(stop_reason),
|
|
105
101
|
};
|
|
106
102
|
}
|
|
107
|
-
|
|
103
|
+
;
|
|
104
|
+
async requestTextCompletion(prompt, options) {
|
|
108
105
|
const payload = this.preparePayload(prompt, options);
|
|
109
106
|
const executor = this.getExecutor();
|
|
110
|
-
const res = await executor.
|
|
111
|
-
|
|
112
|
-
contentType: "application/json",
|
|
113
|
-
body: JSON.stringify(payload),
|
|
107
|
+
const res = await executor.converse({
|
|
108
|
+
...payload,
|
|
114
109
|
});
|
|
115
|
-
const completion =
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
}
|
|
110
|
+
const completion = {
|
|
111
|
+
...BedrockDriver.getExtractedExecuton(res, prompt),
|
|
112
|
+
original_response: options.include_original_response ? res : undefined,
|
|
113
|
+
};
|
|
119
114
|
return completion;
|
|
120
115
|
}
|
|
116
|
+
extractRegion(modelString, defaultRegion) {
|
|
117
|
+
// Match region in full ARN pattern
|
|
118
|
+
const arnMatch = modelString.match(/arn:aws[^:]*:bedrock:([^:]+):/);
|
|
119
|
+
if (arnMatch) {
|
|
120
|
+
return arnMatch[1];
|
|
121
|
+
}
|
|
122
|
+
// Match common AWS regions directly in string
|
|
123
|
+
const regionMatch = modelString.match(/(?:us|eu|ap|sa|ca|me|af)[-](east|west|central|south|north|southeast|southwest|northeast|northwest)[-][1-9]/);
|
|
124
|
+
if (regionMatch) {
|
|
125
|
+
return regionMatch[0];
|
|
126
|
+
}
|
|
127
|
+
return defaultRegion;
|
|
128
|
+
}
|
|
129
|
+
async getCanStream(model, type) {
|
|
130
|
+
let canStream = false;
|
|
131
|
+
let error = null;
|
|
132
|
+
const region = this.extractRegion(model, this.options.region);
|
|
133
|
+
if (type == BedrockModelType.FoundationModel || type == BedrockModelType.Unknown) {
|
|
134
|
+
try {
|
|
135
|
+
const response = await this.getService(region).getFoundationModel({
|
|
136
|
+
modelIdentifier: model
|
|
137
|
+
});
|
|
138
|
+
canStream = response.modelDetails?.responseStreamingSupported ?? false;
|
|
139
|
+
return canStream;
|
|
140
|
+
}
|
|
141
|
+
catch (e) {
|
|
142
|
+
error = e;
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
if (type == BedrockModelType.InferenceProfile || type == BedrockModelType.Unknown) {
|
|
146
|
+
try {
|
|
147
|
+
const response = await this.getService(region).getInferenceProfile({
|
|
148
|
+
inferenceProfileIdentifier: model
|
|
149
|
+
});
|
|
150
|
+
canStream = await this.getCanStream(response.models?.[0].modelArn ?? "", BedrockModelType.FoundationModel);
|
|
151
|
+
return canStream;
|
|
152
|
+
}
|
|
153
|
+
catch (e) {
|
|
154
|
+
error = e;
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
if (type == BedrockModelType.CustomModel || type == BedrockModelType.Unknown) {
|
|
158
|
+
try {
|
|
159
|
+
const response = await this.getService(region).getCustomModel({
|
|
160
|
+
modelIdentifier: model
|
|
161
|
+
});
|
|
162
|
+
canStream = await this.getCanStream(response.baseModelArn ?? "", BedrockModelType.FoundationModel);
|
|
163
|
+
return canStream;
|
|
164
|
+
}
|
|
165
|
+
catch (e) {
|
|
166
|
+
error = e;
|
|
167
|
+
}
|
|
168
|
+
}
|
|
169
|
+
if (error) {
|
|
170
|
+
console.warn("Error on canStream check for model: " + model + " region detected: " + region, error);
|
|
171
|
+
}
|
|
172
|
+
return canStream;
|
|
173
|
+
}
|
|
121
174
|
async canStream(options) {
|
|
122
175
|
let canStream = supportStreamingCache.get(options.model);
|
|
123
176
|
if (canStream == null) {
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
177
|
+
let type = BedrockModelType.Unknown;
|
|
178
|
+
if (options.model.includes("foundation-model")) {
|
|
179
|
+
type = BedrockModelType.FoundationModel;
|
|
180
|
+
}
|
|
181
|
+
else if (options.model.includes("inference-profile")) {
|
|
182
|
+
type = BedrockModelType.InferenceProfile;
|
|
183
|
+
}
|
|
184
|
+
else if (options.model.includes("custom-model")) {
|
|
185
|
+
type = BedrockModelType.CustomModel;
|
|
186
|
+
}
|
|
187
|
+
canStream = await this.getCanStream(options.model, type);
|
|
128
188
|
supportStreamingCache.set(options.model, canStream);
|
|
129
189
|
}
|
|
130
190
|
return canStream;
|
|
131
191
|
}
|
|
132
|
-
async
|
|
192
|
+
async requestTextCompletionStream(prompt, options) {
|
|
133
193
|
const payload = this.preparePayload(prompt, options);
|
|
134
194
|
const executor = this.getExecutor();
|
|
135
|
-
return executor.
|
|
136
|
-
|
|
137
|
-
contentType: "application/json",
|
|
138
|
-
body: JSON.stringify(payload),
|
|
195
|
+
return executor.converseStream({
|
|
196
|
+
...payload,
|
|
139
197
|
}).then((res) => {
|
|
140
|
-
|
|
141
|
-
|
|
198
|
+
const stream = res.stream;
|
|
199
|
+
if (!stream) {
|
|
200
|
+
throw new Error("[Bedrock] Stream not found in response");
|
|
142
201
|
}
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
if (typeof prompt === 'object' && prompt.messages) {
|
|
146
|
-
const p = prompt;
|
|
147
|
-
const lastMessage = p.messages[p.messages.length - 1];
|
|
148
|
-
return lastMessage.content[0].text === '{';
|
|
149
|
-
}
|
|
150
|
-
return false;
|
|
151
|
-
};
|
|
152
|
-
return transformAsyncIterator(res.body, (stream) => {
|
|
153
|
-
const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
|
|
202
|
+
return transformAsyncIterator(stream, (stream) => {
|
|
203
|
+
//const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
|
|
154
204
|
//console.log("Debug Segment for model " + options.model, JSON.stringify(segment));
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
}
|
|
158
|
-
else if (segment.completion) { // who is this?
|
|
159
|
-
return segment.completion;
|
|
160
|
-
}
|
|
161
|
-
else if (segment.text) { //cohere
|
|
162
|
-
return segment.text;
|
|
163
|
-
}
|
|
164
|
-
else if (segment.completions) {
|
|
165
|
-
return segment.completions[0].data?.text;
|
|
166
|
-
}
|
|
167
|
-
else if (segment.generation) {
|
|
168
|
-
return segment.generation;
|
|
169
|
-
}
|
|
170
|
-
else if (segment.generations) {
|
|
171
|
-
return segment.generations[0].text;
|
|
172
|
-
}
|
|
173
|
-
else if (segment.outputs) {
|
|
174
|
-
// mistral.mixtral-8x7b-instruct-v0:1
|
|
175
|
-
return segment.outputs[0].text;
|
|
176
|
-
//segment.outputs[0].stop_reason;
|
|
177
|
-
}
|
|
178
|
-
else if (segment.outputText) {
|
|
179
|
-
// Amazon Titan
|
|
180
|
-
return segment.outputText;
|
|
181
|
-
//completionReason
|
|
182
|
-
// token count too
|
|
183
|
-
}
|
|
184
|
-
else {
|
|
185
|
-
segment.toString();
|
|
186
|
-
}
|
|
187
|
-
}, () => addBracket() ? '{' : '');
|
|
205
|
+
return BedrockDriver.getExtractedStream(stream, prompt);
|
|
206
|
+
});
|
|
188
207
|
}).catch((err) => {
|
|
189
208
|
this.logger.error("[Bedrock] Failed to stream", err);
|
|
190
209
|
throw err;
|
|
191
210
|
});
|
|
192
211
|
}
|
|
193
212
|
preparePayload(prompt, options) {
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
temperature: options.temperature,
|
|
202
|
-
max_gen_len: options.max_tokens,
|
|
203
|
-
};
|
|
213
|
+
const model_options = options.model_options;
|
|
214
|
+
let additionalField = {};
|
|
215
|
+
if (options.model.includes("amazon")) {
|
|
216
|
+
//Titan models also exists but does not support any additional options
|
|
217
|
+
if (options.model.includes("nova")) {
|
|
218
|
+
additionalField = { inferenceConfig: { topK: model_options?.top_k } };
|
|
219
|
+
}
|
|
204
220
|
}
|
|
205
|
-
else if (
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
221
|
+
else if (options.model.includes("claude")) {
|
|
222
|
+
if (options.model.includes("claude-3-7")) {
|
|
223
|
+
const thinking_options = options.model_options;
|
|
224
|
+
const thinking = thinking_options?.thinking_mode ?? false;
|
|
225
|
+
if (!model_options?.max_tokens) {
|
|
226
|
+
model_options.max_tokens = thinking ? 128000 : 8192;
|
|
209
227
|
}
|
|
210
|
-
|
|
211
|
-
|
|
228
|
+
additionalField = {
|
|
229
|
+
top_k: model_options?.top_k,
|
|
230
|
+
reasoning_config: {
|
|
231
|
+
type: thinking ? "enabled" : "disabled",
|
|
232
|
+
budget_tokens: thinking_options?.thinking_budget_tokens,
|
|
233
|
+
}
|
|
234
|
+
};
|
|
235
|
+
if (thinking && (thinking_options?.thinking_budget_tokens ?? 0) > 64000) {
|
|
236
|
+
additionalField = {
|
|
237
|
+
...additionalField,
|
|
238
|
+
anthorpic_beta: ["output-128k-2025-02-19"]
|
|
239
|
+
};
|
|
240
|
+
}
|
|
241
|
+
}
|
|
242
|
+
//Needs max_tokens to be set
|
|
243
|
+
if (!model_options?.max_tokens) {
|
|
244
|
+
if (options.model.includes("claude-3-5")) {
|
|
245
|
+
model_options.max_tokens = 8192;
|
|
246
|
+
//Bug with AWS Converse Sonnet 3.5, does not effect Haiku.
|
|
247
|
+
//See https://github.com/boto/boto3/issues/4279
|
|
248
|
+
if (options.model.includes("claude-3-5-sonnet")) {
|
|
249
|
+
model_options.max_tokens = 4096;
|
|
250
|
+
}
|
|
212
251
|
}
|
|
213
252
|
else {
|
|
214
|
-
|
|
253
|
+
model_options.max_tokens = 4096;
|
|
215
254
|
}
|
|
216
|
-
}
|
|
217
|
-
|
|
218
|
-
anthropic_version: "bedrock-2023-05-31",
|
|
219
|
-
...prompt,
|
|
220
|
-
temperature: options.temperature,
|
|
221
|
-
max_tokens: maxToken(),
|
|
222
|
-
};
|
|
255
|
+
}
|
|
256
|
+
additionalField = { top_k: model_options?.top_k };
|
|
223
257
|
}
|
|
224
|
-
else if (
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
temperature: options.temperature,
|
|
228
|
-
maxTokens: options.max_tokens,
|
|
229
|
-
};
|
|
258
|
+
else if (options.model.includes("meta")) {
|
|
259
|
+
//If last message is "```json", remove it. Model requires the final message to be a user message
|
|
260
|
+
prompt.messages = converseRemoveJSONprefill(prompt.messages);
|
|
230
261
|
}
|
|
231
|
-
else if (
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
262
|
+
else if (options.model.includes("mistral")) {
|
|
263
|
+
//7B instruct and 8x7B instruct
|
|
264
|
+
if (options.model.includes("7b")) {
|
|
265
|
+
additionalField = { top_k: model_options?.top_k };
|
|
266
|
+
//Does not support system messages
|
|
267
|
+
if (prompt.system && prompt.system?.length != 0) {
|
|
268
|
+
prompt.messages?.push(converseSystemToMessages(prompt.system));
|
|
269
|
+
prompt.system = undefined;
|
|
270
|
+
prompt.messages = converseConcatMessages(prompt.messages);
|
|
271
|
+
}
|
|
272
|
+
}
|
|
273
|
+
else {
|
|
274
|
+
//Other models such as Mistral Small,Large and Large 2
|
|
275
|
+
//Support no additional fields.
|
|
276
|
+
prompt.messages = converseRemoveJSONprefill(prompt.messages);
|
|
277
|
+
}
|
|
237
278
|
}
|
|
238
|
-
else if (
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
279
|
+
else if (options.model.includes("ai21")) {
|
|
280
|
+
//If last message is "```json", remove it. Model requires the final message to be a user message
|
|
281
|
+
prompt.messages = converseRemoveJSONprefill(prompt.messages);
|
|
282
|
+
//Jamba models support no additional options
|
|
283
|
+
//Jurassic 2 models do.
|
|
284
|
+
if (options.model.includes("j2")) {
|
|
285
|
+
additionalField = {
|
|
286
|
+
presencePenalty: { scale: model_options?.presence_penalty },
|
|
287
|
+
frequencyPenalty: { scale: model_options?.frequency_penalty },
|
|
288
|
+
};
|
|
289
|
+
//Does not support system messages
|
|
290
|
+
if (prompt.system && prompt.system?.length != 0) {
|
|
291
|
+
prompt.messages?.push(converseSystemToMessages(prompt.system));
|
|
292
|
+
prompt.system = undefined;
|
|
293
|
+
prompt.messages = converseConcatMessages(prompt.messages);
|
|
294
|
+
}
|
|
295
|
+
}
|
|
244
296
|
}
|
|
245
|
-
else if (
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
297
|
+
else if (options.model.includes("cohere.command")) {
|
|
298
|
+
// If last message is "```json", remove it.
|
|
299
|
+
// Model requires the final message to be a user message or does not support assistant messages
|
|
300
|
+
prompt.messages = converseRemoveJSONprefill(prompt.messages);
|
|
301
|
+
//Command R and R plus
|
|
302
|
+
if (options.model.includes("cohere.command-r")) {
|
|
303
|
+
additionalField = {
|
|
304
|
+
k: model_options?.top_k,
|
|
305
|
+
frequency_penalty: model_options?.frequency_penalty,
|
|
306
|
+
presence_penalty: model_options?.presence_penalty,
|
|
307
|
+
};
|
|
308
|
+
}
|
|
309
|
+
else {
|
|
310
|
+
// Command non-R
|
|
311
|
+
additionalField = { k: model_options?.top_k };
|
|
312
|
+
//Does not support system messages
|
|
313
|
+
if (prompt.system && prompt.system?.length != 0) {
|
|
314
|
+
prompt.messages?.push(converseSystemToMessages(prompt.system));
|
|
315
|
+
prompt.system = undefined;
|
|
316
|
+
prompt.messages = converseConcatMessages(prompt.messages);
|
|
317
|
+
}
|
|
318
|
+
}
|
|
255
319
|
}
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
320
|
+
//If last message is "```json", add corresponding ``` as a stop sequence.
|
|
321
|
+
if (prompt.messages && prompt.messages.length > 0) {
|
|
322
|
+
if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") {
|
|
323
|
+
let stopSeq = model_options?.stop_sequence;
|
|
324
|
+
if (!stopSeq) {
|
|
325
|
+
model_options.stop_sequence = ["```"];
|
|
326
|
+
}
|
|
327
|
+
else if (!stopSeq.includes("```")) {
|
|
328
|
+
stopSeq.push("```");
|
|
329
|
+
model_options.stop_sequence = stopSeq;
|
|
330
|
+
}
|
|
331
|
+
}
|
|
332
|
+
}
|
|
333
|
+
return {
|
|
334
|
+
messages: prompt.messages,
|
|
335
|
+
system: prompt.system,
|
|
336
|
+
modelId: options.model,
|
|
337
|
+
inferenceConfig: {
|
|
338
|
+
maxTokens: model_options?.max_tokens,
|
|
339
|
+
temperature: model_options?.temperature,
|
|
340
|
+
topP: model_options?.top_p,
|
|
341
|
+
stopSequences: model_options?.stop_sequence,
|
|
342
|
+
},
|
|
343
|
+
additionalModelRequestFields: {
|
|
344
|
+
...additionalField,
|
|
345
|
+
},
|
|
346
|
+
};
|
|
347
|
+
}
|
|
348
|
+
async requestImageGeneration(prompt, options) {
|
|
349
|
+
if (options.output_modality !== Modalities.image) {
|
|
350
|
+
throw new Error(`Image generation requires image output_modality`);
|
|
262
351
|
}
|
|
263
|
-
|
|
264
|
-
|
|
352
|
+
if (options.model_options?._option_id !== "bedrock-nova-canvas") {
|
|
353
|
+
this.logger.warn("Invalid model options", { options: options.model_options });
|
|
354
|
+
}
|
|
355
|
+
const model_options = options.model_options;
|
|
356
|
+
const executor = this.getExecutor();
|
|
357
|
+
const taskType = model_options.taskType ?? NovaImageGenerationTaskType.TEXT_IMAGE;
|
|
358
|
+
this.logger.info("Task type: " + taskType);
|
|
359
|
+
if (typeof prompt === "string") {
|
|
360
|
+
throw new Error("Bad prompt format");
|
|
265
361
|
}
|
|
362
|
+
const payload = await formatNovaImageGenerationPayload(taskType, prompt, options);
|
|
363
|
+
const res = await executor.invokeModel({
|
|
364
|
+
modelId: options.model,
|
|
365
|
+
contentType: "application/json",
|
|
366
|
+
accept: "application/json",
|
|
367
|
+
body: JSON.stringify(payload),
|
|
368
|
+
}, {
|
|
369
|
+
requestTimeout: 60000 * 5
|
|
370
|
+
});
|
|
371
|
+
const decoder = new TextDecoder();
|
|
372
|
+
const body = decoder.decode(res.body);
|
|
373
|
+
const result = JSON.parse(body);
|
|
374
|
+
return {
|
|
375
|
+
error: result.error,
|
|
376
|
+
result: {
|
|
377
|
+
images: result.images,
|
|
378
|
+
}
|
|
379
|
+
};
|
|
266
380
|
}
|
|
267
381
|
async startTraining(dataset, options) {
|
|
268
382
|
//convert options.params to Record<string, string>
|
|
@@ -327,12 +441,13 @@ export class BedrockDriver extends AbstractDriver {
|
|
|
327
441
|
async listModels() {
|
|
328
442
|
this.logger.debug("[Bedrock] listing models");
|
|
329
443
|
// exclude trainable models since they are not executable
|
|
330
|
-
|
|
444
|
+
// exclude embedding models, not to be used for typical completions.
|
|
445
|
+
const filter = (m) => (m.inferenceTypesSupported?.includes("ON_DEMAND") && !m.outputModalities?.includes("EMBEDDING")) ?? false;
|
|
331
446
|
return this._listModels(filter);
|
|
332
447
|
}
|
|
333
448
|
async _listModels(foundationFilter) {
|
|
334
449
|
const service = this.getService();
|
|
335
|
-
const [foundationals, customs] = await Promise.all([
|
|
450
|
+
const [foundationals, customs, inferenceProfiles] = await Promise.all([
|
|
336
451
|
service.listFoundationModels({}).catch(() => {
|
|
337
452
|
this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions.");
|
|
338
453
|
return undefined;
|
|
@@ -341,6 +456,10 @@ export class BedrockDriver extends AbstractDriver {
|
|
|
341
456
|
this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions.");
|
|
342
457
|
return undefined;
|
|
343
458
|
}),
|
|
459
|
+
service.listInferenceProfiles({}).catch(() => {
|
|
460
|
+
this.logger.warn("[Bedrock] Can't list inference profiles. Check if the user has the right permissions.");
|
|
461
|
+
return undefined;
|
|
462
|
+
}),
|
|
344
463
|
]);
|
|
345
464
|
if (!foundationals?.modelSummaries) {
|
|
346
465
|
throw new Error("Foundation models not found");
|
|
@@ -349,6 +468,12 @@ export class BedrockDriver extends AbstractDriver {
|
|
|
349
468
|
if (foundationFilter) {
|
|
350
469
|
fmodels = fmodels.filter(foundationFilter);
|
|
351
470
|
}
|
|
471
|
+
// Keep only foundation models whose provider is listed in supportedProviders.
// Both callbacks must return their result: without the returns every model
// was filtered out. The comparison is case-insensitive because the list is
// lower-case while providerName may be capitalized (e.g. "Anthropic").
const supportedProviders = ["amazon", "anthropic", "cohere", "ai21", "mistral", "meta", "deepseek"];
fmodels = fmodels.filter((m) => {
    const providerName = m.providerName?.toLowerCase() ?? "";
    return supportedProviders.some((provider) => providerName.includes(provider));
});
|
|
352
477
|
const aimodels = fmodels.map((m) => {
|
|
353
478
|
if (!m.modelId) {
|
|
354
479
|
throw new Error("modelId not found");
|
|
@@ -357,6 +482,7 @@ export class BedrockDriver extends AbstractDriver {
|
|
|
357
482
|
id: m.modelArn ?? m.modelId,
|
|
358
483
|
name: `${m.providerName} ${m.modelName}`,
|
|
359
484
|
provider: this.provider,
|
|
485
|
+
input_modalities: m.inputModalities ?? [],
|
|
360
486
|
//description: ``,
|
|
361
487
|
owner: m.providerName,
|
|
362
488
|
can_stream: m.responseStreamingSupported ?? false,
|
|
@@ -382,16 +508,33 @@ export class BedrockDriver extends AbstractDriver {
|
|
|
382
508
|
this.validateConnection;
|
|
383
509
|
});
|
|
384
510
|
}
|
|
511
|
+
//add inference profiles
|
|
512
|
+
if (inferenceProfiles?.inferenceProfileSummaries) {
|
|
513
|
+
inferenceProfiles.inferenceProfileSummaries.forEach((p) => {
|
|
514
|
+
if (!p.inferenceProfileArn) {
|
|
515
|
+
throw new Error("Profile ARN not found");
|
|
516
|
+
}
|
|
517
|
+
const model = {
|
|
518
|
+
id: p.inferenceProfileArn ?? p.inferenceProfileId,
|
|
519
|
+
name: p.inferenceProfileName ?? p.inferenceProfileArn,
|
|
520
|
+
provider: this.provider,
|
|
521
|
+
};
|
|
522
|
+
aimodels.push(model);
|
|
523
|
+
});
|
|
524
|
+
}
|
|
385
525
|
return aimodels;
|
|
386
526
|
}
|
|
387
|
-
async generateEmbeddings({
|
|
527
|
+
async generateEmbeddings({ text, image, model }) {
|
|
388
528
|
this.logger.info("[Bedrock] Generating embeddings with model " + model);
|
|
529
|
+
const defaultModel = image ? "amazon.titan-embed-image-v1" : "amazon.titan-embed-text-v2:0";
|
|
530
|
+
const modelID = model ?? defaultModel;
|
|
389
531
|
const invokeBody = {
|
|
390
|
-
inputText:
|
|
532
|
+
inputText: text,
|
|
533
|
+
inputImage: image
|
|
391
534
|
};
|
|
392
535
|
const executor = this.getExecutor();
|
|
393
536
|
const res = await executor.invokeModel({
|
|
394
|
-
modelId:
|
|
537
|
+
modelId: modelID,
|
|
395
538
|
contentType: "application/json",
|
|
396
539
|
body: JSON.stringify(invokeBody),
|
|
397
540
|
});
|
|
@@ -403,7 +546,7 @@ export class BedrockDriver extends AbstractDriver {
|
|
|
403
546
|
}
|
|
404
547
|
return {
|
|
405
548
|
values: result.embedding,
|
|
406
|
-
model:
|
|
549
|
+
model: modelID,
|
|
407
550
|
token_count: result.inputTextTokenCount
|
|
408
551
|
};
|
|
409
552
|
}
|
|
@@ -434,40 +577,4 @@ function jobInfo(job, jobId) {
|
|
|
434
577
|
details
|
|
435
578
|
};
|
|
436
579
|
}
|
|
437
|
-
function claudeFinishReason(reason) {
|
|
438
|
-
if (!reason)
|
|
439
|
-
return undefined;
|
|
440
|
-
switch (reason) {
|
|
441
|
-
case 'end_turn': return "stop";
|
|
442
|
-
case 'max_tokens': return "length";
|
|
443
|
-
default: return reason; //stop_sequence
|
|
444
|
-
}
|
|
445
|
-
}
|
|
446
|
-
function cohereFinishReason(reason) {
|
|
447
|
-
if (!reason)
|
|
448
|
-
return undefined;
|
|
449
|
-
switch (reason) {
|
|
450
|
-
case 'COMPLETE': return "stop";
|
|
451
|
-
case 'MAX_TOKENS': return "length";
|
|
452
|
-
default: return reason;
|
|
453
|
-
}
|
|
454
|
-
}
|
|
455
|
-
function a21FinishReason(reason) {
|
|
456
|
-
if (!reason)
|
|
457
|
-
return undefined;
|
|
458
|
-
switch (reason) {
|
|
459
|
-
case 'endoftext': return "stop";
|
|
460
|
-
case 'length': return "length";
|
|
461
|
-
default: return reason;
|
|
462
|
-
}
|
|
463
|
-
}
|
|
464
|
-
function titanFinishReason(reason) {
|
|
465
|
-
if (!reason)
|
|
466
|
-
return undefined;
|
|
467
|
-
switch (reason) {
|
|
468
|
-
case 'FINISH': return "stop";
|
|
469
|
-
case 'LENGTH': return "length";
|
|
470
|
-
default: return reason;
|
|
471
|
-
}
|
|
472
|
-
}
|
|
473
580
|
//# sourceMappingURL=index.js.map
|