neural-ai-sdk 0.1.2 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +5 -8
- package/dist/index.js +3 -0
- package/dist/models/deepseek-model.d.ts +4 -0
- package/dist/models/deepseek-model.js +58 -4
- package/dist/models/google-model.d.ts +4 -0
- package/dist/models/google-model.js +138 -4
- package/dist/models/huggingface-model.d.ts +7 -2
- package/dist/models/huggingface-model.js +192 -40
- package/dist/models/ollama-model.d.ts +4 -0
- package/dist/models/ollama-model.js +163 -17
- package/dist/models/openai-model.d.ts +8 -0
- package/dist/models/openai-model.js +64 -1
- package/dist/types.d.ts +18 -0
- package/dist/utils.d.ts +5 -0
- package/dist/utils.js +22 -0
- package/package.json +2 -2
package/README.md
CHANGED
@@ -42,12 +42,13 @@ async function generateResponse() {
 generateResponse();
 ```
 
-### Environment Variables Support
+### Automatic Environment Variables Support
 
-
+The SDK automatically loads environment variables from `.env` files when imported, so you don't need to manually configure dotenv. Simply create a `.env` file in your project root, and the API keys will be automatically detected:
 
 ```typescript
-// No need to provide API keys in code if they're set
+// No need to provide API keys in code if they're set in .env files
+// No need to manually call require('dotenv').config()
 const openaiModel = NeuralAI.createModel(AIProvider.OPENAI, {
   model: "gpt-4",
 });
@@ -229,11 +230,7 @@ HUGGINGFACE_API_KEY=your_huggingface_key_here
 OLLAMA_BASE_URL=http://localhost:11434/api
 ```
 
-
-
-```javascript
-require("dotenv").config();
-```
+The SDK automatically loads environment variables from `.env` files when imported, so you don't need to manually configure dotenv.
 
 ## Configuration Options
 
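In practice, the documented flow now reduces to the snippet below — a minimal sketch based on the README text above, assuming `OPENAI_API_KEY` lives in a `.env` file at the project root:

```typescript
// Sketch of the README's "automatic env" flow; the key is picked up on import.
import { NeuralAI, AIProvider } from "neural-ai-sdk";

const openaiModel = NeuralAI.createModel(AIProvider.OPENAI, {
  model: "gpt-4", // no apiKey needed in code
});
```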
package/dist/index.js
CHANGED
@@ -22,6 +22,9 @@ const google_model_2 = require("./models/google-model");
 const deepseek_model_2 = require("./models/deepseek-model");
 const ollama_model_2 = require("./models/ollama-model");
 const huggingface_model_2 = require("./models/huggingface-model");
+const utils_1 = require("./utils");
+// Automatically load environment variables when the module is imported
+(0, utils_1.loadEnvVariables)();
 class NeuralAI {
     /**
      * Create an AI model instance based on the provider and configuration
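The practical effect of this top-level call is that env loading becomes an import-time side effect; consumers never invoke it explicitly (a sketch of the consumer's view, not SDK code):

```typescript
// Importing the SDK entry point is enough to trigger the .env lookup once.
import "neural-ai-sdk"; // side effect: loadEnvVariables() has already run
console.log(process.env.OPENAI_API_KEY ? "key loaded" : "key missing");
```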
package/dist/models/deepseek-model.d.ts
CHANGED
@@ -6,4 +6,8 @@ export declare class DeepSeekModel extends BaseModel {
     constructor(config: AIModelConfig);
     generate(request: AIModelRequest): Promise<AIModelResponse>;
     stream(request: AIModelRequest): AsyncGenerator<string, void, unknown>;
+    /**
+     * Process function calls from DeepSeek API response
+     */
+    private processFunctionCalls;
 }
package/dist/models/deepseek-model.js
CHANGED
@@ -28,19 +28,39 @@ class DeepSeekModel extends base_model_1.BaseModel {
             role: "user",
             content: request.prompt,
         });
-
+        // Prepare request payload
+        const payload = {
             model: config.model || "deepseek-chat",
             messages,
             temperature: config.temperature,
             max_tokens: config.maxTokens,
             top_p: config.topP,
-        }
+        };
+        // Add function calling support if functions are provided
+        if (request.functions && request.functions.length > 0) {
+            payload.functions = request.functions;
+            // Handle function call configuration
+            if (request.functionCall) {
+                if (request.functionCall === "auto") {
+                    payload.function_call = "auto";
+                }
+                else if (request.functionCall === "none") {
+                    payload.function_call = "none";
+                }
+                else if (typeof request.functionCall === "object") {
+                    payload.function_call = { name: request.functionCall.name };
+                }
+            }
+        }
+        const response = await axios_1.default.post(`${this.baseURL}/chat/completions`, payload, {
             headers: {
                 "Content-Type": "application/json",
                 Authorization: `Bearer ${config.apiKey ||
                     (0, utils_1.getApiKey)(config.apiKey, "DEEPSEEK_API_KEY", "DeepSeek")}`,
             },
         });
+        // Process function calls if any
+        const functionCalls = this.processFunctionCalls(response.data);
         return {
             text: response.data.choices[0].message.content,
             usage: {
@@ -48,6 +68,7 @@ class DeepSeekModel extends base_model_1.BaseModel {
                 completionTokens: response.data.usage?.completion_tokens,
                 totalTokens: response.data.usage?.total_tokens,
             },
+            functionCalls,
             raw: response.data,
         };
     }
@@ -64,14 +85,32 @@ class DeepSeekModel extends base_model_1.BaseModel {
             role: "user",
             content: request.prompt,
         });
-
+        // Prepare request payload
+        const payload = {
             model: config.model || "deepseek-chat",
             messages,
             temperature: config.temperature,
             max_tokens: config.maxTokens,
             top_p: config.topP,
             stream: true,
-        }
+        };
+        // Add function calling support if functions are provided
+        if (request.functions && request.functions.length > 0) {
+            payload.functions = request.functions;
+            // Handle function call configuration
+            if (request.functionCall) {
+                if (request.functionCall === "auto") {
+                    payload.function_call = "auto";
+                }
+                else if (request.functionCall === "none") {
+                    payload.function_call = "none";
+                }
+                else if (typeof request.functionCall === "object") {
+                    payload.function_call = { name: request.functionCall.name };
+                }
+            }
+        }
+        const response = await axios_1.default.post(`${this.baseURL}/chat/completions`, payload, {
             headers: {
                 "Content-Type": "application/json",
                 Authorization: `Bearer ${config.apiKey ||
@@ -101,5 +140,20 @@ class DeepSeekModel extends base_model_1.BaseModel {
             }
         }
     }
+    /**
+     * Process function calls from DeepSeek API response
+     */
+    processFunctionCalls(response) {
+        if (!response.choices?.[0]?.message?.function_call) {
+            return undefined;
+        }
+        const functionCall = response.choices[0].message.function_call;
+        return [
+            {
+                name: functionCall.name,
+                arguments: functionCall.arguments,
+            },
+        ];
+    }
 }
 exports.DeepSeekModel = DeepSeekModel;
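For reference, this is roughly the request body the new DeepSeek path assembles when functions are supplied. The function definition below is illustrative, not part of the SDK:

```typescript
// Shape of the POST body sent to `${baseURL}/chat/completions` when a caller
// passes functions plus functionCall: "auto" (all values are made up):
const payload = {
  model: "deepseek-chat",
  messages: [{ role: "user", content: "What's the weather in Tokyo?" }],
  temperature: 0.7,
  max_tokens: 256,
  top_p: 1,
  functions: [
    {
      name: "getWeather",
      description: "Look up current weather for a location",
      parameters: {
        type: "object",
        properties: { location: { type: "string" } },
        required: ["location"],
      },
    },
  ],
  function_call: "auto", // or "none", or { name: "getWeather" }
};
console.log(JSON.stringify(payload, null, 2));
```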
package/dist/models/google-model.d.ts
CHANGED
@@ -6,6 +6,10 @@ export declare class GoogleModel extends BaseModel {
     constructor(config: AIModelConfig);
     generate(request: AIModelRequest): Promise<AIModelResponse>;
     stream(request: AIModelRequest): AsyncGenerator<string, void, unknown>;
+    /**
+     * Extract function calls from text using regex patterns since we're using prompt engineering
+     */
+    private extractFunctionCallsFromText;
     /**
      * Format content for Google's Gemini API, handling both text and images
      */
package/dist/models/google-model.js
CHANGED
@@ -15,14 +15,56 @@ class GoogleModel extends base_model_1.BaseModel {
     }
     async generate(request) {
         const config = this.mergeConfig(request.options);
-
-
+        // Create base model configuration with Gemini 2.0 models
+        const modelConfig = {
+            model: config.model || "gemini-2.0-flash", // Using 2.0 models as default
             generationConfig: {
                 temperature: config.temperature,
                 maxOutputTokens: config.maxTokens,
                 topP: config.topP,
             },
-        }
+        };
+        const model = this.client.getGenerativeModel(modelConfig);
+        // Handle function calling through prompt engineering for Gemini 2.0 models
+        // as native function calling may not be fully supported in the same way
+        if (request.functions && request.functions.length > 0) {
+            try {
+                // Create an enhanced prompt with function definitions
+                let enhancedPrompt = request.prompt || "";
+                if (request.systemPrompt) {
+                    enhancedPrompt = `${request.systemPrompt}\n\n${enhancedPrompt}`;
+                }
+                // Add function definitions to the prompt
+                enhancedPrompt += `\n\nYou have access to the following functions:\n\`\`\`json\n${JSON.stringify(request.functions, null, 2)}\n\`\`\`\n\n`;
+                // Add specific instructions based on function call mode
+                if (typeof request.functionCall === "object") {
+                    enhancedPrompt += `You MUST use the function: ${request.functionCall.name}\n`;
+                    enhancedPrompt += `Format your response as a function call using JSON in this exact format:\n`;
+                    enhancedPrompt += `{"name": "${request.functionCall.name}", "arguments": {...}}\n`;
+                    enhancedPrompt += `Don't include any explanations, just output the function call.\n`;
+                }
+                else if (request.functionCall === "auto") {
+                    enhancedPrompt += `If appropriate for the request, call one of these functions.\n`;
+                    enhancedPrompt += `Format your response as a function call using JSON in this exact format:\n`;
+                    enhancedPrompt += `{"name": "functionName", "arguments": {...}}\n`;
+                }
+                const result = await model.generateContent(enhancedPrompt);
+                const response = result.response;
+                const text = response.text();
+                // Extract function calls from the response text
+                const functionCalls = this.extractFunctionCallsFromText(text);
+                return {
+                    text,
+                    functionCalls,
+                    raw: response,
+                };
+            }
+            catch (error) {
+                console.warn("Function calling with prompt engineering failed for Gemini, falling back to text-only", error);
+                // Fall back to regular text generation if prompt engineering fails
+            }
+        }
+        // Regular text generation without function calling
         const content = await this.formatMultiModalContent(request);
         const result = await model.generateContent(content);
         const response = result.response;
@@ -32,9 +74,11 @@ class GoogleModel extends base_model_1.BaseModel {
         };
     }
     async *stream(request) {
+        // For streaming, we'll keep it simpler as function calling with streaming
+        // has additional complexities
         const config = this.mergeConfig(request.options);
         const model = this.client.getGenerativeModel({
-            model: config.model || "gemini-2.0-flash", //
+            model: config.model || "gemini-2.0-flash", // Using 2.0 models as default
             generationConfig: {
                 temperature: config.temperature,
                 maxOutputTokens: config.maxTokens,
@@ -50,6 +94,96 @@ class GoogleModel extends base_model_1.BaseModel {
             }
         }
     }
+    /**
+     * Extract function calls from text using regex patterns since we're using prompt engineering
+     */
+    extractFunctionCallsFromText(text) {
+        try {
+            if (!text)
+                return undefined;
+            const functionCalls = [];
+            // Pattern 1: JSON object with name and arguments
+            // Try full regex match first
+            const jsonRegex = /\{[\s\n]*"name"[\s\n]*:[\s\n]*"([^"]+)"[\s\n]*,[\s\n]*"arguments"[\s\n]*:[\s\n]*(\{.*?\})\s*\}/gs;
+            let match;
+            while ((match = jsonRegex.exec(text)) !== null) {
+                try {
+                    const name = match[1];
+                    const args = match[2];
+                    functionCalls.push({
+                        name,
+                        arguments: args,
+                    });
+                }
+                catch (e) {
+                    console.warn("Error parsing function call:", e);
+                }
+            }
+            // Pattern 2: Markdown code blocks that might contain JSON
+            const markdownRegex = /```(?:json)?\s*\n\s*(\{[\s\S]*?\})\s*\n```/gs;
+            while ((match = markdownRegex.exec(text)) !== null) {
+                try {
+                    const jsonBlock = match[1].trim();
+                    // Try to parse the JSON
+                    const jsonObj = JSON.parse(jsonBlock);
+                    if (jsonObj.name && (jsonObj.arguments || jsonObj.args)) {
+                        functionCalls.push({
+                            name: jsonObj.name,
+                            arguments: JSON.stringify(jsonObj.arguments || jsonObj.args),
+                        });
+                    }
+                }
+                catch (e) {
+                    // JSON might be malformed, try more aggressive parsing
+                    const nameMatch = match[1].match(/"name"\s*:\s*"([^"]+)"/);
+                    const argsMatch = match[1].match(/"arguments"\s*:\s*(\{[^}]*\})/);
+                    if (nameMatch && argsMatch) {
+                        try {
+                            // Try to fix and parse the arguments
+                            const argumentsStr = argsMatch[1].replace(/,\s*$/, "");
+                            const fixedArgs = argumentsStr + (argumentsStr.endsWith("}") ? "" : "}");
+                            functionCalls.push({
+                                name: nameMatch[1],
+                                arguments: fixedArgs,
+                            });
+                        }
+                        catch (e) {
+                            console.warn("Failed to fix JSON format:", e);
+                        }
+                    }
+                }
+            }
+            // Pattern 3: If still no matches, try looser regex
+            if (functionCalls.length === 0) {
+                // Extract function name and args separately with more permissive patterns
+                const nameMatch = text.match(/"name"\s*:\s*"([^"]+)"/);
+                if (nameMatch) {
+                    const name = nameMatch[1];
+                    // Find arguments block, accounting for potential closing bracket issues
+                    const argsRegex = /"arguments"\s*:\s*(\{[^{]*?(?:\}|$))/;
+                    const argsMatch = text.match(argsRegex);
+                    if (argsMatch) {
+                        let args = argsMatch[1].trim();
+                        // Fix common JSON formatting issues
+                        if (!args.endsWith("}")) {
+                            args += "}";
+                        }
+                        // Clean up trailing commas which cause JSON.parse to fail
+                        args = args.replace(/,\s*}/g, "}");
+                        functionCalls.push({
+                            name,
+                            arguments: args,
+                        });
+                    }
+                }
+            }
+            return functionCalls.length > 0 ? functionCalls : undefined;
+        }
+        catch (error) {
+            console.warn("Error extracting function calls from text:", error);
+            return undefined;
+        }
+    }
     /**
      * Format content for Google's Gemini API, handling both text and images
      */
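Since the Gemini support here relies on prompt engineering rather than a native tools API, the extraction regex does the heavy lifting. A quick standalone check of Pattern 1 against an invented model reply:

```typescript
// Pattern 1 from extractFunctionCallsFromText, run on a fabricated reply:
const reply = 'Sure: {"name": "getWeather", "arguments": {"location": "Tokyo"}}';
const jsonRegex =
  /\{[\s\n]*"name"[\s\n]*:[\s\n]*"([^"]+)"[\s\n]*,[\s\n]*"arguments"[\s\n]*:[\s\n]*(\{.*?\})\s*\}/gs;
for (const match of reply.matchAll(jsonRegex)) {
  console.log(match[1]);             // getWeather
  console.log(JSON.parse(match[2])); // { location: 'Tokyo' }
}
```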
package/dist/models/huggingface-model.d.ts
CHANGED
@@ -13,6 +13,11 @@ export declare class HuggingFaceModel extends BaseModel {
      * Generate a response using multimodal inputs (text + images)
      */
     private generateWithImages;
+    /**
+     * Extract function calls from text using pattern matching
+     * This is more robust than the previous implementation and handles various formats
+     */
+    private extractFunctionCallsFromText;
     /**
      * Try generating with nested inputs format (common in newer models)
      */
@@ -22,9 +27,9 @@ export declare class HuggingFaceModel extends BaseModel {
      */
     private generateWithFlatFormat;
     /**
-     * Helper to parse HuggingFace response
+     * Helper to parse HuggingFace response with function call extraction
      */
-    private
+    private parseResponseWithFunctionCalls;
     /**
      * Fallback method that uses multipart/form-data for older HuggingFace models
      */
package/dist/models/huggingface-model.js
CHANGED
@@ -18,7 +18,8 @@ class HuggingFaceModel extends base_model_1.BaseModel {
     }
     async generate(request) {
         const config = this.mergeConfig(request.options);
-
+        // Use a more accessible default model that doesn't require special permissions
+        const model = config.model || "mistralai/Mistral-7B-Instruct-v0.2";
         try {
             // Try multimodal approach if images are present
             if (request.image ||
@@ -52,6 +53,10 @@ class HuggingFaceModel extends base_model_1.BaseModel {
                     " Try a different vision-capable model like 'llava-hf/llava-1.5-7b-hf' or check HuggingFace's documentation for this specific model.";
                 throw new Error(errorMessage);
             }
+            else if (error.response?.status === 403) {
+                // Handle permission errors more specifically
+                throw new Error(`Permission denied for model "${model}". Try using a different model with public access. Error: ${error.response?.data || error.message}`);
+            }
             throw error;
         }
     }
@@ -63,37 +68,62 @@ class HuggingFaceModel extends base_model_1.BaseModel {
         if (request.systemPrompt) {
             fullPrompt = `${request.systemPrompt}\n\n${fullPrompt}`;
         }
+        // If functions are provided, enhance the prompt to handle function calling
+        if (request.functions && request.functions.length > 0) {
+            // Format the functions in a way the model can understand
+            fullPrompt += `\n\nAVAILABLE FUNCTIONS:\n${JSON.stringify(request.functions, null, 2)}\n\n`;
+            // Add guidance based on function call setting
+            if (typeof request.functionCall === "object") {
+                fullPrompt += `You must call the function: ${request.functionCall.name}.\n`;
+                fullPrompt += `Format your answer as a function call using JSON, like this:\n`;
+                fullPrompt += `{"name": "${request.functionCall.name}", "arguments": {...}}\n`;
+                fullPrompt += `Don't include any explanations, just output the function call.\n`;
+            }
+            else if (request.functionCall === "auto") {
+                fullPrompt += `Call one of the available functions if appropriate. Format the function call as JSON, like this:\n`;
+                fullPrompt += `{"name": "functionName", "arguments": {...}}\n`;
+            }
+        }
         const payload = {
             inputs: fullPrompt,
             parameters: {
-                temperature: config.temperature,
-                max_new_tokens: config.maxTokens,
-                top_p: config.topP,
+                temperature: config.temperature || 0.1,
+                max_new_tokens: config.maxTokens || 500,
+                top_p: config.topP || 0.9,
                 return_full_text: false,
             },
         };
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try {
+            const response = await axios_1.default.post(`${this.baseURL}/${model}`, payload, {
+                headers: {
+                    Authorization: `Bearer ${config.apiKey ||
+                        (0, utils_1.getApiKey)(config.apiKey, "HUGGINGFACE_API_KEY", "HuggingFace")}`,
+                    "Content-Type": "application/json",
+                },
+            });
+            // Parse the response
+            let text = "";
+            if (Array.isArray(response.data)) {
+                text = response.data[0]?.generated_text || "";
+            }
+            else if (response.data.generated_text) {
+                text = response.data.generated_text;
+            }
+            else {
+                text = JSON.stringify(response.data);
+            }
+            // Extract function calls from the response
+            const functionCalls = this.extractFunctionCallsFromText(text);
+            return {
+                text,
+                functionCalls,
+                raw: response.data,
+            };
         }
-
-
+        catch (error) {
+            console.error("Error generating with HuggingFace model:", error);
+            throw error;
         }
-        return {
-            text,
-            raw: response.data,
-        };
     }
     /**
      * Generate a response using multimodal inputs (text + images)
@@ -128,6 +158,86 @@ class HuggingFaceModel extends base_model_1.BaseModel {
             .join("; ")}`;
         throw new Error(errorMessage);
     }
+    /**
+     * Extract function calls from text using pattern matching
+     * This is more robust than the previous implementation and handles various formats
+     */
+    extractFunctionCallsFromText(text) {
+        if (!text)
+            return undefined;
+        try {
+            const functionCalls = [];
+            // Pattern 1: Clean JSON function call format
+            // Example: {"name": "functionName", "arguments": {...}}
+            const jsonRegex = /\{[\s\n]*"name"[\s\n]*:[\s\n]*"([^"]+)"[\s\n]*,[\s\n]*"arguments"[\s\n]*:[\s\n]*([\s\S]*?)\}/g;
+            let match;
+            while ((match = jsonRegex.exec(text)) !== null) {
+                try {
+                    // Try to parse the arguments part as JSON
+                    const name = match[1];
+                    let args = match[2].trim();
+                    // Check if args is already a valid JSON string
+                    try {
+                        JSON.parse(args);
+                        functionCalls.push({
+                            name,
+                            arguments: args,
+                        });
+                    }
+                    catch (e) {
+                        // If not valid JSON, try to extract the JSON object
+                        const argsMatch = args.match(/\{[\s\S]*\}/);
+                        if (argsMatch) {
+                            functionCalls.push({
+                                name,
+                                arguments: argsMatch[0],
+                            });
+                        }
+                        else {
+                            functionCalls.push({
+                                name,
+                                arguments: "{}",
+                            });
+                        }
+                    }
+                }
+                catch (e) {
+                    console.warn("Error parsing function call:", e);
+                }
+            }
+            // Pattern 2: Function-like syntax
+            // Example: functionName({param1: "value", param2: 123})
+            const functionRegex = /([a-zA-Z0-9_]+)\s*\(\s*(\{[\s\S]*?\})\s*\)/g;
+            while ((match = functionRegex.exec(text)) !== null) {
+                functionCalls.push({
+                    name: match[1],
+                    arguments: match[2],
+                });
+            }
+            // Pattern 3: Markdown code block with JSON
+            // Example: ```json\n{"name": "functionName", "arguments": {...}}\n```
+            const markdownRegex = /```(?:json)?\s*\n\s*(\{[\s\S]*?\})\s*\n```/g;
+            while ((match = markdownRegex.exec(text)) !== null) {
+                try {
+                    const jsonObj = JSON.parse(match[1]);
+                    if (jsonObj.name && (jsonObj.arguments || jsonObj.args)) {
+                        functionCalls.push({
+                            name: jsonObj.name,
+                            arguments: JSON.stringify(jsonObj.arguments || jsonObj.args),
+                        });
+                    }
+                }
+                catch (e) {
+                    // Ignore parse errors for markdown blocks
+                }
+            }
+            return functionCalls.length > 0 ? functionCalls : undefined;
+        }
+        catch (e) {
+            console.warn("Error in extractFunctionCallsFromText:", e);
+            return undefined;
+        }
+    }
     /**
      * Try generating with nested inputs format (common in newer models)
      */
@@ -135,14 +245,26 @@ class HuggingFaceModel extends base_model_1.BaseModel {
         const prompt = request.systemPrompt
             ? `${request.systemPrompt}\n\n${request.prompt}`
             : request.prompt;
+        // Handle function calling by adding function definitions to the prompt
+        let enhancedPrompt = prompt;
+        if (request.functions && request.functions.length > 0) {
+            const functionText = JSON.stringify({ functions: request.functions }, null, 2);
+            enhancedPrompt = `${enhancedPrompt}\n\nAvailable functions:\n\`\`\`json\n${functionText}\n\`\`\`\n\n`;
+            if (typeof request.functionCall === "object") {
+                enhancedPrompt += `Please call the function: ${request.functionCall.name}\n\n`;
+            }
+            else if (request.functionCall === "auto") {
+                enhancedPrompt += "Call the appropriate function if needed.\n\n";
+            }
+        }
         let payload = {
             inputs: {
-                text:
+                text: enhancedPrompt,
             },
             parameters: {
-                temperature: config.temperature,
-                max_new_tokens: config.maxTokens,
-                top_p: config.topP,
+                temperature: config.temperature || 0.1,
+                max_new_tokens: config.maxTokens || 500,
+                top_p: config.topP || 0.9,
                 return_full_text: false,
             },
         };
@@ -178,8 +300,8 @@ class HuggingFaceModel extends base_model_1.BaseModel {
                 "Content-Type": "application/json",
             },
         });
-        // Parse response
-        return this.
+        // Parse response with function call extraction
+        return this.parseResponseWithFunctionCalls(response);
     }
     /**
      * Try generating with flat inputs format (common in some models)
@@ -188,13 +310,25 @@ class HuggingFaceModel extends base_model_1.BaseModel {
         const prompt = request.systemPrompt
             ? `${request.systemPrompt}\n\n${request.prompt}`
             : request.prompt;
+        // Handle function calling by adding function definitions to the prompt
+        let enhancedPrompt = prompt;
+        if (request.functions && request.functions.length > 0) {
+            const functionText = JSON.stringify({ functions: request.functions }, null, 2);
+            enhancedPrompt = `${enhancedPrompt}\n\nAvailable functions:\n\`\`\`json\n${functionText}\n\`\`\`\n\n`;
+            if (typeof request.functionCall === "object") {
+                enhancedPrompt += `Please call the function: ${request.functionCall.name}\n\n`;
+            }
+            else if (request.functionCall === "auto") {
+                enhancedPrompt += "Call the appropriate function if needed.\n\n";
+            }
+        }
         // Some models expect a flat structure with inputs as a string
         let payload = {
-            inputs:
+            inputs: enhancedPrompt,
             parameters: {
-                temperature: config.temperature,
-                max_new_tokens: config.maxTokens,
-                top_p: config.topP,
+                temperature: config.temperature || 0.1,
+                max_new_tokens: config.maxTokens || 500,
+                top_p: config.topP || 0.9,
                 return_full_text: false,
             },
         };
@@ -218,13 +352,13 @@ class HuggingFaceModel extends base_model_1.BaseModel {
                 "Content-Type": "application/json",
             },
         });
-        // Parse response
-        return this.
+        // Parse response with function call extraction
+        return this.parseResponseWithFunctionCalls(response);
     }
     /**
-     * Helper to parse HuggingFace response
+     * Helper to parse HuggingFace response with function call extraction
      */
-
+    parseResponseWithFunctionCalls(response) {
        let text = "";
        if (Array.isArray(response.data)) {
            text = response.data[0]?.generated_text || "";
@@ -238,8 +372,11 @@ class HuggingFaceModel extends base_model_1.BaseModel {
         else {
             text = JSON.stringify(response.data);
         }
+        // Extract function calls from the response text
+        const functionCalls = this.extractFunctionCallsFromText(text);
         return {
             text,
+            functionCalls,
             raw: response.data,
         };
     }
@@ -249,11 +386,23 @@ class HuggingFaceModel extends base_model_1.BaseModel {
     async generateWithMultipartForm(request, config, model) {
         // Create a multipart form-data payload for multimodal models
         const formData = new FormData();
-        // Add text prompt
+        // Add text prompt with function definitions
         const prompt = request.systemPrompt
             ? `${request.systemPrompt}\n\n${request.prompt}`
             : request.prompt;
-
+        // Handle function calling by adding function definitions to the prompt
+        let enhancedPrompt = prompt;
+        if (request.functions && request.functions.length > 0) {
+            const functionText = JSON.stringify({ functions: request.functions }, null, 2);
+            enhancedPrompt = `${enhancedPrompt}\n\nAvailable functions:\n\`\`\`json\n${functionText}\n\`\`\`\n\n`;
+            if (typeof request.functionCall === "object") {
+                enhancedPrompt += `Please call the function: ${request.functionCall.name}\n\n`;
+            }
+            else if (request.functionCall === "auto") {
+                enhancedPrompt += "Call the appropriate function if needed.\n\n";
+            }
+        }
+        formData.append("text", enhancedPrompt);
         // Process the convenience 'image' property
         if (request.image) {
             const { base64 } = await (0, image_utils_1.processImage)(request.image);
@@ -304,8 +453,11 @@ class HuggingFaceModel extends base_model_1.BaseModel {
         else {
             text = JSON.stringify(response.data);
         }
+        // Extract function calls from the response text
+        const functionCalls = this.extractFunctionCallsFromText(text);
         return {
             text,
+            functionCalls,
             raw: response.data,
         };
     }
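The HuggingFace extractor's Pattern 2 also accepts function-like syntax in the generated text. A standalone check against an invented sample:

```typescript
// Pattern 2 from the HuggingFace extractor; the sample text is fabricated:
const text = 'Calling getWeather({"location": "Tokyo", "unit": "celsius"}) now.';
const functionRegex = /([a-zA-Z0-9_]+)\s*\(\s*(\{[\s\S]*?\})\s*\)/g;
for (const match of text.matchAll(functionRegex)) {
  console.log(match[1], JSON.parse(match[2]));
  // -> getWeather { location: 'Tokyo', unit: 'celsius' }
}
```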
package/dist/models/ollama-model.d.ts
CHANGED
@@ -6,6 +6,10 @@ export declare class OllamaModel extends BaseModel {
     constructor(config: AIModelConfig);
     generate(request: AIModelRequest): Promise<AIModelResponse>;
     stream(request: AIModelRequest): AsyncGenerator<string, void, unknown>;
+    /**
+     * Extract function calls from text using various patterns
+     */
+    private extractFunctionCallsFromText;
     /**
      * Creates the request payload for Ollama, handling multimodal content if provided
      */
package/dist/models/ollama-model.js
CHANGED
@@ -18,8 +18,18 @@ class OllamaModel extends base_model_1.BaseModel {
     async generate(request) {
         const config = this.mergeConfig(request.options);
         try {
+            // Create and modify the payload for function calling
             const payload = await this.createRequestPayload(request, config);
-            const response = await axios_1.default.post(`${this.baseURL}/generate`, payload
+            const response = await axios_1.default.post(`${this.baseURL}/generate`, payload, {
+                headers: {
+                    "Content-Type": "application/json",
+                },
+            });
+            if (!response.data || !response.data.response) {
+                throw new Error("Invalid response from Ollama API");
+            }
+            // Try to extract function calls from the response
+            const functionCalls = this.extractFunctionCallsFromText(response.data.response, request);
             return {
                 text: response.data.response,
                 usage: {
@@ -27,6 +37,7 @@ class OllamaModel extends base_model_1.BaseModel {
                     completionTokens: response.data.eval_count,
                     totalTokens: response.data.prompt_eval_count + response.data.eval_count,
                 },
+                functionCalls,
                 raw: response.data,
             };
         }
@@ -48,6 +59,9 @@ class OllamaModel extends base_model_1.BaseModel {
             const payload = await this.createRequestPayload(request, config, true);
             const response = await axios_1.default.post(`${this.baseURL}/generate`, payload, {
                 responseType: "stream",
+                headers: {
+                    "Content-Type": "application/json",
+                },
             });
             const reader = response.data;
             for await (const chunk of reader) {
@@ -77,13 +91,105 @@ class OllamaModel extends base_model_1.BaseModel {
             throw error;
         }
     }
+    /**
+     * Extract function calls from text using various patterns
+     */
+    extractFunctionCallsFromText(text, currentRequest) {
+        if (!text)
+            return undefined;
+        try {
+            // Try multiple patterns for function calls
+            // Pattern 1: JSON format with name and arguments
+            // E.g., {"name": "getWeather", "arguments": {"location": "Tokyo"}}
+            const jsonRegex = /\{[\s\n]*"name"[\s\n]*:[\s\n]*"([^"]+)"[\s\n]*,[\s\n]*"arguments"[\s\n]*:[\s\n]*([\s\S]*?)\}/g;
+            const jsonMatches = [...text.matchAll(jsonRegex)];
+            if (jsonMatches.length > 0) {
+                return jsonMatches.map((match) => {
+                    try {
+                        // Try to parse the arguments as JSON
+                        const argsText = match[2];
+                        let args;
+                        try {
+                            args = JSON.parse(argsText);
+                            return {
+                                name: match[1],
+                                arguments: JSON.stringify(args),
+                            };
+                        }
+                        catch (e) {
+                            // If parsing fails, use the raw text
+                            return {
+                                name: match[1],
+                                arguments: argsText,
+                            };
+                        }
+                    }
+                    catch (e) {
+                        console.warn("Error parsing function call:", e);
+                        return {
+                            name: match[1],
+                            arguments: "{}",
+                        };
+                    }
+                });
+            }
+            // Pattern 2: Function call pattern: functionName({"key": "value"})
+            const functionRegex = /([a-zA-Z0-9_]+)\s*\(\s*(\{[\s\S]*?\})\s*\)/g;
+            const functionMatches = [...text.matchAll(functionRegex)];
+            if (functionMatches.length > 0) {
+                return functionMatches.map((match) => ({
+                    name: match[1],
+                    arguments: match[2],
+                }));
+            }
+            // Pattern 3: Look for more specific calculator patterns
+            if (currentRequest.functionCall &&
+                typeof currentRequest.functionCall === "object" &&
+                currentRequest.functionCall.name === "calculator") {
+                const calculatorRegex = /"?operation"?\s*:\s*"?([^",\s]+)"?,\s*"?a"?\s*:\s*(\d+),\s*"?b"?\s*:\s*(\d+)/;
+                const calculatorMatch = text.match(calculatorRegex);
+                if (calculatorMatch) {
+                    const operation = calculatorMatch[1];
+                    const a = parseInt(calculatorMatch[2]);
+                    const b = parseInt(calculatorMatch[3]);
+                    return [
+                        {
+                            name: "calculator",
+                            arguments: JSON.stringify({ operation, a, b }),
+                        },
+                    ];
+                }
+            }
+            // Pattern 4: Look for more specific weather patterns
+            if (currentRequest.functionCall &&
+                typeof currentRequest.functionCall === "object" &&
+                currentRequest.functionCall.name === "getWeather") {
+                const weatherRegex = /"?location"?\s*:\s*"([^"]+)"(?:,\s*"?unit"?\s*:\s*"([^"]+)")?/;
+                const weatherMatch = text.match(weatherRegex);
+                if (weatherMatch) {
+                    const location = weatherMatch[1];
+                    const unit = weatherMatch[2] || "celsius";
+                    return [
+                        {
+                            name: "getWeather",
+                            arguments: JSON.stringify({ location, unit }),
+                        },
+                    ];
+                }
+            }
+        }
+        catch (e) {
+            console.warn("Error in extractFunctionCallsFromText:", e);
+        }
+        return undefined;
+    }
     /**
      * Creates the request payload for Ollama, handling multimodal content if provided
      */
     async createRequestPayload(request, config, isStream = false) {
         // Base payload
         const payload = {
-            model: config.model || "
+            model: config.model || "llama3", // Updated default to a model that better supports function calling
             temperature: config.temperature,
             num_predict: config.maxTokens,
             top_p: config.topP,
@@ -92,10 +198,13 @@ class OllamaModel extends base_model_1.BaseModel {
         if (isStream) {
             payload.stream = true;
         }
-        //
-
-            (request.content &&
-
+        // Check if we should use chat format (messages array) or text format
+        const useMessagesFormat = request.image ||
+            (request.content &&
+                request.content.some((item) => item.type === "image")) ||
+            (request.functions && request.functions.length > 0);
+        if (useMessagesFormat) {
+            // Modern message-based format for Ollama
             const messages = [];
             // Add system prompt if provided
             if (request.systemPrompt) {
@@ -104,27 +213,27 @@ class OllamaModel extends base_model_1.BaseModel {
                     content: request.systemPrompt,
                 });
             }
-            // Create
-
-            // Add
+            // Create user message content parts
+            let userContent = [];
+            // Add main text content
             if (request.prompt) {
-
+                userContent.push({
                     type: "text",
                     text: request.prompt,
                 });
             }
-            //
+            // Add structured content
             if (request.content) {
                 for (const item of request.content) {
                     if (item.type === "text") {
-
+                        userContent.push({
                             type: "text",
                             text: item.text,
                         });
                     }
                     else if (item.type === "image") {
                         const { base64, mimeType } = await (0, image_utils_1.processImage)(item.source);
-
+                        userContent.push({
                             type: "image",
                             image: {
                                 data: base64,
@@ -134,10 +243,10 @@ class OllamaModel extends base_model_1.BaseModel {
                     }
                 }
             }
-            //
+            // Add simple image if provided
             if (request.image) {
                 const { base64, mimeType } = await (0, image_utils_1.processImage)(request.image);
-
+                userContent.push({
                     type: "image",
                     image: {
                         data: base64,
@@ -145,9 +254,46 @@ class OllamaModel extends base_model_1.BaseModel {
                     },
                 });
             }
-            //
+            // Create the user message
+            let userMessage = {
+                role: "user",
+            };
+            // If we have a single text content, use string format
+            if (userContent.length === 1 && userContent[0].type === "text") {
+                userMessage.content = userContent[0].text;
+            }
+            else {
+                userMessage.content = userContent;
+            }
             messages.push(userMessage);
-            //
+            // Add function calling data to system prompt
+            if (request.functions && request.functions.length > 0) {
+                // Create a system prompt for function calling
+                let functionSystemPrompt = request.systemPrompt || "";
+                // Add function definitions as JSON
+                functionSystemPrompt += `\n\nAvailable functions:\n\`\`\`json\n${JSON.stringify(request.functions, null, 2)}\n\`\`\`\n\n`;
+                // Add instruction based on functionCall setting
+                if (typeof request.functionCall === "object") {
+                    functionSystemPrompt += `You must call the function: ${request.functionCall.name}.\n`;
+                    functionSystemPrompt += `Format your response as a function call using this exact format:\n`;
+                    functionSystemPrompt += `{"name": "${request.functionCall.name}", "arguments": {...}}\n`;
+                }
+                else if (request.functionCall === "auto") {
+                    functionSystemPrompt += `Call one of these functions if appropriate for the user's request.\n`;
+                    functionSystemPrompt += `Format your response as a function call using this exact format:\n`;
+                    functionSystemPrompt += `{"name": "functionName", "arguments": {...}}\n`;
+                }
+                // Replace or add the system message
+                if (messages.length > 0 && messages[0].role === "system") {
+                    messages[0].content = functionSystemPrompt;
+                }
+                else {
+                    messages.unshift({
+                        role: "system",
+                        content: functionSystemPrompt,
+                    });
+                }
+            }
            payload.messages = messages;
        }
        else {
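When functions are present, Ollama requests now switch to the chat-style messages format, with the function contract injected as a system message. Roughly what the generated payload looks like — the system-message text below is paraphrased from the diff above, and all values are illustrative:

```typescript
// Approximate payload produced by createRequestPayload for a
// function-calling request (not emitted by the SDK verbatim):
const ollamaPayload = {
  model: "llama3",
  temperature: 0.2,
  messages: [
    {
      role: "system",
      content:
        "Available functions: [JSON array of the definitions]\n" +
        "Call one of these functions if appropriate for the user's request.\n" +
        "Format your response as a function call using this exact format:\n" +
        '{"name": "functionName", "arguments": {...}}\n',
    },
    { role: "user", content: "What's the weather in Tokyo?" },
  ],
};
console.log(JSON.stringify(ollamaPayload, null, 2));
```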
package/dist/models/openai-model.d.ts
CHANGED
@@ -6,6 +6,14 @@ export declare class OpenAIModel extends BaseModel {
     constructor(config: AIModelConfig);
     generate(request: AIModelRequest): Promise<AIModelResponse>;
     stream(request: AIModelRequest): AsyncGenerator<string, void, unknown>;
+    /**
+     * Prepare OpenAI function calling options
+     */
+    private prepareFunctionCalling;
+    /**
+     * Process function calls from OpenAI response
+     */
+    private processFunctionCalls;
     /**
      * Format messages for OpenAI API, including handling multimodal content
      */
package/dist/models/openai-model.js
CHANGED
@@ -20,13 +20,18 @@ class OpenAIModel extends base_model_1.BaseModel {
         const config = this.mergeConfig(request.options);
         // Process messages for OpenAI API
         const messages = await this.formatMessages(request);
+        // Prepare function calling if requested
+        const functionOptions = this.prepareFunctionCalling(request);
         const response = await this.client.chat.completions.create({
             model: config.model || "gpt-3.5-turbo",
             messages,
             temperature: config.temperature,
             max_tokens: config.maxTokens,
             top_p: config.topP,
+            ...functionOptions,
         });
+        // Process function calls if any are present
+        const functionCalls = this.processFunctionCalls(response);
         return {
             text: response.choices[0].message.content || "",
             usage: {
@@ -34,6 +39,7 @@ class OpenAIModel extends base_model_1.BaseModel {
                 completionTokens: response.usage?.completion_tokens,
                 totalTokens: response.usage?.total_tokens,
             },
+            functionCalls,
             raw: response,
         };
     }
@@ -41,14 +47,19 @@ class OpenAIModel extends base_model_1.BaseModel {
         const config = this.mergeConfig(request.options);
         // Process messages for OpenAI API
         const messages = await this.formatMessages(request);
-
+        // Prepare function calling if requested
+        const functionOptions = this.prepareFunctionCalling(request);
+        const response = await this.client.chat.completions.create({
             model: config.model || "gpt-3.5-turbo",
             messages,
             temperature: config.temperature,
             max_tokens: config.maxTokens,
             top_p: config.topP,
             stream: true,
+            ...functionOptions,
         });
+        // Using a more compatible approach with Stream API
+        const stream = response;
         for await (const chunk of stream) {
             const content = chunk.choices[0]?.delta?.content || "";
             if (content) {
@@ -56,6 +67,58 @@ class OpenAIModel extends base_model_1.BaseModel {
             }
         }
     }
+    /**
+     * Prepare OpenAI function calling options
+     */
+    prepareFunctionCalling(request) {
+        if (!request.functions || request.functions.length === 0) {
+            return {};
+        }
+        // Transform our function definitions to OpenAI's format
+        const tools = request.functions.map((func) => ({
+            type: "function",
+            function: {
+                name: func.name,
+                description: func.description,
+                parameters: func.parameters,
+            },
+        }));
+        // Handle tool_choice (functionCall in our API)
+        let tool_choice = undefined;
+        if (request.functionCall) {
+            if (request.functionCall === "auto") {
+                tool_choice = "auto";
+            }
+            else if (request.functionCall === "none") {
+                tool_choice = "none";
+            }
+            else if (typeof request.functionCall === "object") {
+                tool_choice = {
+                    type: "function",
+                    function: { name: request.functionCall.name },
+                };
+            }
+        }
+        return {
+            tools,
+            tool_choice,
+        };
+    }
+    /**
+     * Process function calls from OpenAI response
+     */
+    processFunctionCalls(response) {
+        const toolCalls = response.choices[0]?.message?.tool_calls;
+        if (!toolCalls || toolCalls.length === 0) {
+            return undefined;
+        }
+        return toolCalls
+            .filter((call) => call.type === "function")
+            .map((call) => ({
+            name: call.function.name,
+            arguments: call.function.arguments,
+        }));
+    }
     /**
      * Format messages for OpenAI API, including handling multimodal content
      */
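Unlike the other providers, OpenAI gets native function calling: SDK-level definitions are translated into the `tools` / `tool_choice` fields of the Chat Completions API. For a forced call, the mapping comes out roughly like this (a sketch mirroring the diff; the function itself is illustrative):

```typescript
// What prepareFunctionCalling returns for functionCall: { name: "getWeather" }:
const functionOptions = {
  tools: [
    {
      type: "function",
      function: {
        name: "getWeather",
        description: "Look up current weather for a location",
        parameters: {
          type: "object",
          properties: { location: { type: "string" } },
          required: ["location"],
        },
      },
    },
  ],
  tool_choice: { type: "function", function: { name: "getWeather" } },
};
console.log(JSON.stringify(functionOptions, null, 2));
```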
package/dist/types.d.ts
CHANGED
@@ -23,6 +23,19 @@ export interface ImageContent {
     source: string | Buffer;
 }
 export type Content = TextContent | ImageContent;
+export interface FunctionDefinition {
+    name: string;
+    description: string;
+    parameters: {
+        type: string;
+        properties?: Record<string, any>;
+        required?: string[];
+    };
+}
+export interface FunctionCall {
+    name: string;
+    arguments: string;
+}
 export interface AIModelResponse {
     text: string;
     usage?: {
@@ -30,6 +43,7 @@ export interface AIModelResponse {
         completionTokens?: number;
         totalTokens?: number;
     };
+    functionCalls?: FunctionCall[];
     raw?: any;
 }
 export interface AIModelRequest {
@@ -38,6 +52,10 @@ export interface AIModelRequest {
     options?: Partial<AIModelConfig>;
     content?: Content[];
     image?: string | Buffer;
+    functions?: FunctionDefinition[];
+    functionCall?: "auto" | "none" | {
+        name: string;
+    };
 }
 export interface AIModel {
     provider: AIProvider;
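These types tie the feature together across all five providers. An end-to-end sketch of how a caller would use them, assuming the package entry point re-exports the types:

```typescript
import type {
  AIModelRequest,
  AIModelResponse,
  FunctionDefinition,
} from "neural-ai-sdk";

const getWeather: FunctionDefinition = {
  name: "getWeather",
  description: "Look up current weather for a location",
  parameters: {
    type: "object",
    properties: { location: { type: "string" }, unit: { type: "string" } },
    required: ["location"],
  },
};

const request: AIModelRequest = {
  prompt: "What's the weather in Tokyo?",
  functions: [getWeather],
  functionCall: "auto",
};

function handleResponse(response: AIModelResponse): void {
  for (const call of response.functionCalls ?? []) {
    // Every provider implementation above returns `arguments` as a JSON string
    console.log(call.name, JSON.parse(call.arguments));
  }
}
```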
package/dist/utils.d.ts
CHANGED
@@ -1,6 +1,11 @@
 /**
  * Utilities for the Neural AI SDK
  */
+/**
+ * Try to load environment variables from .env files
+ * This is done automatically when the module is imported
+ */
+export declare function loadEnvVariables(): void;
 /**
  * Get an API key from config or environment variables
  * @param configKey The API key from the config object
package/dist/utils.js
CHANGED
@@ -3,8 +3,30 @@
  * Utilities for the Neural AI SDK
  */
 Object.defineProperty(exports, "__esModule", { value: true });
+exports.loadEnvVariables = loadEnvVariables;
 exports.getApiKey = getApiKey;
 exports.getBaseUrl = getBaseUrl;
+/**
+ * Try to load environment variables from .env files
+ * This is done automatically when the module is imported
+ */
+function loadEnvVariables() {
+    try {
+        // Only require dotenv if it's available
+        // This avoids errors if the user hasn't installed dotenv
+        const dotenv = require('dotenv');
+        // Load from .env file in the project root by default
+        dotenv.config();
+        // Also try to load from any parent directories to support monorepos
+        // and projects where the .env file might be in a different location
+        dotenv.config({ path: '../../.env' });
+        dotenv.config({ path: '../.env' });
+    }
+    catch (error) {
+        // Silent fail if dotenv is not available
+        // This is intentional to not break the module if dotenv is not installed
+    }
+}
 /**
  * Get an API key from config or environment variables
  * @param configKey The API key from the config object
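Two dotenv behaviors are worth noting for this loader: `dotenv.config()` resolves relative paths against `process.cwd()` (not the SDK's install directory), and by default it never overwrites variables that are already set, so repeated calls are safe. A standalone sketch, not the shipped code:

```typescript
// Minimal equivalent of loadEnvVariables, with the path caveat spelled out.
import * as dotenv from "dotenv";

export function loadEnvVariables(): void {
  try {
    dotenv.config();                       // <cwd>/.env
    dotenv.config({ path: "../../.env" }); // e.g. monorepo root, two levels up
    dotenv.config({ path: "../.env" });    // one level up
  } catch {
    // Mirror the SDK: fail silently if dotenv cannot be loaded.
  }
}
```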
package/package.json
CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "neural-ai-sdk",
-  "version": "0.1.2",
+  "version": "0.1.4",
   "description": "Unified SDK for interacting with various AI LLM providers",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
@@ -45,6 +45,7 @@
   "dependencies": {
     "@google/generative-ai": "^0.2.0",
     "axios": "^1.6.0",
+    "dotenv": "^16.3.1",
     "openai": "^4.28.0"
   },
   "devDependencies": {
@@ -52,7 +53,6 @@
     "@types/node": "^20.8.3",
     "@typescript-eslint/eslint-plugin": "^6.7.4",
     "@typescript-eslint/parser": "^6.7.4",
-    "dotenv": "^16.3.1",
     "eslint": "^8.51.0",
     "jest": "^29.7.0",
     "rimraf": "^5.0.10",