@llumiverse/drivers 0.15.0 → 0.17.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. package/README.md +3 -3
  2. package/lib/cjs/adobe/firefly.js +119 -0
  3. package/lib/cjs/adobe/firefly.js.map +1 -0
  4. package/lib/cjs/bedrock/converse.js +177 -0
  5. package/lib/cjs/bedrock/converse.js.map +1 -0
  6. package/lib/cjs/bedrock/index.js +338 -234
  7. package/lib/cjs/bedrock/index.js.map +1 -1
  8. package/lib/cjs/bedrock/nova-image-payload.js +207 -0
  9. package/lib/cjs/bedrock/nova-image-payload.js.map +1 -0
  10. package/lib/cjs/groq/index.js +34 -9
  11. package/lib/cjs/groq/index.js.map +1 -1
  12. package/lib/cjs/huggingface_ie.js +28 -12
  13. package/lib/cjs/huggingface_ie.js.map +1 -1
  14. package/lib/cjs/index.js +1 -0
  15. package/lib/cjs/index.js.map +1 -1
  16. package/lib/cjs/mistral/index.js +32 -13
  17. package/lib/cjs/mistral/index.js.map +1 -1
  18. package/lib/cjs/mistral/types.js.map +1 -1
  19. package/lib/cjs/openai/index.js +164 -29
  20. package/lib/cjs/openai/index.js.map +1 -1
  21. package/lib/cjs/replicate.js +19 -34
  22. package/lib/cjs/replicate.js.map +1 -1
  23. package/lib/cjs/test/TestValidationErrorCompletionStream.js.map +1 -1
  24. package/lib/cjs/test/index.js.map +1 -1
  25. package/lib/cjs/togetherai/index.js +40 -10
  26. package/lib/cjs/togetherai/index.js.map +1 -1
  27. package/lib/cjs/vertexai/embeddings/embeddings-image.js +26 -0
  28. package/lib/cjs/vertexai/embeddings/embeddings-image.js.map +1 -0
  29. package/lib/cjs/vertexai/embeddings/embeddings-text.js +1 -1
  30. package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -1
  31. package/lib/cjs/vertexai/index.js +134 -35
  32. package/lib/cjs/vertexai/index.js.map +1 -1
  33. package/lib/cjs/vertexai/models/claude.js +252 -0
  34. package/lib/cjs/vertexai/models/claude.js.map +1 -0
  35. package/lib/cjs/vertexai/models/gemini.js +172 -25
  36. package/lib/cjs/vertexai/models/gemini.js.map +1 -1
  37. package/lib/cjs/vertexai/models/imagen.js +317 -0
  38. package/lib/cjs/vertexai/models/imagen.js.map +1 -0
  39. package/lib/cjs/vertexai/models.js +12 -64
  40. package/lib/cjs/vertexai/models.js.map +1 -1
  41. package/lib/cjs/watsonx/index.js +47 -10
  42. package/lib/cjs/watsonx/index.js.map +1 -1
  43. package/lib/cjs/xai/index.js +71 -0
  44. package/lib/cjs/xai/index.js.map +1 -0
  45. package/lib/esm/adobe/firefly.js +115 -0
  46. package/lib/esm/adobe/firefly.js.map +1 -0
  47. package/lib/esm/bedrock/converse.js +171 -0
  48. package/lib/esm/bedrock/converse.js.map +1 -0
  49. package/lib/esm/bedrock/index.js +339 -232
  50. package/lib/esm/bedrock/index.js.map +1 -1
  51. package/lib/esm/bedrock/nova-image-payload.js +203 -0
  52. package/lib/esm/bedrock/nova-image-payload.js.map +1 -0
  53. package/lib/esm/groq/index.js +34 -9
  54. package/lib/esm/groq/index.js.map +1 -1
  55. package/lib/esm/huggingface_ie.js +29 -13
  56. package/lib/esm/huggingface_ie.js.map +1 -1
  57. package/lib/esm/index.js +1 -0
  58. package/lib/esm/index.js.map +1 -1
  59. package/lib/esm/mistral/index.js +32 -13
  60. package/lib/esm/mistral/index.js.map +1 -1
  61. package/lib/esm/mistral/types.js.map +1 -1
  62. package/lib/esm/openai/index.js +165 -30
  63. package/lib/esm/openai/index.js.map +1 -1
  64. package/lib/esm/replicate.js +19 -34
  65. package/lib/esm/replicate.js.map +1 -1
  66. package/lib/esm/test/TestValidationErrorCompletionStream.js.map +1 -1
  67. package/lib/esm/test/index.js.map +1 -1
  68. package/lib/esm/togetherai/index.js +40 -10
  69. package/lib/esm/togetherai/index.js.map +1 -1
  70. package/lib/esm/vertexai/embeddings/embeddings-image.js +23 -0
  71. package/lib/esm/vertexai/embeddings/embeddings-image.js.map +1 -0
  72. package/lib/esm/vertexai/embeddings/embeddings-text.js +1 -1
  73. package/lib/esm/vertexai/embeddings/embeddings-text.js.map +1 -1
  74. package/lib/esm/vertexai/index.js +135 -37
  75. package/lib/esm/vertexai/index.js.map +1 -1
  76. package/lib/esm/vertexai/models/claude.js +247 -0
  77. package/lib/esm/vertexai/models/claude.js.map +1 -0
  78. package/lib/esm/vertexai/models/gemini.js +173 -26
  79. package/lib/esm/vertexai/models/gemini.js.map +1 -1
  80. package/lib/esm/vertexai/models/imagen.js +310 -0
  81. package/lib/esm/vertexai/models/imagen.js.map +1 -0
  82. package/lib/esm/vertexai/models.js +12 -61
  83. package/lib/esm/vertexai/models.js.map +1 -1
  84. package/lib/esm/watsonx/index.js +47 -10
  85. package/lib/esm/watsonx/index.js.map +1 -1
  86. package/lib/esm/xai/index.js +64 -0
  87. package/lib/esm/xai/index.js.map +1 -0
  88. package/lib/types/adobe/firefly.d.ts +30 -0
  89. package/lib/types/adobe/firefly.d.ts.map +1 -0
  90. package/lib/types/bedrock/converse.d.ts +8 -0
  91. package/lib/types/bedrock/converse.d.ts.map +1 -0
  92. package/lib/types/bedrock/index.d.ts +27 -12
  93. package/lib/types/bedrock/index.d.ts.map +1 -1
  94. package/lib/types/bedrock/nova-image-payload.d.ts +74 -0
  95. package/lib/types/bedrock/nova-image-payload.d.ts.map +1 -0
  96. package/lib/types/bedrock/payloads.d.ts +9 -65
  97. package/lib/types/bedrock/payloads.d.ts.map +1 -1
  98. package/lib/types/groq/index.d.ts +3 -3
  99. package/lib/types/groq/index.d.ts.map +1 -1
  100. package/lib/types/huggingface_ie.d.ts +5 -7
  101. package/lib/types/huggingface_ie.d.ts.map +1 -1
  102. package/lib/types/index.d.ts +1 -0
  103. package/lib/types/index.d.ts.map +1 -1
  104. package/lib/types/mistral/index.d.ts +4 -4
  105. package/lib/types/mistral/index.d.ts.map +1 -1
  106. package/lib/types/mistral/types.d.ts +1 -0
  107. package/lib/types/mistral/types.d.ts.map +1 -1
  108. package/lib/types/openai/index.d.ts +5 -4
  109. package/lib/types/openai/index.d.ts.map +1 -1
  110. package/lib/types/replicate.d.ts +4 -9
  111. package/lib/types/replicate.d.ts.map +1 -1
  112. package/lib/types/test/index.d.ts +2 -2
  113. package/lib/types/test/index.d.ts.map +1 -1
  114. package/lib/types/togetherai/index.d.ts +4 -4
  115. package/lib/types/togetherai/index.d.ts.map +1 -1
  116. package/lib/types/vertexai/embeddings/embeddings-image.d.ts +11 -0
  117. package/lib/types/vertexai/embeddings/embeddings-image.d.ts.map +1 -0
  118. package/lib/types/vertexai/index.d.ts +21 -8
  119. package/lib/types/vertexai/index.d.ts.map +1 -1
  120. package/lib/types/vertexai/models/claude.d.ts +20 -0
  121. package/lib/types/vertexai/models/claude.d.ts.map +1 -0
  122. package/lib/types/vertexai/models/gemini.d.ts +4 -4
  123. package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
  124. package/lib/types/vertexai/models/imagen.d.ts +75 -0
  125. package/lib/types/vertexai/models/imagen.d.ts.map +1 -0
  126. package/lib/types/vertexai/models.d.ts +3 -6
  127. package/lib/types/vertexai/models.d.ts.map +1 -1
  128. package/lib/types/watsonx/index.d.ts +3 -3
  129. package/lib/types/watsonx/index.d.ts.map +1 -1
  130. package/lib/types/watsonx/interfaces.d.ts +4 -0
  131. package/lib/types/watsonx/interfaces.d.ts.map +1 -1
  132. package/lib/types/xai/index.d.ts +19 -0
  133. package/lib/types/xai/index.d.ts.map +1 -0
  134. package/package.json +25 -26
  135. package/src/adobe/firefly.ts +207 -0
  136. package/src/bedrock/converse.ts +194 -0
  137. package/src/bedrock/index.ts +359 -240
  138. package/src/bedrock/nova-image-payload.ts +309 -0
  139. package/src/bedrock/payloads.ts +12 -66
  140. package/src/groq/index.ts +35 -12
  141. package/src/huggingface_ie.ts +34 -13
  142. package/src/index.ts +1 -0
  143. package/src/mistral/index.ts +35 -13
  144. package/src/mistral/types.ts +2 -1
  145. package/src/openai/index.ts +186 -35
  146. package/src/replicate.ts +24 -35
  147. package/src/test/TestValidationErrorCompletionStream.ts +2 -2
  148. package/src/test/index.ts +3 -2
  149. package/src/togetherai/index.ts +44 -12
  150. package/src/vertexai/embeddings/embeddings-image.ts +50 -0
  151. package/src/vertexai/embeddings/embeddings-text.ts +1 -1
  152. package/src/vertexai/index.ts +186 -46
  153. package/src/vertexai/models/claude.ts +281 -0
  154. package/src/vertexai/models/gemini.ts +186 -29
  155. package/src/vertexai/models/imagen.ts +401 -0
  156. package/src/vertexai/models.ts +16 -78
  157. package/src/watsonx/index.ts +50 -12
  158. package/src/watsonx/interfaces.ts +4 -0
  159. package/src/xai/index.ts +110 -0
package/src/vertexai/models/imagen.ts (new file)
@@ -0,0 +1,401 @@
+ import { AIModel, Completion, ExecutionOptions, ImageGeneration, Modalities, ModelType, PromptRole, PromptSegment, readStreamAsBase64 } from "@llumiverse/core";
+ import { VertexAIDriver } from "../index.js";
+
+ const projectId = process.env.GOOGLE_PROJECT_ID;
+ const location = 'us-central1';
+
+ import aiplatform, { protos } from '@google-cloud/aiplatform';
+
+ // Imports the Google Cloud Prediction Service Client library
+ const { PredictionServiceClient } = aiplatform.v1;
+
+ // Import the helper module for converting arbitrary protobuf.Value objects
+ import { helpers } from '@google-cloud/aiplatform';
+ import { ImagenOptions } from "../../../../core/src/options/vertexai.js";
+
+ interface ImagenBaseReference {
+ referenceType: "REFERENCE_TYPE_RAW" | "REFERENCE_TYPE_MASK" | "REFERENCE_TYPE_SUBJECT" |
+ "REFERENCE_TYPE_CONTROL" | "REFERENCE_TYPE_STYLE";
+ referenceId: number;
+ referenceImage: {
+ bytesBase64Encoded: string; //10MB max
+ }
+ }
+
+ export enum ImagenTaskType {
+ TEXT_IMAGE = "TEXT_IMAGE",
+ EDIT_MODE_INPAINT_REMOVAL = "EDIT_MODE_INPAINT_REMOVAL",
+ EDIT_MODE_INPAINT_INSERTION = "EDIT_MODE_INPAINT_INSERTION",
+ EDIT_MODE_BGSWAP = "EDIT_MODE_BGSWAP",
+ EDIT_MODE_OUTPAINT = "EDIT_MODE_OUTPAINT",
+ CUSTOMIZATION_SUBJECT = "CUSTOMIZATION_SUBJECT",
+ CUSTOMIZATION_STYLE = "CUSTOMIZATION_STYLE",
+ CUSTOMIZATION_CONTROLLED = "CUSTOMIZATION_CONTROLLED",
+ CUSTOMIZATION_INSTRUCT = "CUSTOMIZATION_INSTRUCT",
+ }
+
+ export enum ImagenMaskMode {
+ MASK_MODE_USER_PROVIDED = "MASK_MODE_USER_PROVIDED",
+ MASK_MODE_BACKGROUND = "MASK_MODE_BACKGROUND",
+ MASK_MODE_FOREGROUND = "MASK_MODE_FOREGROUND",
+ MASK_MODE_SEMANTIC = "MASK_MODE_SEMANTIC",
+ }
+
+ interface ImagenReferenceRaw extends ImagenBaseReference {
+ referenceType: "REFERENCE_TYPE_RAW";
+ }
+
+ interface ImagenReferenceMask extends Omit<ImagenBaseReference, "referenceImage"> {
+ referenceType: "REFERENCE_TYPE_MASK";
+ maskImageConfig: {
+ maskMode?: ImagenMaskMode;
+ maskClasses?: number[]; //Used for MASK_MODE_SEMANTIC, based on https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api-customization#segment-ids
+ dilation?: number; //Recommendation depends on mode: Inpaint: 0.01, BGSwap: 0.0, Outpaint: 0.01-0.03
+ }
+ referenceImage?: { //Only used for MASK_MODE_USER_PROVIDED
+ bytesBase64Encoded: string; //10MB max
+ }
+ }
+
+ interface ImagenReferenceSubject extends ImagenBaseReference {
+ referenceType: "REFERENCE_TYPE_SUBJECT";
+ subjectImageConfig: {
+ subjectDescription: string;
+ subjectType: "SUBJECT_TYPE_PERSON" | "SUBJECT_TYPE_ANIMAL" | "SUBJECT_TYPE_PRODUCT" | "SUBJECT_TYPE_DEFAULT";
+ }
+ }
+
+ interface ImagenReferenceControl extends ImagenBaseReference {
+ referenceType: "REFERENCE_TYPE_CONTROL";
+ controlImageConfig: {
+ controlType: "CONTROL_TYPE_FACE_MESH" | "CONTROL_TYPE_CANNY" | "CONTROL_TYPE_SCRIBBLE";
+ enableControlImageComputation?: boolean; //If true, the model will compute the control image
+ }
+ }
+
+ interface ImagenReferenceStyle extends ImagenBaseReference {
+ referenceType: "REFERENCE_TYPE_STYLE";
+ styleImageConfig: {
+ styleDescription?: string;
+ }
+ }
+
+ type ImagenMessage = ImagenReferenceRaw | ImagenReferenceMask | ImagenReferenceSubject | ImagenReferenceControl | ImagenReferenceStyle;
+
+ export interface ImagenPrompt {
+ prompt: string;
+ referenceImages?: ImagenMessage[];
+ subjectDescription?: string; //Used for image customization to describe in the reference image
+ negativePrompt?: string; //Used for negative prompts
+ }
+
+ // Specifies the location of the api endpoint
+ const clientOptions = {
+ apiEndpoint: `${location}-aiplatform.googleapis.com`,
+ };
+
+ // Instantiates a client
+ const predictionServiceClient = new PredictionServiceClient(clientOptions);
+
+ function getImagenParameters(taskType: string, options: ImagenOptions) {
+ const commonParameters = {
+ sampleCount: options?.number_of_images,
+ seed: options?.seed,
+ safetySetting: options?.safety_setting,
+ personGeneration: options?.person_generation,
+ negativePrompt: taskType ? undefined : "", //Filled in later from the prompt
+ //TODO: Add more safety and prompt rejection information
+ //includeSafetyAttributes: true,
+ //includeRaiReason: true,
+ };
+ switch (taskType) {
+ case ImagenTaskType.EDIT_MODE_INPAINT_REMOVAL:
+ return {
+ ...commonParameters,
+ editMode: "EDIT_MODE_INPAINT_REMOVAL",
+ editConfig: {
+ baseSteps: options?.edit_steps,
+ },
+ }
+ case ImagenTaskType.EDIT_MODE_INPAINT_INSERTION:
+ return {
+ ...commonParameters,
+ editMode: "EDIT_MODE_INPAINT_INSERTION",
+ editConfig: {
+ baseSteps: options?.edit_steps,
+ },
+ }
+ case ImagenTaskType.EDIT_MODE_BGSWAP:
+ return {
+ ...commonParameters,
+ editMode: "EDIT_MODE_BGSWAP",
+ editConfig: {
+ baseSteps: options?.edit_steps,
+ },
+ }
+ case ImagenTaskType.EDIT_MODE_OUTPAINT:
+ return {
+ ...commonParameters,
+ editMode: "EDIT_MODE_OUTPAINT",
+ editConfig: {
+ baseSteps: options?.edit_steps,
+ },
+ }
+ case ImagenTaskType.TEXT_IMAGE:
+ return {
+ ...commonParameters,
+ // You can't use a seed value and watermark at the same time.
+ addWatermark: options?.add_watermark,
+ aspectRatio: options?.aspect_ratio,
+ enhancePrompt: options?.enhance_prompt,
+ };
+ case ImagenTaskType.CUSTOMIZATION_SUBJECT:
+ case ImagenTaskType.CUSTOMIZATION_CONTROLLED:
+ case ImagenTaskType.CUSTOMIZATION_INSTRUCT:
+ case ImagenTaskType.CUSTOMIZATION_STYLE:
+ return {
+ ...commonParameters,
+ }
+ default:
+ throw new Error("Task type not supported");
+ }
+ }
+
+ export class ImagenModelDefinition {
+
+ model: AIModel
+
+ constructor(modelId: string) {
+ this.model = {
+ id: modelId,
+ name: modelId,
+ provider: 'vertexai',
+ type: ModelType.Image,
+ can_stream: false,
+ };
+ }
+
+ async createPrompt(_driver: VertexAIDriver, segments: PromptSegment[], options: ExecutionOptions): Promise<ImagenPrompt> {
+ const splits = options.model.split("/");
+ const modelName = splits[splits.length - 1];
+ options = { ...options, model: modelName };
+
+ const prompt: ImagenPrompt = {
+ prompt: "",
+ }
+
+ //Collect text prompts, Imagen does not support roles, so everything gets merged together
+ // however we still respect our typical pattern. System First, Safety Last.
+ const system: string[] = [];
+ const user: string[] = [];
+ const safety: string[] = [];
+ const negative: string[] = [];
+
+ const mask_mode = (options.model_options as ImagenOptions)?.mask_mode;
+ const imagenOptions = options.model_options as ImagenOptions;
+
+ for (const msg of segments) {
+ if (msg.role === PromptRole.safety) {
+ safety.push(msg.content);
+ } else if (msg.role === PromptRole.system) {
+ system.push(msg.content);
+ } else if (msg.role === PromptRole.negative) {
+ negative.push(msg.content);
+ } else {
+ //Everything else is assumed to be user or user adjacent.
+ user.push(msg.content);
+ }
+ if (msg.files) {
+ //Get images from messages
+ if (!prompt.referenceImages) {
+ prompt.referenceImages = [];
+ }
+
+ //Always required, but only used by customisation.
+ //Each ref ID refers to a single "reference", i.e. object. To provide multiple images of a single ref,
+ //include multiple images in one prompt.
+ const refId = prompt.referenceImages.length + 1;
+ for (const img of msg.files) {
+ if (img.mime_type?.includes("image")) {
+ if (msg.role !== PromptRole.mask) {
+ //Editing based mode requires a reference image
+ if (imagenOptions?.edit_mode?.includes("EDIT_MODE")) {
+ prompt.referenceImages.push({
+ referenceType: "REFERENCE_TYPE_RAW",
+ referenceId: refId,
+ referenceImage: {
+ bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+ }
+ });
+ //If mask is auto-generated, add a mask reference
+ if (mask_mode !== ImagenMaskMode.MASK_MODE_USER_PROVIDED) {
+ prompt.referenceImages.push({
+ referenceType: "REFERENCE_TYPE_MASK",
+ referenceId: refId,
+ maskImageConfig: {
+ maskMode: mask_mode,
+ dilation: imagenOptions?.mask_dilation,
+ }
+ });
+ }
+ }
+ else if ((options.model_options as ImagenOptions)?.edit_mode === ImagenTaskType.CUSTOMIZATION_SUBJECT) {
+ //First image is always the control image
+ if (refId == 1) {
+ //Customization subject mode requires a control image
+ prompt.referenceImages.push({
+ referenceType: "REFERENCE_TYPE_CONTROL",
+ referenceId: refId,
+ referenceImage: {
+ bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+ },
+ controlImageConfig: {
+ controlType: imagenOptions?.controlType === "CONTROL_TYPE_FACE_MESH" ? "CONTROL_TYPE_FACE_MESH" : "CONTROL_TYPE_CANNY",
+ enableControlImageComputation: imagenOptions?.controlImageComputation,
+ }
+ });
+ } else {
+ // Subject images
+ prompt.referenceImages.push({
+ referenceType: "REFERENCE_TYPE_SUBJECT",
+ referenceId: refId,
+ referenceImage: {
+ bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+ },
+ subjectImageConfig: {
+ subjectDescription: prompt.subjectDescription ?? msg.content,
+ subjectType: imagenOptions?.subjectType ?? "SUBJECT_TYPE_DEFAULT",
+ }
+ });
+ }
+ } else if ((options.model_options as ImagenOptions)?.edit_mode === ImagenTaskType.CUSTOMIZATION_STYLE) {
+ // Style images
+ prompt.referenceImages.push({
+ referenceType: "REFERENCE_TYPE_STYLE",
+ referenceId: refId,
+ referenceImage: {
+ bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+ },
+ styleImageConfig: {
+ styleDescription: prompt.subjectDescription ?? msg.content,
+ }
+ });
+ } else if ((options.model_options as ImagenOptions)?.edit_mode === ImagenTaskType.CUSTOMIZATION_CONTROLLED) {
+ // Control images
+ prompt.referenceImages.push({
+ referenceType: "REFERENCE_TYPE_CONTROL",
+ referenceId: refId,
+ referenceImage: {
+ bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+ },
+ controlImageConfig: {
+ controlType: imagenOptions?.controlType === "CONTROL_TYPE_FACE_MESH" ? "CONTROL_TYPE_FACE_MESH" : "CONTROL_TYPE_CANNY",
+ enableControlImageComputation: imagenOptions?.controlImageComputation,
+ }
+ });
+ } else if ((options.model_options as ImagenOptions)?.edit_mode === ImagenTaskType.CUSTOMIZATION_INSTRUCT) {
+ // Control images
+ prompt.referenceImages.push({
+ referenceType: "REFERENCE_TYPE_RAW",
+ referenceId: refId,
+ referenceImage: {
+ bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+ },
+ });
+ }
+ }
+ //If mask is user-provided, add a mask reference
+ if (msg.role === PromptRole.mask && mask_mode === ImagenMaskMode.MASK_MODE_USER_PROVIDED) {
+ prompt.referenceImages.push({
+ referenceType: "REFERENCE_TYPE_MASK",
+ referenceId: refId,
+ referenceImage: {
+ bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+ },
+ maskImageConfig: {
+ maskMode: mask_mode,
+ dilation: imagenOptions?.mask_dilation,
+ }
+ });
+ }
+ }
+ }
+ }
+ }
+
+ //Extract the text from the segments
+ prompt.prompt += [system.join("\n\n"), user.join("\n\n"), safety.join("\n\n")].join("\n\n");
+
+ //Negative prompt
+ if (negative.length > 0) {
+ prompt.negativePrompt = negative.join(", ");
+ }
+
+ console.log(prompt);
+
+ return prompt
+ }
+
+ async requestImageGeneration(driver: VertexAIDriver, prompt: ImagenPrompt, options: ExecutionOptions): Promise<Completion<ImageGeneration>> {
+ if (options.model_options?._option_id !== "vertexai-imagen") {
+ driver.logger.warn("Invalid model options", {options: options.model_options });
+ }
+ options.model_options = options.model_options as ImagenOptions;
+
+ if (options.output_modality !== Modalities.image) {
+ throw new Error(`Image generation requires image output_modality`);
+ }
+
+ const taskType: string = options.model_options.edit_mode ?? ImagenTaskType.TEXT_IMAGE;
+
+ driver.logger.info("Task type: " + taskType);
+
+ const modelName = options.model.split("/").pop() ?? '';
+
+ // Configure the parent resource
+ const endpoint = `projects/${projectId}/locations/${location}/publishers/google/models/${modelName}`;
+
+ const instanceValue = helpers.toValue(prompt);
+ if (!instanceValue) {
+ throw new Error('No instance value found');
+ }
+ const instances = [instanceValue];
+
+ let parameter: any = getImagenParameters(taskType, options.model_options);
+ parameter.negativePrompt = prompt.negativePrompt ?? undefined;
+
+ const numberOfImages = options.model_options?.number_of_images ?? 1;
+
+ // Remove all undefined values
+ parameter = Object.fromEntries(
+ Object.entries(parameter).filter(([_, v]) => v !== undefined)
+ ) as any;
+
+ const parameters = helpers.toValue(parameter);
+
+ const request: protos.google.cloud.aiplatform.v1.IPredictRequest = {
+ endpoint,
+ instances,
+ parameters,
+ };
+
+ // Predict request
+ const [response] = await predictionServiceClient.predict(request, { timeout: 120000 * numberOfImages }); //Extended timeout for image generation
+ const predictions = response.predictions;
+
+ if (!predictions) {
+ throw new Error('No predictions found');
+ }
+
+ // Extract base64 encoded images from predictions
+ const images: string[] = predictions.map(prediction =>
+ prediction.structValue?.fields?.bytesBase64Encoded?.stringValue ?? ''
+ );
+
+ return {
+ result: {
+ images
+ },
+ };
+ }
+ }
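
For orientation, a minimal sketch of how the new Imagen model definition could be driven end to end. Only ImagenModelDefinition, createPrompt, and requestImageGeneration are taken from the diff above; the driver options, model path, and the exact ExecutionOptions shape are illustrative assumptions.

    import { Modalities, PromptRole, PromptSegment } from "@llumiverse/core";
    import { VertexAIDriver } from "@llumiverse/drivers"; // assumed export
    import { ImagenModelDefinition } from "./models/imagen.js"; // path as in this package

    // Driver options here are assumptions, not the documented shape.
    const driver = new VertexAIDriver({ project: "my-gcp-project", region: "us-central1" } as any);

    const model = "publishers/google/models/imagen-3.0-generate-001"; // illustrative model path
    const imagen = new ImagenModelDefinition(model);

    const segments: PromptSegment[] = [
        { role: PromptRole.user, content: "A watercolor fox in a snowy forest" },
        { role: PromptRole.negative, content: "blurry, low contrast" }, // becomes negativePrompt
    ];

    const options = {
        model,
        output_modality: Modalities.image, // enforced by requestImageGeneration
        model_options: { _option_id: "vertexai-imagen", number_of_images: 2 },
    } as any; // ExecutionOptions simplified for the sketch

    const prompt = await imagen.createPrompt(driver, segments, options);
    const completion = await imagen.requestImageGeneration(driver, prompt, options);
    // completion.result.images: one base64-encoded image per requested sample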
package/src/vertexai/models.ts
@@ -1,89 +1,27 @@
- import { AIModel, Completion, ExecutionOptions, ModelType, PromptOptions, PromptSegment } from "@llumiverse/core";
- import { VertexAIDriver } from "./index.js";
+ import { AIModel, Completion, CompletionChunkObject, PromptOptions, PromptSegment, ExecutionOptions } from "@llumiverse/core";
+ import { VertexAIDriver , trimModelName} from "./index.js";
  import { GeminiModelDefinition } from "./models/gemini.js";
-
-
-
+ import { ClaudeModelDefinition } from "./models/claude.js";

  export interface ModelDefinition<PromptT = any> {
  model: AIModel;
  versions?: string[]; // the versions of the model that are available. ex: ['001', '002']
  createPrompt: (driver: VertexAIDriver, segments: PromptSegment[], options: PromptOptions) => Promise<PromptT>;
- requestCompletion: (driver: VertexAIDriver, prompt: PromptT, options: ExecutionOptions) => Promise<Completion>;
- requestCompletionStream: (driver: VertexAIDriver, promp: PromptT, options: ExecutionOptions) => Promise<AsyncIterable<string>>;
- }
-
- export function getModelName(model: string) {
- const i = model.lastIndexOf('@');
- return i > -1 ? model.substring(0, i) : model;
+ requestTextCompletion: (driver: VertexAIDriver, prompt: PromptT, options: ExecutionOptions) => Promise<Completion>;
+ requestTextCompletionStream: (driver: VertexAIDriver, promp: PromptT, options: ExecutionOptions) => Promise<AsyncIterable<CompletionChunkObject>>;
  }

  export function getModelDefinition(model: string): ModelDefinition {
- const modelName = getModelName(model);
- const def = Models[modelName];
- if (!def) {
- throw new Error(`Unknown model ${model}`);
+ const splits = model.split("/");
+ const publisher = splits[1];
+ const modelName = trimModelName(splits[splits.length - 1]);
+
+ if (publisher?.includes("anthropic")) {
+ return new ClaudeModelDefinition(modelName);
+ } else if (publisher?.includes("google")) {
+ return new GeminiModelDefinition(modelName);
  }
- return def;
- }
-
- export function getAIModels() {
- return Object.values(Models).map(m => m.model);
- }
-
- // Builtin models. VertexAI doesn't provide an API to list models. so we have to hardcode them here.
- export const BuiltinModels: AIModel<string>[] = [
- {
- id: "gemini-1.5-flash",
- name: "Gemini Pro 1.5 Flash",
- provider: "vertexai",
- owner: "google",
- type: ModelType.MultiModal,
- can_stream: true,
- is_multimodal: true

- },
- {
- id: "gemini-1.5-pro",
- name: "Gemini Pro 1.5 Pro",
- provider: "vertexai",
- owner: "google",
- type: ModelType.MultiModal,
- can_stream: true,
- is_multimodal: true
-
- },
- {
- id: "gemini-1.0-pro",
- name: "Gemini Pro 1.0",
- provider: "vertexai",
- owner: "google",
- type: ModelType.Text,
- can_stream: true,
- },
- {
- id: "textembedding-gecko",
- name: "Gecko Text Embeddings",
- provider: "vertexai",
- owner: "google",
- type: ModelType.Embedding,
- },
- {
- id: "textembedding-gecko-multilingual",
- name: "Gecko Multilingual Text Embeddings",
- provider: "vertexai",
- owner: "google",
- type: ModelType.Embedding,
- },
-
-
-
- ]
-
-
-
- const Models: Record<string, ModelDefinition> = {
- "gemini-1.5-flash": new GeminiModelDefinition("gemini-1.5-flash"),
- "gemini-1.5-pro": new GeminiModelDefinition("gemini-1.5-pro"),
- "gemini-1.0-pro": new GeminiModelDefinition(),
- }
+ //Fallback, assume it is Gemini.
+ return new GeminiModelDefinition(modelName);
+ }
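
The rewritten getModelDefinition routes on the publisher segment of the model path instead of a hardcoded model table. A short illustration (model paths are hypothetical, and it is assumed, by analogy with the removed getModelName helper, that trimModelName strips a trailing @version):

    // "publishers/<publisher>/models/<name>" — splits[1] is the publisher
    getModelDefinition("publishers/anthropic/models/claude-3-5-sonnet@20240620");
    // -> ClaudeModelDefinition("claude-3-5-sonnet")

    getModelDefinition("publishers/google/models/gemini-1.5-pro");
    // -> GeminiModelDefinition("gemini-1.5-pro")

    getModelDefinition("publishers/mistralai/models/mistral-large");
    // -> unknown publisher, falls through to the Gemini fallback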
package/src/watsonx/index.ts
@@ -1,4 +1,4 @@
- import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions } from "@llumiverse/core";
+ import { AIModel, AbstractDriver, Completion, CompletionChunk, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, TextFallbackOptions } from "@llumiverse/core";
  import { transformSSEStream } from "@llumiverse/core/async";
  import { FetchClient } from "api-fetch-client";
  import { GenerateEmbeddingPayload, GenerateEmbeddingResponse, WatsonAuthToken, WatsonxListModelResponse, WatsonxModelSpec, WatsonxTextGenerationPayload, WatsonxTextGenerationResponse } from "./interfaces.js";
@@ -29,13 +29,21 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>
  this.fetchClient = new FetchClient(this.endpoint_url).withAuthCallback(async () => this.getAuthToken().then(token => `Bearer ${token}`));
  }

- async requestCompletion(prompt: string, options: ExecutionOptions): Promise<Completion<any>> {
+ async requestTextCompletion(prompt: string, options: ExecutionOptions): Promise<Completion<any>> {
+ if (options.model_options?._option_id !== "text-fallback") {
+ this.logger.warn("Invalid model options", {options: options.model_options });
+ }
+ options.model_options = options.model_options as TextFallbackOptions;
+
  const payload: WatsonxTextGenerationPayload = {
  model_id: options.model,
  input: prompt + "\n",
  parameters: {
- max_new_tokens: options.max_tokens,
- //time_limit: options.time_limit,
+ max_new_tokens: options.model_options.max_tokens,
+ temperature: options.model_options.temperature,
+ top_k: options.model_options.top_k,
+ top_p: options.model_options.top_p,
+ stop_sequences: options.model_options.stop_sequence,
  },
  project_id: this.projectId,
  }
@@ -51,19 +59,25 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>
  result: result.generated_token_count,
  total: result.input_token_count + result.generated_token_count,
  },
- finish_reason: result.stop_reason,
+ finish_reason: watsonFinishReason(result.stop_reason),
  original_response: options.include_original_response ? res : undefined,
  }
  }

- async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<string>> {
-
+ async requestTextCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunk>> {
+ if (options.model_options?._option_id !== "text-fallback") {
+ this.logger.warn("Invalid model options", {options: options.model_options });
+ }
+ options.model_options = options.model_options as TextFallbackOptions;
  const payload: WatsonxTextGenerationPayload = {
  model_id: options.model,
  input: prompt + "\n",
  parameters: {
- max_new_tokens: options.max_tokens,
- //time_limit: options.time_limit,
+ max_new_tokens: options.model_options.max_tokens,
+ temperature: options.model_options.temperature,
+ top_k: options.model_options.top_k,
+ top_p: options.model_options.top_p,
+ stop_sequences: options.model_options.stop_sequence,
  },
  project_id: this.projectId,
  }
@@ -75,7 +89,15 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>

  return transformSSEStream(stream, (data: string) => {
  const json = JSON.parse(data) as WatsonxTextGenerationResponse;
- return json.results[0]?.generated_text ?? '';
+ return {
+ result: json.results[0]?.generated_text ?? '',
+ finish_reason: watsonFinishReason(json.results[0]?.stop_reason),
+ token_usage: {
+ prompt: json.results[0].input_token_count,
+ result: json.results[0].generated_token_count,
+ total: json.results[0].input_token_count + json.results[0].generated_token_count,
+ },
+ };
  });

  }
@@ -130,15 +152,22 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>
  return this.listModels()
  .then(() => true)
  .catch((err) => {
- this.logger.warn("Failed to connect to WatsonX", err);
+ this.logger.warn("Failed to connect to WatsonX", { error: err });
  return false
  });
  }

  async generateEmbeddings(options: EmbeddingsOptions): Promise<EmbeddingsResult> {
+ if (options.image) {
+ throw new Error("Image embeddings not supported by Watsonx");
+ }
+
+ if (!options.text) {
+ throw new Error ("No text provided");
+ }

  const payload: GenerateEmbeddingPayload = {
- inputs: [options.content],
+ inputs: [options.text],
  model_id: options.model ?? 'ibm/slate-125m-english-rtrvr',
  project_id: this.projectId
  }
@@ -154,6 +183,15 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>

  }

+ function watsonFinishReason(reason: string | undefined) {
+ if (!reason) return undefined;
+ switch (reason) {
+ case 'eos_token': return "stop";
+ case 'max_tokens': return "length";
+ default: return reason;
+ }
+ }
+


  /*interface ListModelsParams extends ModelSearchPayload {
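
The new watsonFinishReason helper normalizes Watsonx stop reasons to the finish_reason vocabulary shared by the other drivers; anything outside the two mapped values passes through unchanged (example inputs are illustrative):

    watsonFinishReason("eos_token");     // -> "stop"
    watsonFinishReason("max_tokens");    // -> "length"
    watsonFinishReason(undefined);       // -> undefined (no reason reported)
    watsonFinishReason("stop_sequence"); // -> "stop_sequence" (passed through)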
package/src/watsonx/interfaces.ts
@@ -6,6 +6,10 @@ export interface WatsonxTextGenerationPayload {
  parameters: {
  max_new_tokens?: number;
  time_limit?: number;
+ stop_sequences?: string[];
+ temperature?: number;
+ top_k?: number;
+ top_p?: number;
  },
  project_id: string;
  }
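
With the four new optional parameters in place, a fully populated WatsonxTextGenerationPayload would look like this sketch (the model id, project id, and values are illustrative):

    const payload: WatsonxTextGenerationPayload = {
        model_id: "ibm/granite-13b-chat-v2", // illustrative
        input: "Summarize the report in three bullet points.\n",
        parameters: {
            max_new_tokens: 256,
            temperature: 0.7,
            top_k: 50,
            top_p: 0.9,
            stop_sequences: ["\n\n"],
        },
        project_id: "my-watsonx-project", // illustrative
    };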