@llumiverse/drivers 0.14.0 → 0.16.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. package/README.md +3 -3
  2. package/lib/cjs/adobe/firefly.js +119 -0
  3. package/lib/cjs/adobe/firefly.js.map +1 -0
  4. package/lib/cjs/bedrock/converse.js +177 -0
  5. package/lib/cjs/bedrock/converse.js.map +1 -0
  6. package/lib/cjs/bedrock/index.js +329 -228
  7. package/lib/cjs/bedrock/index.js.map +1 -1
  8. package/lib/cjs/bedrock/nova-image-payload.js +207 -0
  9. package/lib/cjs/bedrock/nova-image-payload.js.map +1 -0
  10. package/lib/cjs/groq/index.js +34 -9
  11. package/lib/cjs/groq/index.js.map +1 -1
  12. package/lib/cjs/huggingface_ie.js +28 -12
  13. package/lib/cjs/huggingface_ie.js.map +1 -1
  14. package/lib/cjs/index.js +1 -0
  15. package/lib/cjs/index.js.map +1 -1
  16. package/lib/cjs/mistral/index.js +31 -12
  17. package/lib/cjs/mistral/index.js.map +1 -1
  18. package/lib/cjs/mistral/types.js.map +1 -1
  19. package/lib/cjs/openai/index.js +149 -27
  20. package/lib/cjs/openai/index.js.map +1 -1
  21. package/lib/cjs/replicate.js +16 -18
  22. package/lib/cjs/replicate.js.map +1 -1
  23. package/lib/cjs/test/TestValidationErrorCompletionStream.js.map +1 -1
  24. package/lib/cjs/test/index.js.map +1 -1
  25. package/lib/cjs/togetherai/index.js +40 -10
  26. package/lib/cjs/togetherai/index.js.map +1 -1
  27. package/lib/cjs/vertexai/embeddings/embeddings-image.js +26 -0
  28. package/lib/cjs/vertexai/embeddings/embeddings-image.js.map +1 -0
  29. package/lib/cjs/vertexai/embeddings/embeddings-text.js +1 -1
  30. package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -1
  31. package/lib/cjs/vertexai/index.js +92 -25
  32. package/lib/cjs/vertexai/index.js.map +1 -1
  33. package/lib/cjs/vertexai/models/claude.js +252 -0
  34. package/lib/cjs/vertexai/models/claude.js.map +1 -0
  35. package/lib/cjs/vertexai/models/gemini.js +169 -27
  36. package/lib/cjs/vertexai/models/gemini.js.map +1 -1
  37. package/lib/cjs/vertexai/models/imagen.js +317 -0
  38. package/lib/cjs/vertexai/models/imagen.js.map +1 -0
  39. package/lib/cjs/vertexai/models.js +12 -107
  40. package/lib/cjs/vertexai/models.js.map +1 -1
  41. package/lib/cjs/watsonx/index.js +39 -8
  42. package/lib/cjs/watsonx/index.js.map +1 -1
  43. package/lib/cjs/xai/index.js +71 -0
  44. package/lib/cjs/xai/index.js.map +1 -0
  45. package/lib/esm/adobe/firefly.js +115 -0
  46. package/lib/esm/adobe/firefly.js.map +1 -0
  47. package/lib/esm/bedrock/converse.js +171 -0
  48. package/lib/esm/bedrock/converse.js.map +1 -0
  49. package/lib/esm/bedrock/index.js +331 -230
  50. package/lib/esm/bedrock/index.js.map +1 -1
  51. package/lib/esm/bedrock/nova-image-payload.js +203 -0
  52. package/lib/esm/bedrock/nova-image-payload.js.map +1 -0
  53. package/lib/esm/groq/index.js +34 -9
  54. package/lib/esm/groq/index.js.map +1 -1
  55. package/lib/esm/huggingface_ie.js +29 -13
  56. package/lib/esm/huggingface_ie.js.map +1 -1
  57. package/lib/esm/index.js +1 -0
  58. package/lib/esm/index.js.map +1 -1
  59. package/lib/esm/mistral/index.js +31 -12
  60. package/lib/esm/mistral/index.js.map +1 -1
  61. package/lib/esm/mistral/types.js.map +1 -1
  62. package/lib/esm/openai/index.js +150 -28
  63. package/lib/esm/openai/index.js.map +1 -1
  64. package/lib/esm/replicate.js +17 -19
  65. package/lib/esm/replicate.js.map +1 -1
  66. package/lib/esm/test/TestValidationErrorCompletionStream.js.map +1 -1
  67. package/lib/esm/test/index.js.map +1 -1
  68. package/lib/esm/togetherai/index.js +40 -10
  69. package/lib/esm/togetherai/index.js.map +1 -1
  70. package/lib/esm/vertexai/embeddings/embeddings-image.js +23 -0
  71. package/lib/esm/vertexai/embeddings/embeddings-image.js.map +1 -0
  72. package/lib/esm/vertexai/embeddings/embeddings-text.js +1 -1
  73. package/lib/esm/vertexai/embeddings/embeddings-text.js.map +1 -1
  74. package/lib/esm/vertexai/index.js +93 -27
  75. package/lib/esm/vertexai/index.js.map +1 -1
  76. package/lib/esm/vertexai/models/claude.js +247 -0
  77. package/lib/esm/vertexai/models/claude.js.map +1 -0
  78. package/lib/esm/vertexai/models/gemini.js +170 -28
  79. package/lib/esm/vertexai/models/gemini.js.map +1 -1
  80. package/lib/esm/vertexai/models/imagen.js +310 -0
  81. package/lib/esm/vertexai/models/imagen.js.map +1 -0
  82. package/lib/esm/vertexai/models.js +12 -104
  83. package/lib/esm/vertexai/models.js.map +1 -1
  84. package/lib/esm/watsonx/index.js +39 -8
  85. package/lib/esm/watsonx/index.js.map +1 -1
  86. package/lib/esm/xai/index.js +64 -0
  87. package/lib/esm/xai/index.js.map +1 -0
  88. package/lib/types/adobe/firefly.d.ts +30 -0
  89. package/lib/types/adobe/firefly.d.ts.map +1 -0
  90. package/lib/types/bedrock/converse.d.ts +8 -0
  91. package/lib/types/bedrock/converse.d.ts.map +1 -0
  92. package/lib/types/bedrock/index.d.ts +26 -11
  93. package/lib/types/bedrock/index.d.ts.map +1 -1
  94. package/lib/types/bedrock/nova-image-payload.d.ts +74 -0
  95. package/lib/types/bedrock/nova-image-payload.d.ts.map +1 -0
  96. package/lib/types/bedrock/payloads.d.ts +9 -65
  97. package/lib/types/bedrock/payloads.d.ts.map +1 -1
  98. package/lib/types/groq/index.d.ts +3 -3
  99. package/lib/types/groq/index.d.ts.map +1 -1
  100. package/lib/types/huggingface_ie.d.ts +5 -7
  101. package/lib/types/huggingface_ie.d.ts.map +1 -1
  102. package/lib/types/index.d.ts +1 -0
  103. package/lib/types/index.d.ts.map +1 -1
  104. package/lib/types/mistral/index.d.ts +4 -4
  105. package/lib/types/mistral/index.d.ts.map +1 -1
  106. package/lib/types/mistral/types.d.ts +1 -0
  107. package/lib/types/mistral/types.d.ts.map +1 -1
  108. package/lib/types/openai/index.d.ts +5 -4
  109. package/lib/types/openai/index.d.ts.map +1 -1
  110. package/lib/types/replicate.d.ts +4 -9
  111. package/lib/types/replicate.d.ts.map +1 -1
  112. package/lib/types/test/index.d.ts +2 -2
  113. package/lib/types/test/index.d.ts.map +1 -1
  114. package/lib/types/togetherai/index.d.ts +4 -4
  115. package/lib/types/togetherai/index.d.ts.map +1 -1
  116. package/lib/types/vertexai/embeddings/embeddings-image.d.ts +11 -0
  117. package/lib/types/vertexai/embeddings/embeddings-image.d.ts.map +1 -0
  118. package/lib/types/vertexai/index.d.ts +19 -8
  119. package/lib/types/vertexai/index.d.ts.map +1 -1
  120. package/lib/types/vertexai/models/claude.d.ts +20 -0
  121. package/lib/types/vertexai/models/claude.d.ts.map +1 -0
  122. package/lib/types/vertexai/models/gemini.d.ts +4 -4
  123. package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
  124. package/lib/types/vertexai/models/imagen.d.ts +75 -0
  125. package/lib/types/vertexai/models/imagen.d.ts.map +1 -0
  126. package/lib/types/vertexai/models.d.ts +3 -6
  127. package/lib/types/vertexai/models.d.ts.map +1 -1
  128. package/lib/types/watsonx/index.d.ts +3 -3
  129. package/lib/types/watsonx/index.d.ts.map +1 -1
  130. package/lib/types/xai/index.d.ts +19 -0
  131. package/lib/types/xai/index.d.ts.map +1 -0
  132. package/package.json +24 -23
  133. package/src/adobe/firefly.ts +207 -0
  134. package/src/bedrock/converse.ts +194 -0
  135. package/src/bedrock/index.ts +349 -237
  136. package/src/bedrock/nova-image-payload.ts +309 -0
  137. package/src/bedrock/payloads.ts +12 -66
  138. package/src/groq/index.ts +35 -12
  139. package/src/huggingface_ie.ts +34 -13
  140. package/src/index.ts +1 -0
  141. package/src/mistral/index.ts +34 -12
  142. package/src/mistral/types.ts +2 -1
  143. package/src/openai/index.ts +167 -33
  144. package/src/replicate.ts +21 -20
  145. package/src/test/TestValidationErrorCompletionStream.ts +2 -2
  146. package/src/test/index.ts +3 -2
  147. package/src/togetherai/index.ts +44 -12
  148. package/src/vertexai/embeddings/embeddings-image.ts +50 -0
  149. package/src/vertexai/embeddings/embeddings-text.ts +1 -1
  150. package/src/vertexai/index.ts +114 -37
  151. package/src/vertexai/models/claude.ts +281 -0
  152. package/src/vertexai/models/gemini.ts +181 -31
  153. package/src/vertexai/models/imagen.ts +401 -0
  154. package/src/vertexai/models.ts +16 -120
  155. package/src/watsonx/index.ts +42 -10
  156. package/src/xai/index.ts +110 -0
  157. package/lib/cjs/vertexai/models/codey-chat.js +0 -65
  158. package/lib/cjs/vertexai/models/codey-chat.js.map +0 -1
  159. package/lib/cjs/vertexai/models/codey-text.js +0 -35
  160. package/lib/cjs/vertexai/models/codey-text.js.map +0 -1
  161. package/lib/cjs/vertexai/models/palm-model-base.js +0 -59
  162. package/lib/cjs/vertexai/models/palm-model-base.js.map +0 -1
  163. package/lib/cjs/vertexai/models/palm2-chat.js +0 -65
  164. package/lib/cjs/vertexai/models/palm2-chat.js.map +0 -1
  165. package/lib/cjs/vertexai/models/palm2-text.js +0 -35
  166. package/lib/cjs/vertexai/models/palm2-text.js.map +0 -1
  167. package/lib/cjs/vertexai/utils/tensor.js +0 -86
  168. package/lib/cjs/vertexai/utils/tensor.js.map +0 -1
  169. package/lib/esm/vertexai/models/codey-chat.js +0 -61
  170. package/lib/esm/vertexai/models/codey-chat.js.map +0 -1
  171. package/lib/esm/vertexai/models/codey-text.js +0 -31
  172. package/lib/esm/vertexai/models/codey-text.js.map +0 -1
  173. package/lib/esm/vertexai/models/palm-model-base.js +0 -55
  174. package/lib/esm/vertexai/models/palm-model-base.js.map +0 -1
  175. package/lib/esm/vertexai/models/palm2-chat.js +0 -61
  176. package/lib/esm/vertexai/models/palm2-chat.js.map +0 -1
  177. package/lib/esm/vertexai/models/palm2-text.js +0 -31
  178. package/lib/esm/vertexai/models/palm2-text.js.map +0 -1
  179. package/lib/esm/vertexai/utils/tensor.js +0 -82
  180. package/lib/esm/vertexai/utils/tensor.js.map +0 -1
  181. package/lib/types/vertexai/models/codey-chat.d.ts +0 -51
  182. package/lib/types/vertexai/models/codey-chat.d.ts.map +0 -1
  183. package/lib/types/vertexai/models/codey-text.d.ts +0 -39
  184. package/lib/types/vertexai/models/codey-text.d.ts.map +0 -1
  185. package/lib/types/vertexai/models/palm-model-base.d.ts +0 -61
  186. package/lib/types/vertexai/models/palm-model-base.d.ts.map +0 -1
  187. package/lib/types/vertexai/models/palm2-chat.d.ts +0 -61
  188. package/lib/types/vertexai/models/palm2-chat.d.ts.map +0 -1
  189. package/lib/types/vertexai/models/palm2-text.d.ts +0 -39
  190. package/lib/types/vertexai/models/palm2-text.d.ts.map +0 -1
  191. package/lib/types/vertexai/utils/tensor.d.ts +0 -6
  192. package/lib/types/vertexai/utils/tensor.d.ts.map +0 -1
  193. package/src/vertexai/models/codey-chat.ts +0 -115
  194. package/src/vertexai/models/codey-text.ts +0 -69
  195. package/src/vertexai/models/palm-model-base.ts +0 -128
  196. package/src/vertexai/models/palm2-chat.ts +0 -119
  197. package/src/vertexai/models/palm2-text.ts +0 -69
  198. package/src/vertexai/utils/tensor.ts +0 -82
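The largest change in this release is a rewrite of the Bedrock driver (file 135, package/src/bedrock/index.ts) around the AWS Bedrock Converse API; its diff is reproduced below. A minimal usage sketch follows, based only on names visible in that diff — the import path and the exact public surface of 0.16.0 are assumptions, not something this diff confirms:

    // Sketch only: import path and public API assumed, not verified against 0.16.0.
    import { BedrockDriver } from "@llumiverse/drivers";

    async function main() {
        // BedrockDriverOptions carries the region and optional AWS credentials (see diff below).
        const driver = new BedrockDriver({ region: "us-east-1" });

        // listModels() now also returns inference profiles and filters out embedding-only models.
        const models = await driver.listModels();
        console.log(models.map((m) => m.id));

        // generateEmbeddings() changed from { content, model } to { text, image, model };
        // when model is omitted it falls back to a Titan text or image embedding model.
        const embeddings = await driver.generateEmbeddings({ text: "Hello Bedrock" });
        console.log(embeddings.model, embeddings.values.length);
    }

    main();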
@@ -1,18 +1,38 @@
  import { Bedrock, CreateModelCustomizationJobCommand, FoundationModelSummary, GetModelCustomizationJobCommand, GetModelCustomizationJobCommandOutput, ModelCustomizationJobStatus, StopModelCustomizationJobCommand } from "@aws-sdk/client-bedrock";
- import { BedrockRuntime, InvokeModelCommandOutput, ResponseStream } from "@aws-sdk/client-bedrock-runtime";
+ import { BedrockRuntime, ConverseRequest, ConverseResponse, ConverseStreamOutput, InferenceConfiguration } from "@aws-sdk/client-bedrock-runtime";
  import { S3Client } from "@aws-sdk/client-s3";
- import { AIModel, AbstractDriver, Completion, DataSource, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptOptions, PromptSegment, TrainingJob, TrainingJobStatus, TrainingOptions } from "@llumiverse/core";
+ import { AbstractDriver, AIModel, Completion, CompletionChunkObject, DataSource, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, ExecutionTokenUsage, ImageGeneration, Modalities, PromptOptions, PromptSegment, TextFallbackOptions, TrainingJob, TrainingJobStatus, TrainingOptions } from "@llumiverse/core";
  import { transformAsyncIterator } from "@llumiverse/core/async";
- import { ClaudeMessagesPrompt, formatClaudePrompt } from "@llumiverse/core/formatters";
+ import { formatNovaPrompt, NovaMessagesPrompt } from "@llumiverse/core/formatters";
  import { AwsCredentialIdentity, Provider } from "@smithy/types";
  import mnemonist from "mnemonist";
- import { AI21RequestPayload, AmazonRequestPayload, ClaudeRequestPayload, CohereCommandRPayload, CohereRequestPayload, LLama2RequestPayload, MistralPayload } from "./payloads.js";
+ import { BedrockClaudeOptions, NovaCanvasOptions } from "../../../core/src/options/bedrock.js";
+ import { converseConcatMessages, converseRemoveJSONprefill, converseSystemToMessages, fortmatConversePrompt } from "./converse.js";
+ import { formatNovaImageGenerationPayload, NovaImageGenerationTaskType } from "./nova-image-payload.js";
  import { forceUploadFile } from "./s3.js";

  const { LRUCache } = mnemonist;

  const supportStreamingCache = new LRUCache<string, boolean>(4096);

+ enum BedrockModelType {
+     FoundationModel = "foundation-model",
+     InferenceProfile = "inference-profile",
+     CustomModel = "custom-model",
+     Unknown = "unknown",
+ };
+
+ function converseFinishReason(reason: string | undefined) {
+     //Possible values:
+     //end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
+     if (!reason) return undefined;
+     switch (reason) {
+         case 'end_turn': return "stop";
+         case 'max_tokens': return "length";
+         default: return reason;
+     }
+ }
+
  export interface BedrockModelCapabilities {
      name: string;
      canStream: boolean;
@@ -40,7 +60,7 @@ export interface BedrockDriverOptions extends DriverOptions {
      credentials?: AwsCredentialIdentity | Provider<AwsCredentialIdentity>;
  }

- export type BedrockPrompt = string | ClaudeMessagesPrompt;
+ export type BedrockPrompt = NovaMessagesPrompt | ConverseRequest;

  export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockPrompt> {

@@ -50,6 +70,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP

      private _executor?: BedrockRuntime;
      private _service?: Bedrock;
+     private _service_region?: string;

      constructor(options: BedrockDriverOptions) {
          super(options);
@@ -69,164 +90,170 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
          return this._executor;
      }

-     getService() {
-         if (!this._service) {
+     getService(region: string = this.options.region) {
+         if (!this._service || this._service_region != region) {
              this._service = new Bedrock({
-                 region: this.options.region,
+                 region: region,
                  credentials: this.options.credentials,
              });
+             this._service_region = region;
          }
          return this._service;
      }

      protected async formatPrompt(segments: PromptSegment[], opts: PromptOptions): Promise<BedrockPrompt> {
-         //TODO move the anthropic test in abstract driver?
-         if (opts.model.includes('anthropic')) {
-             //TODO: need to type better the types aren't checked properly by TS
-             return await formatClaudePrompt(segments, opts.result_schema);
-         } else {
-             return await super.formatPrompt(segments, opts) as string;
+         if (opts.model.includes("canvas")) {
+             return await formatNovaPrompt(segments, opts.result_schema);
          }
+         return await fortmatConversePrompt(segments, opts.result_schema);
      }

-     extractDataFromResponse(prompt: BedrockPrompt, response: InvokeModelCommandOutput): Completion {
-
-         const decoder = new TextDecoder();
-         const body = decoder.decode(response.body);
-         const result = JSON.parse(body);
-
-         const getTextAnsStopReason = (): string[] => {
-             if (result.generation) {
-                 // LLAMA2
-                 return [result.generation, result.stop_reason]; // comes in coirrect format (stop, length)
-             } else if (result.generations) {
-                 // Cohere
-                 return [result.generations[0].text, cohereFinishReason(result.generations[0].finish_reason)];
-             } else if (result.chat_history) {
-                 //Cohere Command R
-                 return [result.text, cohereFinishReason(result.finish_reason)];
-             } else if (result.completions) {
-                 //A21
-                 return [result.completions[0].data?.text, a21FinishReason(result.completions[0].finishReason?.reason)];
-             } else if (result.content) {
-                 // Claude
-                 //if last prompt.messages is {, add { to the response
-                 const p = prompt as ClaudeMessagesPrompt;
-                 const lastMessage = (p as ClaudeMessagesPrompt).messages[p.messages.length - 1];
-                 const res = lastMessage.content[0].text === '{' ? '{' + result.content[0]?.text : result.content[0]?.text;
-
-                 return [res, claudeFinishReason(result.stop_reason)];
-
-             } else if (result.outputs) {
-                 // mistral
-                 return [result.outputs[0]?.text, result.outputs[0]?.stop_reason]; // the stop reason is in the expected format ("stop" and "length")
-             } else if (result.results) {
-                 // Amazon Titan
-                 return [result.results[0]?.outputText ?? '', titanFinishReason(result.results[0]?.completionReason)];
-             } else if (result.completion) { // TODO: who uses this?
-                 return [result.completion];
-             } else {
-                 return [result.toString()];
-             }
-         };
-
-         const [text, finish_reason] = getTextAnsStopReason();
-
-         const promptLength = typeof prompt === 'string' ? prompt.length :
-             (prompt.system || '').length + prompt.messages.reduce((acc, m) => acc + m.content.length, 0);
+     static getExtractedExecuton(result: ConverseResponse, _prompt?: BedrockPrompt): CompletionChunkObject {
          return {
-             result: text,
+             result: result.output?.message?.content?.map(c => c.text).join("\n") ?? "",
              token_usage: {
-                 result: text?.length,
-                 prompt: promptLength,
-                 total: text?.length + promptLength,
+                 prompt: result.usage?.inputTokens,
+                 result: result.usage?.outputTokens,
+                 total: result.usage?.totalTokens,
              },
-             finish_reason
+             finish_reason: converseFinishReason(result.stopReason),
          }
-     }
+     };
+
+     static getExtractedStream(result: ConverseStreamOutput, _prompt?: BedrockPrompt): CompletionChunkObject {
+         let output: string = "";
+         let stop_reason = "";
+         let token_usage: ExecutionTokenUsage | undefined;
+         if (result.contentBlockDelta) {
+             output = result.contentBlockDelta.delta?.text ?? "";
+         }
+         if (result.messageStop) {
+             stop_reason = result.messageStop.stopReason ?? "";
+         }
+         if (result.metadata) {
+             token_usage = {
+                 prompt: result.metadata.usage?.inputTokens,
+                 result: result.metadata.usage?.outputTokens,
+                 total: result.metadata.usage?.totalTokens,
+             }
+         }
+         return {
+             result: output,
+             token_usage: token_usage,
+             finish_reason: converseFinishReason(stop_reason),
+         }
+     };

-     async requestCompletion(prompt: BedrockPrompt, options: ExecutionOptions): Promise<Completion> {
+     async requestTextCompletion(prompt: ConverseRequest, options: ExecutionOptions): Promise<Completion> {

          const payload = this.preparePayload(prompt, options);
          const executor = this.getExecutor();
-         const res = await executor.invokeModel({
-             modelId: options.model,
-             contentType: "application/json",
-             body: JSON.stringify(payload),
+
+         const res = await executor.converse({
+             ...payload,
          });
-         const completion = this.extractDataFromResponse(prompt, res);
-         if (options.include_original_response) {
-             completion.original_response = res;
-         }
+
+         const completion = {
+             ...BedrockDriver.getExtractedExecuton(res, prompt),
+             original_response: options.include_original_response ? res : undefined,
+         } satisfies Completion;
+
          return completion;
      }

+     extractRegion(modelString: string, defaultRegion: string): string {
+         // Match region in full ARN pattern
+         const arnMatch = modelString.match(/arn:aws[^:]*:bedrock:([^:]+):/);
+         if (arnMatch) {
+             return arnMatch[1];
+         }
+
+         // Match common AWS regions directly in string
+         const regionMatch = modelString.match(/(?:us|eu|ap|sa|ca|me|af)[-](east|west|central|south|north|southeast|southwest|northeast|northwest)[-][1-9]/);
+         if (regionMatch) {
+             return regionMatch[0];
+         }
+
+         return defaultRegion;
+     }
+
+     private async getCanStream(model: string, type: BedrockModelType): Promise<boolean> {
+         let canStream: boolean = false;
+         let error: any = null;
+         const region = this.extractRegion(model, this.options.region);
+         if (type == BedrockModelType.FoundationModel || type == BedrockModelType.Unknown) {
+             try {
+                 const response = await this.getService(region).getFoundationModel({
+                     modelIdentifier: model
+                 });
+                 canStream = response.modelDetails?.responseStreamingSupported ?? false;
+                 return canStream;
+             } catch (e) {
+                 error = e;
+             }
+         }
+         if (type == BedrockModelType.InferenceProfile || type == BedrockModelType.Unknown) {
+             try {
+                 const response = await this.getService(region).getInferenceProfile({
+                     inferenceProfileIdentifier: model
+                 });
+                 canStream = await this.getCanStream(response.models?.[0].modelArn ?? "", BedrockModelType.FoundationModel);
+                 return canStream;
+             } catch (e) {
+                 error = e;
+             }
+         }
+         if (type == BedrockModelType.CustomModel || type == BedrockModelType.Unknown) {
+             try {
+                 const response = await this.getService(region).getCustomModel({
+                     modelIdentifier: model
+                 });
+                 canStream = await this.getCanStream(response.baseModelArn ?? "", BedrockModelType.FoundationModel);
+                 return canStream;
+             } catch (e) {
+                 error = e;
+             }
+         }
+         if (error) {
+             console.warn("Error on canStream check for model: " + model + " region detected: " + region, error);
+         }
+         return canStream;
+     }
+
      protected async canStream(options: ExecutionOptions): Promise<boolean> {
          let canStream = supportStreamingCache.get(options.model);
          if (canStream == null) {
-             const response = await this.getService().getFoundationModel({
-                 modelIdentifier: options.model
-             });
-             canStream = response.modelDetails?.responseStreamingSupported ?? false;
+             let type = BedrockModelType.Unknown;
+             if (options.model.includes("foundation-model")) {
+                 type = BedrockModelType.FoundationModel;
+             } else if (options.model.includes("inference-profile")) {
+                 type = BedrockModelType.InferenceProfile;
+             } else if (options.model.includes("custom-model")) {
+                 type = BedrockModelType.CustomModel;
+             }
+             canStream = await this.getCanStream(options.model, type);
              supportStreamingCache.set(options.model, canStream);
          }
          return canStream;
      }

-     async requestCompletionStream(prompt: BedrockPrompt, options: ExecutionOptions): Promise<AsyncIterable<string>> {
+     async requestTextCompletionStream(prompt: ConverseRequest, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
          const payload = this.preparePayload(prompt, options);
          const executor = this.getExecutor();
-         return executor.invokeModelWithResponseStream({
-             modelId: options.model,
-             contentType: "application/json",
-             body: JSON.stringify(payload),
+         return executor.converseStream({
+             ...payload,
          }).then((res) => {
+             const stream = res.stream;

-             if (!res.body) {
-                 throw new Error("Body not found");
+             if (!stream) {
+                 throw new Error("[Bedrock] Stream not found in response");
              }
-             const decoder = new TextDecoder();

-             const addBracket = () => {
-                 if (typeof prompt === 'object' && (prompt as ClaudeMessagesPrompt).messages) {
-                     const p = prompt as ClaudeMessagesPrompt;
-                     const lastMessage = p.messages[p.messages.length - 1];
-                     return lastMessage.content[0].text === '{';
-                 }
-                 return false;
-             };
-
-             return transformAsyncIterator(res.body, (stream: ResponseStream) => {
-                 const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
+             return transformAsyncIterator(stream, (stream: ConverseStreamOutput) => {
+                 //const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
                  //console.log("Debug Segment for model " + options.model, JSON.stringify(segment));
-                 if (segment.delta) { // who is this?
-                     return segment.delta.text || '';
-                 } else if (segment.completion) { // who is this?
-                     return segment.completion;
-                 } else if (segment.text) { //cohere
-                     return segment.text;
-                 } else if (segment.completions) {
-                     return segment.completions[0].data?.text;
-                 } else if (segment.generation) {
-                     return segment.generation;
-                 } else if (segment.generations) {
-                     return segment.generations[0].text;
-                 } else if (segment.outputs) {
-                     // mistral.mixtral-8x7b-instruct-v0:1
-                     return segment.outputs[0].text;
-                     //segment.outputs[0].stop_reason;
-                 } else if (segment.outputText) {
-                     // Amazon Titan
-                     return segment.outputText;
-                     //completionReason
-                     // token count too
-                 } else {
-                     segment.toString();
-                 }
-
-             },
-                 () => addBracket() ? '{' : ''
-             );
+                 return BedrockDriver.getExtractedStream(stream, prompt);
+             });

          }).catch((err) => {
              this.logger.error("[Bedrock] Failed to stream", err);
@@ -234,79 +261,180 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
          });
      }

+     preparePayload(prompt: ConverseRequest, options: ExecutionOptions) {
+         const model_options = options.model_options as TextFallbackOptions;

+         let additionalField = {};

-     preparePayload(prompt: BedrockPrompt, options: ExecutionOptions) {
-
-         //split arn on / should give provider
-         //TODO: check if works with custom models
-         //const provider = options.model.split("/")[0];
-         const contains = (str: string, substr: string) => str.indexOf(substr) !== -1;
-
-         if (contains(options.model, "meta")) {
-             return {
-                 prompt,
-                 temperature: options.temperature,
-                 max_gen_len: options.max_tokens,
-             } as LLama2RequestPayload
-         } else if (contains(options.model, "claude")) {
-
-             const maxToken = () => {
-                 if (options.max_tokens) {
-                     return options.max_tokens;
-
-                 } else if (contains(options.model, "claude-3-5")) {
-                     return 8192;
+         if (options.model.includes("amazon")) {
+             //Titan models also exists but does not support any additional options
+             if (options.model.includes("nova")) {
+                 additionalField = { inferenceConfig: { topK: model_options?.top_k } };
+             }
+         } else if (options.model.includes("claude")) {
+             if (options.model.includes("claude-3-7")) {
+                 const thinking_options = options.model_options as BedrockClaudeOptions;
+                 const thinking = thinking_options?.thinking_mode ?? false;
+                 if (!model_options?.max_tokens) {
+                     model_options.max_tokens = thinking ? 128000 : 8192;
+                 }
+                 additionalField = {
+                     top_k: model_options?.top_k,
+                     reasoning_config: {
+                         type: thinking ? "enabled" : "disabled",
+                         budget_tokens: thinking_options?.thinking_budget_tokens,
+                     }
+                 };
+                 if(thinking && (thinking_options?.thinking_budget_tokens ?? 0) > 64000){
+                     additionalField = {
+                         ...additionalField,
+                         anthorpic_beta: ["output-128k-2025-02-19"]
+                     };
+                 }
+             }
+             //Needs max_tokens to be set
+             if (!model_options?.max_tokens) {
+                 if (options.model.includes("claude-3-5")) {
+                     model_options.max_tokens = 8192;
+
+                     //Bug with AWS Converse Sonnet 3.5, does not effect Haiku.
+                     //See https://github.com/boto/boto3/issues/4279
+                     if (options.model.includes("claude-3-5-sonnet")) {
+                         model_options.max_tokens = 4096;
+                     }
                  } else {
-                     return 4096
+                     model_options.max_tokens = 4096;
                  }
              }
-             return {
-                 anthropic_version: "bedrock-2023-05-31",
-                 ...(prompt as ClaudeMessagesPrompt),
-                 temperature: options.temperature,
-                 max_tokens: maxToken(),
-             } as ClaudeRequestPayload;
-         } else if (contains(options.model, "ai21")) {
-             return {
-                 prompt: prompt,
-                 temperature: options.temperature,
-                 maxTokens: options.max_tokens,
-             } as AI21RequestPayload;
-         } else if (contains(options.model, "command-r-plus")) {
-             return {
-                 message: prompt as string,
-                 max_tokens: options.max_tokens,
-                 temperature: options.temperature,
-             } as CohereCommandRPayload;
+             additionalField = { top_k: model_options?.top_k };
+         } else if (options.model.includes("meta")) {
+             //If last message is "```json", remove it. Model requires the final message to be a user message
+             prompt.messages = converseRemoveJSONprefill(prompt.messages);
+         } else if (options.model.includes("mistral")) {
+             //7B instruct and 8x7B instruct
+             if (options.model.includes("7b")) {
+                 additionalField = { top_k: model_options?.top_k };
+                 //Does not support system messages
+                 if (prompt.system && prompt.system?.length != 0) {
+                     prompt.messages?.push(converseSystemToMessages(prompt.system));
+                     prompt.system = undefined;
+                     prompt.messages = converseConcatMessages(prompt.messages);
+                 }
+             } else {
+                 //Other models such as Mistral Small,Large and Large 2
+                 //Support no additional fields.
+                 prompt.messages = converseRemoveJSONprefill(prompt.messages);
+             }
+         } else if (options.model.includes("ai21")) {
+             //If last message is "```json", remove it. Model requires the final message to be a user message
+             prompt.messages = converseRemoveJSONprefill(prompt.messages);
+             //Jamba models support no additional options
+             //Jurassic 2 models do.
+             if (options.model.includes("j2")) {
+                 additionalField = {
+                     presencePenalty: { scale: model_options?.presence_penalty },
+                     frequencyPenalty: { scale: model_options?.frequency_penalty },
+                 };
+                 //Does not support system messages
+                 if (prompt.system && prompt.system?.length != 0) {
+                     prompt.messages?.push(converseSystemToMessages(prompt.system));
+                     prompt.system = undefined;
+                     prompt.messages = converseConcatMessages(prompt.messages);
+                 }
+             }
+         } else if (options.model.includes("cohere.command")) {
+             // If last message is "```json", remove it.
+             // Model requires the final message to be a user message or does not support assistant messages
+             prompt.messages = converseRemoveJSONprefill(prompt.messages);
+             //Command R and R plus
+             if (options.model.includes("cohere.command-r")) {
+                 additionalField = {
+                     k: model_options?.top_k,
+                     frequency_penalty: model_options?.frequency_penalty,
+                     presence_penalty: model_options?.presence_penalty,
+                 };
+             } else {
+                 // Command non-R
+                 additionalField = { k: model_options?.top_k };
+                 //Does not support system messages
+                 if (prompt.system && prompt.system?.length != 0) {
+                     prompt.messages?.push(converseSystemToMessages(prompt.system));
+                     prompt.system = undefined;
+                     prompt.messages = converseConcatMessages(prompt.messages);
+                 }
+             }
+         }

+         //If last message is "```json", add corresponding ``` as a stop sequence.
+         if (prompt.messages && prompt.messages.length > 0) {
+             if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") {
+                 let stopSeq = model_options?.stop_sequence;
+                 if (!stopSeq) {
+                     model_options.stop_sequence = ["```"];
+                 } else if (!stopSeq.includes("```")) {
+                     stopSeq.push("```");
+                     model_options.stop_sequence = stopSeq;
+                 }
+             }
          }
-         else if (contains(options.model, "cohere")) {
-             return {
-                 prompt: prompt,
-                 temperature: options.temperature,
-                 max_tokens: options.max_tokens,
-             } as CohereRequestPayload;
-         } else if (contains(options.model, "amazon")) {
-             return {
-                 inputText: "User: " + (prompt as string) + "\nBot:", // see https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html#model-parameters-titan-request-response
-                 textGenerationConfig: {
-                     temperature: options.temperature,
-                     topP: options.top_p,
-                     maxTokenCount: options.max_tokens,
-                     //stopSequences: ["\n"],
-                 },
-             } as AmazonRequestPayload;
-         } else if (contains(options.model, "mistral")) {
-             return {
-                 prompt: prompt,
-                 temperature: options.temperature,
-                 max_tokens: options.max_tokens,
-             } as MistralPayload;
-         } else {
-             throw new Error("Cannot prepare payload for unknown provider: " + options.model);
+
+         return {
+             messages: prompt.messages,
+             system: prompt.system,
+             modelId: options.model,
+             inferenceConfig: {
+                 maxTokens: model_options?.max_tokens,
+                 temperature: model_options?.temperature,
+                 topP: model_options?.top_p,
+                 stopSequences: model_options?.stop_sequence,
+             } satisfies InferenceConfiguration,
+             additionalModelRequestFields: {
+                 ...additionalField,
+             },
+         } satisfies ConverseRequest;
+     }
+
+
+     async requestImageGeneration(prompt: NovaMessagesPrompt, options: ExecutionOptions): Promise<Completion<ImageGeneration>> {
+         if (options.output_modality !== Modalities.image) {
+             throw new Error(`Image generation requires image output_modality`);
+         }
+         if (options.model_options?._option_id !== "bedrock-nova-canvas") {
+             this.logger.warn("Invalid model options", {options: options.model_options });
          }
+         const model_options = options.model_options as NovaCanvasOptions;
+
+         const executor = this.getExecutor();
+         const taskType = model_options.taskType ?? NovaImageGenerationTaskType.TEXT_IMAGE;
+
+         this.logger.info("Task type: " + taskType);
+
+         if (typeof prompt === "string") {
+             throw new Error("Bad prompt format");
+         }
+
+         const payload = await formatNovaImageGenerationPayload(taskType, prompt, options);

+         const res = await executor.invokeModel({
+             modelId: options.model,
+             contentType: "application/json",
+             accept: "application/json",
+             body: JSON.stringify(payload),
+         },
+         {
+             requestTimeout: 60000 * 5
+         });
+
+         const decoder = new TextDecoder();
+         const body = decoder.decode(res.body);
+         const result = JSON.parse(body);
+
+         return {
+             error: result.error,
+             result: {
+                 images: result.images,
+             }
+         }
      }

      async startTraining(dataset: DataSource, options: TrainingOptions): Promise<TrainingJob> {
@@ -387,13 +515,14 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
      async listModels(): Promise<AIModel[]> {
          this.logger.debug("[Bedrock] listing models");
          // exclude trainable models since they are not executable
-         const filter = (m: FoundationModelSummary) => m.inferenceTypesSupported?.includes("ON_DEMAND") ?? false;
+         // exclude embedding models, not to be used for typical completions.
+         const filter = (m: FoundationModelSummary) => (m.inferenceTypesSupported?.includes("ON_DEMAND") && !m.outputModalities?.includes("EMBEDDING")) ?? false;
          return this._listModels(filter);
      }

      async _listModels(foundationFilter?: (m: FoundationModelSummary) => boolean): Promise<AIModel[]> {
          const service = this.getService();
-         const [foundationals, customs] = await Promise.all([
+         const [foundationals, customs, inferenceProfiles] = await Promise.all([
              service.listFoundationModels({}).catch(() => {
                  this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions.");
                  return undefined
@@ -402,6 +531,10 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
                  this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions.");
                  return undefined
              }),
+             service.listInferenceProfiles({}).catch(() => {
+                 this.logger.warn("[Bedrock] Can't list inference profiles. Check if the user has the right permissions.");
+                 return undefined
+             }),
          ]);

          if (!foundationals?.modelSummaries) {
@@ -454,21 +587,41 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
              });
          }

+         //add inference profiles
+         if (inferenceProfiles?.inferenceProfileSummaries) {
+             inferenceProfiles.inferenceProfileSummaries.forEach((p) => {
+                 if (!p.inferenceProfileArn) {
+                     throw new Error("Profile ARN not found");
+                 }
+
+                 const model: AIModel = {
+                     id: p.inferenceProfileArn ?? p.inferenceProfileId,
+                     name: p.inferenceProfileName ?? p.inferenceProfileArn,
+                     provider: this.provider,
+                 };
+
+                 aimodels.push(model);
+             });
+         }
+
          return aimodels;
      }

-     async generateEmbeddings({ content, model = "amazon.titan-embed-text-v1" }: EmbeddingsOptions): Promise<EmbeddingsResult> {
+     async generateEmbeddings({ text, image, model }: EmbeddingsOptions): Promise<EmbeddingsResult> {

          this.logger.info("[Bedrock] Generating embeddings with model " + model);
+         const defaultModel = image ? "amazon.titan-embed-image-v1" : "amazon.titan-embed-text-v2:0";
+         const modelID = model ?? defaultModel;

          const invokeBody = {
-             inputText: content
+             inputText: text,
+             inputImage: image
          }

          const executor = this.getExecutor();
          const res = await executor.invokeModel(
              {
-                 modelId: model,
+                 modelId: modelID,
                  contentType: "application/json",
                  body: JSON.stringify(invokeBody),
              }
@@ -485,16 +638,12 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP

          return {
              values: result.embedding,
-             model: model,
+             model: modelID,
              token_count: result.inputTextTokenCount
          };
-
      }
-
  }

-
-
  function jobInfo(job: GetModelCustomizationJobCommandOutput, jobId: string): TrainingJob {
      const jobStatus = job.status;
      let status = TrainingJobStatus.running;
@@ -517,41 +666,4 @@ function jobInfo(job: GetModelCustomizationJobCommandOutput, jobId: string): Tra
          status,
          details
      }
- }
-
-
- function claudeFinishReason(reason: string | undefined) {
-     if (!reason) return undefined;
-     switch (reason) {
-         case 'end_turn': return "stop";
-         case 'max_tokens': return "length";
-         default: return reason; //stop_sequence
-     }
- }
-
- function cohereFinishReason(reason: string | undefined) {
-     if (!reason) return undefined;
-     switch (reason) {
-         case 'COMPLETE': return "stop";
-         case 'MAX_TOKENS': return "length";
-         default: return reason;
-     }
- }
-
- function a21FinishReason(reason: string | undefined) {
-     if (!reason) return undefined;
-     switch (reason) {
-         case 'endoftext': return "stop";
-         case 'length': return "length";
-         default: return reason;
-     }
- }
-
- function titanFinishReason(reason: string | undefined) {
-     if (!reason) return undefined;
-     switch (reason) {
-         case 'FINISH': return "stop";
-         case 'LENGTH': return "length";
-         default: return reason;
-     }
- }
+ }
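
Migration note: streaming completions now resolve to an AsyncIterable<CompletionChunkObject> (partial text plus token usage and a finish reason normalized by converseFinishReason) instead of an AsyncIterable<string>. A rough sketch of consuming the new chunk shape; how the stream is obtained is assumed to follow requestTextCompletionStream() as shown in the diff above:

    import { CompletionChunkObject } from "@llumiverse/core";

    // Collect streamed chunks into the final text. finish_reason is set on the chunk
    // built from the Converse messageStop event ("stop", "length", or the raw value),
    // and token_usage on the chunk built from the metadata event.
    async function collectStream(stream: AsyncIterable<CompletionChunkObject>): Promise<string> {
        let text = "";
        for await (const chunk of stream) {
            text += chunk.result ?? "";
            if (chunk.token_usage) {
                console.log("tokens used:", chunk.token_usage.total);
            }
            if (chunk.finish_reason) {
                console.log("finish_reason:", chunk.finish_reason);
            }
        }
        return text;
    }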