@drax/ai-back 3.32.0 → 3.35.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/config/GoogleAiConfig.js +8 -0
- package/dist/config/OllamaAiConfig.js +9 -0
- package/dist/factory/AiProviderFactory.js +12 -3
- package/dist/factory/GoogleAiProviderFactory.js +14 -0
- package/dist/factory/OllamaAiProviderFactory.js +14 -0
- package/dist/index.js +7 -1
- package/dist/providers/GoogleAiProvider.js +367 -0
- package/dist/providers/OllamaAiProvider.js +342 -0
- package/dist/tools/BuilderTool.js +9 -0
- package/package.json +4 -2
- package/src/config/GoogleAiConfig.ts +13 -0
- package/src/config/OllamaAiConfig.ts +14 -0
- package/src/factory/AiProviderFactory.ts +12 -5
- package/src/factory/GoogleAiProviderFactory.ts +26 -0
- package/src/factory/OllamaAiProviderFactory.ts +27 -0
- package/src/index.ts +12 -0
- package/src/providers/GoogleAiProvider.ts +489 -0
- package/src/providers/OllamaAiProvider.ts +469 -0
- package/src/tools/BuilderTool.ts +12 -0
- package/test/GoogleAiProvider.test.ts +211 -0
- package/test/ToolBuilder.test.ts +48 -0
- package/tsconfig.tsbuildinfo +1 -1
- package/types/config/GoogleAiConfig.d.ts +8 -0
- package/types/config/GoogleAiConfig.d.ts.map +1 -0
- package/types/config/OllamaAiConfig.d.ts +9 -0
- package/types/config/OllamaAiConfig.d.ts.map +1 -0
- package/types/factory/AiProviderFactory.d.ts.map +1 -1
- package/types/factory/GoogleAiProviderFactory.d.ts +8 -0
- package/types/factory/GoogleAiProviderFactory.d.ts.map +1 -0
- package/types/factory/OllamaAiProviderFactory.d.ts +8 -0
- package/types/factory/OllamaAiProviderFactory.d.ts.map +1 -0
- package/types/index.d.ts.map +1 -1
- package/types/providers/GoogleAiProvider.d.ts +63 -0
- package/types/providers/GoogleAiProvider.d.ts.map +1 -0
- package/types/providers/OllamaAiProvider.d.ts +78 -0
- package/types/providers/OllamaAiProvider.d.ts.map +1 -0
- package/types/tools/BuilderTool.d.ts.map +1 -1
- package/.env +0 -4
- package/dist/agents/ChatbotTaskService.js +0 -143
- package/dist/agents/ChatbotTaskTools.js +0 -756
- package/dist/controllers/AIController.js +0 -150
- package/dist/interfaces/IAILog.js +0 -1
- package/dist/routes/ChatbotTaskRoutes.js +0 -8
- package/dist/tools/ToolBuilder.js +0 -243
- package/dist/vectors/ChromaVector.js +0 -65
- package/types/agents/ChatbotTaskService.d.ts +0 -42
- package/types/agents/ChatbotTaskService.d.ts.map +0 -1
- package/types/agents/ChatbotTaskTools.d.ts +0 -54
- package/types/agents/ChatbotTaskTools.d.ts.map +0 -1
- package/types/controllers/AIController.d.ts +0 -25
- package/types/controllers/AIController.d.ts.map +0 -1
- package/types/interfaces/IAILog.d.ts +0 -77
- package/types/interfaces/IAILog.d.ts.map +0 -1
- package/types/routes/ChatbotTaskRoutes.d.ts +0 -4
- package/types/routes/ChatbotTaskRoutes.d.ts.map +0 -1
- package/types/tools/ToolBuilder.d.ts +0 -47
- package/types/tools/ToolBuilder.d.ts.map +0 -1
- package/types/vectors/ChromaVector.d.ts +0 -21
- package/types/vectors/ChromaVector.d.ts.map +0 -1
|
@@ -0,0 +1,469 @@
|
|
|
1
|
+
import {toJSONSchema} from "zod";
|
|
2
|
+
import type {
|
|
3
|
+
IAIProvider,
|
|
4
|
+
IPromptContentPart,
|
|
5
|
+
IPromptMessage,
|
|
6
|
+
IPromptParams,
|
|
7
|
+
IPromptResponse,
|
|
8
|
+
IPromptTool
|
|
9
|
+
} from "../interfaces/IAIProvider";
|
|
10
|
+
import type {AILogService} from "../services/AILogService";
|
|
11
|
+
import type {IAILogBase} from "@drax/ai-share";
|
|
12
|
+
|
|
13
|
+
/**
 * A chat message in the shape sent to (and echoed back from) Ollama's `/api/chat` endpoint.
 */
