@drax/ai-back 3.31.0 → 3.33.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,489 @@
1
+ import {GoogleGenAI} from "@google/genai";
2
+ import type {
3
+ Content,
4
+ FunctionCall,
5
+ FunctionDeclaration,
6
+ GenerateContentConfig,
7
+ Part
8
+ } from "@google/genai";
9
+ import {toJSONSchema} from "zod";
10
+ import type {
11
+ IAIProvider,
12
+ IPromptContentPart,
13
+ IPromptMessage,
14
+ IPromptParams,
15
+ IPromptResponse,
16
+ IPromptTool
17
+ } from "../interfaces/IAIProvider";
18
+ import type {AILogService} from "../services/AILogService";
19
+ import type {IAILogBase} from "@drax/ai-share";
20
+
21
/** Values recorded in the AI log for a single prompt call. */
type PromptLogParams = {
    model: string,
    systemPrompt: string,
    startedAt: Date,
    endedAt?: Date,
    inputTokens?: number,
    outputTokens?: number,
    tokens?: number,
    output?: unknown,
    success: boolean,
    errorMessage?: string,
}

/**
 * IAIProvider implementation backed by the Google Gen AI SDK (`@google/genai`).
 *
 * Supports:
 * - text and image prompts;
 * - conversation history;
 * - structured output (zod schema or JSON schema);
 * - function calling with a bounded tool loop;
 * - embeddings.
 *
 * When an AILogService is supplied, every prompt call is also logged.
 */
class GoogleAiProvider implements IAIProvider{
    protected _apiKey: string
    protected _model: string
    protected _visionModel?: string
    protected _client: GoogleGenAI | undefined
    protected _aiLogService?: AILogService

    /**
     * @param apiKey       Google AI API key (required).
     * @param model        Default model id used for prompts (required).
     * @param visionModel  Optional model id used instead of `model` when the prompt involves images.
     * @param aiLogService Optional log sink; when omitted, prompt calls are not logged.
     * @throws Error when `apiKey` or `model` is empty.
     */
    constructor(apiKey: string, model: string, visionModel?: string, aiLogService?: AILogService) {
        if (!apiKey) {
            throw new Error("Google AI apiKey required")
        }
        if (!model) {
            throw new Error("Google AI model required")
        }

        this._apiKey = apiKey
        this._model = model
        this._visionModel = visionModel
        this._aiLogService = aiLogService
    }

    /**
     * Default model id.
     * @throws Error if the model id is empty.
     */
    get model(){
        if(!this._model){
            throw new Error("Google AI model not found")
        }
        return this._model;
    }

    /** SDK client, created on first use. */
    get client(){
        if(!this._client){
            this._client = new GoogleGenAI({
                apiKey: this._apiKey,
            });
        }

        return this._client
    }

    /** Optional model id used for prompts that include images. */
    protected get visionModel(){
        return this._visionModel
    }

    /**
     * Builds the parts of the current user turn.
     *
     * Precedence:
     * 1. A non-empty `userContent` is used as-is.
     * 2. Otherwise, `userImages` are sent, preceded by `userInput` as text if present.
     * 3. Otherwise, a single text part with `userInput` (or an empty string) is sent.
     */
    protected buildUserContent(input: IPromptParams): Part[] {
        if(input.userContent && input.userContent.length > 0){
            return this.mapContentParts(input.userContent)
        }

        if(input.userImages && input.userImages.length > 0){
            const content: Part[] = []

            if(input.userInput){
                content.push({text: input.userInput})
            }

            content.push(...input.userImages.map(image => this.mapImageUrl(image.url)))

            return content
        }

        return input.userInput ? [{text: input.userInput}] : [{text: ""}]
    }

    /** Converts provider-neutral content parts into Gemini parts (text or image). */
    protected mapContentParts(content: IPromptContentPart[]): Part[]{
        return content.map(part => {
            if(part.type === 'text'){
                return {
                    text: part.text
                }
            }

            return this.mapImageUrl(part.imageUrl)
        })
    }

    /**
     * Converts an image URL into a Gemini part.
     *
     * - Base64 data URLs become `inlineData`.
     * - Any other URL becomes `fileData`, with the MIME type guessed from the extension.
     *
     * Accepted data URLs: `data:<mime>[;<param>...];base64,<payload>`. The payload may contain
     * line breaks; whitespace is stripped so only base64 characters are sent.
     */
    protected mapImageUrl(url: string): Part {
        const dataUrlMatch = url.match(/^data:([^;,]+)(?:;[^;,]*)*;base64,([\s\S]+)$/i)
        const mimeType = dataUrlMatch?.[1]
        const data = dataUrlMatch?.[2]?.replace(/\s+/g, "")

        if(mimeType && data){
            return {
                inlineData: {
                    mimeType,
                    data,
                }
            }
        }

        return {
            fileData: {
                fileUri: url,
                mimeType: this.inferImageMimeType(url),
            }
        }
    }

