@llumiverse/drivers 0.15.0 → 0.17.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. package/README.md +3 -3
  2. package/lib/cjs/adobe/firefly.js +119 -0
  3. package/lib/cjs/adobe/firefly.js.map +1 -0
  4. package/lib/cjs/bedrock/converse.js +177 -0
  5. package/lib/cjs/bedrock/converse.js.map +1 -0
  6. package/lib/cjs/bedrock/index.js +338 -234
  7. package/lib/cjs/bedrock/index.js.map +1 -1
  8. package/lib/cjs/bedrock/nova-image-payload.js +207 -0
  9. package/lib/cjs/bedrock/nova-image-payload.js.map +1 -0
  10. package/lib/cjs/groq/index.js +34 -9
  11. package/lib/cjs/groq/index.js.map +1 -1
  12. package/lib/cjs/huggingface_ie.js +28 -12
  13. package/lib/cjs/huggingface_ie.js.map +1 -1
  14. package/lib/cjs/index.js +1 -0
  15. package/lib/cjs/index.js.map +1 -1
  16. package/lib/cjs/mistral/index.js +32 -13
  17. package/lib/cjs/mistral/index.js.map +1 -1
  18. package/lib/cjs/mistral/types.js.map +1 -1
  19. package/lib/cjs/openai/index.js +164 -29
  20. package/lib/cjs/openai/index.js.map +1 -1
  21. package/lib/cjs/replicate.js +19 -34
  22. package/lib/cjs/replicate.js.map +1 -1
  23. package/lib/cjs/test/TestValidationErrorCompletionStream.js.map +1 -1
  24. package/lib/cjs/test/index.js.map +1 -1
  25. package/lib/cjs/togetherai/index.js +40 -10
  26. package/lib/cjs/togetherai/index.js.map +1 -1
  27. package/lib/cjs/vertexai/embeddings/embeddings-image.js +26 -0
  28. package/lib/cjs/vertexai/embeddings/embeddings-image.js.map +1 -0
  29. package/lib/cjs/vertexai/embeddings/embeddings-text.js +1 -1
  30. package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -1
  31. package/lib/cjs/vertexai/index.js +134 -35
  32. package/lib/cjs/vertexai/index.js.map +1 -1
  33. package/lib/cjs/vertexai/models/claude.js +252 -0
  34. package/lib/cjs/vertexai/models/claude.js.map +1 -0
  35. package/lib/cjs/vertexai/models/gemini.js +172 -25
  36. package/lib/cjs/vertexai/models/gemini.js.map +1 -1
  37. package/lib/cjs/vertexai/models/imagen.js +317 -0
  38. package/lib/cjs/vertexai/models/imagen.js.map +1 -0
  39. package/lib/cjs/vertexai/models.js +12 -64
  40. package/lib/cjs/vertexai/models.js.map +1 -1
  41. package/lib/cjs/watsonx/index.js +47 -10
  42. package/lib/cjs/watsonx/index.js.map +1 -1
  43. package/lib/cjs/xai/index.js +71 -0
  44. package/lib/cjs/xai/index.js.map +1 -0
  45. package/lib/esm/adobe/firefly.js +115 -0
  46. package/lib/esm/adobe/firefly.js.map +1 -0
  47. package/lib/esm/bedrock/converse.js +171 -0
  48. package/lib/esm/bedrock/converse.js.map +1 -0
  49. package/lib/esm/bedrock/index.js +339 -232
  50. package/lib/esm/bedrock/index.js.map +1 -1
  51. package/lib/esm/bedrock/nova-image-payload.js +203 -0
  52. package/lib/esm/bedrock/nova-image-payload.js.map +1 -0
  53. package/lib/esm/groq/index.js +34 -9
  54. package/lib/esm/groq/index.js.map +1 -1
  55. package/lib/esm/huggingface_ie.js +29 -13
  56. package/lib/esm/huggingface_ie.js.map +1 -1
  57. package/lib/esm/index.js +1 -0
  58. package/lib/esm/index.js.map +1 -1
  59. package/lib/esm/mistral/index.js +32 -13
  60. package/lib/esm/mistral/index.js.map +1 -1
  61. package/lib/esm/mistral/types.js.map +1 -1
  62. package/lib/esm/openai/index.js +165 -30
  63. package/lib/esm/openai/index.js.map +1 -1
  64. package/lib/esm/replicate.js +19 -34
  65. package/lib/esm/replicate.js.map +1 -1
  66. package/lib/esm/test/TestValidationErrorCompletionStream.js.map +1 -1
  67. package/lib/esm/test/index.js.map +1 -1
  68. package/lib/esm/togetherai/index.js +40 -10
  69. package/lib/esm/togetherai/index.js.map +1 -1
  70. package/lib/esm/vertexai/embeddings/embeddings-image.js +23 -0
  71. package/lib/esm/vertexai/embeddings/embeddings-image.js.map +1 -0
  72. package/lib/esm/vertexai/embeddings/embeddings-text.js +1 -1
  73. package/lib/esm/vertexai/embeddings/embeddings-text.js.map +1 -1
  74. package/lib/esm/vertexai/index.js +135 -37
  75. package/lib/esm/vertexai/index.js.map +1 -1
  76. package/lib/esm/vertexai/models/claude.js +247 -0
  77. package/lib/esm/vertexai/models/claude.js.map +1 -0
  78. package/lib/esm/vertexai/models/gemini.js +173 -26
  79. package/lib/esm/vertexai/models/gemini.js.map +1 -1
  80. package/lib/esm/vertexai/models/imagen.js +310 -0
  81. package/lib/esm/vertexai/models/imagen.js.map +1 -0
  82. package/lib/esm/vertexai/models.js +12 -61
  83. package/lib/esm/vertexai/models.js.map +1 -1
  84. package/lib/esm/watsonx/index.js +47 -10
  85. package/lib/esm/watsonx/index.js.map +1 -1
  86. package/lib/esm/xai/index.js +64 -0
  87. package/lib/esm/xai/index.js.map +1 -0
  88. package/lib/types/adobe/firefly.d.ts +30 -0
  89. package/lib/types/adobe/firefly.d.ts.map +1 -0
  90. package/lib/types/bedrock/converse.d.ts +8 -0
  91. package/lib/types/bedrock/converse.d.ts.map +1 -0
  92. package/lib/types/bedrock/index.d.ts +27 -12
  93. package/lib/types/bedrock/index.d.ts.map +1 -1
  94. package/lib/types/bedrock/nova-image-payload.d.ts +74 -0
  95. package/lib/types/bedrock/nova-image-payload.d.ts.map +1 -0
  96. package/lib/types/bedrock/payloads.d.ts +9 -65
  97. package/lib/types/bedrock/payloads.d.ts.map +1 -1
  98. package/lib/types/groq/index.d.ts +3 -3
  99. package/lib/types/groq/index.d.ts.map +1 -1
  100. package/lib/types/huggingface_ie.d.ts +5 -7
  101. package/lib/types/huggingface_ie.d.ts.map +1 -1
  102. package/lib/types/index.d.ts +1 -0
  103. package/lib/types/index.d.ts.map +1 -1
  104. package/lib/types/mistral/index.d.ts +4 -4
  105. package/lib/types/mistral/index.d.ts.map +1 -1
  106. package/lib/types/mistral/types.d.ts +1 -0
  107. package/lib/types/mistral/types.d.ts.map +1 -1
  108. package/lib/types/openai/index.d.ts +5 -4
  109. package/lib/types/openai/index.d.ts.map +1 -1
  110. package/lib/types/replicate.d.ts +4 -9
  111. package/lib/types/replicate.d.ts.map +1 -1
  112. package/lib/types/test/index.d.ts +2 -2
  113. package/lib/types/test/index.d.ts.map +1 -1
  114. package/lib/types/togetherai/index.d.ts +4 -4
  115. package/lib/types/togetherai/index.d.ts.map +1 -1
  116. package/lib/types/vertexai/embeddings/embeddings-image.d.ts +11 -0
  117. package/lib/types/vertexai/embeddings/embeddings-image.d.ts.map +1 -0
  118. package/lib/types/vertexai/index.d.ts +21 -8
  119. package/lib/types/vertexai/index.d.ts.map +1 -1
  120. package/lib/types/vertexai/models/claude.d.ts +20 -0
  121. package/lib/types/vertexai/models/claude.d.ts.map +1 -0
  122. package/lib/types/vertexai/models/gemini.d.ts +4 -4
  123. package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
  124. package/lib/types/vertexai/models/imagen.d.ts +75 -0
  125. package/lib/types/vertexai/models/imagen.d.ts.map +1 -0
  126. package/lib/types/vertexai/models.d.ts +3 -6
  127. package/lib/types/vertexai/models.d.ts.map +1 -1
  128. package/lib/types/watsonx/index.d.ts +3 -3
  129. package/lib/types/watsonx/index.d.ts.map +1 -1
  130. package/lib/types/watsonx/interfaces.d.ts +4 -0
  131. package/lib/types/watsonx/interfaces.d.ts.map +1 -1
  132. package/lib/types/xai/index.d.ts +19 -0
  133. package/lib/types/xai/index.d.ts.map +1 -0
  134. package/package.json +25 -26
  135. package/src/adobe/firefly.ts +207 -0
  136. package/src/bedrock/converse.ts +194 -0
  137. package/src/bedrock/index.ts +359 -240
  138. package/src/bedrock/nova-image-payload.ts +309 -0
  139. package/src/bedrock/payloads.ts +12 -66
  140. package/src/groq/index.ts +35 -12
  141. package/src/huggingface_ie.ts +34 -13
  142. package/src/index.ts +1 -0
  143. package/src/mistral/index.ts +35 -13
  144. package/src/mistral/types.ts +2 -1
  145. package/src/openai/index.ts +186 -35
  146. package/src/replicate.ts +24 -35
  147. package/src/test/TestValidationErrorCompletionStream.ts +2 -2
  148. package/src/test/index.ts +3 -2
  149. package/src/togetherai/index.ts +44 -12
  150. package/src/vertexai/embeddings/embeddings-image.ts +50 -0
  151. package/src/vertexai/embeddings/embeddings-text.ts +1 -1
  152. package/src/vertexai/index.ts +186 -46
  153. package/src/vertexai/models/claude.ts +281 -0
  154. package/src/vertexai/models/gemini.ts +186 -29
  155. package/src/vertexai/models/imagen.ts +401 -0
  156. package/src/vertexai/models.ts +16 -78
  157. package/src/watsonx/index.ts +50 -12
  158. package/src/watsonx/interfaces.ts +4 -0
  159. package/src/xai/index.ts +110 -0
@@ -1,18 +1,39 @@
  import { Bedrock, CreateModelCustomizationJobCommand, GetModelCustomizationJobCommand, ModelCustomizationJobStatus, StopModelCustomizationJobCommand } from "@aws-sdk/client-bedrock";
  import { BedrockRuntime } from "@aws-sdk/client-bedrock-runtime";
  import { S3Client } from "@aws-sdk/client-s3";
- import { AbstractDriver, TrainingJobStatus } from "@llumiverse/core";
+ import { AbstractDriver, Modalities, TrainingJobStatus } from "@llumiverse/core";
  import { transformAsyncIterator } from "@llumiverse/core/async";
- import { formatClaudePrompt } from "@llumiverse/core/formatters";
- import mnemonist from "mnemonist";
+ import { formatNovaPrompt } from "@llumiverse/core/formatters";
+ import { LRUCache } from "mnemonist";
+ import { converseConcatMessages, converseRemoveJSONprefill, converseSystemToMessages, fortmatConversePrompt } from "./converse.js";
+ import { formatNovaImageGenerationPayload, NovaImageGenerationTaskType } from "./nova-image-payload.js";
  import { forceUploadFile } from "./s3.js";
- const { LRUCache } = mnemonist;
  const supportStreamingCache = new LRUCache(4096);
+ var BedrockModelType;
+ (function (BedrockModelType) {
+     BedrockModelType["FoundationModel"] = "foundation-model";
+     BedrockModelType["InferenceProfile"] = "inference-profile";
+     BedrockModelType["CustomModel"] = "custom-model";
+     BedrockModelType["Unknown"] = "unknown";
+ })(BedrockModelType || (BedrockModelType = {}));
+ ;
+ function converseFinishReason(reason) {
+     //Possible values:
+     //end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
+     if (!reason)
+         return undefined;
+     switch (reason) {
+         case 'end_turn': return "stop";
+         case 'max_tokens': return "length";
+         default: return reason;
+     }
+ }
  export class BedrockDriver extends AbstractDriver {
      static PROVIDER = "bedrock";
      provider = BedrockDriver.PROVIDER;
      _executor;
      _service;
+     _service_region;
      constructor(options) {
          super(options);
          if (!options.region) {
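Note: the single `converseFinishReason` helper added above replaces the four per-provider mappers (`claudeFinishReason`, `cohereFinishReason`, `a21FinishReason`, `titanFinishReason`) deleted at the end of this file; the Converse API normalizes stop reasons across providers. A minimal sketch of the mapping (sample inputs are illustrative, not an exhaustive list):

    converseFinishReason("end_turn");   // => "stop"
    converseFinishReason("max_tokens"); // => "length"
    converseFinishReason("tool_use");   // => "tool_use" (unknown values pass through)
    converseFinishReason(undefined);    // => undefined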
@@ -28,241 +49,334 @@ export class BedrockDriver extends AbstractDriver {
          }
          return this._executor;
      }
-     getService() {
-         if (!this._service) {
+     getService(region = this.options.region) {
+         if (!this._service || this._service_region != region) {
              this._service = new Bedrock({
-                 region: this.options.region,
+                 region: region,
                  credentials: this.options.credentials,
              });
+             this._service_region = region;
          }
          return this._service;
      }
      async formatPrompt(segments, opts) {
-         //TODO move the anthropic test in abstract driver?
-         if (opts.model.includes('anthropic')) {
-             //TODO: need to type better the types aren't checked properly by TS
-             return await formatClaudePrompt(segments, opts.result_schema);
-         }
-         else {
-             return await super.formatPrompt(segments, opts);
+         if (opts.model.includes("canvas")) {
+             return await formatNovaPrompt(segments, opts.result_schema);
          }
+         return await fortmatConversePrompt(segments, opts.result_schema);
      }
-     extractDataFromResponse(prompt, response) {
-         const decoder = new TextDecoder();
-         const body = decoder.decode(response.body);
-         const result = JSON.parse(body);
-         const getTextAnsStopReason = () => {
-             if (result.generation) {
-                 // LLAMA2
-                 return [result.generation, result.stop_reason]; // comes in coirrect format (stop, length)
-             }
-             else if (result.generations) {
-                 // Cohere
-                 return [result.generations[0].text, cohereFinishReason(result.generations[0].finish_reason)];
-             }
-             else if (result.chat_history) {
-                 //Cohere Command R
-                 return [result.text, cohereFinishReason(result.finish_reason)];
-             }
-             else if (result.completions) {
-                 //A21
-                 return [result.completions[0].data?.text, a21FinishReason(result.completions[0].finishReason?.reason)];
-             }
-             else if (result.content) {
-                 // Claude
-                 //if last prompt.messages is {, add { to the response
-                 const p = prompt;
-                 const lastMessage = p.messages[p.messages.length - 1];
-                 const res = lastMessage.content[0].text === '{' ? '{' + result.content[0]?.text : result.content[0]?.text;
-                 return [res, claudeFinishReason(result.stop_reason)];
-             }
-             else if (result.outputs) {
-                 // mistral
-                 return [result.outputs[0]?.text, result.outputs[0]?.stop_reason]; // the stop reason is in the expected format ("stop" and "length")
-             }
-             else if (result.results) {
-                 // Amazon Titan
-                 return [result.results[0]?.outputText ?? '', titanFinishReason(result.results[0]?.completionReason)];
-             }
-             else if (result.completion) { // TODO: who uses this?
-                 return [result.completion];
-             }
-             else {
-                 return [result.toString()];
-             }
-         };
-         const [text, finish_reason] = getTextAnsStopReason();
-         const promptLength = typeof prompt === 'string' ? prompt.length :
-             (prompt.system || '').length + prompt.messages.reduce((acc, m) => acc + m.content.length, 0);
+     static getExtractedExecuton(result, _prompt) {
          return {
-             result: text,
+             result: result.output?.message?.content?.map(c => c.text).join("\n") ?? "",
              token_usage: {
-                 result: text?.length,
-                 prompt: promptLength,
-                 total: text?.length + promptLength,
+                 prompt: result.usage?.inputTokens,
+                 result: result.usage?.outputTokens,
+                 total: result.usage?.totalTokens,
              },
-             finish_reason
+             finish_reason: converseFinishReason(result.stopReason),
+         };
+     }
+     ;
+     static getExtractedStream(result, _prompt) {
+         let output = "";
+         let stop_reason = "";
+         let token_usage;
+         if (result.contentBlockDelta) {
+             output = result.contentBlockDelta.delta?.text ?? "";
+         }
+         if (result.messageStop) {
+             stop_reason = result.messageStop.stopReason ?? "";
+         }
+         if (result.metadata) {
+             token_usage = {
+                 prompt: result.metadata.usage?.inputTokens,
+                 result: result.metadata.usage?.outputTokens,
+                 total: result.metadata.usage?.totalTokens,
+             };
+         }
+         return {
+             result: output,
+             token_usage: token_usage,
+             finish_reason: converseFinishReason(stop_reason),
          };
      }
-     async requestCompletion(prompt, options) {
+     ;
+     async requestTextCompletion(prompt, options) {
          const payload = this.preparePayload(prompt, options);
          const executor = this.getExecutor();
-         const res = await executor.invokeModel({
-             modelId: options.model,
-             contentType: "application/json",
-             body: JSON.stringify(payload),
+         const res = await executor.converse({
+             ...payload,
          });
-         const completion = this.extractDataFromResponse(prompt, res);
-         if (options.include_original_response) {
-             completion.original_response = res;
-         }
+         const completion = {
+             ...BedrockDriver.getExtractedExecuton(res, prompt),
+             original_response: options.include_original_response ? res : undefined,
+         };
          return completion;
      }
+     extractRegion(modelString, defaultRegion) {
+         // Match region in full ARN pattern
+         const arnMatch = modelString.match(/arn:aws[^:]*:bedrock:([^:]+):/);
+         if (arnMatch) {
+             return arnMatch[1];
+         }
+         // Match common AWS regions directly in string
+         const regionMatch = modelString.match(/(?:us|eu|ap|sa|ca|me|af)[-](east|west|central|south|north|southeast|southwest|northeast|northwest)[-][1-9]/);
+         if (regionMatch) {
+             return regionMatch[0];
+         }
+         return defaultRegion;
+     }
+     async getCanStream(model, type) {
+         let canStream = false;
+         let error = null;
+         const region = this.extractRegion(model, this.options.region);
+         if (type == BedrockModelType.FoundationModel || type == BedrockModelType.Unknown) {
+             try {
+                 const response = await this.getService(region).getFoundationModel({
+                     modelIdentifier: model
+                 });
+                 canStream = response.modelDetails?.responseStreamingSupported ?? false;
+                 return canStream;
+             }
+             catch (e) {
+                 error = e;
+             }
+         }
+         if (type == BedrockModelType.InferenceProfile || type == BedrockModelType.Unknown) {
+             try {
+                 const response = await this.getService(region).getInferenceProfile({
+                     inferenceProfileIdentifier: model
+                 });
+                 canStream = await this.getCanStream(response.models?.[0].modelArn ?? "", BedrockModelType.FoundationModel);
+                 return canStream;
+             }
+             catch (e) {
+                 error = e;
+             }
+         }
+         if (type == BedrockModelType.CustomModel || type == BedrockModelType.Unknown) {
+             try {
+                 const response = await this.getService(region).getCustomModel({
+                     modelIdentifier: model
+                 });
+                 canStream = await this.getCanStream(response.baseModelArn ?? "", BedrockModelType.FoundationModel);
+                 return canStream;
+             }
+             catch (e) {
+                 error = e;
+             }
+         }
+         if (error) {
+             console.warn("Error on canStream check for model: " + model + " region detected: " + region, error);
+         }
+         return canStream;
+     }
      async canStream(options) {
          let canStream = supportStreamingCache.get(options.model);
          if (canStream == null) {
-             const response = await this.getService().getFoundationModel({
-                 modelIdentifier: options.model
-             });
-             canStream = response.modelDetails?.responseStreamingSupported ?? false;
+             let type = BedrockModelType.Unknown;
+             if (options.model.includes("foundation-model")) {
+                 type = BedrockModelType.FoundationModel;
+             }
+             else if (options.model.includes("inference-profile")) {
+                 type = BedrockModelType.InferenceProfile;
+             }
+             else if (options.model.includes("custom-model")) {
+                 type = BedrockModelType.CustomModel;
+             }
+             canStream = await this.getCanStream(options.model, type);
              supportStreamingCache.set(options.model, canStream);
          }
          return canStream;
      }
-     async requestCompletionStream(prompt, options) {
+     async requestTextCompletionStream(prompt, options) {
          const payload = this.preparePayload(prompt, options);
          const executor = this.getExecutor();
-         return executor.invokeModelWithResponseStream({
-             modelId: options.model,
-             contentType: "application/json",
-             body: JSON.stringify(payload),
+         return executor.converseStream({
+             ...payload,
          }).then((res) => {
-             if (!res.body) {
-                 throw new Error("Body not found");
+             const stream = res.stream;
+             if (!stream) {
+                 throw new Error("[Bedrock] Stream not found in response");
              }
-             const decoder = new TextDecoder();
-             const addBracket = () => {
-                 if (typeof prompt === 'object' && prompt.messages) {
-                     const p = prompt;
-                     const lastMessage = p.messages[p.messages.length - 1];
-                     return lastMessage.content[0].text === '{';
-                 }
-                 return false;
-             };
-             return transformAsyncIterator(res.body, (stream) => {
-                 const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
+             return transformAsyncIterator(stream, (stream) => {
+                 //const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
                  //console.log("Debug Segment for model " + options.model, JSON.stringify(segment));
-                 if (segment.delta) { // who is this?
-                     return segment.delta.text || '';
-                 }
-                 else if (segment.completion) { // who is this?
-                     return segment.completion;
-                 }
-                 else if (segment.text) { //cohere
-                     return segment.text;
-                 }
-                 else if (segment.completions) {
-                     return segment.completions[0].data?.text;
-                 }
-                 else if (segment.generation) {
-                     return segment.generation;
-                 }
-                 else if (segment.generations) {
-                     return segment.generations[0].text;
-                 }
-                 else if (segment.outputs) {
-                     // mistral.mixtral-8x7b-instruct-v0:1
-                     return segment.outputs[0].text;
-                     //segment.outputs[0].stop_reason;
-                 }
-                 else if (segment.outputText) {
-                     // Amazon Titan
-                     return segment.outputText;
-                     //completionReason
-                     // token count too
-                 }
-                 else {
-                     segment.toString();
-                 }
-             }, () => addBracket() ? '{' : '');
+                 return BedrockDriver.getExtractedStream(stream, prompt);
+             });
          }).catch((err) => {
              this.logger.error("[Bedrock] Failed to stream", err);
              throw err;
          });
      }
      preparePayload(prompt, options) {
-         //split arn on / should give provider
-         //TODO: check if works with custom models
-         //const provider = options.model.split("/")[0];
-         const contains = (str, substr) => str.indexOf(substr) !== -1;
-         if (contains(options.model, "meta")) {
-             return {
-                 prompt,
-                 temperature: options.temperature,
-                 max_gen_len: options.max_tokens,
-             };
+         const model_options = options.model_options;
+         let additionalField = {};
+         if (options.model.includes("amazon")) {
+             //Titan models also exists but does not support any additional options
+             if (options.model.includes("nova")) {
+                 additionalField = { inferenceConfig: { topK: model_options?.top_k } };
+             }
          }
-         else if (contains(options.model, "claude")) {
-             const maxToken = () => {
-                 if (options.max_tokens) {
-                     return options.max_tokens;
+         else if (options.model.includes("claude")) {
+             if (options.model.includes("claude-3-7")) {
+                 const thinking_options = options.model_options;
+                 const thinking = thinking_options?.thinking_mode ?? false;
+                 if (!model_options?.max_tokens) {
+                     model_options.max_tokens = thinking ? 128000 : 8192;
                  }
-                 else if (contains(options.model, "claude-3-5")) {
-                     return 8192;
+                 additionalField = {
+                     top_k: model_options?.top_k,
+                     reasoning_config: {
+                         type: thinking ? "enabled" : "disabled",
+                         budget_tokens: thinking_options?.thinking_budget_tokens,
+                     }
+                 };
+                 if (thinking && (thinking_options?.thinking_budget_tokens ?? 0) > 64000) {
+                     additionalField = {
+                         ...additionalField,
+                         anthorpic_beta: ["output-128k-2025-02-19"]
+                     };
+                 }
+             }
+             //Needs max_tokens to be set
+             if (!model_options?.max_tokens) {
+                 if (options.model.includes("claude-3-5")) {
+                     model_options.max_tokens = 8192;
+                     //Bug with AWS Converse Sonnet 3.5, does not effect Haiku.
+                     //See https://github.com/boto/boto3/issues/4279
+                     if (options.model.includes("claude-3-5-sonnet")) {
+                         model_options.max_tokens = 4096;
+                     }
                  }
                  else {
-                     return 4096;
+                     model_options.max_tokens = 4096;
                  }
-             };
-             return {
-                 anthropic_version: "bedrock-2023-05-31",
-                 ...prompt,
-                 temperature: options.temperature,
-                 max_tokens: maxToken(),
-             };
+             }
+             additionalField = { top_k: model_options?.top_k };
          }
-         else if (contains(options.model, "ai21")) {
-             return {
-                 prompt: prompt,
-                 temperature: options.temperature,
-                 maxTokens: options.max_tokens,
-             };
+         else if (options.model.includes("meta")) {
+             //If last message is "```json", remove it. Model requires the final message to be a user message
+             prompt.messages = converseRemoveJSONprefill(prompt.messages);
          }
-         else if (contains(options.model, "command-r-plus")) {
-             return {
-                 message: prompt,
-                 max_tokens: options.max_tokens,
-                 temperature: options.temperature,
-             };
+         else if (options.model.includes("mistral")) {
+             //7B instruct and 8x7B instruct
+             if (options.model.includes("7b")) {
+                 additionalField = { top_k: model_options?.top_k };
+                 //Does not support system messages
+                 if (prompt.system && prompt.system?.length != 0) {
+                     prompt.messages?.push(converseSystemToMessages(prompt.system));
+                     prompt.system = undefined;
+                     prompt.messages = converseConcatMessages(prompt.messages);
+                 }
+             }
+             else {
+                 //Other models such as Mistral Small,Large and Large 2
+                 //Support no additional fields.
+                 prompt.messages = converseRemoveJSONprefill(prompt.messages);
+             }
          }
-         else if (contains(options.model, "cohere")) {
-             return {
-                 prompt: prompt,
-                 temperature: options.temperature,
-                 max_tokens: options.max_tokens,
-             };
+         else if (options.model.includes("ai21")) {
+             //If last message is "```json", remove it. Model requires the final message to be a user message
+             prompt.messages = converseRemoveJSONprefill(prompt.messages);
+             //Jamba models support no additional options
+             //Jurassic 2 models do.
+             if (options.model.includes("j2")) {
+                 additionalField = {
+                     presencePenalty: { scale: model_options?.presence_penalty },
+                     frequencyPenalty: { scale: model_options?.frequency_penalty },
+                 };
+                 //Does not support system messages
+                 if (prompt.system && prompt.system?.length != 0) {
+                     prompt.messages?.push(converseSystemToMessages(prompt.system));
+                     prompt.system = undefined;
+                     prompt.messages = converseConcatMessages(prompt.messages);
+                 }
+             }
          }
-         else if (contains(options.model, "amazon")) {
-             return {
-                 inputText: "User: " + prompt + "\nBot:", // see https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html#model-parameters-titan-request-response
-                 textGenerationConfig: {
-                     temperature: options.temperature,
-                     topP: options.top_p,
-                     maxTokenCount: options.max_tokens,
-                     //stopSequences: ["\n"],
-                 },
-             };
+         else if (options.model.includes("cohere.command")) {
+             // If last message is "```json", remove it.
+             // Model requires the final message to be a user message or does not support assistant messages
+             prompt.messages = converseRemoveJSONprefill(prompt.messages);
+             //Command R and R plus
+             if (options.model.includes("cohere.command-r")) {
+                 additionalField = {
+                     k: model_options?.top_k,
+                     frequency_penalty: model_options?.frequency_penalty,
+                     presence_penalty: model_options?.presence_penalty,
+                 };
+             }
+             else {
+                 // Command non-R
+                 additionalField = { k: model_options?.top_k };
+                 //Does not support system messages
+                 if (prompt.system && prompt.system?.length != 0) {
+                     prompt.messages?.push(converseSystemToMessages(prompt.system));
+                     prompt.system = undefined;
+                     prompt.messages = converseConcatMessages(prompt.messages);
+                 }
+             }
          }
-         else if (contains(options.model, "mistral")) {
-             return {
-                 prompt: prompt,
-                 temperature: options.temperature,
-                 max_tokens: options.max_tokens,
-             };
+         //If last message is "```json", add corresponding ``` as a stop sequence.
+         if (prompt.messages && prompt.messages.length > 0) {
+             if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") {
+                 let stopSeq = model_options?.stop_sequence;
+                 if (!stopSeq) {
+                     model_options.stop_sequence = ["```"];
+                 }
+                 else if (!stopSeq.includes("```")) {
+                     stopSeq.push("```");
+                     model_options.stop_sequence = stopSeq;
+                 }
+             }
+         }
+         return {
+             messages: prompt.messages,
+             system: prompt.system,
+             modelId: options.model,
+             inferenceConfig: {
+                 maxTokens: model_options?.max_tokens,
+                 temperature: model_options?.temperature,
+                 topP: model_options?.top_p,
+                 stopSequences: model_options?.stop_sequence,
+             },
+             additionalModelRequestFields: {
+                 ...additionalField,
+             },
+         };
+     }
+     async requestImageGeneration(prompt, options) {
+         if (options.output_modality !== Modalities.image) {
+             throw new Error(`Image generation requires image output_modality`);
          }
-         else {
-             throw new Error("Cannot prepare payload for unknown provider: " + options.model);
+         if (options.model_options?._option_id !== "bedrock-nova-canvas") {
+             this.logger.warn("Invalid model options", { options: options.model_options });
+         }
+         const model_options = options.model_options;
+         const executor = this.getExecutor();
+         const taskType = model_options.taskType ?? NovaImageGenerationTaskType.TEXT_IMAGE;
+         this.logger.info("Task type: " + taskType);
+         if (typeof prompt === "string") {
+             throw new Error("Bad prompt format");
          }
+         const payload = await formatNovaImageGenerationPayload(taskType, prompt, options);
+         const res = await executor.invokeModel({
+             modelId: options.model,
+             contentType: "application/json",
+             accept: "application/json",
+             body: JSON.stringify(payload),
+         }, {
+             requestTimeout: 60000 * 5
+         });
+         const decoder = new TextDecoder();
+         const body = decoder.decode(res.body);
+         const result = JSON.parse(body);
+         return {
+             error: result.error,
+             result: {
+                 images: result.images,
+             }
+         };
      }
      async startTraining(dataset, options) {
          //convert options.params to Record<string, string>
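Note: `preparePayload` now emits one Converse-shaped request for every provider, with provider quirks confined to message rewriting and `additionalModelRequestFields`. A sketch of the object it returns, with illustrative values (the model id and numbers are made up for the example):

    const payload = {
        messages: [/* Converse-format messages from the prompt */],
        system: prompt.system,
        modelId: "anthropic.claude-3-5-haiku-20241022-v1:0",  // illustrative
        inferenceConfig: {
            maxTokens: 4096,
            temperature: 0.7,
            topP: 0.9,
            stopSequences: ["```"], // auto-added when the prompt ends with a "```json" prefill
        },
        additionalModelRequestFields: { top_k: 50 }, // provider-specific extras
    };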
@@ -327,12 +441,13 @@ export class BedrockDriver extends AbstractDriver {
      async listModels() {
          this.logger.debug("[Bedrock] listing models");
          // exclude trainable models since they are not executable
-         const filter = (m) => m.inferenceTypesSupported?.includes("ON_DEMAND") ?? false;
+         // exclude embedding models, not to be used for typical completions.
+         const filter = (m) => (m.inferenceTypesSupported?.includes("ON_DEMAND") && !m.outputModalities?.includes("EMBEDDING")) ?? false;
          return this._listModels(filter);
      }
      async _listModels(foundationFilter) {
          const service = this.getService();
-         const [foundationals, customs] = await Promise.all([
+         const [foundationals, customs, inferenceProfiles] = await Promise.all([
              service.listFoundationModels({}).catch(() => {
                  this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions.");
                  return undefined;
@@ -341,6 +456,10 @@ export class BedrockDriver extends AbstractDriver {
                  this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions.");
                  return undefined;
              }),
+             service.listInferenceProfiles({}).catch(() => {
+                 this.logger.warn("[Bedrock] Can't list inference profiles. Check if the user has the right permissions.");
+                 return undefined;
+             }),
          ]);
          if (!foundationals?.modelSummaries) {
              throw new Error("Foundation models not found");
@@ -349,6 +468,12 @@ export class BedrockDriver extends AbstractDriver {
          if (foundationFilter) {
              fmodels = fmodels.filter(foundationFilter);
          }
+         const supportedProviders = ["amazon", "anthropic", "cohere", "ai21", "mistral", "meta", "deepseek"];
+         fmodels = fmodels.filter((m) => {
+             supportedProviders.some((provider) => {
+                 m.providerName?.includes(provider) ?? false;
+             });
+         });
          const aimodels = fmodels.map((m) => {
              if (!m.modelId) {
                  throw new Error("modelId not found");
@@ -357,6 +482,7 @@ export class BedrockDriver extends AbstractDriver {
                  id: m.modelArn ?? m.modelId,
                  name: `${m.providerName} ${m.modelName}`,
                  provider: this.provider,
+                 input_modalities: m.inputModalities ?? [],
                  //description: ``,
                  owner: m.providerName,
                  can_stream: m.responseStreamingSupported ?? false,
@@ -382,16 +508,33 @@ export class BedrockDriver extends AbstractDriver {
                  this.validateConnection;
              });
          }
+         //add inference profiles
+         if (inferenceProfiles?.inferenceProfileSummaries) {
+             inferenceProfiles.inferenceProfileSummaries.forEach((p) => {
+                 if (!p.inferenceProfileArn) {
+                     throw new Error("Profile ARN not found");
+                 }
+                 const model = {
+                     id: p.inferenceProfileArn ?? p.inferenceProfileId,
+                     name: p.inferenceProfileName ?? p.inferenceProfileArn,
+                     provider: this.provider,
+                 };
+                 aimodels.push(model);
+             });
+         }
          return aimodels;
      }
-     async generateEmbeddings({ content, model = "amazon.titan-embed-text-v1" }) {
+     async generateEmbeddings({ text, image, model }) {
          this.logger.info("[Bedrock] Generating embeddings with model " + model);
+         const defaultModel = image ? "amazon.titan-embed-image-v1" : "amazon.titan-embed-text-v2:0";
+         const modelID = model ?? defaultModel;
          const invokeBody = {
-             inputText: content
+             inputText: text,
+             inputImage: image
          };
          const executor = this.getExecutor();
          const res = await executor.invokeModel({
-             modelId: model,
+             modelId: modelID,
              contentType: "application/json",
              body: JSON.stringify(invokeBody),
          });
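Note: `generateEmbeddings` now takes `text` and/or `image` instead of a single `content` string and selects a Titan default per modality. A hedged usage sketch (`driver` and `base64Png` are assumed for the example, not part of the diff):

    // Falls back to amazon.titan-embed-text-v2:0 when no model is given.
    const textEmb = await driver.generateEmbeddings({ text: "hello world" });
    // Falls back to amazon.titan-embed-image-v1 for image input.
    const imageEmb = await driver.generateEmbeddings({ image: base64Png });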
@@ -403,7 +546,7 @@ export class BedrockDriver extends AbstractDriver {
          }
          return {
              values: result.embedding,
-             model: model,
+             model: modelID,
              token_count: result.inputTextTokenCount
          };
      }
@@ -434,40 +577,4 @@ function jobInfo(job, jobId) {
          details
      };
  }
- function claudeFinishReason(reason) {
-     if (!reason)
-         return undefined;
-     switch (reason) {
-         case 'end_turn': return "stop";
-         case 'max_tokens': return "length";
-         default: return reason; //stop_sequence
-     }
- }
- function cohereFinishReason(reason) {
-     if (!reason)
-         return undefined;
-     switch (reason) {
-         case 'COMPLETE': return "stop";
-         case 'MAX_TOKENS': return "length";
-         default: return reason;
-     }
- }
- function a21FinishReason(reason) {
-     if (!reason)
-         return undefined;
-     switch (reason) {
-         case 'endoftext': return "stop";
-         case 'length': return "length";
-         default: return reason;
-     }
- }
- function titanFinishReason(reason) {
-     if (!reason)
-         return undefined;
-     switch (reason) {
-         case 'FINISH': return "stop";
-         case 'LENGTH': return "length";
-         default: return reason;
-     }
- }
  //# sourceMappingURL=index.js.map
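Note: on the streaming side, `getExtractedStream` folds the three Converse stream event kinds into the driver's chunk shape. An illustrative mapping (event objects abbreviated):

    // { contentBlockDelta: { delta: { text: "Hel" } } } -> { result: "Hel" }
    // { messageStop: { stopReason: "end_turn" } }       -> { result: "", finish_reason: "stop" }
    // { metadata: { usage: { inputTokens: 12, outputTokens: 34, totalTokens: 46 } } }
    //   -> { result: "", token_usage: { prompt: 12, result: 34, total: 46 } }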