type OllamaMessage = {
    role: "system" | "user" | "assistant" | "tool",
    content: string,
    /** Base64 image payloads with any `data:` URL prefix removed (see `imageUrlToBase64`). */
    images?: string[],
    /** Tool invocations requested by the model; present on assistant turns re-sent in the history. */
    tool_calls?: OllamaToolCall[],
    /** On `tool` messages: the name of the tool whose result is in `content`. */
    tool_name?: string,
}

/**
 * One tool invocation as found in `message.tool_calls` of an `/api/chat` response.
 * `arguments` may arrive already decoded or as a JSON string (both are handled by
 * `parseToolArguments`).
 */
type OllamaToolCall = {
    function?: {
        name?: string,
        arguments?: string | object,
    }
}

/** Facts about one prompt execution, turned into an `IAILogBase` entry by `buildLogPayload`. */
type OllamaPromptLogParams = {
    model: string,
    systemPrompt: string,
    startedAt: Date,
    endedAt?: Date,
    inputTokens?: number,
    outputTokens?: number,
    tokens?: number,
    output?: unknown,
    success: boolean,
    errorMessage?: string,
}

/**
 * `IAIProvider` implementation that talks to an Ollama server over its HTTP API.
 *
 * - `prompt` calls `/api/chat` with `stream: false`. It can attach a structured-output
 *   `format` and runs a tool-calling loop bounded by `toolMaxIterations` (default 5).
 * - `generateEmbedding` calls `/api/embed`.
 * - When an `AILogService` is supplied, every prompt (successful or failed) is logged with
 *   provider `"ollamaai"`. Logging errors are written to the console and never rethrown.
 */
class OllamaAiProvider implements IAIProvider{
    protected _baseUrl: string
    protected _model: string
    protected _visionModel?: string
    protected _embeddingModel?: string
    protected _aiLogService?: AILogService

    /**
     * @param baseUrl Ollama server URL; trailing slashes are stripped.
     * @param model Default chat model.
     * @param visionModel Used instead of `model` when the prompt contains images (if set).
     * @param embeddingModel Model for `generateEmbedding`; falls back to `model`.
     * @param aiLogService Optional sink for prompt logs.
     * @throws Error when `baseUrl` or `model` is empty.
     */
    constructor(baseUrl: string, model: string, visionModel?: string, embeddingModel?: string, aiLogService?: AILogService) {

        if (!baseUrl) {
            throw new Error("Ollama AI baseUrl required")
        }
        if (!model) {
            throw new Error("Ollama AI model required")
        }

        this._baseUrl = baseUrl.replace(/\/+$/, "")
        this._model = model
        this._visionModel = visionModel
        this._embeddingModel = embeddingModel
        this._aiLogService = aiLogService
    }

    /** Default chat model. @throws Error if it is unset. */
    get model(){
        if(!this._model){
            throw new Error("Ollama AI model not found")
        }
        return this._model;
    }

    /** Optional model for prompts that contain images. */
    protected get visionModel(){
        return this._visionModel
    }

    /** Embedding model, falling back to the chat model when none was configured. */
    protected get embeddingModel(){
        return this._embeddingModel ?? this.model
    }

    /**
     * POSTs `body` as JSON to `${baseUrl}${path}` and returns the parsed JSON response.
     * @throws Error with the HTTP status and response text when the status is not 2xx.
     */
    protected async post<T>(path: string, body: object): Promise<T> {
        const response = await fetch(`${this._baseUrl}${path}`, {
            method: "POST",
            headers: {
                "Content-Type": "application/json",
            },
            body: JSON.stringify(body),
        })

        if(!response.ok){
            const errorText = await response.text()
            throw new Error(`Ollama AI request failed (${response.status}): ${errorText}`)
        }

        return await response.json() as T
    }

