@llumiverse/drivers 0.15.0 → 0.17.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. package/README.md +3 -3
  2. package/lib/cjs/adobe/firefly.js +119 -0
  3. package/lib/cjs/adobe/firefly.js.map +1 -0
  4. package/lib/cjs/bedrock/converse.js +177 -0
  5. package/lib/cjs/bedrock/converse.js.map +1 -0
  6. package/lib/cjs/bedrock/index.js +338 -234
  7. package/lib/cjs/bedrock/index.js.map +1 -1
  8. package/lib/cjs/bedrock/nova-image-payload.js +207 -0
  9. package/lib/cjs/bedrock/nova-image-payload.js.map +1 -0
  10. package/lib/cjs/groq/index.js +34 -9
  11. package/lib/cjs/groq/index.js.map +1 -1
  12. package/lib/cjs/huggingface_ie.js +28 -12
  13. package/lib/cjs/huggingface_ie.js.map +1 -1
  14. package/lib/cjs/index.js +1 -0
  15. package/lib/cjs/index.js.map +1 -1
  16. package/lib/cjs/mistral/index.js +32 -13
  17. package/lib/cjs/mistral/index.js.map +1 -1
  18. package/lib/cjs/mistral/types.js.map +1 -1
  19. package/lib/cjs/openai/index.js +164 -29
  20. package/lib/cjs/openai/index.js.map +1 -1
  21. package/lib/cjs/replicate.js +19 -34
  22. package/lib/cjs/replicate.js.map +1 -1
  23. package/lib/cjs/test/TestValidationErrorCompletionStream.js.map +1 -1
  24. package/lib/cjs/test/index.js.map +1 -1
  25. package/lib/cjs/togetherai/index.js +40 -10
  26. package/lib/cjs/togetherai/index.js.map +1 -1
  27. package/lib/cjs/vertexai/embeddings/embeddings-image.js +26 -0
  28. package/lib/cjs/vertexai/embeddings/embeddings-image.js.map +1 -0
  29. package/lib/cjs/vertexai/embeddings/embeddings-text.js +1 -1
  30. package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -1
  31. package/lib/cjs/vertexai/index.js +134 -35
  32. package/lib/cjs/vertexai/index.js.map +1 -1
  33. package/lib/cjs/vertexai/models/claude.js +252 -0
  34. package/lib/cjs/vertexai/models/claude.js.map +1 -0
  35. package/lib/cjs/vertexai/models/gemini.js +172 -25
  36. package/lib/cjs/vertexai/models/gemini.js.map +1 -1
  37. package/lib/cjs/vertexai/models/imagen.js +317 -0
  38. package/lib/cjs/vertexai/models/imagen.js.map +1 -0
  39. package/lib/cjs/vertexai/models.js +12 -64
  40. package/lib/cjs/vertexai/models.js.map +1 -1
  41. package/lib/cjs/watsonx/index.js +47 -10
  42. package/lib/cjs/watsonx/index.js.map +1 -1
  43. package/lib/cjs/xai/index.js +71 -0
  44. package/lib/cjs/xai/index.js.map +1 -0
  45. package/lib/esm/adobe/firefly.js +115 -0
  46. package/lib/esm/adobe/firefly.js.map +1 -0
  47. package/lib/esm/bedrock/converse.js +171 -0
  48. package/lib/esm/bedrock/converse.js.map +1 -0
  49. package/lib/esm/bedrock/index.js +339 -232
  50. package/lib/esm/bedrock/index.js.map +1 -1
  51. package/lib/esm/bedrock/nova-image-payload.js +203 -0
  52. package/lib/esm/bedrock/nova-image-payload.js.map +1 -0
  53. package/lib/esm/groq/index.js +34 -9
  54. package/lib/esm/groq/index.js.map +1 -1
  55. package/lib/esm/huggingface_ie.js +29 -13
  56. package/lib/esm/huggingface_ie.js.map +1 -1
  57. package/lib/esm/index.js +1 -0
  58. package/lib/esm/index.js.map +1 -1
  59. package/lib/esm/mistral/index.js +32 -13
  60. package/lib/esm/mistral/index.js.map +1 -1
  61. package/lib/esm/mistral/types.js.map +1 -1
  62. package/lib/esm/openai/index.js +165 -30
  63. package/lib/esm/openai/index.js.map +1 -1
  64. package/lib/esm/replicate.js +19 -34
  65. package/lib/esm/replicate.js.map +1 -1
  66. package/lib/esm/test/TestValidationErrorCompletionStream.js.map +1 -1
  67. package/lib/esm/test/index.js.map +1 -1
  68. package/lib/esm/togetherai/index.js +40 -10
  69. package/lib/esm/togetherai/index.js.map +1 -1
  70. package/lib/esm/vertexai/embeddings/embeddings-image.js +23 -0
  71. package/lib/esm/vertexai/embeddings/embeddings-image.js.map +1 -0
  72. package/lib/esm/vertexai/embeddings/embeddings-text.js +1 -1
  73. package/lib/esm/vertexai/embeddings/embeddings-text.js.map +1 -1
  74. package/lib/esm/vertexai/index.js +135 -37
  75. package/lib/esm/vertexai/index.js.map +1 -1
  76. package/lib/esm/vertexai/models/claude.js +247 -0
  77. package/lib/esm/vertexai/models/claude.js.map +1 -0
  78. package/lib/esm/vertexai/models/gemini.js +173 -26
  79. package/lib/esm/vertexai/models/gemini.js.map +1 -1
  80. package/lib/esm/vertexai/models/imagen.js +310 -0
  81. package/lib/esm/vertexai/models/imagen.js.map +1 -0
  82. package/lib/esm/vertexai/models.js +12 -61
  83. package/lib/esm/vertexai/models.js.map +1 -1
  84. package/lib/esm/watsonx/index.js +47 -10
  85. package/lib/esm/watsonx/index.js.map +1 -1
  86. package/lib/esm/xai/index.js +64 -0
  87. package/lib/esm/xai/index.js.map +1 -0
  88. package/lib/types/adobe/firefly.d.ts +30 -0
  89. package/lib/types/adobe/firefly.d.ts.map +1 -0
  90. package/lib/types/bedrock/converse.d.ts +8 -0
  91. package/lib/types/bedrock/converse.d.ts.map +1 -0
  92. package/lib/types/bedrock/index.d.ts +27 -12
  93. package/lib/types/bedrock/index.d.ts.map +1 -1
  94. package/lib/types/bedrock/nova-image-payload.d.ts +74 -0
  95. package/lib/types/bedrock/nova-image-payload.d.ts.map +1 -0
  96. package/lib/types/bedrock/payloads.d.ts +9 -65
  97. package/lib/types/bedrock/payloads.d.ts.map +1 -1
  98. package/lib/types/groq/index.d.ts +3 -3
  99. package/lib/types/groq/index.d.ts.map +1 -1
  100. package/lib/types/huggingface_ie.d.ts +5 -7
  101. package/lib/types/huggingface_ie.d.ts.map +1 -1
  102. package/lib/types/index.d.ts +1 -0
  103. package/lib/types/index.d.ts.map +1 -1
  104. package/lib/types/mistral/index.d.ts +4 -4
  105. package/lib/types/mistral/index.d.ts.map +1 -1
  106. package/lib/types/mistral/types.d.ts +1 -0
  107. package/lib/types/mistral/types.d.ts.map +1 -1
  108. package/lib/types/openai/index.d.ts +5 -4
  109. package/lib/types/openai/index.d.ts.map +1 -1
  110. package/lib/types/replicate.d.ts +4 -9
  111. package/lib/types/replicate.d.ts.map +1 -1
  112. package/lib/types/test/index.d.ts +2 -2
  113. package/lib/types/test/index.d.ts.map +1 -1
  114. package/lib/types/togetherai/index.d.ts +4 -4
  115. package/lib/types/togetherai/index.d.ts.map +1 -1
  116. package/lib/types/vertexai/embeddings/embeddings-image.d.ts +11 -0
  117. package/lib/types/vertexai/embeddings/embeddings-image.d.ts.map +1 -0
  118. package/lib/types/vertexai/index.d.ts +21 -8
  119. package/lib/types/vertexai/index.d.ts.map +1 -1
  120. package/lib/types/vertexai/models/claude.d.ts +20 -0
  121. package/lib/types/vertexai/models/claude.d.ts.map +1 -0
  122. package/lib/types/vertexai/models/gemini.d.ts +4 -4
  123. package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
  124. package/lib/types/vertexai/models/imagen.d.ts +75 -0
  125. package/lib/types/vertexai/models/imagen.d.ts.map +1 -0
  126. package/lib/types/vertexai/models.d.ts +3 -6
  127. package/lib/types/vertexai/models.d.ts.map +1 -1
  128. package/lib/types/watsonx/index.d.ts +3 -3
  129. package/lib/types/watsonx/index.d.ts.map +1 -1
  130. package/lib/types/watsonx/interfaces.d.ts +4 -0
  131. package/lib/types/watsonx/interfaces.d.ts.map +1 -1
  132. package/lib/types/xai/index.d.ts +19 -0
  133. package/lib/types/xai/index.d.ts.map +1 -0
  134. package/package.json +25 -26
  135. package/src/adobe/firefly.ts +207 -0
  136. package/src/bedrock/converse.ts +194 -0
  137. package/src/bedrock/index.ts +359 -240
  138. package/src/bedrock/nova-image-payload.ts +309 -0
  139. package/src/bedrock/payloads.ts +12 -66
  140. package/src/groq/index.ts +35 -12
  141. package/src/huggingface_ie.ts +34 -13
  142. package/src/index.ts +1 -0
  143. package/src/mistral/index.ts +35 -13
  144. package/src/mistral/types.ts +2 -1
  145. package/src/openai/index.ts +186 -35
  146. package/src/replicate.ts +24 -35
  147. package/src/test/TestValidationErrorCompletionStream.ts +2 -2
  148. package/src/test/index.ts +3 -2
  149. package/src/togetherai/index.ts +44 -12
  150. package/src/vertexai/embeddings/embeddings-image.ts +50 -0
  151. package/src/vertexai/embeddings/embeddings-text.ts +1 -1
  152. package/src/vertexai/index.ts +186 -46
  153. package/src/vertexai/models/claude.ts +281 -0
  154. package/src/vertexai/models/gemini.ts +186 -29
  155. package/src/vertexai/models/imagen.ts +401 -0
  156. package/src/vertexai/models.ts +16 -78
  157. package/src/watsonx/index.ts +50 -12
  158. package/src/watsonx/interfaces.ts +4 -0
  159. package/src/xai/index.ts +110 -0
