neural-ai-sdk 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,4 +6,8 @@ export declare class DeepSeekModel extends BaseModel {
  constructor(config: AIModelConfig);
  generate(request: AIModelRequest): Promise<AIModelResponse>;
  stream(request: AIModelRequest): AsyncGenerator<string, void, unknown>;
+ /**
+ * Process function calls from DeepSeek API response
+ */
+ private processFunctionCalls;
  }
@@ -28,19 +28,39 @@ class DeepSeekModel extends base_model_1.BaseModel {
  role: "user",
  content: request.prompt,
  });
- const response = await axios_1.default.post(`${this.baseURL}/chat/completions`, {
+ // Prepare request payload
+ const payload = {
  model: config.model || "deepseek-chat",
  messages,
  temperature: config.temperature,
  max_tokens: config.maxTokens,
  top_p: config.topP,
- }, {
+ };
+ // Add function calling support if functions are provided
+ if (request.functions && request.functions.length > 0) {
+ payload.functions = request.functions;
+ // Handle function call configuration
+ if (request.functionCall) {
+ if (request.functionCall === "auto") {
+ payload.function_call = "auto";
+ }
+ else if (request.functionCall === "none") {
+ payload.function_call = "none";
+ }
+ else if (typeof request.functionCall === "object") {
+ payload.function_call = { name: request.functionCall.name };
+ }
+ }
+ }
+ const response = await axios_1.default.post(`${this.baseURL}/chat/completions`, payload, {
  headers: {
  "Content-Type": "application/json",
  Authorization: `Bearer ${config.apiKey ||
  (0, utils_1.getApiKey)(config.apiKey, "DEEPSEEK_API_KEY", "DeepSeek")}`,
  },
  });
+ // Process function calls if any
+ const functionCalls = this.processFunctionCalls(response.data);
  return {
  text: response.data.choices[0].message.content,
  usage: {
@@ -48,6 +68,7 @@ class DeepSeekModel extends base_model_1.BaseModel {
  completionTokens: response.data.usage?.completion_tokens,
  totalTokens: response.data.usage?.total_tokens,
  },
+ functionCalls,
  raw: response.data,
  };
  }
@@ -64,14 +85,32 @@ class DeepSeekModel extends base_model_1.BaseModel {
  role: "user",
  content: request.prompt,
  });
- const response = await axios_1.default.post(`${this.baseURL}/chat/completions`, {
+ // Prepare request payload
+ const payload = {
  model: config.model || "deepseek-chat",
  messages,
  temperature: config.temperature,
  max_tokens: config.maxTokens,
  top_p: config.topP,
  stream: true,
- }, {
+ };
+ // Add function calling support if functions are provided
+ if (request.functions && request.functions.length > 0) {
+ payload.functions = request.functions;
+ // Handle function call configuration
+ if (request.functionCall) {
+ if (request.functionCall === "auto") {
+ payload.function_call = "auto";
+ }
+ else if (request.functionCall === "none") {
+ payload.function_call = "none";
+ }
+ else if (typeof request.functionCall === "object") {
+ payload.function_call = { name: request.functionCall.name };
+ }
+ }
+ }
+ const response = await axios_1.default.post(`${this.baseURL}/chat/completions`, payload, {
  headers: {
  "Content-Type": "application/json",
  Authorization: `Bearer ${config.apiKey ||
@@ -101,5 +140,20 @@ class DeepSeekModel extends base_model_1.BaseModel {
  }
  }
  }
+ /**
+ * Process function calls from DeepSeek API response
+ */
+ processFunctionCalls(response) {
+ if (!response.choices?.[0]?.message?.function_call) {
+ return undefined;
+ }
+ const functionCall = response.choices[0].message.function_call;
+ return [
+ {
+ name: functionCall.name,
+ arguments: functionCall.arguments,
+ },
+ ];
+ }
  }
  exports.DeepSeekModel = DeepSeekModel;
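
Taken together, the DeepSeek changes let a request opt into function calling and surface any `function_call` the API returns via `response.functionCalls`. A minimal usage sketch (the import from the package index is an assumption; the request fields match the `types.d.ts` additions below):

```ts
import { DeepSeekModel } from "neural-ai-sdk"; // assumed re-export from the package index

const model = new DeepSeekModel({ apiKey: process.env.DEEPSEEK_API_KEY });

const response = await model.generate({
  prompt: "What's the weather in Tokyo?",
  functions: [
    {
      name: "getWeather", // hypothetical function for illustration
      description: "Get the current weather for a location",
      parameters: {
        type: "object",
        properties: { location: { type: "string" } },
        required: ["location"],
      },
    },
  ],
  functionCall: "auto",
});

// processFunctionCalls maps response.data.choices[0].message.function_call
// into the SDK's FunctionCall[] shape; with no call it stays undefined.
console.log(response.functionCalls?.[0]);
```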
@@ -6,6 +6,10 @@ export declare class GoogleModel extends BaseModel {
  constructor(config: AIModelConfig);
  generate(request: AIModelRequest): Promise<AIModelResponse>;
  stream(request: AIModelRequest): AsyncGenerator<string, void, unknown>;
+ /**
+ * Extract function calls from text using regex patterns since we're using prompt engineering
+ */
+ private extractFunctionCallsFromText;
  /**
  * Format content for Google's Gemini API, handling both text and images
  */
@@ -15,14 +15,56 @@ class GoogleModel extends base_model_1.BaseModel {
  }
  async generate(request) {
  const config = this.mergeConfig(request.options);
- const model = this.client.getGenerativeModel({
- model: config.model || "gemini-2.0-flash", // Updated default model
+ // Create base model configuration with Gemini 2.0 models
+ const modelConfig = {
+ model: config.model || "gemini-2.0-flash", // Using 2.0 models as default
  generationConfig: {
  temperature: config.temperature,
  maxOutputTokens: config.maxTokens,
  topP: config.topP,
  },
- });
+ };
+ const model = this.client.getGenerativeModel(modelConfig);
+ // Handle function calling through prompt engineering for Gemini 2.0 models
+ // as native function calling may not be fully supported in the same way
+ if (request.functions && request.functions.length > 0) {
+ try {
+ // Create an enhanced prompt with function definitions
+ let enhancedPrompt = request.prompt || "";
+ if (request.systemPrompt) {
+ enhancedPrompt = `${request.systemPrompt}\n\n${enhancedPrompt}`;
+ }
+ // Add function definitions to the prompt
+ enhancedPrompt += `\n\nYou have access to the following functions:\n\`\`\`json\n${JSON.stringify(request.functions, null, 2)}\n\`\`\`\n\n`;
+ // Add specific instructions based on function call mode
+ if (typeof request.functionCall === "object") {
+ enhancedPrompt += `You MUST use the function: ${request.functionCall.name}\n`;
+ enhancedPrompt += `Format your response as a function call using JSON in this exact format:\n`;
+ enhancedPrompt += `{"name": "${request.functionCall.name}", "arguments": {...}}\n`;
+ enhancedPrompt += `Don't include any explanations, just output the function call.\n`;
+ }
+ else if (request.functionCall === "auto") {
+ enhancedPrompt += `If appropriate for the request, call one of these functions.\n`;
+ enhancedPrompt += `Format your response as a function call using JSON in this exact format:\n`;
+ enhancedPrompt += `{"name": "functionName", "arguments": {...}}\n`;
+ }
+ const result = await model.generateContent(enhancedPrompt);
+ const response = result.response;
+ const text = response.text();
+ // Extract function calls from the response text
+ const functionCalls = this.extractFunctionCallsFromText(text);
+ return {
+ text,
+ functionCalls,
+ raw: response,
+ };
+ }
+ catch (error) {
+ console.warn("Function calling with prompt engineering failed for Gemini, falling back to text-only", error);
+ // Fall back to regular text generation if prompt engineering fails
+ }
+ }
+ // Regular text generation without function calling
  const content = await this.formatMultiModalContent(request);
  const result = await model.generateContent(content);
  const response = result.response;
@@ -32,9 +74,11 @@ class GoogleModel extends base_model_1.BaseModel {
  };
  }
  async *stream(request) {
+ // For streaming, we'll keep it simpler as function calling with streaming
+ // has additional complexities
  const config = this.mergeConfig(request.options);
  const model = this.client.getGenerativeModel({
- model: config.model || "gemini-2.0-flash", // Updated default model
+ model: config.model || "gemini-2.0-flash", // Using 2.0 models as default
  generationConfig: {
  temperature: config.temperature,
  maxOutputTokens: config.maxTokens,
@@ -50,6 +94,96 @@ class GoogleModel extends base_model_1.BaseModel {
  }
  }
  }
+ /**
+ * Extract function calls from text using regex patterns since we're using prompt engineering
+ */
+ extractFunctionCallsFromText(text) {
+ try {
+ if (!text)
+ return undefined;
+ const functionCalls = [];
+ // Pattern 1: JSON object with name and arguments
+ // Try full regex match first
+ const jsonRegex = /\{[\s\n]*"name"[\s\n]*:[\s\n]*"([^"]+)"[\s\n]*,[\s\n]*"arguments"[\s\n]*:[\s\n]*(\{.*?\})\s*\}/gs;
+ let match;
+ while ((match = jsonRegex.exec(text)) !== null) {
+ try {
+ const name = match[1];
+ const args = match[2];
+ functionCalls.push({
+ name,
+ arguments: args,
+ });
+ }
+ catch (e) {
+ console.warn("Error parsing function call:", e);
+ }
+ }
+ // Pattern 2: Markdown code blocks that might contain JSON
+ const markdownRegex = /```(?:json)?\s*\n\s*(\{[\s\S]*?\})\s*\n```/gs;
+ while ((match = markdownRegex.exec(text)) !== null) {
+ try {
+ const jsonBlock = match[1].trim();
+ // Try to parse the JSON
+ const jsonObj = JSON.parse(jsonBlock);
+ if (jsonObj.name && (jsonObj.arguments || jsonObj.args)) {
+ functionCalls.push({
+ name: jsonObj.name,
+ arguments: JSON.stringify(jsonObj.arguments || jsonObj.args),
+ });
+ }
+ }
+ catch (e) {
+ // JSON might be malformed, try more aggressive parsing
+ const nameMatch = match[1].match(/"name"\s*:\s*"([^"]+)"/);
+ const argsMatch = match[1].match(/"arguments"\s*:\s*(\{[^}]*\})/);
+ if (nameMatch && argsMatch) {
+ try {
+ // Try to fix and parse the arguments
+ const argumentsStr = argsMatch[1].replace(/,\s*$/, "");
+ const fixedArgs = argumentsStr + (argumentsStr.endsWith("}") ? "" : "}");
+ functionCalls.push({
+ name: nameMatch[1],
+ arguments: fixedArgs,
+ });
+ }
+ catch (e) {
+ console.warn("Failed to fix JSON format:", e);
+ }
+ }
+ }
+ }
+ // Pattern 3: If still no matches, try looser regex
+ if (functionCalls.length === 0) {
+ // Extract function name and args separately with more permissive patterns
+ const nameMatch = text.match(/"name"\s*:\s*"([^"]+)"/);
+ if (nameMatch) {
+ const name = nameMatch[1];
+ // Find arguments block, accounting for potential closing bracket issues
+ const argsRegex = /"arguments"\s*:\s*(\{[^{]*?(?:\}|$))/;
+ const argsMatch = text.match(argsRegex);
+ if (argsMatch) {
+ let args = argsMatch[1].trim();
+ // Fix common JSON formatting issues
+ if (!args.endsWith("}")) {
+ args += "}";
+ }
+ // Clean up trailing commas which cause JSON.parse to fail
+ args = args.replace(/,\s*}/g, "}");
+ functionCalls.push({
+ name,
+ arguments: args,
+ });
+ }
+ }
+ }
+ return functionCalls.length > 0 ? functionCalls : undefined;
+ }
+ catch (error) {
+ console.warn("Error extracting function calls from text:", error);
+ return undefined;
+ }
+ }
  /**
  * Format content for Google's Gemini API, handling both text and images
  */
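
Because the Gemini path relies on prompt engineering rather than a native tools API, everything depends on the extraction regexes above. A standalone sketch of Pattern 1, using the same regex as the diff (the reply string is illustrative):

```ts
// Pattern 1 from extractFunctionCallsFromText: a bare JSON function call
const jsonRegex =
  /\{[\s\n]*"name"[\s\n]*:[\s\n]*"([^"]+)"[\s\n]*,[\s\n]*"arguments"[\s\n]*:[\s\n]*(\{.*?\})\s*\}/gs;

const reply = 'Sure! {"name": "getWeather", "arguments": {"location": "Tokyo"}}';
const match = jsonRegex.exec(reply);

// match?.[1] === "getWeather"
// match?.[2] === '{"location": "Tokyo"}' (kept as a string, per the FunctionCall type)
```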
@@ -13,6 +13,11 @@ export declare class HuggingFaceModel extends BaseModel {
  * Generate a response using multimodal inputs (text + images)
  */
  private generateWithImages;
+ /**
+ * Extract function calls from text using pattern matching
+ * This is more robust than the previous implementation and handles various formats
+ */
+ private extractFunctionCallsFromText;
  /**
  * Try generating with nested inputs format (common in newer models)
  */
@@ -22,9 +27,9 @@ export declare class HuggingFaceModel extends BaseModel {
  */
  private generateWithFlatFormat;
  /**
- * Helper to parse HuggingFace response in various formats
+ * Helper to parse HuggingFace response with function call extraction
  */
- private parseResponse;
+ private parseResponseWithFunctionCalls;
  /**
  * Fallback method that uses multipart/form-data for older HuggingFace models
  */
@@ -18,7 +18,8 @@ class HuggingFaceModel extends base_model_1.BaseModel {
  }
  async generate(request) {
  const config = this.mergeConfig(request.options);
- const model = config.model || "meta-llama/Llama-2-7b-chat-hf";
+ // Use a more accessible default model that doesn't require special permissions
+ const model = config.model || "mistralai/Mistral-7B-Instruct-v0.2";
  try {
  // Try multimodal approach if images are present
  if (request.image ||
@@ -52,6 +53,10 @@ class HuggingFaceModel extends base_model_1.BaseModel {
  " Try a different vision-capable model like 'llava-hf/llava-1.5-7b-hf' or check HuggingFace's documentation for this specific model.";
  throw new Error(errorMessage);
  }
+ else if (error.response?.status === 403) {
+ // Handle permission errors more specifically
+ throw new Error(`Permission denied for model "${model}". Try using a different model with public access. Error: ${error.response?.data || error.message}`);
+ }
  throw error;
  }
  }
@@ -63,37 +68,62 @@ class HuggingFaceModel extends base_model_1.BaseModel {
  if (request.systemPrompt) {
  fullPrompt = `${request.systemPrompt}\n\n${fullPrompt}`;
  }
+ // If functions are provided, enhance the prompt to handle function calling
+ if (request.functions && request.functions.length > 0) {
+ // Format the functions in a way the model can understand
+ fullPrompt += `\n\nAVAILABLE FUNCTIONS:\n${JSON.stringify(request.functions, null, 2)}\n\n`;
+ // Add guidance based on function call setting
+ if (typeof request.functionCall === "object") {
+ fullPrompt += `You must call the function: ${request.functionCall.name}.\n`;
+ fullPrompt += `Format your answer as a function call using JSON, like this:\n`;
+ fullPrompt += `{"name": "${request.functionCall.name}", "arguments": {...}}\n`;
+ fullPrompt += `Don't include any explanations, just output the function call.\n`;
+ }
+ else if (request.functionCall === "auto") {
+ fullPrompt += `Call one of the available functions if appropriate. Format the function call as JSON, like this:\n`;
+ fullPrompt += `{"name": "functionName", "arguments": {...}}\n`;
+ }
+ }
  const payload = {
  inputs: fullPrompt,
  parameters: {
- temperature: config.temperature,
- max_new_tokens: config.maxTokens,
- top_p: config.topP,
+ temperature: config.temperature || 0.1,
+ max_new_tokens: config.maxTokens || 500,
+ top_p: config.topP || 0.9,
  return_full_text: false,
  },
  };
- const response = await axios_1.default.post(`${this.baseURL}/${model}`, payload, {
- headers: {
- Authorization: `Bearer ${config.apiKey ||
- (0, utils_1.getApiKey)(config.apiKey, "HUGGINGFACE_API_KEY", "HuggingFace")}`,
- "Content-Type": "application/json",
- },
- });
- // HuggingFace can return different formats depending on the model
- let text = "";
- if (Array.isArray(response.data)) {
- text = response.data[0]?.generated_text || "";
- }
- else if (response.data.generated_text) {
- text = response.data.generated_text;
+ try {
+ const response = await axios_1.default.post(`${this.baseURL}/${model}`, payload, {
+ headers: {
+ Authorization: `Bearer ${config.apiKey ||
+ (0, utils_1.getApiKey)(config.apiKey, "HUGGINGFACE_API_KEY", "HuggingFace")}`,
+ "Content-Type": "application/json",
+ },
+ });
+ // Parse the response
+ let text = "";
+ if (Array.isArray(response.data)) {
+ text = response.data[0]?.generated_text || "";
+ }
+ else if (response.data.generated_text) {
+ text = response.data.generated_text;
+ }
+ else {
+ text = JSON.stringify(response.data);
+ }
+ // Extract function calls from the response
+ const functionCalls = this.extractFunctionCallsFromText(text);
+ return {
+ text,
+ functionCalls,
+ raw: response.data,
+ };
  }
- else {
- text = JSON.stringify(response.data);
+ catch (error) {
+ console.error("Error generating with HuggingFace model:", error);
+ throw error;
  }
- return {
- text,
- raw: response.data,
- };
  }
  /**
  * Generate a response using multimodal inputs (text + images)
@@ -128,6 +158,86 @@ class HuggingFaceModel extends base_model_1.BaseModel {
  .join("; ")}`;
  throw new Error(errorMessage);
  }
+ /**
+ * Extract function calls from text using pattern matching
+ * This is more robust than the previous implementation and handles various formats
+ */
+ extractFunctionCallsFromText(text) {
+ if (!text)
+ return undefined;
+ try {
+ const functionCalls = [];
+ // Pattern 1: Clean JSON function call format
+ // Example: {"name": "functionName", "arguments": {...}}
+ const jsonRegex = /\{[\s\n]*"name"[\s\n]*:[\s\n]*"([^"]+)"[\s\n]*,[\s\n]*"arguments"[\s\n]*:[\s\n]*([\s\S]*?)\}/g;
+ let match;
+ while ((match = jsonRegex.exec(text)) !== null) {
+ try {
+ // Try to parse the arguments part as JSON
+ const name = match[1];
+ let args = match[2].trim();
+ // Check if args is already a valid JSON string
+ try {
+ JSON.parse(args);
+ functionCalls.push({
+ name,
+ arguments: args,
+ });
+ }
+ catch (e) {
+ // If not valid JSON, try to extract the JSON object
+ const argsMatch = args.match(/\{[\s\S]*\}/);
+ if (argsMatch) {
+ functionCalls.push({
+ name,
+ arguments: argsMatch[0],
+ });
+ }
+ else {
+ functionCalls.push({
+ name,
+ arguments: "{}",
+ });
+ }
+ }
+ }
+ catch (e) {
+ console.warn("Error parsing function call:", e);
+ }
+ }
+ // Pattern 2: Function-like syntax
+ // Example: functionName({param1: "value", param2: 123})
+ const functionRegex = /([a-zA-Z0-9_]+)\s*\(\s*(\{[\s\S]*?\})\s*\)/g;
+ while ((match = functionRegex.exec(text)) !== null) {
+ functionCalls.push({
+ name: match[1],
+ arguments: match[2],
+ });
+ }
+ // Pattern 3: Markdown code block with JSON
+ // Example: ```json\n{"name": "functionName", "arguments": {...}}\n```
+ const markdownRegex = /```(?:json)?\s*\n\s*(\{[\s\S]*?\})\s*\n```/g;
+ while ((match = markdownRegex.exec(text)) !== null) {
+ try {
+ const jsonObj = JSON.parse(match[1]);
+ if (jsonObj.name && (jsonObj.arguments || jsonObj.args)) {
+ functionCalls.push({
+ name: jsonObj.name,
+ arguments: JSON.stringify(jsonObj.arguments || jsonObj.args),
+ });
+ }
+ }
+ catch (e) {
+ // Ignore parse errors for markdown blocks
+ }
+ }
+ return functionCalls.length > 0 ? functionCalls : undefined;
+ }
+ catch (e) {
+ console.warn("Error in extractFunctionCallsFromText:", e);
+ return undefined;
+ }
+ }
  /**
  * Try generating with nested inputs format (common in newer models)
  */
@@ -135,14 +245,26 @@ class HuggingFaceModel extends base_model_1.BaseModel {
  const prompt = request.systemPrompt
  ? `${request.systemPrompt}\n\n${request.prompt}`
  : request.prompt;
+ // Handle function calling by adding function definitions to the prompt
+ let enhancedPrompt = prompt;
+ if (request.functions && request.functions.length > 0) {
+ const functionText = JSON.stringify({ functions: request.functions }, null, 2);
+ enhancedPrompt = `${enhancedPrompt}\n\nAvailable functions:\n\`\`\`json\n${functionText}\n\`\`\`\n\n`;
+ if (typeof request.functionCall === "object") {
+ enhancedPrompt += `Please call the function: ${request.functionCall.name}\n\n`;
+ }
+ else if (request.functionCall === "auto") {
+ enhancedPrompt += "Call the appropriate function if needed.\n\n";
+ }
+ }
  let payload = {
  inputs: {
- text: prompt,
+ text: enhancedPrompt,
  },
  parameters: {
- temperature: config.temperature,
- max_new_tokens: config.maxTokens,
- top_p: config.topP,
+ temperature: config.temperature || 0.1,
+ max_new_tokens: config.maxTokens || 500,
+ top_p: config.topP || 0.9,
  return_full_text: false,
  },
  };
@@ -178,8 +300,8 @@ class HuggingFaceModel extends base_model_1.BaseModel {
  "Content-Type": "application/json",
  },
  });
- // Parse response
- return this.parseResponse(response);
+ // Parse response with function call extraction
+ return this.parseResponseWithFunctionCalls(response);
  }
  /**
  * Try generating with flat inputs format (common in some models)
@@ -188,13 +310,25 @@ class HuggingFaceModel extends base_model_1.BaseModel {
  const prompt = request.systemPrompt
  ? `${request.systemPrompt}\n\n${request.prompt}`
  : request.prompt;
+ // Handle function calling by adding function definitions to the prompt
+ let enhancedPrompt = prompt;
+ if (request.functions && request.functions.length > 0) {
+ const functionText = JSON.stringify({ functions: request.functions }, null, 2);
+ enhancedPrompt = `${enhancedPrompt}\n\nAvailable functions:\n\`\`\`json\n${functionText}\n\`\`\`\n\n`;
+ if (typeof request.functionCall === "object") {
+ enhancedPrompt += `Please call the function: ${request.functionCall.name}\n\n`;
+ }
+ else if (request.functionCall === "auto") {
+ enhancedPrompt += "Call the appropriate function if needed.\n\n";
+ }
+ }
  // Some models expect a flat structure with inputs as a string
  let payload = {
- inputs: prompt,
+ inputs: enhancedPrompt,
  parameters: {
- temperature: config.temperature,
- max_new_tokens: config.maxTokens,
- top_p: config.topP,
+ temperature: config.temperature || 0.1,
+ max_new_tokens: config.maxTokens || 500,
+ top_p: config.topP || 0.9,
  return_full_text: false,
  },
  };
@@ -218,13 +352,13 @@ class HuggingFaceModel extends base_model_1.BaseModel {
  "Content-Type": "application/json",
  },
  });
- // Parse response
- return this.parseResponse(response);
+ // Parse response with function call extraction
+ return this.parseResponseWithFunctionCalls(response);
  }
  /**
- * Helper to parse HuggingFace response in various formats
+ * Helper to parse HuggingFace response with function call extraction
  */
- parseResponse(response) {
+ parseResponseWithFunctionCalls(response) {
  let text = "";
  if (Array.isArray(response.data)) {
  text = response.data[0]?.generated_text || "";
@@ -238,8 +372,11 @@ class HuggingFaceModel extends base_model_1.BaseModel {
  else {
  text = JSON.stringify(response.data);
  }
+ // Extract function calls from the response text
+ const functionCalls = this.extractFunctionCallsFromText(text);
  return {
  text,
+ functionCalls,
  raw: response.data,
  };
  }
@@ -249,11 +386,23 @@ class HuggingFaceModel extends base_model_1.BaseModel {
  async generateWithMultipartForm(request, config, model) {
  // Create a multipart form-data payload for multimodal models
  const formData = new FormData();
- // Add text prompt
+ // Add text prompt with function definitions
  const prompt = request.systemPrompt
  ? `${request.systemPrompt}\n\n${request.prompt}`
  : request.prompt;
- formData.append("text", prompt);
+ // Handle function calling by adding function definitions to the prompt
+ let enhancedPrompt = prompt;
+ if (request.functions && request.functions.length > 0) {
+ const functionText = JSON.stringify({ functions: request.functions }, null, 2);
+ enhancedPrompt = `${enhancedPrompt}\n\nAvailable functions:\n\`\`\`json\n${functionText}\n\`\`\`\n\n`;
+ if (typeof request.functionCall === "object") {
+ enhancedPrompt += `Please call the function: ${request.functionCall.name}\n\n`;
+ }
+ else if (request.functionCall === "auto") {
+ enhancedPrompt += "Call the appropriate function if needed.\n\n";
+ }
+ }
+ formData.append("text", enhancedPrompt);
  // Process the convenience 'image' property
  if (request.image) {
  const { base64 } = await (0, image_utils_1.processImage)(request.image);
@@ -304,8 +453,11 @@ class HuggingFaceModel extends base_model_1.BaseModel {
  else {
  text = JSON.stringify(response.data);
  }
+ // Extract function calls from the response text
+ const functionCalls = this.extractFunctionCallsFromText(text);
  return {
  text,
+ functionCalls,
  raw: response.data,
  };
  }
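
The HuggingFace path likewise builds an enhanced prompt instead of using a provider tools API. A sketch of what the text route sends when `functionCall` is `"auto"` (the function definition is illustrative; the instruction wording mirrors the code above):

```ts
// Prompt assembled by the HuggingFace text path for functionCall: "auto"
const functions = [
  {
    name: "getWeather", // hypothetical function for illustration
    description: "Get the current weather for a location",
    parameters: { type: "object", properties: { location: { type: "string" } } },
  },
];

let fullPrompt = "What's the weather in Tokyo?";
fullPrompt += `\n\nAVAILABLE FUNCTIONS:\n${JSON.stringify(functions, null, 2)}\n\n`;
fullPrompt += `Call one of the available functions if appropriate. Format the function call as JSON, like this:\n`;
fullPrompt += `{"name": "functionName", "arguments": {...}}\n`;
```

One design note on the new parameter defaults: `config.temperature || 0.1` uses `||`, so an explicit `temperature: 0` (or `topP: 0`) is coerced to the fallback value; `??` would preserve zeros.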
@@ -6,6 +6,10 @@ export declare class OllamaModel extends BaseModel {
  constructor(config: AIModelConfig);
  generate(request: AIModelRequest): Promise<AIModelResponse>;
  stream(request: AIModelRequest): AsyncGenerator<string, void, unknown>;
+ /**
+ * Extract function calls from text using various patterns
+ */
+ private extractFunctionCallsFromText;
  /**
  * Creates the request payload for Ollama, handling multimodal content if provided
  */
@@ -18,8 +18,18 @@ class OllamaModel extends base_model_1.BaseModel {
  async generate(request) {
  const config = this.mergeConfig(request.options);
  try {
+ // Create and modify the payload for function calling
  const payload = await this.createRequestPayload(request, config);
- const response = await axios_1.default.post(`${this.baseURL}/generate`, payload);
+ const response = await axios_1.default.post(`${this.baseURL}/generate`, payload, {
+ headers: {
+ "Content-Type": "application/json",
+ },
+ });
+ if (!response.data || !response.data.response) {
+ throw new Error("Invalid response from Ollama API");
+ }
+ // Try to extract function calls from the response
+ const functionCalls = this.extractFunctionCallsFromText(response.data.response, request);
  return {
  text: response.data.response,
  usage: {
@@ -27,6 +37,7 @@ class OllamaModel extends base_model_1.BaseModel {
  completionTokens: response.data.eval_count,
  totalTokens: response.data.prompt_eval_count + response.data.eval_count,
  },
+ functionCalls,
  raw: response.data,
  };
  }
@@ -48,6 +59,9 @@ class OllamaModel extends base_model_1.BaseModel {
  const payload = await this.createRequestPayload(request, config, true);
  const response = await axios_1.default.post(`${this.baseURL}/generate`, payload, {
  responseType: "stream",
+ headers: {
+ "Content-Type": "application/json",
+ },
  });
  const reader = response.data;
  for await (const chunk of reader) {
@@ -77,13 +91,105 @@ class OllamaModel extends base_model_1.BaseModel {
  throw error;
  }
  }
+ /**
+ * Extract function calls from text using various patterns
+ */
+ extractFunctionCallsFromText(text, currentRequest) {
+ if (!text)
+ return undefined;
+ try {
+ // Try multiple patterns for function calls
+ // Pattern 1: JSON format with name and arguments
+ // E.g., {"name": "getWeather", "arguments": {"location": "Tokyo"}}
+ const jsonRegex = /\{[\s\n]*"name"[\s\n]*:[\s\n]*"([^"]+)"[\s\n]*,[\s\n]*"arguments"[\s\n]*:[\s\n]*([\s\S]*?)\}/g;
+ const jsonMatches = [...text.matchAll(jsonRegex)];
+ if (jsonMatches.length > 0) {
+ return jsonMatches.map((match) => {
+ try {
+ // Try to parse the arguments as JSON
+ const argsText = match[2];
+ let args;
+ try {
+ args = JSON.parse(argsText);
+ return {
+ name: match[1],
+ arguments: JSON.stringify(args),
+ };
+ }
+ catch (e) {
+ // If parsing fails, use the raw text
+ return {
+ name: match[1],
+ arguments: argsText,
+ };
+ }
+ }
+ catch (e) {
+ console.warn("Error parsing function call:", e);
+ return {
+ name: match[1],
+ arguments: "{}",
+ };
+ }
+ });
+ }
+ // Pattern 2: Function call pattern: functionName({"key": "value"})
+ const functionRegex = /([a-zA-Z0-9_]+)\s*\(\s*(\{[\s\S]*?\})\s*\)/g;
+ const functionMatches = [...text.matchAll(functionRegex)];
+ if (functionMatches.length > 0) {
+ return functionMatches.map((match) => ({
+ name: match[1],
+ arguments: match[2],
+ }));
+ }
+ // Pattern 3: Look for more specific calculator patterns
+ if (currentRequest.functionCall &&
+ typeof currentRequest.functionCall === "object" &&
+ currentRequest.functionCall.name === "calculator") {
+ const calculatorRegex = /"?operation"?\s*:\s*"?([^",\s]+)"?,\s*"?a"?\s*:\s*(\d+),\s*"?b"?\s*:\s*(\d+)/;
+ const calculatorMatch = text.match(calculatorRegex);
+ if (calculatorMatch) {
+ const operation = calculatorMatch[1];
+ const a = parseInt(calculatorMatch[2]);
+ const b = parseInt(calculatorMatch[3]);
+ return [
+ {
+ name: "calculator",
+ arguments: JSON.stringify({ operation, a, b }),
+ },
+ ];
+ }
+ }
+ // Pattern 4: Look for more specific weather patterns
+ if (currentRequest.functionCall &&
+ typeof currentRequest.functionCall === "object" &&
+ currentRequest.functionCall.name === "getWeather") {
+ const weatherRegex = /"?location"?\s*:\s*"([^"]+)"(?:,\s*"?unit"?\s*:\s*"([^"]+)")?/;
+ const weatherMatch = text.match(weatherRegex);
+ if (weatherMatch) {
+ const location = weatherMatch[1];
+ const unit = weatherMatch[2] || "celsius";
+ return [
+ {
+ name: "getWeather",
+ arguments: JSON.stringify({ location, unit }),
+ },
+ ];
+ }
+ }
+ }
+ catch (e) {
+ console.warn("Error in extractFunctionCallsFromText:", e);
+ }
+ return undefined;
+ }
  /**
  * Creates the request payload for Ollama, handling multimodal content if provided
  */
  async createRequestPayload(request, config, isStream = false) {
  // Base payload
  const payload = {
- model: config.model || "llama2",
+ model: config.model || "llama3", // Updated default to a model that better supports function calling
  temperature: config.temperature,
  num_predict: config.maxTokens,
  top_p: config.topP,
@@ -92,10 +198,13 @@ class OllamaModel extends base_model_1.BaseModel {
  if (isStream) {
  payload.stream = true;
  }
- // If there are any image inputs, use the messages format
- if (request.image ||
- (request.content && request.content.some((item) => item.type === "image"))) {
- // Create a messages array for multimodal models (similar to OpenAI format)
+ // Check if we should use chat format (messages array) or text format
+ const useMessagesFormat = request.image ||
+ (request.content &&
+ request.content.some((item) => item.type === "image")) ||
+ (request.functions && request.functions.length > 0);
+ if (useMessagesFormat) {
+ // Modern message-based format for Ollama
  const messages = [];
  // Add system prompt if provided
  if (request.systemPrompt) {
@@ -104,27 +213,27 @@ class OllamaModel extends base_model_1.BaseModel {
  content: request.systemPrompt,
  });
  }
- // Create a user message with potentially multiple content parts
- const userMessage = { role: "user", content: [] };
- // Add the main prompt as text content
+ // Create user message content parts
+ let userContent = [];
+ // Add main text content
  if (request.prompt) {
- userMessage.content.push({
+ userContent.push({
  type: "text",
  text: request.prompt,
  });
  }
- // Process structured content if available
+ // Add structured content
  if (request.content) {
  for (const item of request.content) {
  if (item.type === "text") {
- userMessage.content.push({
+ userContent.push({
  type: "text",
  text: item.text,
  });
  }
  else if (item.type === "image") {
  const { base64, mimeType } = await (0, image_utils_1.processImage)(item.source);
- userMessage.content.push({
+ userContent.push({
  type: "image",
  image: {
  data: base64,
@@ -134,10 +243,10 @@ class OllamaModel extends base_model_1.BaseModel {
  }
  }
  }
- // Handle the convenience image property
+ // Add simple image if provided
  if (request.image) {
  const { base64, mimeType } = await (0, image_utils_1.processImage)(request.image);
- userMessage.content.push({
+ userContent.push({
  type: "image",
  image: {
  data: base64,
@@ -145,9 +254,46 @@ class OllamaModel extends base_model_1.BaseModel {
  },
  });
  }
- // Add the user message
+ // Create the user message
+ let userMessage = {
+ role: "user",
+ };
+ // If we have a single text content, use string format
+ if (userContent.length === 1 && userContent[0].type === "text") {
+ userMessage.content = userContent[0].text;
+ }
+ else {
+ userMessage.content = userContent;
+ }
  messages.push(userMessage);
- // Set the messages in the payload
+ // Add function calling data to system prompt
+ if (request.functions && request.functions.length > 0) {
+ // Create a system prompt for function calling
+ let functionSystemPrompt = request.systemPrompt || "";
+ // Add function definitions as JSON
+ functionSystemPrompt += `\n\nAvailable functions:\n\`\`\`json\n${JSON.stringify(request.functions, null, 2)}\n\`\`\`\n\n`;
+ // Add instruction based on functionCall setting
+ if (typeof request.functionCall === "object") {
+ functionSystemPrompt += `You must call the function: ${request.functionCall.name}.\n`;
+ functionSystemPrompt += `Format your response as a function call using this exact format:\n`;
+ functionSystemPrompt += `{"name": "${request.functionCall.name}", "arguments": {...}}\n`;
+ }
+ else if (request.functionCall === "auto") {
+ functionSystemPrompt += `Call one of these functions if appropriate for the user's request.\n`;
+ functionSystemPrompt += `Format your response as a function call using this exact format:\n`;
+ functionSystemPrompt += `{"name": "functionName", "arguments": {...}}\n`;
+ }
+ // Replace or add the system message
+ if (messages.length > 0 && messages[0].role === "system") {
+ messages[0].content = functionSystemPrompt;
+ }
+ else {
+ messages.unshift({
+ role: "system",
+ content: functionSystemPrompt,
+ });
+ }
+ }
  payload.messages = messages;
  }
  else {
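
The Ollama extractor accepts looser shapes than the other providers, including bare `functionName({...})` syntax (Pattern 2), plus request-aware fallbacks keyed to the `calculator` and `getWeather` names. A standalone sketch of Pattern 2, using the same regex as the diff (the reply string is illustrative):

```ts
// Pattern 2 from extractFunctionCallsFromText: function-like call syntax
const functionRegex = /([a-zA-Z0-9_]+)\s*\(\s*(\{[\s\S]*?\})\s*\)/g;

const reply = 'getWeather({"location": "Tokyo", "unit": "celsius"})';
const calls = [...reply.matchAll(functionRegex)].map((m) => ({
  name: m[1],
  arguments: m[2],
}));
// calls[0] → { name: "getWeather", arguments: '{"location": "Tokyo", "unit": "celsius"}' }
```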
@@ -6,6 +6,14 @@ export declare class OpenAIModel extends BaseModel {
  constructor(config: AIModelConfig);
  generate(request: AIModelRequest): Promise<AIModelResponse>;
  stream(request: AIModelRequest): AsyncGenerator<string, void, unknown>;
+ /**
+ * Prepare OpenAI function calling options
+ */
+ private prepareFunctionCalling;
+ /**
+ * Process function calls from OpenAI response
+ */
+ private processFunctionCalls;
  /**
  * Format messages for OpenAI API, including handling multimodal content
  */
@@ -20,13 +20,18 @@ class OpenAIModel extends base_model_1.BaseModel {
  const config = this.mergeConfig(request.options);
  // Process messages for OpenAI API
  const messages = await this.formatMessages(request);
+ // Prepare function calling if requested
+ const functionOptions = this.prepareFunctionCalling(request);
  const response = await this.client.chat.completions.create({
  model: config.model || "gpt-3.5-turbo",
  messages,
  temperature: config.temperature,
  max_tokens: config.maxTokens,
  top_p: config.topP,
+ ...functionOptions,
  });
+ // Process function calls if any are present
+ const functionCalls = this.processFunctionCalls(response);
  return {
  text: response.choices[0].message.content || "",
  usage: {
@@ -34,6 +39,7 @@ class OpenAIModel extends base_model_1.BaseModel {
  completionTokens: response.usage?.completion_tokens,
  totalTokens: response.usage?.total_tokens,
  },
+ functionCalls,
  raw: response,
  };
  }
@@ -41,14 +47,19 @@ class OpenAIModel extends base_model_1.BaseModel {
  const config = this.mergeConfig(request.options);
  // Process messages for OpenAI API
  const messages = await this.formatMessages(request);
- const stream = await this.client.chat.completions.create({
+ // Prepare function calling if requested
+ const functionOptions = this.prepareFunctionCalling(request);
+ const response = await this.client.chat.completions.create({
  model: config.model || "gpt-3.5-turbo",
  messages,
  temperature: config.temperature,
  max_tokens: config.maxTokens,
  top_p: config.topP,
  stream: true,
+ ...functionOptions,
  });
+ // Using a more compatible approach with Stream API
+ const stream = response;
  for await (const chunk of stream) {
  const content = chunk.choices[0]?.delta?.content || "";
  if (content) {
@@ -56,6 +67,58 @@ class OpenAIModel extends base_model_1.BaseModel {
  }
  }
  }
+ /**
+ * Prepare OpenAI function calling options
+ */
+ prepareFunctionCalling(request) {
+ if (!request.functions || request.functions.length === 0) {
+ return {};
+ }
+ // Transform our function definitions to OpenAI's format
+ const tools = request.functions.map((func) => ({
+ type: "function",
+ function: {
+ name: func.name,
+ description: func.description,
+ parameters: func.parameters,
+ },
+ }));
+ // Handle tool_choice (functionCall in our API)
+ let tool_choice = undefined;
+ if (request.functionCall) {
+ if (request.functionCall === "auto") {
+ tool_choice = "auto";
+ }
+ else if (request.functionCall === "none") {
+ tool_choice = "none";
+ }
+ else if (typeof request.functionCall === "object") {
+ tool_choice = {
+ type: "function",
+ function: { name: request.functionCall.name },
+ };
+ }
+ }
+ return {
+ tools,
+ tool_choice,
+ };
+ }
+ /**
+ * Process function calls from OpenAI response
+ */
+ processFunctionCalls(response) {
+ const toolCalls = response.choices[0]?.message?.tool_calls;
+ if (!toolCalls || toolCalls.length === 0) {
+ return undefined;
+ }
+ return toolCalls
+ .filter((call) => call.type === "function")
+ .map((call) => ({
+ name: call.function.name,
+ arguments: call.function.arguments,
+ }));
+ }
  /**
  * Format messages for OpenAI API, including handling multimodal content
  */
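
OpenAI is the one provider here with a native tools API, so `prepareFunctionCalling` only reshapes the SDK's `FunctionDefinition` list into OpenAI's `tools`/`tool_choice` options. A sketch of the mapping (the definition is illustrative):

```ts
// SDK-side definition (shape per FunctionDefinition in types.d.ts)
const func = {
  name: "getWeather", // hypothetical function for illustration
  description: "Get the current weather for a location",
  parameters: {
    type: "object",
    properties: { location: { type: "string" } },
    required: ["location"],
  },
};

// What prepareFunctionCalling returns for functionCall: { name: "getWeather" }:
const options = {
  tools: [
    {
      type: "function",
      function: { name: func.name, description: func.description, parameters: func.parameters },
    },
  ],
  tool_choice: { type: "function", function: { name: "getWeather" } },
};
```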
package/dist/types.d.ts CHANGED
@@ -23,6 +23,19 @@ export interface ImageContent {
  source: string | Buffer;
  }
  export type Content = TextContent | ImageContent;
+ export interface FunctionDefinition {
+ name: string;
+ description: string;
+ parameters: {
+ type: string;
+ properties?: Record<string, any>;
+ required?: string[];
+ };
+ }
+ export interface FunctionCall {
+ name: string;
+ arguments: string;
+ }
  export interface AIModelResponse {
  text: string;
  usage?: {
@@ -30,6 +43,7 @@ export interface AIModelResponse {
  completionTokens?: number;
  totalTokens?: number;
  };
+ functionCalls?: FunctionCall[];
  raw?: any;
  }
  export interface AIModelRequest {
@@ -38,6 +52,10 @@ export interface AIModelRequest {
  options?: Partial<AIModelConfig>;
  content?: Content[];
  image?: string | Buffer;
+ functions?: FunctionDefinition[];
+ functionCall?: "auto" | "none" | {
+ name: string;
+ };
  }
  export interface AIModel {
  provider: AIProvider;
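
These types are what unify the provider implementations: requests optionally carry `functions` and `functionCall`, responses optionally carry `functionCalls`, and `arguments` is always a JSON string that callers parse themselves. A minimal consumption sketch (the imports assume the package re-exports its types from the index):

```ts
import { AIModelRequest, AIModelResponse } from "neural-ai-sdk"; // assumed re-exports

const request: AIModelRequest = {
  prompt: "Add 2 and 3",
  functions: [
    {
      name: "calculator", // hypothetical function for illustration
      description: "Perform basic arithmetic",
      parameters: {
        type: "object",
        properties: {
          operation: { type: "string" },
          a: { type: "number" },
          b: { type: "number" },
        },
        required: ["operation", "a", "b"],
      },
    },
  ],
  functionCall: { name: "calculator" }, // force this specific function
};

// The response shape is uniform across providers:
function handleResponse(response: AIModelResponse) {
  for (const call of response.functionCalls ?? []) {
    const args = JSON.parse(call.arguments); // arguments is always a JSON string
    console.log(call.name, args);
  }
}
```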
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "neural-ai-sdk",
- "version": "0.1.3",
+ "version": "0.1.4",
  "description": "Unified SDK for interacting with various AI LLM providers",
  "main": "dist/index.js",
  "types": "dist/index.d.ts",