    /**
     * Builds the user turn.
     *
     * Precedence: a non-empty `userContent` wins. Otherwise `userInput` is combined with
     * `userImages`. Otherwise `userInput` alone is used (empty string if absent).
     */
    protected async buildUserMessage(input: IPromptParams): Promise<OllamaMessage> {
        if(input.userContent && input.userContent.length > 0){
            return await this.mapContentPartsToMessage(input.userContent)
        }

        if(input.userImages && input.userImages.length > 0){
            return {
                role: "user",
                content: input.userInput ?? "",
                images: await Promise.all(input.userImages.map(image => this.imageUrlToBase64(image.url))),
            }
        }

        return {
            role: "user",
            content: input.userInput ?? "",
        }
    }

    /**
     * Flattens multimodal content parts into one Ollama message.
     *
     * Text parts are joined with newlines. Image parts are converted to base64 (concurrently,
     * keeping their original order). `images` is omitted when there are none.
     */
    protected async mapContentPartsToMessage(content: IPromptContentPart[], role: "user" | "assistant" | "system" = "user"): Promise<OllamaMessage> {
        const text: string[] = []
        const imageUrls: string[] = []

        for(const part of content){
            if(part.type === "text"){
                text.push(part.text)
            }else{
                imageUrls.push(part.imageUrl)
            }
        }

        // Promise.all preserves input order, so image order matches the content parts.
        const images = await Promise.all(imageUrls.map(url => this.imageUrlToBase64(url)))

        return {
            role,
            content: text.join("\n"),
            ...(images.length > 0 ? {images} : {}),
        }
    }

    /**
     * Returns the raw base64 payload of an image.
     *
     * A `data:<mime>;base64,` URL has its payload extracted directly. Any other URL is
     * fetched and its bytes are base64-encoded.
     * @throws Error when fetching a remote image returns a non-2xx status.
     */
    protected async imageUrlToBase64(url: string): Promise<string> {
        const dataUrlMatch = url.match(/^data:[^;,]+;base64,(.+)$/)

        if(dataUrlMatch){
            return dataUrlMatch[1]
        }

        const response = await fetch(url)

        if(!response.ok){
            throw new Error(`Ollama AI image request failed (${response.status}): ${url}`)
        }

        const buffer = Buffer.from(await response.arrayBuffer())
        return buffer.toString("base64")
    }

    /** Converts prompt history into Ollama messages, keeping order and roles. */
    protected async mapHistory(history: IPromptMessage[] = []): Promise<OllamaMessage[]>{
        const messages: OllamaMessage[] = []

        for(const message of history){
            if(typeof message.content === "string"){
                messages.push({
                    role: message.role,
                    content: message.content,
                })
                continue
            }

            messages.push(await this.mapContentPartsToMessage(message.content, message.role))
        }

        return messages
    }

    /** True when the user turn or any history message carries an image. */
    protected hasImageInput(input: IPromptParams){
        if(input.userImages && input.userImages.length > 0){
            return true
        }

        if(input.userContent?.some(part => part.type === 'image')){
            return true
        }

        return input.history?.some(message =>
            Array.isArray(message.content) && message.content.some(part => part.type === 'image')
        ) ?? false
    }

    /** JSON snapshot of the prompt inputs for the log (tool executors are omitted). */
    protected serializePromptInput(input: IPromptParams, systemPrompt: string){
        return JSON.stringify({
            systemPrompt,
            history: input.history,
            userInput: input.userInput,
            userContent: input.userContent,
            memory: input.memory,
            knowledgeBase: input.knowledgeBase,
            tools: input.tools?.map(tool => ({
                name: tool.name,
                description: tool.description,
                parameters: tool.parameters,
            })),
        })
    }

    /** Strings pass through; null/undefined become undefined; anything else is JSON-encoded. */
    protected serializePromptOutput(output: unknown){
        if (typeof output === "string") {
            return output
        }

        if (output === null || output === undefined) {
            return undefined
        }

        return JSON.stringify(output)
    }