package/lib/cjs/bedrock/index.js
@@ -1,7 +1,4 @@
  "use strict";
- var __importDefault = (this && this.__importDefault) || function (mod) {
- return (mod && mod.__esModule) ? mod : { "default": mod };
- };
  Object.defineProperty(exports, "__esModule", { value: true });
  exports.BedrockDriver = void 0;
  const client_bedrock_1 = require("@aws-sdk/client-bedrock");
@@ -10,15 +7,36 @@ const client_s3_1 = require("@aws-sdk/client-s3");
  const core_1 = require("@llumiverse/core");
  const async_1 = require("@llumiverse/core/async");
  const formatters_1 = require("@llumiverse/core/formatters");
- const mnemonist_1 = __importDefault(require("mnemonist"));
+ const mnemonist_1 = require("mnemonist");
+ const converse_js_1 = require("./converse.js");
+ const nova_image_payload_js_1 = require("./nova-image-payload.js");
  const s3_js_1 = require("./s3.js");
- const { LRUCache } = mnemonist_1.default;
- const supportStreamingCache = new LRUCache(4096);
+ const supportStreamingCache = new mnemonist_1.LRUCache(4096);
+ var BedrockModelType;
+ (function (BedrockModelType) {
+ BedrockModelType["FoundationModel"] = "foundation-model";
+ BedrockModelType["InferenceProfile"] = "inference-profile";
+ BedrockModelType["CustomModel"] = "custom-model";
+ BedrockModelType["Unknown"] = "unknown";
+ })(BedrockModelType || (BedrockModelType = {}));
+ ;
+ function converseFinishReason(reason) {
+ //Possible values:
+ //end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
+ if (!reason)
+ return undefined;
+ switch (reason) {
+ case 'end_turn': return "stop";
+ case 'max_tokens': return "length";
+ default: return reason;
+ }
+ }
  class BedrockDriver extends core_1.AbstractDriver {
  static PROVIDER = "bedrock";
  provider = BedrockDriver.PROVIDER;
  _executor;
  _service;
+ _service_region;
  constructor(options) {
  super(options);
  if (!options.region) {
@@ -34,241 +52,334 @@ class BedrockDriver extends core_1.AbstractDriver {
  }
  return this._executor;
  }
- getService() {
- if (!this._service) {
+ getService(region = this.options.region) {
+ if (!this._service || this._service_region != region) {
  this._service = new client_bedrock_1.Bedrock({
- region: this.options.region,
+ region: region,
  credentials: this.options.credentials,
  });
+ this._service_region = region;
  }
  return this._service;
  }
  async formatPrompt(segments, opts) {
- //TODO move the anthropic test in abstract driver?
- if (opts.model.includes('anthropic')) {
- //TODO: need to type better the types aren't checked properly by TS
- return await (0, formatters_1.formatClaudePrompt)(segments, opts.result_schema);
- }
- else {
- return await super.formatPrompt(segments, opts);
+ if (opts.model.includes("canvas")) {
+ return await (0, formatters_1.formatNovaPrompt)(segments, opts.result_schema);
  }
+ return await (0, converse_js_1.fortmatConversePrompt)(segments, opts.result_schema);
  }
- extractDataFromResponse(prompt, response) {
- const decoder = new TextDecoder();
- const body = decoder.decode(response.body);
- const result = JSON.parse(body);
- const getTextAnsStopReason = () => {
- if (result.generation) {
- // LLAMA2
- return [result.generation, result.stop_reason]; // comes in coirrect format (stop, length)
- }
- else if (result.generations) {
- // Cohere
- return [result.generations[0].text, cohereFinishReason(result.generations[0].finish_reason)];
- }
- else if (result.chat_history) {
- //Cohere Command R
- return [result.text, cohereFinishReason(result.finish_reason)];
- }
- else if (result.completions) {
- //A21
- return [result.completions[0].data?.text, a21FinishReason(result.completions[0].finishReason?.reason)];
- }
- else if (result.content) {
- // Claude
- //if last prompt.messages is {, add { to the response
- const p = prompt;
- const lastMessage = p.messages[p.messages.length - 1];
- const res = lastMessage.content[0].text === '{' ? '{' + result.content[0]?.text : result.content[0]?.text;
- return [res, claudeFinishReason(result.stop_reason)];
- }
- else if (result.outputs) {
- // mistral
- return [result.outputs[0]?.text, result.outputs[0]?.stop_reason]; // the stop reason is in the expected format ("stop" and "length")
- }
- else if (result.results) {
- // Amazon Titan
- return [result.results[0]?.outputText ?? '', titanFinishReason(result.results[0]?.completionReason)];
- }
- else if (result.completion) { // TODO: who uses this?
- return [result.completion];
- }
- else {
- return [result.toString()];
- }
- };
- const [text, finish_reason] = getTextAnsStopReason();
- const promptLength = typeof prompt === 'string' ? prompt.length :
- (prompt.system || '').length + prompt.messages.reduce((acc, m) => acc + m.content.length, 0);
+ static getExtractedExecuton(result, _prompt) {
  return {
- result: text,
+ result: result.output?.message?.content?.map(c => c.text).join("\n") ?? "",
  token_usage: {
- result: text?.length,
- prompt: promptLength,
- total: text?.length + promptLength,
+ prompt: result.usage?.inputTokens,
+ result: result.usage?.outputTokens,
+ total: result.usage?.totalTokens,
  },
- finish_reason
+ finish_reason: converseFinishReason(result.stopReason),
+ };
+ }
+ ;
+ static getExtractedStream(result, _prompt) {
+ let output = "";
+ let stop_reason = "";
+ let token_usage;
+ if (result.contentBlockDelta) {
+ output = result.contentBlockDelta.delta?.text ?? "";
+ }
+ if (result.messageStop) {
+ stop_reason = result.messageStop.stopReason ?? "";
+ }
+ if (result.metadata) {
+ token_usage = {
+ prompt: result.metadata.usage?.inputTokens,
+ result: result.metadata.usage?.outputTokens,
+ total: result.metadata.usage?.totalTokens,
+ };
+ }
+ return {
+ result: output,
+ token_usage: token_usage,
+ finish_reason: converseFinishReason(stop_reason),
  };
  }
- async requestCompletion(prompt, options) {
+ ;
+ async requestTextCompletion(prompt, options) {
  const payload = this.preparePayload(prompt, options);
  const executor = this.getExecutor();
- const res = await executor.invokeModel({
- modelId: options.model,
- contentType: "application/json",
- body: JSON.stringify(payload),
+ const res = await executor.converse({
+ ...payload,
  });
- const completion = this.extractDataFromResponse(prompt, res);
- if (options.include_original_response) {
- completion.original_response = res;
- }
+ const completion = {
+ ...BedrockDriver.getExtractedExecuton(res, prompt),
+ original_response: options.include_original_response ? res : undefined,
+ };
  return completion;
  }
+ extractRegion(modelString, defaultRegion) {
+ // Match region in full ARN pattern
+ const arnMatch = modelString.match(/arn:aws[^:]*:bedrock:([^:]+):/);
+ if (arnMatch) {
+ return arnMatch[1];
+ }
+ // Match common AWS regions directly in string
+ const regionMatch = modelString.match(/(?:us|eu|ap|sa|ca|me|af)[-](east|west|central|south|north|southeast|southwest|northeast|northwest)[-][1-9]/);
+ if (regionMatch) {
+ return regionMatch[0];
+ }
+ return defaultRegion;
+ }
+ async getCanStream(model, type) {
+ let canStream = false;
+ let error = null;
+ const region = this.extractRegion(model, this.options.region);
+ if (type == BedrockModelType.FoundationModel || type == BedrockModelType.Unknown) {
+ try {
+ const response = await this.getService(region).getFoundationModel({
+ modelIdentifier: model
+ });
+ canStream = response.modelDetails?.responseStreamingSupported ?? false;
+ return canStream;
+ }
+ catch (e) {
+ error = e;
+ }
+ }
+ if (type == BedrockModelType.InferenceProfile || type == BedrockModelType.Unknown) {
+ try {
+ const response = await this.getService(region).getInferenceProfile({
+ inferenceProfileIdentifier: model
+ });
+ canStream = await this.getCanStream(response.models?.[0].modelArn ?? "", BedrockModelType.FoundationModel);
+ return canStream;
+ }
+ catch (e) {
+ error = e;
+ }
+ }
+ if (type == BedrockModelType.CustomModel || type == BedrockModelType.Unknown) {
+ try {
+ const response = await this.getService(region).getCustomModel({
+ modelIdentifier: model
+ });
+ canStream = await this.getCanStream(response.baseModelArn ?? "", BedrockModelType.FoundationModel);
+ return canStream;
+ }
+ catch (e) {
+ error = e;
+ }
+ }
+ if (error) {
+ console.warn("Error on canStream check for model: " + model + " region detected: " + region, error);
+ }
+ return canStream;
+ }
  async canStream(options) {
  let canStream = supportStreamingCache.get(options.model);
  if (canStream == null) {
- const response = await this.getService().getFoundationModel({
- modelIdentifier: options.model
- });
- canStream = response.modelDetails?.responseStreamingSupported ?? false;
+ let type = BedrockModelType.Unknown;
+ if (options.model.includes("foundation-model")) {
+ type = BedrockModelType.FoundationModel;
+ }
+ else if (options.model.includes("inference-profile")) {
+ type = BedrockModelType.InferenceProfile;
+ }
+ else if (options.model.includes("custom-model")) {
+ type = BedrockModelType.CustomModel;
+ }
+ canStream = await this.getCanStream(options.model, type);
  supportStreamingCache.set(options.model, canStream);
  }
  return canStream;
  }
- async requestCompletionStream(prompt, options) {
+ async requestTextCompletionStream(prompt, options) {
  const payload = this.preparePayload(prompt, options);
  const executor = this.getExecutor();
- return executor.invokeModelWithResponseStream({
- modelId: options.model,
- contentType: "application/json",
- body: JSON.stringify(payload),
+ return executor.converseStream({
+ ...payload,
  }).then((res) => {
- if (!res.body) {
- throw new Error("Body not found");
+ const stream = res.stream;
+ if (!stream) {
+ throw new Error("[Bedrock] Stream not found in response");
  }
- const decoder = new TextDecoder();
- const addBracket = () => {
- if (typeof prompt === 'object' && prompt.messages) {
- const p = prompt;
- const lastMessage = p.messages[p.messages.length - 1];
- return lastMessage.content[0].text === '{';
- }
- return false;
- };
- return (0, async_1.transformAsyncIterator)(res.body, (stream) => {
- const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
+ return (0, async_1.transformAsyncIterator)(stream, (stream) => {
+ //const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
  //console.log("Debug Segment for model " + options.model, JSON.stringify(segment));
- if (segment.delta) { // who is this?
- return segment.delta.text || '';
- }
- else if (segment.completion) { // who is this?
- return segment.completion;
- }
- else if (segment.text) { //cohere
- return segment.text;
- }
- else if (segment.completions) {
- return segment.completions[0].data?.text;
- }
- else if (segment.generation) {
- return segment.generation;
- }
- else if (segment.generations) {
- return segment.generations[0].text;
- }
- else if (segment.outputs) {
- // mistral.mixtral-8x7b-instruct-v0:1
- return segment.outputs[0].text;
- //segment.outputs[0].stop_reason;
- }
- else if (segment.outputText) {
- // Amazon Titan
- return segment.outputText;
- //completionReason
- // token count too
- }
- else {
- segment.toString();
- }
- }, () => addBracket() ? '{' : '');
+ return BedrockDriver.getExtractedStream(stream, prompt);
+ });
  }).catch((err) => {
  this.logger.error("[Bedrock] Failed to stream", err);
  throw err;
  });
  }
  preparePayload(prompt, options) {
- //split arn on / should give provider
- //TODO: check if works with custom models
- //const provider = options.model.split("/")[0];
- const contains = (str, substr) => str.indexOf(substr) !== -1;
- if (contains(options.model, "meta")) {
- return {
- prompt,
- temperature: options.temperature,
- max_gen_len: options.max_tokens,
- };
+ const model_options = options.model_options;
+ let additionalField = {};
+ if (options.model.includes("amazon")) {
+ //Titan models also exists but does not support any additional options
+ if (options.model.includes("nova")) {
+ additionalField = { inferenceConfig: { topK: model_options?.top_k } };
+ }
  }
- else if (contains(options.model, "claude")) {
- const maxToken = () => {
- if (options.max_tokens) {
- return options.max_tokens;
+ else if (options.model.includes("claude")) {
+ if (options.model.includes("claude-3-7")) {
+ const thinking_options = options.model_options;
+ const thinking = thinking_options?.thinking_mode ?? false;
+ if (!model_options?.max_tokens) {
+ model_options.max_tokens = thinking ? 128000 : 8192;
  }
- else if (contains(options.model, "claude-3-5")) {
- return 8192;
+ additionalField = {
+ top_k: model_options?.top_k,
+ reasoning_config: {
+ type: thinking ? "enabled" : "disabled",
+ budget_tokens: thinking_options?.thinking_budget_tokens,
+ }
+ };
+ if (thinking && (thinking_options?.thinking_budget_tokens ?? 0) > 64000) {
+ additionalField = {
+ ...additionalField,
+ anthorpic_beta: ["output-128k-2025-02-19"]
+ };
+ }
+ }
+ //Needs max_tokens to be set
+ if (!model_options?.max_tokens) {
+ if (options.model.includes("claude-3-5")) {
+ model_options.max_tokens = 8192;
+ //Bug with AWS Converse Sonnet 3.5, does not effect Haiku.
+ //See https://github.com/boto/boto3/issues/4279
+ if (options.model.includes("claude-3-5-sonnet")) {
+ model_options.max_tokens = 4096;
+ }
  }
  else {
- return 4096;
+ model_options.max_tokens = 4096;
  }
- };
- return {
- anthropic_version: "bedrock-2023-05-31",
- ...prompt,
- temperature: options.temperature,
- max_tokens: maxToken(),
- };
+ }
+ additionalField = { top_k: model_options?.top_k };
  }
- else if (contains(options.model, "ai21")) {
- return {
- prompt: prompt,
- temperature: options.temperature,
- maxTokens: options.max_tokens,
- };
+ else if (options.model.includes("meta")) {
+ //If last message is "```json", remove it. Model requires the final message to be a user message
+ prompt.messages = (0, converse_js_1.converseRemoveJSONprefill)(prompt.messages);
  }
- else if (contains(options.model, "command-r-plus")) {
- return {
- message: prompt,
- max_tokens: options.max_tokens,
- temperature: options.temperature,
- };
+ else if (options.model.includes("mistral")) {
+ //7B instruct and 8x7B instruct
+ if (options.model.includes("7b")) {
+ additionalField = { top_k: model_options?.top_k };
+ //Does not support system messages
+ if (prompt.system && prompt.system?.length != 0) {
+ prompt.messages?.push((0, converse_js_1.converseSystemToMessages)(prompt.system));
+ prompt.system = undefined;
+ prompt.messages = (0, converse_js_1.converseConcatMessages)(prompt.messages);
+ }
+ }
+ else {
+ //Other models such as Mistral Small,Large and Large 2
+ //Support no additional fields.
+ prompt.messages = (0, converse_js_1.converseRemoveJSONprefill)(prompt.messages);
+ }
  }
- else if (contains(options.model, "cohere")) {
- return {
- prompt: prompt,
- temperature: options.temperature,
- max_tokens: options.max_tokens,
- };
+ else if (options.model.includes("ai21")) {
+ //If last message is "```json", remove it. Model requires the final message to be a user message
+ prompt.messages = (0, converse_js_1.converseRemoveJSONprefill)(prompt.messages);
+ //Jamba models support no additional options
+ //Jurassic 2 models do.
+ if (options.model.includes("j2")) {
+ additionalField = {
+ presencePenalty: { scale: model_options?.presence_penalty },
+ frequencyPenalty: { scale: model_options?.frequency_penalty },
+ };
+ //Does not support system messages
+ if (prompt.system && prompt.system?.length != 0) {
+ prompt.messages?.push((0, converse_js_1.converseSystemToMessages)(prompt.system));
+ prompt.system = undefined;
+ prompt.messages = (0, converse_js_1.converseConcatMessages)(prompt.messages);
+ }
+ }
  }
- else if (contains(options.model, "amazon")) {
- return {
- inputText: "User: " + prompt + "\nBot:", // see https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html#model-parameters-titan-request-response
- textGenerationConfig: {
- temperature: options.temperature,
- topP: options.top_p,
- maxTokenCount: options.max_tokens,
- //stopSequences: ["\n"],
- },
- };
+ else if (options.model.includes("cohere.command")) {
+ // If last message is "```json", remove it.
+ // Model requires the final message to be a user message or does not support assistant messages
+ prompt.messages = (0, converse_js_1.converseRemoveJSONprefill)(prompt.messages);
+ //Command R and R plus
+ if (options.model.includes("cohere.command-r")) {
+ additionalField = {
+ k: model_options?.top_k,
+ frequency_penalty: model_options?.frequency_penalty,
+ presence_penalty: model_options?.presence_penalty,
+ };
+ }
+ else {
+ // Command non-R
+ additionalField = { k: model_options?.top_k };
+ //Does not support system messages
+ if (prompt.system && prompt.system?.length != 0) {
+ prompt.messages?.push((0, converse_js_1.converseSystemToMessages)(prompt.system));
+ prompt.system = undefined;
+ prompt.messages = (0, converse_js_1.converseConcatMessages)(prompt.messages);
+ }
+ }
  }
- else if (contains(options.model, "mistral")) {
- return {
- prompt: prompt,
- temperature: options.temperature,
- max_tokens: options.max_tokens,
- };
+ //If last message is "```json", add corresponding ``` as a stop sequence.
+ if (prompt.messages && prompt.messages.length > 0) {
+ if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") {
+ let stopSeq = model_options?.stop_sequence;
+ if (!stopSeq) {
+ model_options.stop_sequence = ["```"];
+ }
+ else if (!stopSeq.includes("```")) {
+ stopSeq.push("```");
+ model_options.stop_sequence = stopSeq;
+ }
+ }
+ }
+ return {
+ messages: prompt.messages,
+ system: prompt.system,
+ modelId: options.model,
+ inferenceConfig: {
+ maxTokens: model_options?.max_tokens,
+ temperature: model_options?.temperature,
+ topP: model_options?.top_p,
+ stopSequences: model_options?.stop_sequence,
+ },
+ additionalModelRequestFields: {
+ ...additionalField,
+ },
+ };
+ }
+ async requestImageGeneration(prompt, options) {
+ if (options.output_modality !== core_1.Modalities.image) {
+ throw new Error(`Image generation requires image output_modality`);
  }
- else {
- throw new Error("Cannot prepare payload for unknown provider: " + options.model);
+ if (options.model_options?._option_id !== "bedrock-nova-canvas") {
+ this.logger.warn("Invalid model options", { options: options.model_options });
+ }
+ const model_options = options.model_options;
+ const executor = this.getExecutor();
+ const taskType = model_options.taskType ?? nova_image_payload_js_1.NovaImageGenerationTaskType.TEXT_IMAGE;
+ this.logger.info("Task type: " + taskType);
+ if (typeof prompt === "string") {
+ throw new Error("Bad prompt format");
  }
+ const payload = await (0, nova_image_payload_js_1.formatNovaImageGenerationPayload)(taskType, prompt, options);
+ const res = await executor.invokeModel({
+ modelId: options.model,
+ contentType: "application/json",
+ accept: "application/json",
+ body: JSON.stringify(payload),
+ }, {
+ requestTimeout: 60000 * 5
+ });
+ const decoder = new TextDecoder();
+ const body = decoder.decode(res.body);
+ const result = JSON.parse(body);
+ return {
+ error: result.error,
+ result: {
+ images: result.images,
+ }
+ };
  }
  async startTraining(dataset, options) {
  //convert options.params to Record<string, string>
@@ -333,12 +444,13 @@ class BedrockDriver extends core_1.AbstractDriver {
  async listModels() {
  this.logger.debug("[Bedrock] listing models");
  // exclude trainable models since they are not executable
- const filter = (m) => m.inferenceTypesSupported?.includes("ON_DEMAND") ?? false;
+ // exclude embedding models, not to be used for typical completions.
+ const filter = (m) => (m.inferenceTypesSupported?.includes("ON_DEMAND") && !m.outputModalities?.includes("EMBEDDING")) ?? false;
  return this._listModels(filter);
  }
  async _listModels(foundationFilter) {
  const service = this.getService();
- const [foundationals, customs] = await Promise.all([
+ const [foundationals, customs, inferenceProfiles] = await Promise.all([
  service.listFoundationModels({}).catch(() => {
  this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions.");
  return undefined;
@@ -347,6 +459,10 @@ class BedrockDriver extends core_1.AbstractDriver {
  this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions.");
  return undefined;
  }),
+ service.listInferenceProfiles({}).catch(() => {
+ this.logger.warn("[Bedrock] Can't list inference profiles. Check if the user has the right permissions.");
+ return undefined;
+ }),
  ]);
  if (!foundationals?.modelSummaries) {
  throw new Error("Foundation models not found");
@@ -355,6 +471,12 @@ class BedrockDriver extends core_1.AbstractDriver {
  if (foundationFilter) {
  fmodels = fmodels.filter(foundationFilter);
  }
+ const supportedProviders = ["amazon", "anthropic", "cohere", "ai21", "mistral", "meta", "deepseek"];
+ fmodels = fmodels.filter((m) => {
+ supportedProviders.some((provider) => {
+ m.providerName?.includes(provider) ?? false;
+ });
+ });
  const aimodels = fmodels.map((m) => {
  if (!m.modelId) {
  throw new Error("modelId not found");
@@ -363,6 +485,7 @@ class BedrockDriver extends core_1.AbstractDriver {
  id: m.modelArn ?? m.modelId,
  name: `${m.providerName} ${m.modelName}`,
  provider: this.provider,
+ input_modalities: m.inputModalities ?? [],
  //description: ``,
  owner: m.providerName,
  can_stream: m.responseStreamingSupported ?? false,
@@ -388,16 +511,33 @@ class BedrockDriver extends core_1.AbstractDriver {
  this.validateConnection;
  });
  }
+ //add inference profiles
+ if (inferenceProfiles?.inferenceProfileSummaries) {
+ inferenceProfiles.inferenceProfileSummaries.forEach((p) => {
+ if (!p.inferenceProfileArn) {
+ throw new Error("Profile ARN not found");
+ }
+ const model = {
+ id: p.inferenceProfileArn ?? p.inferenceProfileId,
+ name: p.inferenceProfileName ?? p.inferenceProfileArn,
+ provider: this.provider,
+ };
+ aimodels.push(model);
+ });
+ }
  return aimodels;
  }
- async generateEmbeddings({ content, model = "amazon.titan-embed-text-v1" }) {
+ async generateEmbeddings({ text, image, model }) {
  this.logger.info("[Bedrock] Generating embeddings with model " + model);
+ const defaultModel = image ? "amazon.titan-embed-image-v1" : "amazon.titan-embed-text-v2:0";
+ const modelID = model ?? defaultModel;
  const invokeBody = {
- inputText: content
+ inputText: text,
+ inputImage: image
  };
  const executor = this.getExecutor();
  const res = await executor.invokeModel({
- modelId: model,
+ modelId: modelID,
  contentType: "application/json",
  body: JSON.stringify(invokeBody),
  });
@@ -409,7 +549,7 @@ class BedrockDriver extends core_1.AbstractDriver {
  }
  return {
  values: result.embedding,
- model: model,
+ model: modelID,
  token_count: result.inputTextTokenCount
  };
  }
@@ -441,40 +581,4 @@ function jobInfo(job, jobId) {
  details
  };
  }
- function claudeFinishReason(reason) {
- if (!reason)
- return undefined;
- switch (reason) {
- case 'end_turn': return "stop";
- case 'max_tokens': return "length";
- default: return reason; //stop_sequence
- }
- }
- function cohereFinishReason(reason) {
- if (!reason)
- return undefined;
- switch (reason) {
- case 'COMPLETE': return "stop";
- case 'MAX_TOKENS': return "length";
- default: return reason;
- }
- }
- function a21FinishReason(reason) {
- if (!reason)
- return undefined;
- switch (reason) {
- case 'endoftext': return "stop";
- case 'length': return "length";
- default: return reason;
- }
- }
- function titanFinishReason(reason) {
- if (!reason)
- return undefined;
- switch (reason) {
- case 'FINISH': return "stop";
- case 'LENGTH': return "length";
- default: return reason;
- }
- }
  //# sourceMappingURL=index.js.map