    /**
     * Guesses an image MIME type from the URL's file extension.
     *
     * The query string and fragment are ignored. Falls back to "image/jpeg".
     */
    protected inferImageMimeType(url: string){
        const path = (url.split(/[?#]/)[0] ?? "").toLowerCase()
        const mimeTypesByExtension: Array<[string, string]> = [
            [".png", "image/png"],
            [".webp", "image/webp"],
            [".gif", "image/gif"],
            [".bmp", "image/bmp"],
            [".heic", "image/heic"],
            [".heif", "image/heif"],
        ]

        for(const [extension, mimeType] of mimeTypesByExtension){
            if(path.endsWith(extension)){
                return mimeType
            }
        }

        return "image/jpeg"
    }

    /**
     * Maps chat history to Gemini contents.
     *
     * - "assistant" messages become the "model" role.
     * - "system" messages are sent as "user" turns prefixed with a "[SYSTEM]" text part.
     * - Every other role is sent as "user".
     */
    protected mapHistory(history: IPromptMessage[] = []): Content[]{
        return history.map(message => {
            const parts = typeof message.content === 'string'
                ? [{text: message.content}]
                : this.mapContentParts(message.content)

            if(message.role === "assistant"){
                return {
                    role: "model",
                    parts,
                }
            }

            if(message.role === "system"){
                return {
                    role: "user",
                    parts: [
                        {text: "[SYSTEM]"},
                        ...parts,
                    ],
                }
            }

            return {
                role: "user",
                parts,
            }
        })
    }

    /** True when the current turn or any history message contains an image part. */
    protected hasImageInput(input: IPromptParams){
        if(input.userImages && input.userImages.length > 0){
            return true
        }

        if(input.userContent?.some(part => part.type === 'image')){
            return true
        }

        return input.history?.some(message =>
            Array.isArray(message.content) && message.content.some(part => part.type === 'image')
        ) ?? false
    }

    /** JSON snapshot of the prompt inputs stored in the log's `input` field. */
    protected serializePromptInput(input: IPromptParams, systemPrompt: string){
        return JSON.stringify({
            systemPrompt,
            history: input.history,
            userInput: input.userInput,
            userContent: input.userContent,
            memory: input.memory,
            knowledgeBase: input.knowledgeBase,
            tools: input.tools?.map(tool => ({
                name: tool.name,
                description: tool.description,
                parameters: tool.parameters,
            })),
        })
    }

    /**
     * Serializes the model output for the log.
     *
     * Strings are kept as-is; null/undefined become undefined; anything else is JSON-encoded.
     */
    protected serializePromptOutput(output: unknown){
        if (typeof output === "string") {
            return output
        }

        if (output === null || output === undefined) {
            return undefined
        }

        return JSON.stringify(output)
    }

    /**
     * Image URLs that were actually sent with the user turn.
     *
     * Uses the same precedence as buildUserContent: a non-empty `userContent` wins over `userImages`.
     */
    protected collectInputImages(input: IPromptParams){
        if(input.userContent && input.userContent.length > 0){
            return input.userContent
                .filter(part => part.type === "image")
                .map(part => ({
                    url: part.imageUrl,
                }))
        }

        return input.userImages?.map(image => ({
            url: image.url,
        }))
    }

    /** Builds the log record for one prompt call. */
    protected buildLogPayload(input: IPromptParams, params: PromptLogParams): IAILogBase {
        return {
            provider: "googleai",
            model: params.model,
            operationTitle: input.operationTitle,
            operationGroup: input.operationGroup,
            ip: input.ip,
            userAgent: input.userAgent,
            input: this.serializePromptInput(input, params.systemPrompt),
            inputImages: this.collectInputImages(input),
            inputFiles: input.inputFiles,
            inputTokens: params.inputTokens,
            outputTokens: params.outputTokens,
            tokens: params.tokens,
            startedAt: params.startedAt,
            endedAt: params.endedAt,
            responseTime: params.endedAt ? `${params.endedAt.getTime() - params.startedAt.getTime()}ms` : undefined,
            output: this.serializePromptOutput(params.output),
            success: params.success,
            errorMessage: params.errorMessage,
            tenant: input.tenant,
            user: input.user,
        }
    }

    /**
     * Writes a prompt log when a log service is configured.
     *
     * Failures are reported with console.error and never rethrown, so logging cannot break a prompt.
     */
    protected async registerPromptLog(input: IPromptParams, params: PromptLogParams){
        if(!this._aiLogService){
            return
        }

        try{
            await this._aiLogService.create(this.buildLogPayload(input, params))
        }catch(e: unknown){
            const error = e instanceof Error ? e : undefined
            console.error("Error registerPromptLog", {
                name: error?.name,
                message: error?.message ?? String(e),
                stack: error?.stack,
            })
        }
    }

    /**
     * Returns the embedding vector for `text`.
     *
     * Returns an empty array if the response contains no embedding.
     */
    async generateEmbedding({text, model="text-embedding-004"}: {text:string,model?:string }): Promise<number[]> {
        const response = await this.client.models.embedContent({
            model,
            contents: text,
        });
        return response.embeddings?.[0]?.values ?? [];
    }

    /**
     * Wraps tools in a single Gemini `functionDeclarations` entry.
     *
     * A tool without parameters gets an empty object schema. Returns [] when there are no tools.
     */
    protected mapTools(tools: IPromptTool[] = []): Array<{functionDeclarations: FunctionDeclaration[]}> {
        if(tools.length === 0){
            return []
        }

        return [{
            functionDeclarations: tools.map(tool => ({
                name: tool.name,
                description: tool.description,
                parametersJsonSchema: tool.parameters ?? {
                    type: "object",
                    properties: {},
                    additionalProperties: false,
                },
            }))
        }]
    }

    /**
     * Request config containing:
     * - the system instruction;
     * - the JSON response schema, when one is requested;
     * - the tool declarations, when tools are provided.
     */
    protected buildResponseConfig(input: IPromptParams, systemPrompt: string): GenerateContentConfig {
        const config: GenerateContentConfig = {
            systemInstruction: systemPrompt,
        }

        const responseJsonSchema = this.normalizeResponseJsonSchema(input)

        if(responseJsonSchema){
            config.responseMimeType = "application/json"
            config.responseJsonSchema = responseJsonSchema
        }

        if(input.tools && input.tools.length > 0){
            config.tools = this.mapTools(input.tools)
        }

        return config
    }

    /**
     * Resolves the response schema.
     *
     * - A zod schema is converted to draft-7 JSON Schema.
     * - A `{type: "json_schema", json_schema: {schema}}` wrapper is unwrapped.
     * - Any other `jsonSchema` is returned unchanged.
     */
    protected normalizeResponseJsonSchema(input: IPromptParams){
        if(input.zodSchema){
            return toJSONSchema(input.zodSchema, {
                target: "draft-7",
            })
        }

        if(!input.jsonSchema){
            return undefined
        }

        const jsonSchema: any = input.jsonSchema

        if(jsonSchema.type === "json_schema" && jsonSchema.json_schema?.schema){
            return jsonSchema.json_schema.schema
        }

        return jsonSchema
    }

    /**
     * Runs each requested tool in order and wraps its result as a `functionResponse` part.
     *
     * @throws Error when a call names a tool that was not provided.
     */
    protected async buildToolResponseParts(functionCalls: FunctionCall[] = [], tools: IPromptTool[] = []){
        const parts: Part[] = []

        for(const functionCall of functionCalls){
            const toolName = functionCall.name
            const tool = tools.find(t => t.name === toolName)

            if(!tool){
                throw new Error(`Tool not found: ${toolName}`)
            }

            const output = await tool.execute(functionCall.args ?? {})

            parts.push({
                functionResponse: {
                    id: functionCall.id,
                    name: toolName,
                    response: {
                        output,
                    },
                }
            })
        }

        return parts
    }

    /** Fallback model turn rebuilt from function calls when the response carries no candidate content. */
    protected buildModelFunctionCallContent(functionCalls: FunctionCall[] = []): Content {
        return {
            role: "model",
            parts: functionCalls.map(functionCall => ({
                functionCall,
            }))
        }
    }

    /**
     * Appends the optional memory and knowledge-base sections to the base system prompt.
     *
     * Default headers are '[MEMORIA]' and '[BASE DE CONOCIMIENTO]'.
     */
    protected buildSystemPrompt(input: IPromptParams, basePrompt: string): string {
        let systemPrompt = basePrompt

        if(input.memory && input.memory.length > 0){
            systemPrompt += `\n\n ${input.memoryHeader ?? '[MEMORIA]'}\n ${input.memory.map(m => `${m.key}: ${m.value}`).join('\n')}`
        }

        if(input.knowledgeBase && input.knowledgeBase.length > 0){
            systemPrompt += `\n\n${input.knowledgeBaseHeader ?? '[BASE DE CONOCIMIENTO]'}\n ${input.knowledgeBase.join('\n')}`
        }

        return systemPrompt
    }

    /**
     * Sends a prompt and runs the function-calling loop until the model answers with text.
     *
     * The loop runs at most `toolMaxIterations` rounds (default 5). Token usage is summed over
     * all rounds. The call is logged on both success and failure.
     *
     * @throws Error when `systemPrompt` is missing.
     * @throws Error when the model returns neither text nor function calls.
     * @throws Error when a requested tool is unknown.
     * @throws Error when the iteration limit is reached.
     * @throws Any SDK or tool error, rethrown after the failure is logged.
     */
    async prompt(input: IPromptParams): Promise<IPromptResponse> {

        if(!input.systemPrompt){
            throw new Error("systemPrompt required")
        }

        const systemPrompt = this.buildSystemPrompt(input, input.systemPrompt)
        const userInput = this.buildUserContent(input)
        // An explicit model wins; otherwise use the vision model (if configured) whenever images are involved.
        const model = input.model ?? (this.hasImageInput(input) ? this.visionModel ?? this.model : this.model)
        const startedAt = new Date()
        const startTime = performance.now()
        let tokens = 0
        let inputTokens = 0
        let outputTokens = 0

        try {
            const contents: Content[] = [
                ...this.mapHistory(input.history),
                {role: 'user', parts: userInput},
            ]
            const tools = input.tools ?? []
            const maxIterations = input.toolMaxIterations ?? 5
            // The request config is the same for every round of the tool loop, so build it once.
            const config = this.buildResponseConfig(input, systemPrompt)
            let output: string | undefined

            for(let iteration = 0; iteration < maxIterations; iteration++){
                const response = await this.client.models.generateContent({
                    model,
                    contents,
                    config,
                });

                tokens += response.usageMetadata?.totalTokenCount ?? 0
                inputTokens += response.usageMetadata?.promptTokenCount ?? 0
                outputTokens += response.usageMetadata?.candidatesTokenCount ?? 0

                const functionCalls = response.functionCalls ?? []

                if(functionCalls.length === 0){
                    output = response.text
                    if(output === undefined){
                        // Neither text nor tool calls: report it here rather than falling through
                        // to the unrelated "max iterations" error below.
                        const finishReason = response.candidates?.[0]?.finishReason
                        throw new Error(`Google AI returned no text output${finishReason ? ` (finishReason: ${finishReason})` : ""}`)
                    }
                    break
                }

                // Echo the model's own turn (or a rebuilt one), then answer it with the tool results.
                contents.push(response.candidates?.[0]?.content ?? this.buildModelFunctionCallContent(functionCalls))
                contents.push({
                    role: "user",
                    parts: await this.buildToolResponseParts(functionCalls, tools),
                })
            }

            if(output === undefined){
                throw new Error(`Tool max iterations reached: ${maxIterations}`)
            }

            const endTime = performance.now()
            const time = endTime - startTime
            const endedAt = new Date()

            await this.registerPromptLog(input, {
                model,
                systemPrompt,
                startedAt,
                endedAt,
                inputTokens,
                outputTokens,
                tokens,
                output,
                success: true,
            })

            return {
                output,
                tokens,
                inputTokens,
                outputTokens,
                time
            }
        } catch (e: unknown) {
            await this.registerPromptLog(input, {
                model,
                systemPrompt,
                startedAt,
                endedAt: new Date(),
                // Include usage accumulated by earlier rounds so the failure log reflects it.
                inputTokens,
                outputTokens,
                tokens,
                success: false,
                errorMessage: e instanceof Error ? e.message : String(e),
            })

            throw e
        }
    }

}

export default GoogleAiProvider
export {GoogleAiProvider}
@@ -258,6 +258,18 @@ class BuilderTool<T = any, C = any, U = any> {
258
258
  return this.fieldAdapter(f.unwrap()).nullable();
259
259
  }
260
260
 
261
+ if (typeof f?.unwrap === "function" && typeName === "ZodDefault") {
262
+ return this.fieldAdapter(f.unwrap()).default(f.def.defaultValue);
263
+ }
264
+
265
+ if (typeof f?.unwrap === "function" && typeName === "ZodCatch") {
266
+ return this.fieldAdapter(f.unwrap()).catch(f.def.catchValue);
267
+ }
268
+
269
+ if (typeof f?.unwrap === "function" && typeName === "ZodReadonly") {
270
+ return this.fieldAdapter(f.unwrap()).readonly();
271
+ }
272
+
261
273
  if (typeName === "ZodArray" && f?.element) {
262
274
  return z.array(this.fieldAdapter(f.element));
263
275
  }