    /**
     * Image URLs that were actually sent in the user turn.
     *
     * This uses the same precedence as `buildUserMessage`: a non-empty `userContent` wins over
     * `userImages`, so the log never reports images the model did not receive.
     */
    protected collectLogInputImages(input: IPromptParams){
        if(input.userContent && input.userContent.length > 0){
            return input.userContent.flatMap(part =>
                part.type === "image" ? [{url: part.imageUrl}] : []
            )
        }

        return input.userImages?.map(image => ({
            url: image.url,
        }))
    }

    /** Builds the `IAILogBase` record for one prompt execution. */
    protected buildLogPayload(input: IPromptParams, params: OllamaPromptLogParams): IAILogBase {
        return {
            provider: "ollamaai",
            model: params.model,
            operationTitle: input.operationTitle,
            operationGroup: input.operationGroup,
            ip: input.ip,
            userAgent: input.userAgent,
            input: this.serializePromptInput(input, params.systemPrompt),
            inputImages: this.collectLogInputImages(input),
            inputFiles: input.inputFiles,
            inputTokens: params.inputTokens,
            outputTokens: params.outputTokens,
            tokens: params.tokens,
            startedAt: params.startedAt,
            endedAt: params.endedAt,
            responseTime: params.endedAt ? `${params.endedAt.getTime() - params.startedAt.getTime()}ms` : undefined,
            output: this.serializePromptOutput(params.output),
            success: params.success,
            errorMessage: params.errorMessage,
            tenant: input.tenant,
            user: input.user,
        }
    }

    /**
     * Persists a prompt log if a log service is configured.
     *
     * Logging is best-effort: failures are written to the console and never propagate,
     * so they cannot mask the prompt's own result or error.
     */
    protected async registerPromptLog(input: IPromptParams, params: OllamaPromptLogParams){
        if(!this._aiLogService){
            return
        }

        try{
            await this._aiLogService.create(this.buildLogPayload(input, params))
        }catch(e: unknown){
            const err = e instanceof Error ? e : undefined
            console.error("Error registerPromptLog", {
                name: err?.name,
                message: err?.message ?? String(e),
                stack: err?.stack,
            })
        }
    }

    /**
     * Embeds `text` via `/api/embed`.
     *
     * Returns the first vector of `embeddings`. It falls back to a single `embedding` field
     * (presumably for older response shapes; verify), and then to `[]`.
     */
    async generateEmbedding({text, model}: {text:string, model?:string }): Promise<number[]> {
        const response = await this.post<any>("/api/embed", {
            model: model ?? this.embeddingModel,
            input: text,
        });

        return response.embeddings?.[0] ?? response.embedding ?? [];
    }

    /**
     * Converts prompt tools into Ollama `function` tool definitions.
     * A tool without `parameters` gets an empty, closed object schema.
     */
    protected mapTools(tools: IPromptTool[] = []){
        return tools.map(tool => ({
            type: "function",
            function: {
                name: tool.name,
                description: tool.description,
                parameters: tool.parameters ?? {
                    type: "object",
                    properties: {},
                    additionalProperties: false,
                },
            },
        }))
    }

    /**
     * Resolves the JSON Schema to send as Ollama's `format`.
     *
     * - A zod schema is converted to draft-7 JSON Schema.
     * - An OpenAI-style `{type: "json_schema", json_schema: {schema}}` wrapper is unwrapped.
     * - Any other `jsonSchema` value is passed through unchanged.
     *
     * Returns undefined when no schema was requested.
     */
    protected normalizeResponseFormat(input: IPromptParams){
        if(input.zodSchema){
            return toJSONSchema(input.zodSchema, {
                target: "draft-7",
            })
        }

        if(!input.jsonSchema){
            return undefined
        }

        const jsonSchema: any = input.jsonSchema

        if(jsonSchema.type === "json_schema" && jsonSchema.json_schema?.schema){
            return jsonSchema.json_schema.schema
        }

        return jsonSchema
    }

    /**
     * Normalizes tool-call arguments.
     *
     * Missing or empty arguments become `{}`. Objects pass through. Strings are parsed as JSON.
     * @throws Error when a string argument is not valid JSON.
     */
    protected parseToolArguments(args: string | object | undefined){
        if(!args){
            return {}
        }

        if(typeof args === "object"){
            return args
        }

        try{
            return JSON.parse(args)
        }catch{
            throw new Error(`Invalid tool arguments: ${args}`)
        }
    }

