@llumiverse/drivers 0.15.0 → 0.17.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -3
- package/lib/cjs/adobe/firefly.js +119 -0
- package/lib/cjs/adobe/firefly.js.map +1 -0
- package/lib/cjs/bedrock/converse.js +177 -0
- package/lib/cjs/bedrock/converse.js.map +1 -0
- package/lib/cjs/bedrock/index.js +338 -234
- package/lib/cjs/bedrock/index.js.map +1 -1
- package/lib/cjs/bedrock/nova-image-payload.js +207 -0
- package/lib/cjs/bedrock/nova-image-payload.js.map +1 -0
- package/lib/cjs/groq/index.js +34 -9
- package/lib/cjs/groq/index.js.map +1 -1
- package/lib/cjs/huggingface_ie.js +28 -12
- package/lib/cjs/huggingface_ie.js.map +1 -1
- package/lib/cjs/index.js +1 -0
- package/lib/cjs/index.js.map +1 -1
- package/lib/cjs/mistral/index.js +32 -13
- package/lib/cjs/mistral/index.js.map +1 -1
- package/lib/cjs/mistral/types.js.map +1 -1
- package/lib/cjs/openai/index.js +164 -29
- package/lib/cjs/openai/index.js.map +1 -1
- package/lib/cjs/replicate.js +19 -34
- package/lib/cjs/replicate.js.map +1 -1
- package/lib/cjs/test/TestValidationErrorCompletionStream.js.map +1 -1
- package/lib/cjs/test/index.js.map +1 -1
- package/lib/cjs/togetherai/index.js +40 -10
- package/lib/cjs/togetherai/index.js.map +1 -1
- package/lib/cjs/vertexai/embeddings/embeddings-image.js +26 -0
- package/lib/cjs/vertexai/embeddings/embeddings-image.js.map +1 -0
- package/lib/cjs/vertexai/embeddings/embeddings-text.js +1 -1
- package/lib/cjs/vertexai/embeddings/embeddings-text.js.map +1 -1
- package/lib/cjs/vertexai/index.js +134 -35
- package/lib/cjs/vertexai/index.js.map +1 -1
- package/lib/cjs/vertexai/models/claude.js +252 -0
- package/lib/cjs/vertexai/models/claude.js.map +1 -0
- package/lib/cjs/vertexai/models/gemini.js +172 -25
- package/lib/cjs/vertexai/models/gemini.js.map +1 -1
- package/lib/cjs/vertexai/models/imagen.js +317 -0
- package/lib/cjs/vertexai/models/imagen.js.map +1 -0
- package/lib/cjs/vertexai/models.js +12 -64
- package/lib/cjs/vertexai/models.js.map +1 -1
- package/lib/cjs/watsonx/index.js +47 -10
- package/lib/cjs/watsonx/index.js.map +1 -1
- package/lib/cjs/xai/index.js +71 -0
- package/lib/cjs/xai/index.js.map +1 -0
- package/lib/esm/adobe/firefly.js +115 -0
- package/lib/esm/adobe/firefly.js.map +1 -0
- package/lib/esm/bedrock/converse.js +171 -0
- package/lib/esm/bedrock/converse.js.map +1 -0
- package/lib/esm/bedrock/index.js +339 -232
- package/lib/esm/bedrock/index.js.map +1 -1
- package/lib/esm/bedrock/nova-image-payload.js +203 -0
- package/lib/esm/bedrock/nova-image-payload.js.map +1 -0
- package/lib/esm/groq/index.js +34 -9
- package/lib/esm/groq/index.js.map +1 -1
- package/lib/esm/huggingface_ie.js +29 -13
- package/lib/esm/huggingface_ie.js.map +1 -1
- package/lib/esm/index.js +1 -0
- package/lib/esm/index.js.map +1 -1
- package/lib/esm/mistral/index.js +32 -13
- package/lib/esm/mistral/index.js.map +1 -1
- package/lib/esm/mistral/types.js.map +1 -1
- package/lib/esm/openai/index.js +165 -30
- package/lib/esm/openai/index.js.map +1 -1
- package/lib/esm/replicate.js +19 -34
- package/lib/esm/replicate.js.map +1 -1
- package/lib/esm/test/TestValidationErrorCompletionStream.js.map +1 -1
- package/lib/esm/test/index.js.map +1 -1
- package/lib/esm/togetherai/index.js +40 -10
- package/lib/esm/togetherai/index.js.map +1 -1
- package/lib/esm/vertexai/embeddings/embeddings-image.js +23 -0
- package/lib/esm/vertexai/embeddings/embeddings-image.js.map +1 -0
- package/lib/esm/vertexai/embeddings/embeddings-text.js +1 -1
- package/lib/esm/vertexai/embeddings/embeddings-text.js.map +1 -1
- package/lib/esm/vertexai/index.js +135 -37
- package/lib/esm/vertexai/index.js.map +1 -1
- package/lib/esm/vertexai/models/claude.js +247 -0
- package/lib/esm/vertexai/models/claude.js.map +1 -0
- package/lib/esm/vertexai/models/gemini.js +173 -26
- package/lib/esm/vertexai/models/gemini.js.map +1 -1
- package/lib/esm/vertexai/models/imagen.js +310 -0
- package/lib/esm/vertexai/models/imagen.js.map +1 -0
- package/lib/esm/vertexai/models.js +12 -61
- package/lib/esm/vertexai/models.js.map +1 -1
- package/lib/esm/watsonx/index.js +47 -10
- package/lib/esm/watsonx/index.js.map +1 -1
- package/lib/esm/xai/index.js +64 -0
- package/lib/esm/xai/index.js.map +1 -0
- package/lib/types/adobe/firefly.d.ts +30 -0
- package/lib/types/adobe/firefly.d.ts.map +1 -0
- package/lib/types/bedrock/converse.d.ts +8 -0
- package/lib/types/bedrock/converse.d.ts.map +1 -0
- package/lib/types/bedrock/index.d.ts +27 -12
- package/lib/types/bedrock/index.d.ts.map +1 -1
- package/lib/types/bedrock/nova-image-payload.d.ts +74 -0
- package/lib/types/bedrock/nova-image-payload.d.ts.map +1 -0
- package/lib/types/bedrock/payloads.d.ts +9 -65
- package/lib/types/bedrock/payloads.d.ts.map +1 -1
- package/lib/types/groq/index.d.ts +3 -3
- package/lib/types/groq/index.d.ts.map +1 -1
- package/lib/types/huggingface_ie.d.ts +5 -7
- package/lib/types/huggingface_ie.d.ts.map +1 -1
- package/lib/types/index.d.ts +1 -0
- package/lib/types/index.d.ts.map +1 -1
- package/lib/types/mistral/index.d.ts +4 -4
- package/lib/types/mistral/index.d.ts.map +1 -1
- package/lib/types/mistral/types.d.ts +1 -0
- package/lib/types/mistral/types.d.ts.map +1 -1
- package/lib/types/openai/index.d.ts +5 -4
- package/lib/types/openai/index.d.ts.map +1 -1
- package/lib/types/replicate.d.ts +4 -9
- package/lib/types/replicate.d.ts.map +1 -1
- package/lib/types/test/index.d.ts +2 -2
- package/lib/types/test/index.d.ts.map +1 -1
- package/lib/types/togetherai/index.d.ts +4 -4
- package/lib/types/togetherai/index.d.ts.map +1 -1
- package/lib/types/vertexai/embeddings/embeddings-image.d.ts +11 -0
- package/lib/types/vertexai/embeddings/embeddings-image.d.ts.map +1 -0
- package/lib/types/vertexai/index.d.ts +21 -8
- package/lib/types/vertexai/index.d.ts.map +1 -1
- package/lib/types/vertexai/models/claude.d.ts +20 -0
- package/lib/types/vertexai/models/claude.d.ts.map +1 -0
- package/lib/types/vertexai/models/gemini.d.ts +4 -4
- package/lib/types/vertexai/models/gemini.d.ts.map +1 -1
- package/lib/types/vertexai/models/imagen.d.ts +75 -0
- package/lib/types/vertexai/models/imagen.d.ts.map +1 -0
- package/lib/types/vertexai/models.d.ts +3 -6
- package/lib/types/vertexai/models.d.ts.map +1 -1
- package/lib/types/watsonx/index.d.ts +3 -3
- package/lib/types/watsonx/index.d.ts.map +1 -1
- package/lib/types/watsonx/interfaces.d.ts +4 -0
- package/lib/types/watsonx/interfaces.d.ts.map +1 -1
- package/lib/types/xai/index.d.ts +19 -0
- package/lib/types/xai/index.d.ts.map +1 -0
- package/package.json +25 -26
- package/src/adobe/firefly.ts +207 -0
- package/src/bedrock/converse.ts +194 -0
- package/src/bedrock/index.ts +359 -240
- package/src/bedrock/nova-image-payload.ts +309 -0
- package/src/bedrock/payloads.ts +12 -66
- package/src/groq/index.ts +35 -12
- package/src/huggingface_ie.ts +34 -13
- package/src/index.ts +1 -0
- package/src/mistral/index.ts +35 -13
- package/src/mistral/types.ts +2 -1
- package/src/openai/index.ts +186 -35
- package/src/replicate.ts +24 -35
- package/src/test/TestValidationErrorCompletionStream.ts +2 -2
- package/src/test/index.ts +3 -2
- package/src/togetherai/index.ts +44 -12
- package/src/vertexai/embeddings/embeddings-image.ts +50 -0
- package/src/vertexai/embeddings/embeddings-text.ts +1 -1
- package/src/vertexai/index.ts +186 -46
- package/src/vertexai/models/claude.ts +281 -0
- package/src/vertexai/models/gemini.ts +186 -29
- package/src/vertexai/models/imagen.ts +401 -0
- package/src/vertexai/models.ts +16 -78
- package/src/watsonx/index.ts +50 -12
- package/src/watsonx/interfaces.ts +4 -0
- package/src/xai/index.ts +110 -0
package/src/vertexai/models/imagen.ts
ADDED
@@ -0,0 +1,401 @@
+import { AIModel, Completion, ExecutionOptions, ImageGeneration, Modalities, ModelType, PromptRole, PromptSegment, readStreamAsBase64 } from "@llumiverse/core";
+import { VertexAIDriver } from "../index.js";
+
+const projectId = process.env.GOOGLE_PROJECT_ID;
+const location = 'us-central1';
+
+import aiplatform, { protos } from '@google-cloud/aiplatform';
+
+// Imports the Google Cloud Prediction Service Client library
+const { PredictionServiceClient } = aiplatform.v1;
+
+// Import the helper module for converting arbitrary protobuf.Value objects
+import { helpers } from '@google-cloud/aiplatform';
+import { ImagenOptions } from "../../../../core/src/options/vertexai.js";
+
+interface ImagenBaseReference {
+    referenceType: "REFERENCE_TYPE_RAW" | "REFERENCE_TYPE_MASK" | "REFERENCE_TYPE_SUBJECT" |
+    "REFERENCE_TYPE_CONTROL" | "REFERENCE_TYPE_STYLE";
+    referenceId: number;
+    referenceImage: {
+        bytesBase64Encoded: string; //10MB max
+    }
+}
+
+export enum ImagenTaskType {
+    TEXT_IMAGE = "TEXT_IMAGE",
+    EDIT_MODE_INPAINT_REMOVAL = "EDIT_MODE_INPAINT_REMOVAL",
+    EDIT_MODE_INPAINT_INSERTION = "EDIT_MODE_INPAINT_INSERTION",
+    EDIT_MODE_BGSWAP = "EDIT_MODE_BGSWAP",
+    EDIT_MODE_OUTPAINT = "EDIT_MODE_OUTPAINT",
+    CUSTOMIZATION_SUBJECT = "CUSTOMIZATION_SUBJECT",
+    CUSTOMIZATION_STYLE = "CUSTOMIZATION_STYLE",
+    CUSTOMIZATION_CONTROLLED = "CUSTOMIZATION_CONTROLLED",
+    CUSTOMIZATION_INSTRUCT = "CUSTOMIZATION_INSTRUCT",
+}
+
+export enum ImagenMaskMode {
+    MASK_MODE_USER_PROVIDED = "MASK_MODE_USER_PROVIDED",
+    MASK_MODE_BACKGROUND = "MASK_MODE_BACKGROUND",
+    MASK_MODE_FOREGROUND = "MASK_MODE_FOREGROUND",
+    MASK_MODE_SEMANTIC = "MASK_MODE_SEMANTIC",
+}
+
+interface ImagenReferenceRaw extends ImagenBaseReference {
+    referenceType: "REFERENCE_TYPE_RAW";
+}
+
+interface ImagenReferenceMask extends Omit<ImagenBaseReference, "referenceImage"> {
+    referenceType: "REFERENCE_TYPE_MASK";
+    maskImageConfig: {
+        maskMode?: ImagenMaskMode;
+        maskClasses?: number[]; //Used for MASK_MODE_SEMANTIC, based on https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api-customization#segment-ids
+        dilation?: number; //Recommendation depends on mode: Inpaint: 0.01, BGSwap: 0.0, Outpaint: 0.01-0.03
+    }
+    referenceImage?: { //Only used for MASK_MODE_USER_PROVIDED
+        bytesBase64Encoded: string; //10MB max
+    }
+}
+
+interface ImagenReferenceSubject extends ImagenBaseReference {
+    referenceType: "REFERENCE_TYPE_SUBJECT";
+    subjectImageConfig: {
+        subjectDescription: string;
+        subjectType: "SUBJECT_TYPE_PERSON" | "SUBJECT_TYPE_ANIMAL" | "SUBJECT_TYPE_PRODUCT" | "SUBJECT_TYPE_DEFAULT";
+    }
+}
+
+interface ImagenReferenceControl extends ImagenBaseReference {
+    referenceType: "REFERENCE_TYPE_CONTROL";
+    controlImageConfig: {
+        controlType: "CONTROL_TYPE_FACE_MESH" | "CONTROL_TYPE_CANNY" | "CONTROL_TYPE_SCRIBBLE";
+        enableControlImageComputation?: boolean; //If true, the model will compute the control image
+    }
+}
+
+interface ImagenReferenceStyle extends ImagenBaseReference {
+    referenceType: "REFERENCE_TYPE_STYLE";
+    styleImageConfig: {
+        styleDescription?: string;
+    }
+}
+
+type ImagenMessage = ImagenReferenceRaw | ImagenReferenceMask | ImagenReferenceSubject | ImagenReferenceControl | ImagenReferenceStyle;
+
+export interface ImagenPrompt {
+    prompt: string;
+    referenceImages?: ImagenMessage[];
+    subjectDescription?: string; //Used for image customization to describe in the reference image
+    negativePrompt?: string; //Used for negative prompts
+}
+
+// Specifies the location of the api endpoint
+const clientOptions = {
+    apiEndpoint: `${location}-aiplatform.googleapis.com`,
+};
+
+// Instantiates a client
+const predictionServiceClient = new PredictionServiceClient(clientOptions);
+
+function getImagenParameters(taskType: string, options: ImagenOptions) {
+    const commonParameters = {
+        sampleCount: options?.number_of_images,
+        seed: options?.seed,
+        safetySetting: options?.safety_setting,
+        personGeneration: options?.person_generation,
+        negativePrompt: taskType ? undefined : "", //Filled in later from the prompt
+        //TODO: Add more safety and prompt rejection information
+        //includeSafetyAttributes: true,
+        //includeRaiReason: true,
+    };
+    switch (taskType) {
+        case ImagenTaskType.EDIT_MODE_INPAINT_REMOVAL:
+            return {
+                ...commonParameters,
+                editMode: "EDIT_MODE_INPAINT_REMOVAL",
+                editConfig: {
+                    baseSteps: options?.edit_steps,
+                },
+            }
+        case ImagenTaskType.EDIT_MODE_INPAINT_INSERTION:
+            return {
+                ...commonParameters,
+                editMode: "EDIT_MODE_INPAINT_INSERTION",
+                editConfig: {
+                    baseSteps: options?.edit_steps,
+                },
+            }
+        case ImagenTaskType.EDIT_MODE_BGSWAP:
+            return {
+                ...commonParameters,
+                editMode: "EDIT_MODE_BGSWAP",
+                editConfig: {
+                    baseSteps: options?.edit_steps,
+                },
+            }
+        case ImagenTaskType.EDIT_MODE_OUTPAINT:
+            return {
+                ...commonParameters,
+                editMode: "EDIT_MODE_OUTPAINT",
+                editConfig: {
+                    baseSteps: options?.edit_steps,
+                },
+            }
+        case ImagenTaskType.TEXT_IMAGE:
+            return {
+                ...commonParameters,
+                // You can't use a seed value and watermark at the same time.
+                addWatermark: options?.add_watermark,
+                aspectRatio: options?.aspect_ratio,
+                enhancePrompt: options?.enhance_prompt,
+            };
+        case ImagenTaskType.CUSTOMIZATION_SUBJECT:
+        case ImagenTaskType.CUSTOMIZATION_CONTROLLED:
+        case ImagenTaskType.CUSTOMIZATION_INSTRUCT:
+        case ImagenTaskType.CUSTOMIZATION_STYLE:
+            return {
+                ...commonParameters,
+            }
+        default:
+            throw new Error("Task type not supported");
+    }
+}
+
+export class ImagenModelDefinition {
+
+    model: AIModel
+
+    constructor(modelId: string) {
+        this.model = {
+            id: modelId,
+            name: modelId,
+            provider: 'vertexai',
+            type: ModelType.Image,
+            can_stream: false,
+        };
+    }
+
+    async createPrompt(_driver: VertexAIDriver, segments: PromptSegment[], options: ExecutionOptions): Promise<ImagenPrompt> {
+        const splits = options.model.split("/");
+        const modelName = splits[splits.length - 1];
+        options = { ...options, model: modelName };
+
+        const prompt: ImagenPrompt = {
+            prompt: "",
+        }
+
+        //Collect text prompts, Imagen does not support roles, so everything gets merged together
+        // however we still respect our typical pattern. System First, Safety Last.
+        const system: string[] = [];
+        const user: string[] = [];
+        const safety: string[] = [];
+        const negative: string[] = [];
+
+        const mask_mode = (options.model_options as ImagenOptions)?.mask_mode;
+        const imagenOptions = options.model_options as ImagenOptions;
+
+        for (const msg of segments) {
+            if (msg.role === PromptRole.safety) {
+                safety.push(msg.content);
+            } else if (msg.role === PromptRole.system) {
+                system.push(msg.content);
+            } else if (msg.role === PromptRole.negative) {
+                negative.push(msg.content);
+            } else {
+                //Everything else is assumed to be user or user adjacent.
+                user.push(msg.content);
+            }
+            if (msg.files) {
+                //Get images from messages
+                if (!prompt.referenceImages) {
+                    prompt.referenceImages = [];
+                }
+
+                //Always required, but only used by customisation.
+                //Each ref ID refers to a single "reference", i.e. object. To provide multiple images of a single ref,
+                //include multiple images in one prompt.
+                const refId = prompt.referenceImages.length + 1;
+                for (const img of msg.files) {
+                    if (img.mime_type?.includes("image")) {
+                        if (msg.role !== PromptRole.mask) {
+                            //Editing based mode requires a reference image
+                            if (imagenOptions?.edit_mode?.includes("EDIT_MODE")) {
+                                prompt.referenceImages.push({
+                                    referenceType: "REFERENCE_TYPE_RAW",
+                                    referenceId: refId,
+                                    referenceImage: {
+                                        bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+                                    }
+                                });
+                                //If mask is auto-generated, add a mask reference
+                                if (mask_mode !== ImagenMaskMode.MASK_MODE_USER_PROVIDED) {
+                                    prompt.referenceImages.push({
+                                        referenceType: "REFERENCE_TYPE_MASK",
+                                        referenceId: refId,
+                                        maskImageConfig: {
+                                            maskMode: mask_mode,
+                                            dilation: imagenOptions?.mask_dilation,
+                                        }
+                                    });
+                                }
+                            }
+                            else if ((options.model_options as ImagenOptions)?.edit_mode === ImagenTaskType.CUSTOMIZATION_SUBJECT) {
+                                //First image is always the control image
+                                if (refId == 1) {
+                                    //Customization subject mode requires a control image
+                                    prompt.referenceImages.push({
+                                        referenceType: "REFERENCE_TYPE_CONTROL",
+                                        referenceId: refId,
+                                        referenceImage: {
+                                            bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+                                        },
+                                        controlImageConfig: {
+                                            controlType: imagenOptions?.controlType === "CONTROL_TYPE_FACE_MESH" ? "CONTROL_TYPE_FACE_MESH" : "CONTROL_TYPE_CANNY",
+                                            enableControlImageComputation: imagenOptions?.controlImageComputation,
+                                        }
+                                    });
+                                } else {
+                                    // Subject images
+                                    prompt.referenceImages.push({
+                                        referenceType: "REFERENCE_TYPE_SUBJECT",
+                                        referenceId: refId,
+                                        referenceImage: {
+                                            bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+                                        },
+                                        subjectImageConfig: {
+                                            subjectDescription: prompt.subjectDescription ?? msg.content,
+                                            subjectType: imagenOptions?.subjectType ?? "SUBJECT_TYPE_DEFAULT",
+                                        }
+                                    });
+                                }
+                            } else if ((options.model_options as ImagenOptions)?.edit_mode === ImagenTaskType.CUSTOMIZATION_STYLE) {
+                                // Style images
+                                prompt.referenceImages.push({
+                                    referenceType: "REFERENCE_TYPE_STYLE",
+                                    referenceId: refId,
+                                    referenceImage: {
+                                        bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+                                    },
+                                    styleImageConfig: {
+                                        styleDescription: prompt.subjectDescription ?? msg.content,
+                                    }
+                                });
+                            } else if ((options.model_options as ImagenOptions)?.edit_mode === ImagenTaskType.CUSTOMIZATION_CONTROLLED) {
+                                // Control images
+                                prompt.referenceImages.push({
+                                    referenceType: "REFERENCE_TYPE_CONTROL",
+                                    referenceId: refId,
+                                    referenceImage: {
+                                        bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+                                    },
+                                    controlImageConfig: {
+                                        controlType: imagenOptions?.controlType === "CONTROL_TYPE_FACE_MESH" ? "CONTROL_TYPE_FACE_MESH" : "CONTROL_TYPE_CANNY",
+                                        enableControlImageComputation: imagenOptions?.controlImageComputation,
+                                    }
+                                });
+                            } else if ((options.model_options as ImagenOptions)?.edit_mode === ImagenTaskType.CUSTOMIZATION_INSTRUCT) {
+                                // Control images
+                                prompt.referenceImages.push({
+                                    referenceType: "REFERENCE_TYPE_RAW",
+                                    referenceId: refId,
+                                    referenceImage: {
+                                        bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+                                    },
+                                });
+                            }
+                        }
+                        //If mask is user-provided, add a mask reference
+                        if (msg.role === PromptRole.mask && mask_mode === ImagenMaskMode.MASK_MODE_USER_PROVIDED) {
+                            prompt.referenceImages.push({
+                                referenceType: "REFERENCE_TYPE_MASK",
+                                referenceId: refId,
+                                referenceImage: {
+                                    bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
+                                },
+                                maskImageConfig: {
+                                    maskMode: mask_mode,
+                                    dilation: imagenOptions?.mask_dilation,
+                                }
+                            });
+                        }
+                    }
+                }
+            }
+        }
+
+        //Extract the text from the segments
+        prompt.prompt += [system.join("\n\n"), user.join("\n\n"), safety.join("\n\n")].join("\n\n");
+
+        //Negative prompt
+        if (negative.length > 0) {
+            prompt.negativePrompt = negative.join(", ");
+        }
+
+        console.log(prompt);
+
+        return prompt
+    }
+
+    async requestImageGeneration(driver: VertexAIDriver, prompt: ImagenPrompt, options: ExecutionOptions): Promise<Completion<ImageGeneration>> {
+        if (options.model_options?._option_id !== "vertexai-imagen") {
+            driver.logger.warn("Invalid model options", {options: options.model_options });
+        }
+        options.model_options = options.model_options as ImagenOptions;
+
+        if (options.output_modality !== Modalities.image) {
+            throw new Error(`Image generation requires image output_modality`);
+        }
+
+        const taskType: string = options.model_options.edit_mode ?? ImagenTaskType.TEXT_IMAGE;
+
+        driver.logger.info("Task type: " + taskType);
+
+        const modelName = options.model.split("/").pop() ?? '';
+
+        // Configure the parent resource
+        const endpoint = `projects/${projectId}/locations/${location}/publishers/google/models/${modelName}`;
+
+        const instanceValue = helpers.toValue(prompt);
+        if (!instanceValue) {
+            throw new Error('No instance value found');
+        }
+        const instances = [instanceValue];
+
+        let parameter: any = getImagenParameters(taskType, options.model_options);
+        parameter.negativePrompt = prompt.negativePrompt ?? undefined;
+
+        const numberOfImages = options.model_options?.number_of_images ?? 1;
+
+        // Remove all undefined values
+        parameter = Object.fromEntries(
+            Object.entries(parameter).filter(([_, v]) => v !== undefined)
+        ) as any;
+
+        const parameters = helpers.toValue(parameter);
+
+        const request: protos.google.cloud.aiplatform.v1.IPredictRequest = {
+            endpoint,
+            instances,
+            parameters,
+        };
+
+        // Predict request
+        const [response] = await predictionServiceClient.predict(request, { timeout: 120000 * numberOfImages }); //Extended timeout for image generation
+        const predictions = response.predictions;
+
+        if (!predictions) {
+            throw new Error('No predictions found');
+        }
+
+        // Extract base64 encoded images from predictions
+        const images: string[] = predictions.map(prediction =>
+            prediction.structValue?.fields?.bytesBase64Encoded?.stringValue ?? ''
+        );
+
+        return {
+            result: {
+                images
+            },
+        };
+    }
+}
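
For orientation, a minimal sketch of driving the new Imagen path, assuming an already configured VertexAIDriver. Only the ImagenModelDefinition API (createPrompt, requestImageGeneration, the "vertexai-imagen" option id and Modalities.image check) comes from the diff above; the import path, the PromptSegment shape, and the model id are illustrative assumptions.

// Minimal sketch, assuming a configured VertexAIDriver instance `driver`.
import { Modalities, PromptRole } from "@llumiverse/core";
import { ImagenModelDefinition } from "./models/imagen.js"; // hypothetical path

async function generateImages(driver: any /* VertexAIDriver, assumed configured */) {
    const imagen = new ImagenModelDefinition("imagen-3.0-generate-001"); // example id
    const segments: any[] = [
        { role: PromptRole.user, content: "A watercolor fox in a snowy forest" },
        { role: PromptRole.negative, content: "text, watermarks" }, // joined into negativePrompt
    ];
    const options: any = {
        model: "publishers/google/models/imagen-3.0-generate-001",
        output_modality: Modalities.image, // required, otherwise requestImageGeneration throws
        model_options: {
            _option_id: "vertexai-imagen", // expected by the validity check in the diff
            number_of_images: 2,           // mapped to sampleCount
            aspect_ratio: "1:1",           // TEXT_IMAGE-only parameter
        },
    };
    const prompt = await imagen.createPrompt(driver, segments, options);
    const completion = await imagen.requestImageGeneration(driver, prompt, options);
    return completion.result.images; // base64-encoded image strings
}
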
package/src/vertexai/models.ts
CHANGED
@@ -1,89 +1,27 @@
-import { AIModel, Completion,
-import { VertexAIDriver } from "./index.js";
+import { AIModel, Completion, CompletionChunkObject, PromptOptions, PromptSegment, ExecutionOptions } from "@llumiverse/core";
+import { VertexAIDriver , trimModelName} from "./index.js";
 import { GeminiModelDefinition } from "./models/gemini.js";
-
-
-
+import { ClaudeModelDefinition } from "./models/claude.js";
 
 export interface ModelDefinition<PromptT = any> {
     model: AIModel;
     versions?: string[]; // the versions of the model that are available. ex: ['001', '002']
     createPrompt: (driver: VertexAIDriver, segments: PromptSegment[], options: PromptOptions) => Promise<PromptT>;
-
-
-}
-
-export function getModelName(model: string) {
-    const i = model.lastIndexOf('@');
-    return i > -1 ? model.substring(0, i) : model;
+    requestTextCompletion: (driver: VertexAIDriver, prompt: PromptT, options: ExecutionOptions) => Promise<Completion>;
+    requestTextCompletionStream: (driver: VertexAIDriver, promp: PromptT, options: ExecutionOptions) => Promise<AsyncIterable<CompletionChunkObject>>;
 }
 
 export function getModelDefinition(model: string): ModelDefinition {
-    const
-    const
-
-
+    const splits = model.split("/");
+    const publisher = splits[1];
+    const modelName = trimModelName(splits[splits.length - 1]);
+
+    if (publisher?.includes("anthropic")) {
+        return new ClaudeModelDefinition(modelName);
+    } else if (publisher?.includes("google")) {
+        return new GeminiModelDefinition(modelName);
     }
-    return def;
-}
-
-export function getAIModels() {
-    return Object.values(Models).map(m => m.model);
-}
-
-// Builtin models. VertexAI doesn't provide an API to list models. so we have to hardcode them here.
-export const BuiltinModels: AIModel<string>[] = [
-    {
-        id: "gemini-1.5-flash",
-        name: "Gemini Pro 1.5 Flash",
-        provider: "vertexai",
-        owner: "google",
-        type: ModelType.MultiModal,
-        can_stream: true,
-        is_multimodal: true
 
-    },
-    {
-        id: "gemini-1.5-pro",
-        name: "Gemini Pro 1.5 Pro",
-        provider: "vertexai",
-        owner: "google",
-        type: ModelType.MultiModal,
-        can_stream: true,
-        is_multimodal: true
-
-    },
-    {
-        id: "gemini-1.0-pro",
-        name: "Gemini Pro 1.0",
-        provider: "vertexai",
-        owner: "google",
-        type: ModelType.Text,
-        can_stream: true,
-    },
-    {
-        id: "textembedding-gecko",
-        name: "Gecko Text Embeddings",
-        provider: "vertexai",
-        owner: "google",
-        type: ModelType.Embedding,
-    },
-    {
-        id: "textembedding-gecko-multilingual",
-        name: "Gecko Multilingual Text Embeddings",
-        provider: "vertexai",
-        owner: "google",
-        type: ModelType.Embedding,
-    },
-
-
-
-]
-
-
-
-const Models: Record<string, ModelDefinition> = {
-    "gemini-1.5-flash": new GeminiModelDefinition("gemini-1.5-flash"),
-    "gemini-1.5-pro": new GeminiModelDefinition("gemini-1.5-pro"),
-    "gemini-1.0-pro": new GeminiModelDefinition(),
-}
+    //Fallback, assume it is Gemini.
+    return new GeminiModelDefinition(modelName);
+}
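
The rewritten getModelDefinition keys off the second path segment of the model id, which implicitly assumes ids shaped like publishers/<publisher>/models/<name>. A sketch of the resulting dispatch; the model ids are illustrative, and trimModelName (not shown in this diff) is assumed to strip an @version suffix like the removed getModelName did:

import { getModelDefinition } from "./models.js"; // path within this package

// Anthropic publisher segment routes to Claude:
getModelDefinition("publishers/anthropic/models/claude-3-5-sonnet@20240620");
// -> ClaudeModelDefinition("claude-3-5-sonnet"), assuming trimModelName drops "@20240620"

// Google publisher segment routes to Gemini:
getModelDefinition("publishers/google/models/gemini-1.5-pro");
// -> GeminiModelDefinition("gemini-1.5-pro")

// A bare id has no publisher segment (splits[1] is undefined), so the fallback applies:
getModelDefinition("gemini-1.5-flash");
// -> GeminiModelDefinition("gemini-1.5-flash")
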
package/src/watsonx/index.ts
CHANGED
@@ -1,4 +1,4 @@
-import { AIModel, AbstractDriver, Completion, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions } from "@llumiverse/core";
+import { AIModel, AbstractDriver, Completion, CompletionChunk, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, TextFallbackOptions } from "@llumiverse/core";
 import { transformSSEStream } from "@llumiverse/core/async";
 import { FetchClient } from "api-fetch-client";
 import { GenerateEmbeddingPayload, GenerateEmbeddingResponse, WatsonAuthToken, WatsonxListModelResponse, WatsonxModelSpec, WatsonxTextGenerationPayload, WatsonxTextGenerationResponse } from "./interfaces.js";
@@ -29,13 +29,21 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>
         this.fetchClient = new FetchClient(this.endpoint_url).withAuthCallback(async () => this.getAuthToken().then(token => `Bearer ${token}`));
     }
 
-    async
+    async requestTextCompletion(prompt: string, options: ExecutionOptions): Promise<Completion<any>> {
+        if (options.model_options?._option_id !== "text-fallback") {
+            this.logger.warn("Invalid model options", {options: options.model_options });
+        }
+        options.model_options = options.model_options as TextFallbackOptions;
+
         const payload: WatsonxTextGenerationPayload = {
             model_id: options.model,
             input: prompt + "\n",
             parameters: {
-                max_new_tokens: options.max_tokens,
-
+                max_new_tokens: options.model_options.max_tokens,
+                temperature: options.model_options.temperature,
+                top_k: options.model_options.top_k,
+                top_p: options.model_options.top_p,
+                stop_sequences: options.model_options.stop_sequence,
             },
             project_id: this.projectId,
         }
@@ -51,19 +59,25 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>
                 result: result.generated_token_count,
                 total: result.input_token_count + result.generated_token_count,
             },
-            finish_reason: result.stop_reason,
+            finish_reason: watsonFinishReason(result.stop_reason),
             original_response: options.include_original_response ? res : undefined,
         }
     }
 
-    async
-
+    async requestTextCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunk>> {
+        if (options.model_options?._option_id !== "text-fallback") {
+            this.logger.warn("Invalid model options", {options: options.model_options });
+        }
+        options.model_options = options.model_options as TextFallbackOptions;
         const payload: WatsonxTextGenerationPayload = {
             model_id: options.model,
             input: prompt + "\n",
             parameters: {
-                max_new_tokens: options.max_tokens,
-
+                max_new_tokens: options.model_options.max_tokens,
+                temperature: options.model_options.temperature,
+                top_k: options.model_options.top_k,
+                top_p: options.model_options.top_p,
+                stop_sequences: options.model_options.stop_sequence,
             },
             project_id: this.projectId,
         }
@@ -75,7 +89,15 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>
 
         return transformSSEStream(stream, (data: string) => {
             const json = JSON.parse(data) as WatsonxTextGenerationResponse;
-            return
+            return {
+                result: json.results[0]?.generated_text ?? '',
+                finish_reason: watsonFinishReason(json.results[0]?.stop_reason),
+                token_usage: {
+                    prompt: json.results[0].input_token_count,
+                    result: json.results[0].generated_token_count,
+                    total: json.results[0].input_token_count + json.results[0].generated_token_count,
+                },
+            };
         });
 
     }
@@ -130,15 +152,22 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>
         return this.listModels()
             .then(() => true)
             .catch((err) => {
-                this.logger.warn("Failed to connect to WatsonX", err);
+                this.logger.warn("Failed to connect to WatsonX", { error: err });
                 return false
             });
     }
 
     async generateEmbeddings(options: EmbeddingsOptions): Promise<EmbeddingsResult> {
+        if (options.image) {
+            throw new Error("Image embeddings not supported by Watsonx");
+        }
+
+        if (!options.text) {
+            throw new Error ("No text provided");
+        }
 
         const payload: GenerateEmbeddingPayload = {
-            inputs: [options.
+            inputs: [options.text],
             model_id: options.model ?? 'ibm/slate-125m-english-rtrvr',
             project_id: this.projectId
         }
@@ -154,6 +183,15 @@ export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string>
 
 }
 
+function watsonFinishReason(reason: string | undefined) {
+    if (!reason) return undefined;
+    switch (reason) {
+        case 'eos_token': return "stop";
+        case 'max_tokens': return "length";
+        default: return reason;
+    }
+}
+
 
 
 /*interface ListModelsParams extends ModelSearchPayload {
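
The new watsonFinishReason helper normalizes Watsonx stop reasons to the finish_reason values used by the other drivers, passing anything unrecognized through unchanged. Since the helper is module-private, here is a standalone restatement of the mapping implied by the diff above:

// Standalone restatement of the mapping (the original is not exported):
function watsonFinishReason(reason: string | undefined): string | undefined {
    if (!reason) return undefined;
    switch (reason) {
        case "eos_token": return "stop";    // normal end-of-sequence
        case "max_tokens": return "length"; // generation hit the token limit
        default: return reason;             // any other reason passes through as-is
    }
}

watsonFinishReason("eos_token");  // "stop"
watsonFinishReason("max_tokens"); // "length"
watsonFinishReason(undefined);    // undefined
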