    /** Tool output as message text: strings as-is, undefined as "", everything else JSON. */
    protected serializeToolOutput(output: unknown){
        if(typeof output === "string"){
            return output
        }

        if(output === undefined){
            return ""
        }

        return JSON.stringify(output)
    }

    /**
     * Executes each requested tool call in order and builds the `tool` result messages.
     * @throws Error when the model calls a tool that is not in `tools`, or when its arguments are invalid.
     */
    protected async buildToolMessages(toolCalls: OllamaToolCall[] = [], tools: IPromptTool[] = []): Promise<OllamaMessage[]>{
        const toolMessages: OllamaMessage[] = []

        // Sequential on purpose: tools with side effects run in the order the model requested.
        for(const toolCall of toolCalls){
            const toolName = toolCall.function?.name
            const tool = tools.find(t => t.name === toolName)

            if(!tool){
                throw new Error(`Tool not found: ${toolName}`)
            }

            const args = this.parseToolArguments(toolCall.function?.arguments)
            const output = await tool.execute(args)

            toolMessages.push({
                role: "tool",
                // Ollama's chat message schema names the originating tool via `tool_name`.
                tool_name: toolName,
                content: this.serializeToolOutput(output),
            })
        }

        return toolMessages
    }

    /**
     * Runs a prompt against `/api/chat`, resolving tool calls until the model answers
     * with plain content.
     *
     * The system prompt is extended with the memory and knowledge-base sections when they
     * are provided. `input.model` wins; otherwise the vision model is used for image input
     * (if configured), else the default model. Token usage is summed over every round trip.
     *
     * @returns The final text output, token counts, and elapsed time in milliseconds.
     * @throws Error when `systemPrompt` is missing, on HTTP/tool errors, or when
     *   `toolMaxIterations` round trips pass without a final answer. Failures are logged
     *   (with usage so far) before being rethrown.
     */
    async prompt(input: IPromptParams): Promise<IPromptResponse> {

        if(!input.systemPrompt){
            throw new Error("systemPrompt required")
        }

        let systemPrompt = input.systemPrompt

        if(input.memory && input.memory.length > 0){
            systemPrompt += `\n\n ${input.memoryHeader ?? '[MEMORIA]'}\n ${input.memory.map(m => `${m.key}: ${m.value}`).join('\n')}`
        }

        if(input.knowledgeBase && input.knowledgeBase.length > 0){
            systemPrompt += `\n\n${input.knowledgeBaseHeader ?? '[BASE DE CONOCIMIENTO]'}\n ${input.knowledgeBase.join('\n')}`
        }

        const model = input.model ?? (this.hasImageInput(input) ? this.visionModel ?? this.model : this.model)
        const startedAt = new Date()
        const startTime = performance.now()
        // Accumulated across every /api/chat round trip, and reported on failure too.
        let tokens = 0
        let inputTokens = 0
        let outputTokens = 0

        try {
            const messages: OllamaMessage[] = [
                {role: 'system', content: systemPrompt},
                ...await this.mapHistory(input.history),
                await this.buildUserMessage(input),
            ]
            const tools = input.tools ?? []
            const maxIterations = input.toolMaxIterations ?? 5
            const responseFormat = this.normalizeResponseFormat(input)
            let output: any

            for(let iteration = 0; iteration < maxIterations; iteration++){
                const response = await this.post<any>("/api/chat", {
                    model,
                    messages,
                    stream: false,
                    ...(responseFormat ? {format: responseFormat} : {}),
                    ...(tools.length > 0 ? {tools: this.mapTools(tools)} : {}),
                });

                const promptCount: number = response.prompt_eval_count ?? 0
                const evalCount: number = response.eval_count ?? 0
                inputTokens += promptCount
                outputTokens += evalCount
                tokens += promptCount + evalCount

                const message = response.message ?? {}
                const toolCalls: OllamaToolCall[] = message.tool_calls ?? []

                if(toolCalls.length === 0){
                    output = message.content ?? response.response ?? ""
                    break
                }

                // Re-send the assistant turn (with its tool_calls) ahead of the tool results
                // so the next round trip sees the full call/result exchange.
                messages.push(message)
                messages.push(...await this.buildToolMessages(toolCalls, tools))
            }

            if(output === undefined){
                throw new Error(`Tool max iterations reached: ${maxIterations}`)
            }

            const endTime = performance.now()
            const time = endTime - startTime
            const endedAt = new Date()

            await this.registerPromptLog(input, {
                model,
                systemPrompt,
                startedAt,
                endedAt,
                inputTokens,
                outputTokens,
                tokens,
                output,
                success: true,
            })

            return {
                output,
                tokens,
                inputTokens,
                outputTokens,
                time
            }
        } catch (e: unknown) {
            const endedAt = new Date()

            await this.registerPromptLog(input, {
                model,
                systemPrompt,
                startedAt,
                endedAt,
                // Round trips completed before the failure still consumed tokens.
                inputTokens,
                outputTokens,
                tokens,
                success: false,
                errorMessage: e instanceof Error ? e.message : String(e),
            })

            throw e
        }
    }

}
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
// Expose the provider both as a named export and as the module default,
// so `import {OllamaAiProvider}` and `import OllamaAiProvider` both work.
export {OllamaAiProvider}
export default OllamaAiProvider
|
package/src/tools/BuilderTool.ts
CHANGED
|
@@ -258,6 +258,18 @@ class BuilderTool<T = any, C = any, U = any> {
|
|
|
258
258
|
return this.fieldAdapter(f.unwrap()).nullable();
|
|
259
259
|
}
|
|
260
260
|
|
|
261
|
+
if (typeof f?.unwrap === "function" && typeName === "ZodDefault") {
|
|
262
|
+
return this.fieldAdapter(f.unwrap()).default(f.def.defaultValue);
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
if (typeof f?.unwrap === "function" && typeName === "ZodCatch") {
|
|
266
|
+
return this.fieldAdapter(f.unwrap()).catch(f.def.catchValue);
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
if (typeof f?.unwrap === "function" && typeName === "ZodReadonly") {
|
|
270
|
+
return this.fieldAdapter(f.unwrap()).readonly();
|
|
271
|
+
}
|
|
272
|
+
|
|
261
273
|
if (typeName === "ZodArray" && f?.element) {
|
|
262
274
|
return z.array(this.fieldAdapter(f.element));
|
|
263
275
|
}
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
import {describe, expect, test} from "vitest";
|
|
2
|
+
import {AiProviderFactory, GoogleAiProvider} from "../src";
|
|
3
|
+
import {IPromptTool} from "../src/interfaces/IAIProvider";
|
|
4
|
+
|
|
5
|
+
describe("GoogleAiProvider Test", () => {

    /**
     * Builds a GoogleAiProvider whose SDK client is replaced by a stub.
     *
     * Every `models.generateContent` call is forwarded to `handler`, which receives the
     * request payload and returns the fake SDK response.
     */
    const createMockedGoogleAi = (handler: (payload: any) => any, visionModel?: string) => {
        class MockedGoogleAiProvider extends GoogleAiProvider {
            constructor() {
                super("test-key", "gemini-2.5-flash", visionModel)
                this._client = {
                    models: {
                        generateContent: async (payload: any) => handler(payload)
                    }
                } as any
            }
        }
        return new MockedGoogleAiProvider()
    }

    test("GoogleAi prompt supports image inputs and vision model fallback", async () => {
        let request: any

        const googleAi = createMockedGoogleAi((payload) => {
            request = payload
            return {
                text: "{\"name\":\"invoice\"}",
                usageMetadata: {
                    totalTokenCount: 10,
                    promptTokenCount: 7,
                    candidatesTokenCount: 3,
                }
            }
        }, "gemini-2.5-flash")

        const r = await googleAi.prompt({
            systemPrompt: "Extract invoice data",
            userInput: "Read this invoice",
            userImages: [{url: "data:image/png;base64,abc123", detail: "high"}]
        })

        expect(r.output).toBe("{\"name\":\"invoice\"}")
        expect(request.model).toBe("gemini-2.5-flash")
        expect(request.contents[0]).toEqual({
            role: "user",
            parts: [
                {text: "Read this invoice"},
                {
                    inlineData: {
                        mimeType: "image/png",
                        data: "abc123",
                    }
                }
            ]
        })
    })

    test("GoogleAi prompt maps OpenAI jsonSchema format to Gemini responseJsonSchema", async () => {
        let request: any

        const googleAi = createMockedGoogleAi((payload) => {
            request = payload
            return {
                text: "{\"name\":\"Pikachu\"}",
                usageMetadata: {
                    totalTokenCount: 8,
                    promptTokenCount: 6,
                    candidatesTokenCount: 2,
                }
            }
        })

        const jsonSchema = {
            type: "json_schema",
            json_schema: {
                name: "element_description",
                schema: {
                    type: "object",
                    properties: {
                        name: {type: "string"}
                    },
                    required: ["name"]
                }
            }
        }

        await googleAi.prompt({
            systemPrompt: "You are an AI assistant.",
            userInput: "What is the most famous pokemon",
            jsonSchema,
        })

        expect(request.config.responseMimeType).toBe("application/json")
        expect(request.config.responseJsonSchema).toEqual(jsonSchema.json_schema.schema)
    })

    test("GoogleAi prompt executes tools and sends function response back to model", async () => {
        const requests: any[] = []
        const weatherTool: IPromptTool = {
            name: "get_weather",
            description: "Get weather for a city",
            parameters: {
                type: "object",
                properties: {
                    city: {type: "string"}
                },
                required: ["city"],
                additionalProperties: false
            },
            execute: async ({city}) => ({city, temperature: 21})
        }

        const googleAi = createMockedGoogleAi((payload) => {
            requests.push(payload)

            // First round trip: the model asks for the weather tool.
            if(requests.length === 1){
                return {
                    functionCalls: [{
                        id: "call_123",
                        name: "get_weather",
                        args: {city: "Buenos Aires"}
                    }],
                    candidates: [{
                        content: {
                            role: "model",
                            parts: [{
                                functionCall: {
                                    id: "call_123",
                                    name: "get_weather",
                                    args: {city: "Buenos Aires"}
                                }
                            }]
                        }
                    }],
                    usageMetadata: {
                        totalTokenCount: 15,
                        promptTokenCount: 10,
                        candidatesTokenCount: 5,
                    }
                }
            }

            // Second round trip: the final answer.
            return {
                text: "21 grados",
                usageMetadata: {
                    totalTokenCount: 9,
                    promptTokenCount: 7,
                    candidatesTokenCount: 2,
                }
            }
        })

        const r = await googleAi.prompt({
            systemPrompt: "You are an AI assistant.",
            userInput: "How is the weather in Buenos Aires?",
            tools: [weatherTool]
        })

        expect(r.output).toBe("21 grados")
        expect(r.tokens).toBe(24)
        expect(requests[0].config.tools).toEqual([{
            functionDeclarations: [{
                name: "get_weather",
                description: "Get weather for a city",
                parametersJsonSchema: weatherTool.parameters,
            }]
        }])
        expect(requests[1].contents[2]).toEqual({
            role: "user",
            parts: [{
                functionResponse: {
                    id: "call_123",
                    name: "get_weather",
                    response: {
                        output: {
                            city: "Buenos Aires",
                            temperature: 21,
                        }
                    }
                }
            }]
        })
    })

    test("AiProviderFactory supports GoogleAi option", () => {
        const overrides: Record<string, string> = {
            GOOGLE_AI_API_KEY: "test-key",
            GOOGLE_AI_MODEL: "gemini-2.5-flash",
            GOOGLE_AI_VISION_MODEL: "gemini-2.5-flash",
            DRAX_DB_ENGINE: "mongo",
        }
        const previous = Object.fromEntries(
            Object.keys(overrides).map(key => [key, process.env[key]])
        )
        Object.assign(process.env, overrides)

        try {
            const googleAi = AiProviderFactory.instance("GoogleAi")
            expect(googleAi).toBeInstanceOf(GoogleAiProvider)
        } finally {
            // Restore the environment so later tests in the same worker are not affected.
            for (const [key, value] of Object.entries(previous)) {
                if (value === undefined) {
                    delete process.env[key]
                } else {
                    process.env[key] = value
                }
            }
        }
    })
})
|