neural-ai-sdk 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/models/deepseek-model.d.ts +4 -0
- package/dist/models/deepseek-model.js +58 -4
- package/dist/models/google-model.d.ts +4 -0
- package/dist/models/google-model.js +138 -4
- package/dist/models/huggingface-model.d.ts +7 -2
- package/dist/models/huggingface-model.js +192 -40
- package/dist/models/ollama-model.d.ts +8 -0
- package/dist/models/ollama-model.js +318 -27
- package/dist/models/openai-model.d.ts +8 -0
- package/dist/models/openai-model.js +64 -1
- package/dist/types.d.ts +18 -0
- package/package.json +1 -1
@@ -6,4 +6,8 @@ export declare class DeepSeekModel extends BaseModel {
|
|
6
6
|
constructor(config: AIModelConfig);
|
7
7
|
generate(request: AIModelRequest): Promise<AIModelResponse>;
|
8
8
|
stream(request: AIModelRequest): AsyncGenerator<string, void, unknown>;
|
9
|
+
/**
|
10
|
+
* Process function calls from DeepSeek API response
|
11
|
+
*/
|
12
|
+
private processFunctionCalls;
|
9
13
|
}
|
@@ -28,19 +28,39 @@ class DeepSeekModel extends base_model_1.BaseModel {
|
|
28
28
|
role: "user",
|
29
29
|
content: request.prompt,
|
30
30
|
});
|
31
|
-
|
31
|
+
// Prepare request payload
|
32
|
+
const payload = {
|
32
33
|
model: config.model || "deepseek-chat",
|
33
34
|
messages,
|
34
35
|
temperature: config.temperature,
|
35
36
|
max_tokens: config.maxTokens,
|
36
37
|
top_p: config.topP,
|
37
|
-
}
|
38
|
+
};
|
39
|
+
// Add function calling support if functions are provided
|
40
|
+
if (request.functions && request.functions.length > 0) {
|
41
|
+
payload.functions = request.functions;
|
42
|
+
// Handle function call configuration
|
43
|
+
if (request.functionCall) {
|
44
|
+
if (request.functionCall === "auto") {
|
45
|
+
payload.function_call = "auto";
|
46
|
+
}
|
47
|
+
else if (request.functionCall === "none") {
|
48
|
+
payload.function_call = "none";
|
49
|
+
}
|
50
|
+
else if (typeof request.functionCall === "object") {
|
51
|
+
payload.function_call = { name: request.functionCall.name };
|
52
|
+
}
|
53
|
+
}
|
54
|
+
}
|
55
|
+
const response = await axios_1.default.post(`${this.baseURL}/chat/completions`, payload, {
|
38
56
|
headers: {
|
39
57
|
"Content-Type": "application/json",
|
40
58
|
Authorization: `Bearer ${config.apiKey ||
|
41
59
|
(0, utils_1.getApiKey)(config.apiKey, "DEEPSEEK_API_KEY", "DeepSeek")}`,
|
42
60
|
},
|
43
61
|
});
|
62
|
+
// Process function calls if any
|
63
|
+
const functionCalls = this.processFunctionCalls(response.data);
|
44
64
|
return {
|
45
65
|
text: response.data.choices[0].message.content,
|
46
66
|
usage: {
|
@@ -48,6 +68,7 @@ class DeepSeekModel extends base_model_1.BaseModel {
|
|
48
68
|
completionTokens: response.data.usage?.completion_tokens,
|
49
69
|
totalTokens: response.data.usage?.total_tokens,
|
50
70
|
},
|
71
|
+
functionCalls,
|
51
72
|
raw: response.data,
|
52
73
|
};
|
53
74
|
}
|
@@ -64,14 +85,32 @@ class DeepSeekModel extends base_model_1.BaseModel {
|
|
64
85
|
role: "user",
|
65
86
|
content: request.prompt,
|
66
87
|
});
|
67
|
-
|
88
|
+
// Prepare request payload
|
89
|
+
const payload = {
|
68
90
|
model: config.model || "deepseek-chat",
|
69
91
|
messages,
|
70
92
|
temperature: config.temperature,
|
71
93
|
max_tokens: config.maxTokens,
|
72
94
|
top_p: config.topP,
|
73
95
|
stream: true,
|
74
|
-
}
|
96
|
+
};
|
97
|
+
// Add function calling support if functions are provided
|
98
|
+
if (request.functions && request.functions.length > 0) {
|
99
|
+
payload.functions = request.functions;
|
100
|
+
// Handle function call configuration
|
101
|
+
if (request.functionCall) {
|
102
|
+
if (request.functionCall === "auto") {
|
103
|
+
payload.function_call = "auto";
|
104
|
+
}
|
105
|
+
else if (request.functionCall === "none") {
|
106
|
+
payload.function_call = "none";
|
107
|
+
}
|
108
|
+
else if (typeof request.functionCall === "object") {
|
109
|
+
payload.function_call = { name: request.functionCall.name };
|
110
|
+
}
|
111
|
+
}
|
112
|
+
}
|
113
|
+
const response = await axios_1.default.post(`${this.baseURL}/chat/completions`, payload, {
|
75
114
|
headers: {
|
76
115
|
"Content-Type": "application/json",
|
77
116
|
Authorization: `Bearer ${config.apiKey ||
|
@@ -101,5 +140,20 @@ class DeepSeekModel extends base_model_1.BaseModel {
|
|
101
140
|
}
|
102
141
|
}
|
103
142
|
}
|
143
|
+
/**
|
144
|
+
* Process function calls from DeepSeek API response
|
145
|
+
*/
|
146
|
+
processFunctionCalls(response) {
|
147
|
+
if (!response.choices?.[0]?.message?.function_call) {
|
148
|
+
return undefined;
|
149
|
+
}
|
150
|
+
const functionCall = response.choices[0].message.function_call;
|
151
|
+
return [
|
152
|
+
{
|
153
|
+
name: functionCall.name,
|
154
|
+
arguments: functionCall.arguments,
|
155
|
+
},
|
156
|
+
];
|
157
|
+
}
|
104
158
|
}
|
105
159
|
exports.DeepSeekModel = DeepSeekModel;
|
@@ -6,6 +6,10 @@ export declare class GoogleModel extends BaseModel {
|
|
6
6
|
constructor(config: AIModelConfig);
|
7
7
|
generate(request: AIModelRequest): Promise<AIModelResponse>;
|
8
8
|
stream(request: AIModelRequest): AsyncGenerator<string, void, unknown>;
|
9
|
+
/**
|
10
|
+
* Extract function calls from text using regex patterns since we're using prompt engineering
|
11
|
+
*/
|
12
|
+
private extractFunctionCallsFromText;
|
9
13
|
/**
|
10
14
|
* Format content for Google's Gemini API, handling both text and images
|
11
15
|
*/
|
@@ -15,14 +15,56 @@ class GoogleModel extends base_model_1.BaseModel {
|
|
15
15
|
}
|
16
16
|
async generate(request) {
|
17
17
|
const config = this.mergeConfig(request.options);
|
18
|
-
|
19
|
-
|
18
|
+
// Create base model configuration with Gemini 2.0 models
|
19
|
+
const modelConfig = {
|
20
|
+
model: config.model || "gemini-2.0-flash", // Using 2.0 models as default
|
20
21
|
generationConfig: {
|
21
22
|
temperature: config.temperature,
|
22
23
|
maxOutputTokens: config.maxTokens,
|
23
24
|
topP: config.topP,
|
24
25
|
},
|
25
|
-
}
|
26
|
+
};
|
27
|
+
const model = this.client.getGenerativeModel(modelConfig);
|
28
|
+
// Handle function calling through prompt engineering for Gemini 2.0 models
|
29
|
+
// as native function calling may not be fully supported in the same way
|
30
|
+
if (request.functions && request.functions.length > 0) {
|
31
|
+
try {
|
32
|
+
// Create an enhanced prompt with function definitions
|
33
|
+
let enhancedPrompt = request.prompt || "";
|
34
|
+
if (request.systemPrompt) {
|
35
|
+
enhancedPrompt = `${request.systemPrompt}\n\n${enhancedPrompt}`;
|
36
|
+
}
|
37
|
+
// Add function definitions to the prompt
|
38
|
+
enhancedPrompt += `\n\nYou have access to the following functions:\n\`\`\`json\n${JSON.stringify(request.functions, null, 2)}\n\`\`\`\n\n`;
|
39
|
+
// Add specific instructions based on function call mode
|
40
|
+
if (typeof request.functionCall === "object") {
|
41
|
+
enhancedPrompt += `You MUST use the function: ${request.functionCall.name}\n`;
|
42
|
+
enhancedPrompt += `Format your response as a function call using JSON in this exact format:\n`;
|
43
|
+
enhancedPrompt += `{"name": "${request.functionCall.name}", "arguments": {...}}\n`;
|
44
|
+
enhancedPrompt += `Don't include any explanations, just output the function call.\n`;
|
45
|
+
}
|
46
|
+
else if (request.functionCall === "auto") {
|
47
|
+
enhancedPrompt += `If appropriate for the request, call one of these functions.\n`;
|
48
|
+
enhancedPrompt += `Format your response as a function call using JSON in this exact format:\n`;
|
49
|
+
enhancedPrompt += `{"name": "functionName", "arguments": {...}}\n`;
|
50
|
+
}
|
51
|
+
const result = await model.generateContent(enhancedPrompt);
|
52
|
+
const response = result.response;
|
53
|
+
const text = response.text();
|
54
|
+
// Extract function calls from the response text
|
55
|
+
const functionCalls = this.extractFunctionCallsFromText(text);
|
56
|
+
return {
|
57
|
+
text,
|
58
|
+
functionCalls,
|
59
|
+
raw: response,
|
60
|
+
};
|
61
|
+
}
|
62
|
+
catch (error) {
|
63
|
+
console.warn("Function calling with prompt engineering failed for Gemini, falling back to text-only", error);
|
64
|
+
// Fall back to regular text generation if prompt engineering fails
|
65
|
+
}
|
66
|
+
}
|
67
|
+
// Regular text generation without function calling
|
26
68
|
const content = await this.formatMultiModalContent(request);
|
27
69
|
const result = await model.generateContent(content);
|
28
70
|
const response = result.response;
|
@@ -32,9 +74,11 @@ class GoogleModel extends base_model_1.BaseModel {
|
|
32
74
|
};
|
33
75
|
}
|
34
76
|
async *stream(request) {
|
77
|
+
// For streaming, we'll keep it simpler as function calling with streaming
|
78
|
+
// has additional complexities
|
35
79
|
const config = this.mergeConfig(request.options);
|
36
80
|
const model = this.client.getGenerativeModel({
|
37
|
-
model: config.model || "gemini-2.0-flash", //
|
81
|
+
model: config.model || "gemini-2.0-flash", // Using 2.0 models as default
|
38
82
|
generationConfig: {
|
39
83
|
temperature: config.temperature,
|
40
84
|
maxOutputTokens: config.maxTokens,
|
@@ -50,6 +94,96 @@ class GoogleModel extends base_model_1.BaseModel {
|
|
50
94
|
}
|
51
95
|
}
|
52
96
|
}
|
97
|
+
/**
|
98
|
+
* Extract function calls from text using regex patterns since we're using prompt engineering
|
99
|
+
*/
|
100
|
+
extractFunctionCallsFromText(text) {
|
101
|
+
try {
|
102
|
+
if (!text)
|
103
|
+
return undefined;
|
104
|
+
const functionCalls = [];
|
105
|
+
// Pattern 1: JSON object with name and arguments
|
106
|
+
// Try full regex match first
|
107
|
+
const jsonRegex = /\{[\s\n]*"name"[\s\n]*:[\s\n]*"([^"]+)"[\s\n]*,[\s\n]*"arguments"[\s\n]*:[\s\n]*(\{.*?\})\s*\}/gs;
|
108
|
+
let match;
|
109
|
+
while ((match = jsonRegex.exec(text)) !== null) {
|
110
|
+
try {
|
111
|
+
const name = match[1];
|
112
|
+
const args = match[2];
|
113
|
+
functionCalls.push({
|
114
|
+
name,
|
115
|
+
arguments: args,
|
116
|
+
});
|
117
|
+
}
|
118
|
+
catch (e) {
|
119
|
+
console.warn("Error parsing function call:", e);
|
120
|
+
}
|
121
|
+
}
|
122
|
+
// Pattern 2: Markdown code blocks that might contain JSON
|
123
|
+
const markdownRegex = /```(?:json)?\s*\n\s*(\{[\s\S]*?\})\s*\n```/gs;
|
124
|
+
while ((match = markdownRegex.exec(text)) !== null) {
|
125
|
+
try {
|
126
|
+
const jsonBlock = match[1].trim();
|
127
|
+
// Try to parse the JSON
|
128
|
+
const jsonObj = JSON.parse(jsonBlock);
|
129
|
+
if (jsonObj.name && (jsonObj.arguments || jsonObj.args)) {
|
130
|
+
functionCalls.push({
|
131
|
+
name: jsonObj.name,
|
132
|
+
arguments: JSON.stringify(jsonObj.arguments || jsonObj.args),
|
133
|
+
});
|
134
|
+
}
|
135
|
+
}
|
136
|
+
catch (e) {
|
137
|
+
// JSON might be malformed, try more aggressive parsing
|
138
|
+
const nameMatch = match[1].match(/"name"\s*:\s*"([^"]+)"/);
|
139
|
+
const argsMatch = match[1].match(/"arguments"\s*:\s*(\{[^}]*\})/);
|
140
|
+
if (nameMatch && argsMatch) {
|
141
|
+
try {
|
142
|
+
// Try to fix and parse the arguments
|
143
|
+
const argumentsStr = argsMatch[1].replace(/,\s*$/, "");
|
144
|
+
const fixedArgs = argumentsStr + (argumentsStr.endsWith("}") ? "" : "}");
|
145
|
+
functionCalls.push({
|
146
|
+
name: nameMatch[1],
|
147
|
+
arguments: fixedArgs,
|
148
|
+
});
|
149
|
+
}
|
150
|
+
catch (e) {
|
151
|
+
console.warn("Failed to fix JSON format:", e);
|
152
|
+
}
|
153
|
+
}
|
154
|
+
}
|
155
|
+
}
|
156
|
+
// Pattern 3: If still no matches, try looser regex
|
157
|
+
if (functionCalls.length === 0) {
|
158
|
+
// Extract function name and args separately with more permissive patterns
|
159
|
+
const nameMatch = text.match(/"name"\s*:\s*"([^"]+)"/);
|
160
|
+
if (nameMatch) {
|
161
|
+
const name = nameMatch[1];
|
162
|
+
// Find arguments block, accounting for potential closing bracket issues
|
163
|
+
const argsRegex = /"arguments"\s*:\s*(\{[^{]*?(?:\}|$))/;
|
164
|
+
const argsMatch = text.match(argsRegex);
|
165
|
+
if (argsMatch) {
|
166
|
+
let args = argsMatch[1].trim();
|
167
|
+
// Fix common JSON formatting issues
|
168
|
+
if (!args.endsWith("}")) {
|
169
|
+
args += "}";
|
170
|
+
}
|
171
|
+
// Clean up trailing commas which cause JSON.parse to fail
|
172
|
+
args = args.replace(/,\s*}/g, "}");
|
173
|
+
functionCalls.push({
|
174
|
+
name,
|
175
|
+
arguments: args,
|
176
|
+
});
|
177
|
+
}
|
178
|
+
}
|
179
|
+
}
|
180
|
+
return functionCalls.length > 0 ? functionCalls : undefined;
|
181
|
+
}
|
182
|
+
catch (error) {
|
183
|
+
console.warn("Error extracting function calls from text:", error);
|
184
|
+
return undefined;
|
185
|
+
}
|
186
|
+
}
|
53
187
|
/**
|
54
188
|
* Format content for Google's Gemini API, handling both text and images
|
55
189
|
*/
|
@@ -13,6 +13,11 @@ export declare class HuggingFaceModel extends BaseModel {
|
|
13
13
|
* Generate a response using multimodal inputs (text + images)
|
14
14
|
*/
|
15
15
|
private generateWithImages;
|
16
|
+
/**
|
17
|
+
* Extract function calls from text using pattern matching
|
18
|
+
* This is more robust than the previous implementation and handles various formats
|
19
|
+
*/
|
20
|
+
private extractFunctionCallsFromText;
|
16
21
|
/**
|
17
22
|
* Try generating with nested inputs format (common in newer models)
|
18
23
|
*/
|
@@ -22,9 +27,9 @@ export declare class HuggingFaceModel extends BaseModel {
|
|
22
27
|
*/
|
23
28
|
private generateWithFlatFormat;
|
24
29
|
/**
|
25
|
-
* Helper to parse HuggingFace response
|
30
|
+
* Helper to parse HuggingFace response with function call extraction
|
26
31
|
*/
|
27
|
-
private
|
32
|
+
private parseResponseWithFunctionCalls;
|
28
33
|
/**
|
29
34
|
* Fallback method that uses multipart/form-data for older HuggingFace models
|
30
35
|
*/
|
@@ -18,7 +18,8 @@ class HuggingFaceModel extends base_model_1.BaseModel {
|
|
18
18
|
}
|
19
19
|
async generate(request) {
|
20
20
|
const config = this.mergeConfig(request.options);
|
21
|
-
|
21
|
+
// Use a more accessible default model that doesn't require special permissions
|
22
|
+
const model = config.model || "mistralai/Mistral-7B-Instruct-v0.2";
|
22
23
|
try {
|
23
24
|
// Try multimodal approach if images are present
|
24
25
|
if (request.image ||
|
@@ -52,6 +53,10 @@ class HuggingFaceModel extends base_model_1.BaseModel {
|
|
52
53
|
" Try a different vision-capable model like 'llava-hf/llava-1.5-7b-hf' or check HuggingFace's documentation for this specific model.";
|
53
54
|
throw new Error(errorMessage);
|
54
55
|
}
|
56
|
+
else if (error.response?.status === 403) {
|
57
|
+
// Handle permission errors more specifically
|
58
|
+
throw new Error(`Permission denied for model "${model}". Try using a different model with public access. Error: ${error.response?.data || error.message}`);
|
59
|
+
}
|
55
60
|
throw error;
|
56
61
|
}
|
57
62
|
}
|
@@ -63,37 +68,62 @@ class HuggingFaceModel extends base_model_1.BaseModel {
|
|
63
68
|
if (request.systemPrompt) {
|
64
69
|
fullPrompt = `${request.systemPrompt}\n\n${fullPrompt}`;
|
65
70
|
}
|
71
|
+
// If functions are provided, enhance the prompt to handle function calling
|
72
|
+
if (request.functions && request.functions.length > 0) {
|
73
|
+
// Format the functions in a way the model can understand
|
74
|
+
fullPrompt += `\n\nAVAILABLE FUNCTIONS:\n${JSON.stringify(request.functions, null, 2)}\n\n`;
|
75
|
+
// Add guidance based on function call setting
|
76
|
+
if (typeof request.functionCall === "object") {
|
77
|
+
fullPrompt += `You must call the function: ${request.functionCall.name}.\n`;
|
78
|
+
fullPrompt += `Format your answer as a function call using JSON, like this:\n`;
|
79
|
+
fullPrompt += `{"name": "${request.functionCall.name}", "arguments": {...}}\n`;
|
80
|
+
fullPrompt += `Don't include any explanations, just output the function call.\n`;
|
81
|
+
}
|
82
|
+
else if (request.functionCall === "auto") {
|
83
|
+
fullPrompt += `Call one of the available functions if appropriate. Format the function call as JSON, like this:\n`;
|
84
|
+
fullPrompt += `{"name": "functionName", "arguments": {...}}\n`;
|
85
|
+
}
|
86
|
+
}
|
66
87
|
const payload = {
|
67
88
|
inputs: fullPrompt,
|
68
89
|
parameters: {
|
69
|
-
temperature: config.temperature,
|
70
|
-
max_new_tokens: config.maxTokens,
|
71
|
-
top_p: config.topP,
|
90
|
+
temperature: config.temperature || 0.1,
|
91
|
+
max_new_tokens: config.maxTokens || 500,
|
92
|
+
top_p: config.topP || 0.9,
|
72
93
|
return_full_text: false,
|
73
94
|
},
|
74
95
|
};
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
96
|
+
try {
|
97
|
+
const response = await axios_1.default.post(`${this.baseURL}/${model}`, payload, {
|
98
|
+
headers: {
|
99
|
+
Authorization: `Bearer ${config.apiKey ||
|
100
|
+
(0, utils_1.getApiKey)(config.apiKey, "HUGGINGFACE_API_KEY", "HuggingFace")}`,
|
101
|
+
"Content-Type": "application/json",
|
102
|
+
},
|
103
|
+
});
|
104
|
+
// Parse the response
|
105
|
+
let text = "";
|
106
|
+
if (Array.isArray(response.data)) {
|
107
|
+
text = response.data[0]?.generated_text || "";
|
108
|
+
}
|
109
|
+
else if (response.data.generated_text) {
|
110
|
+
text = response.data.generated_text;
|
111
|
+
}
|
112
|
+
else {
|
113
|
+
text = JSON.stringify(response.data);
|
114
|
+
}
|
115
|
+
// Extract function calls from the response
|
116
|
+
const functionCalls = this.extractFunctionCallsFromText(text);
|
117
|
+
return {
|
118
|
+
text,
|
119
|
+
functionCalls,
|
120
|
+
raw: response.data,
|
121
|
+
};
|
89
122
|
}
|
90
|
-
|
91
|
-
|
123
|
+
catch (error) {
|
124
|
+
console.error("Error generating with HuggingFace model:", error);
|
125
|
+
throw error;
|
92
126
|
}
|
93
|
-
return {
|
94
|
-
text,
|
95
|
-
raw: response.data,
|
96
|
-
};
|
97
127
|
}
|
98
128
|
/**
|
99
129
|
* Generate a response using multimodal inputs (text + images)
|
@@ -128,6 +158,86 @@ class HuggingFaceModel extends base_model_1.BaseModel {
|
|
128
158
|
.join("; ")}`;
|
129
159
|
throw new Error(errorMessage);
|
130
160
|
}
|
161
|
+
/**
|
162
|
+
* Extract function calls from text using pattern matching
|
163
|
+
* This is more robust than the previous implementation and handles various formats
|
164
|
+
*/
|
165
|
+
extractFunctionCallsFromText(text) {
|
166
|
+
if (!text)
|
167
|
+
return undefined;
|
168
|
+
try {
|
169
|
+
const functionCalls = [];
|
170
|
+
// Pattern 1: Clean JSON function call format
|
171
|
+
// Example: {"name": "functionName", "arguments": {...}}
|
172
|
+
const jsonRegex = /\{[\s\n]*"name"[\s\n]*:[\s\n]*"([^"]+)"[\s\n]*,[\s\n]*"arguments"[\s\n]*:[\s\n]*([\s\S]*?)\}/g;
|
173
|
+
let match;
|
174
|
+
while ((match = jsonRegex.exec(text)) !== null) {
|
175
|
+
try {
|
176
|
+
// Try to parse the arguments part as JSON
|
177
|
+
const name = match[1];
|
178
|
+
let args = match[2].trim();
|
179
|
+
// Check if args is already a valid JSON string
|
180
|
+
try {
|
181
|
+
JSON.parse(args);
|
182
|
+
functionCalls.push({
|
183
|
+
name,
|
184
|
+
arguments: args,
|
185
|
+
});
|
186
|
+
}
|
187
|
+
catch (e) {
|
188
|
+
// If not valid JSON, try to extract the JSON object
|
189
|
+
const argsMatch = args.match(/\{[\s\S]*\}/);
|
190
|
+
if (argsMatch) {
|
191
|
+
functionCalls.push({
|
192
|
+
name,
|
193
|
+
arguments: argsMatch[0],
|
194
|
+
});
|
195
|
+
}
|
196
|
+
else {
|
197
|
+
functionCalls.push({
|
198
|
+
name,
|
199
|
+
arguments: "{}",
|
200
|
+
});
|
201
|
+
}
|
202
|
+
}
|
203
|
+
}
|
204
|
+
catch (e) {
|
205
|
+
console.warn("Error parsing function call:", e);
|
206
|
+
}
|
207
|
+
}
|
208
|
+
// Pattern 2: Function-like syntax
|
209
|
+
// Example: functionName({param1: "value", param2: 123})
|
210
|
+
const functionRegex = /([a-zA-Z0-9_]+)\s*\(\s*(\{[\s\S]*?\})\s*\)/g;
|
211
|
+
while ((match = functionRegex.exec(text)) !== null) {
|
212
|
+
functionCalls.push({
|
213
|
+
name: match[1],
|
214
|
+
arguments: match[2],
|
215
|
+
});
|
216
|
+
}
|
217
|
+
// Pattern 3: Markdown code block with JSON
|
218
|
+
// Example: ```json\n{"name": "functionName", "arguments": {...}}\n```
|
219
|
+
const markdownRegex = /```(?:json)?\s*\n\s*(\{[\s\S]*?\})\s*\n```/g;
|
220
|
+
while ((match = markdownRegex.exec(text)) !== null) {
|
221
|
+
try {
|
222
|
+
const jsonObj = JSON.parse(match[1]);
|
223
|
+
if (jsonObj.name && (jsonObj.arguments || jsonObj.args)) {
|
224
|
+
functionCalls.push({
|
225
|
+
name: jsonObj.name,
|
226
|
+
arguments: JSON.stringify(jsonObj.arguments || jsonObj.args),
|
227
|
+
});
|
228
|
+
}
|
229
|
+
}
|
230
|
+
catch (e) {
|
231
|
+
// Ignore parse errors for markdown blocks
|
232
|
+
}
|
233
|
+
}
|
234
|
+
return functionCalls.length > 0 ? functionCalls : undefined;
|
235
|
+
}
|
236
|
+
catch (e) {
|
237
|
+
console.warn("Error in extractFunctionCallsFromText:", e);
|
238
|
+
return undefined;
|
239
|
+
}
|
240
|
+
}
|
131
241
|
/**
|
132
242
|
* Try generating with nested inputs format (common in newer models)
|
133
243
|
*/
|
@@ -135,14 +245,26 @@ class HuggingFaceModel extends base_model_1.BaseModel {
|
|
135
245
|
const prompt = request.systemPrompt
|
136
246
|
? `${request.systemPrompt}\n\n${request.prompt}`
|
137
247
|
: request.prompt;
|
248
|
+
// Handle function calling by adding function definitions to the prompt
|
249
|
+
let enhancedPrompt = prompt;
|
250
|
+
if (request.functions && request.functions.length > 0) {
|
251
|
+
const functionText = JSON.stringify({ functions: request.functions }, null, 2);
|
252
|
+
enhancedPrompt = `${enhancedPrompt}\n\nAvailable functions:\n\`\`\`json\n${functionText}\n\`\`\`\n\n`;
|
253
|
+
if (typeof request.functionCall === "object") {
|
254
|
+
enhancedPrompt += `Please call the function: ${request.functionCall.name}\n\n`;
|
255
|
+
}
|
256
|
+
else if (request.functionCall === "auto") {
|
257
|
+
enhancedPrompt += "Call the appropriate function if needed.\n\n";
|
258
|
+
}
|
259
|
+
}
|
138
260
|
let payload = {
|
139
261
|
inputs: {
|
140
|
-
text:
|
262
|
+
text: enhancedPrompt,
|
141
263
|
},
|
142
264
|
parameters: {
|
143
|
-
temperature: config.temperature,
|
144
|
-
max_new_tokens: config.maxTokens,
|
145
|
-
top_p: config.topP,
|
265
|
+
temperature: config.temperature || 0.1,
|
266
|
+
max_new_tokens: config.maxTokens || 500,
|
267
|
+
top_p: config.topP || 0.9,
|
146
268
|
return_full_text: false,
|
147
269
|
},
|
148
270
|
};
|
@@ -178,8 +300,8 @@ class HuggingFaceModel extends base_model_1.BaseModel {
|
|
178
300
|
"Content-Type": "application/json",
|
179
301
|
},
|
180
302
|
});
|
181
|
-
// Parse response
|
182
|
-
return this.
|
303
|
+
// Parse response with function call extraction
|
304
|
+
return this.parseResponseWithFunctionCalls(response);
|
183
305
|
}
|
184
306
|
/**
|
185
307
|
* Try generating with flat inputs format (common in some models)
|
@@ -188,13 +310,25 @@ class HuggingFaceModel extends base_model_1.BaseModel {
|
|
188
310
|
const prompt = request.systemPrompt
|
189
311
|
? `${request.systemPrompt}\n\n${request.prompt}`
|
190
312
|
: request.prompt;
|
313
|
+
// Handle function calling by adding function definitions to the prompt
|
314
|
+
let enhancedPrompt = prompt;
|
315
|
+
if (request.functions && request.functions.length > 0) {
|
316
|
+
const functionText = JSON.stringify({ functions: request.functions }, null, 2);
|
317
|
+
enhancedPrompt = `${enhancedPrompt}\n\nAvailable functions:\n\`\`\`json\n${functionText}\n\`\`\`\n\n`;
|
318
|
+
if (typeof request.functionCall === "object") {
|
319
|
+
enhancedPrompt += `Please call the function: ${request.functionCall.name}\n\n`;
|
320
|
+
}
|
321
|
+
else if (request.functionCall === "auto") {
|
322
|
+
enhancedPrompt += "Call the appropriate function if needed.\n\n";
|
323
|
+
}
|
324
|
+
}
|
191
325
|
// Some models expect a flat structure with inputs as a string
|
192
326
|
let payload = {
|
193
|
-
inputs:
|
327
|
+
inputs: enhancedPrompt,
|
194
328
|
parameters: {
|
195
|
-
temperature: config.temperature,
|
196
|
-
max_new_tokens: config.maxTokens,
|
197
|
-
top_p: config.topP,
|
329
|
+
temperature: config.temperature || 0.1,
|
330
|
+
max_new_tokens: config.maxTokens || 500,
|
331
|
+
top_p: config.topP || 0.9,
|
198
332
|
return_full_text: false,
|
199
333
|
},
|
200
334
|
};
|
@@ -218,13 +352,13 @@ class HuggingFaceModel extends base_model_1.BaseModel {
|
|
218
352
|
"Content-Type": "application/json",
|
219
353
|
},
|
220
354
|
});
|
221
|
-
// Parse response
|
222
|
-
return this.
|
355
|
+
// Parse response with function call extraction
|
356
|
+
return this.parseResponseWithFunctionCalls(response);
|
223
357
|
}
|
224
358
|
/**
|
225
|
-
* Helper to parse HuggingFace response
|
359
|
+
* Helper to parse HuggingFace response with function call extraction
|
226
360
|
*/
|
227
|
-
|
361
|
+
parseResponseWithFunctionCalls(response) {
|
228
362
|
let text = "";
|
229
363
|
if (Array.isArray(response.data)) {
|
230
364
|
text = response.data[0]?.generated_text || "";
|
@@ -238,8 +372,11 @@ class HuggingFaceModel extends base_model_1.BaseModel {
|
|
238
372
|
else {
|
239
373
|
text = JSON.stringify(response.data);
|
240
374
|
}
|
375
|
+
// Extract function calls from the response text
|
376
|
+
const functionCalls = this.extractFunctionCallsFromText(text);
|
241
377
|
return {
|
242
378
|
text,
|
379
|
+
functionCalls,
|
243
380
|
raw: response.data,
|
244
381
|
};
|
245
382
|
}
|
@@ -249,11 +386,23 @@ class HuggingFaceModel extends base_model_1.BaseModel {
|
|
249
386
|
async generateWithMultipartForm(request, config, model) {
|
250
387
|
// Create a multipart form-data payload for multimodal models
|
251
388
|
const formData = new FormData();
|
252
|
-
// Add text prompt
|
389
|
+
// Add text prompt with function definitions
|
253
390
|
const prompt = request.systemPrompt
|
254
391
|
? `${request.systemPrompt}\n\n${request.prompt}`
|
255
392
|
: request.prompt;
|
256
|
-
|
393
|
+
// Handle function calling by adding function definitions to the prompt
|
394
|
+
let enhancedPrompt = prompt;
|
395
|
+
if (request.functions && request.functions.length > 0) {
|
396
|
+
const functionText = JSON.stringify({ functions: request.functions }, null, 2);
|
397
|
+
enhancedPrompt = `${enhancedPrompt}\n\nAvailable functions:\n\`\`\`json\n${functionText}\n\`\`\`\n\n`;
|
398
|
+
if (typeof request.functionCall === "object") {
|
399
|
+
enhancedPrompt += `Please call the function: ${request.functionCall.name}\n\n`;
|
400
|
+
}
|
401
|
+
else if (request.functionCall === "auto") {
|
402
|
+
enhancedPrompt += "Call the appropriate function if needed.\n\n";
|
403
|
+
}
|
404
|
+
}
|
405
|
+
formData.append("text", enhancedPrompt);
|
257
406
|
// Process the convenience 'image' property
|
258
407
|
if (request.image) {
|
259
408
|
const { base64 } = await (0, image_utils_1.processImage)(request.image);
|
@@ -304,8 +453,11 @@ class HuggingFaceModel extends base_model_1.BaseModel {
|
|
304
453
|
else {
|
305
454
|
text = JSON.stringify(response.data);
|
306
455
|
}
|
456
|
+
// Extract function calls from the response text
|
457
|
+
const functionCalls = this.extractFunctionCallsFromText(text);
|
307
458
|
return {
|
308
459
|
text,
|
460
|
+
functionCalls,
|
309
461
|
raw: response.data,
|
310
462
|
};
|
311
463
|
}
|
@@ -6,6 +6,14 @@ export declare class OllamaModel extends BaseModel {
|
|
6
6
|
constructor(config: AIModelConfig);
|
7
7
|
generate(request: AIModelRequest): Promise<AIModelResponse>;
|
8
8
|
stream(request: AIModelRequest): AsyncGenerator<string, void, unknown>;
|
9
|
+
/**
|
10
|
+
* Extract function calls from text using various patterns
|
11
|
+
*/
|
12
|
+
private extractFunctionCallsFromText;
|
13
|
+
/**
|
14
|
+
* Tries to fix incomplete JSON strings by adding missing closing braces
|
15
|
+
*/
|
16
|
+
private tryFixIncompleteJSON;
|
9
17
|
/**
|
10
18
|
* Creates the request payload for Ollama, handling multimodal content if provided
|
11
19
|
*/
|
@@ -18,19 +18,70 @@ class OllamaModel extends base_model_1.BaseModel {
|
|
18
18
|
async generate(request) {
|
19
19
|
const config = this.mergeConfig(request.options);
|
20
20
|
try {
|
21
|
+
// Create the payload for the request
|
21
22
|
const payload = await this.createRequestPayload(request, config);
|
22
|
-
|
23
|
+
// Determine which endpoint to use based on the payload format
|
24
|
+
const endpoint = payload.messages ? "chat" : "generate";
|
25
|
+
console.log(`Using Ollama ${endpoint} endpoint with model: ${payload.model || "default"}`);
|
26
|
+
// Set stream to true to handle responses as a stream
|
27
|
+
payload.stream = true;
|
28
|
+
const response = await axios_1.default.post(`${this.baseURL}/${endpoint}`, payload, {
|
29
|
+
responseType: "stream",
|
30
|
+
headers: {
|
31
|
+
"Content-Type": "application/json",
|
32
|
+
},
|
33
|
+
});
|
34
|
+
// Accumulate the complete response
|
35
|
+
let responseText = "";
|
36
|
+
let promptTokens = 0;
|
37
|
+
let completionTokens = 0;
|
38
|
+
// Process the stream
|
39
|
+
const reader = response.data;
|
40
|
+
for await (const chunk of reader) {
|
41
|
+
const lines = chunk.toString().split("\n").filter(Boolean);
|
42
|
+
for (const line of lines) {
|
43
|
+
try {
|
44
|
+
const parsed = JSON.parse(line);
|
45
|
+
// Handle different response formats
|
46
|
+
if (endpoint === "chat") {
|
47
|
+
if (parsed.message && parsed.message.content) {
|
48
|
+
responseText += parsed.message.content;
|
49
|
+
}
|
50
|
+
}
|
51
|
+
else if (parsed.response) {
|
52
|
+
responseText += parsed.response;
|
53
|
+
}
|
54
|
+
// Extract token usage from the final message
|
55
|
+
if (parsed.done) {
|
56
|
+
promptTokens = parsed.prompt_eval_count || 0;
|
57
|
+
completionTokens = parsed.eval_count || 0;
|
58
|
+
}
|
59
|
+
}
|
60
|
+
catch (error) {
|
61
|
+
console.error("Error parsing Ollama stream data:", line, error);
|
62
|
+
}
|
63
|
+
}
|
64
|
+
}
|
65
|
+
console.log(`Extracted response text: "${responseText}"`);
|
66
|
+
// Try to extract function calls from the response
|
67
|
+
const functionCalls = this.extractFunctionCallsFromText(responseText, request);
|
23
68
|
return {
|
24
|
-
text:
|
69
|
+
text: responseText,
|
25
70
|
usage: {
|
26
|
-
promptTokens:
|
27
|
-
completionTokens:
|
28
|
-
totalTokens:
|
71
|
+
promptTokens: promptTokens,
|
72
|
+
completionTokens: completionTokens,
|
73
|
+
totalTokens: promptTokens + completionTokens,
|
29
74
|
},
|
30
|
-
|
75
|
+
functionCalls,
|
76
|
+
raw: { response: responseText }, // We don't have the original raw response, so create one
|
31
77
|
};
|
32
78
|
}
|
33
79
|
catch (error) {
|
80
|
+
console.error("Ollama API error details:", error);
|
81
|
+
if (error.response) {
|
82
|
+
console.error("Response status:", error.response.status);
|
83
|
+
console.error("Response data:", error.response.data);
|
84
|
+
}
|
34
85
|
// Enhance error message if it appears to be related to multimodal support
|
35
86
|
if (error.response?.status === 400 &&
|
36
87
|
(request.image || request.content) &&
|
@@ -39,6 +90,16 @@ class OllamaModel extends base_model_1.BaseModel {
|
|
39
90
|
error.response?.data?.error?.includes("vision"))) {
|
40
91
|
throw new Error(`The model "${config.model || "default"}" doesn't support multimodal inputs. Try a vision-capable model like "llama-3.2-vision" or "llava". Original error: ${error.message}`);
|
41
92
|
}
|
93
|
+
// Check if the error is about the model not being found or loaded
|
94
|
+
if (error.response?.status === 404 ||
|
95
|
+
(error.response?.data &&
|
96
|
+
typeof error.response.data.error === "string" &&
|
97
|
+
error.response.data.error.toLowerCase().includes("model") &&
|
98
|
+
error.response.data.error.toLowerCase().includes("not"))) {
|
99
|
+
throw new Error(`Model "${config.model || "default"}" not found or not loaded in Ollama. ` +
|
100
|
+
`Make sure the model is installed with 'ollama pull ${config.model || "llama2"}' ` +
|
101
|
+
`Original error: ${error.response?.data?.error || error.message}`);
|
102
|
+
}
|
42
103
|
throw error;
|
43
104
|
}
|
44
105
|
}
|
@@ -46,8 +107,13 @@ class OllamaModel extends base_model_1.BaseModel {
|
|
46
107
|
const config = this.mergeConfig(request.options);
|
47
108
|
try {
|
48
109
|
const payload = await this.createRequestPayload(request, config, true);
|
49
|
-
|
110
|
+
// Determine which endpoint to use based on the payload format
|
111
|
+
const endpoint = payload.messages ? "chat" : "generate";
|
112
|
+
const response = await axios_1.default.post(`${this.baseURL}/${endpoint}`, payload, {
|
50
113
|
responseType: "stream",
|
114
|
+
headers: {
|
115
|
+
"Content-Type": "application/json",
|
116
|
+
},
|
51
117
|
});
|
52
118
|
const reader = response.data;
|
53
119
|
for await (const chunk of reader) {
|
@@ -55,7 +121,13 @@ class OllamaModel extends base_model_1.BaseModel {
|
|
55
121
|
for (const line of lines) {
|
56
122
|
try {
|
57
123
|
const parsed = JSON.parse(line);
|
58
|
-
|
124
|
+
// Handle different response formats
|
125
|
+
if (endpoint === "chat") {
|
126
|
+
if (parsed.message && parsed.message.content) {
|
127
|
+
yield parsed.message.content;
|
128
|
+
}
|
129
|
+
}
|
130
|
+
else if (parsed.response) {
|
59
131
|
yield parsed.response;
|
60
132
|
}
|
61
133
|
}
|
@@ -77,25 +149,205 @@ class OllamaModel extends base_model_1.BaseModel {
|
|
77
149
|
throw error;
|
78
150
|
}
|
79
151
|
}
|
152
|
+
/**
|
153
|
+
* Extract function calls from text using various patterns
|
154
|
+
*/
|
155
|
+
extractFunctionCallsFromText(text, currentRequest) {
|
156
|
+
if (!text)
|
157
|
+
return undefined;
|
158
|
+
try {
|
159
|
+
// Fix incomplete JSON - look for patterns where JSON might be incomplete
|
160
|
+
// First, let's try to fix a common issue where the closing brace is missing
|
161
|
+
const fixedText = this.tryFixIncompleteJSON(text);
|
162
|
+
// Pattern 1: JSON format with name and arguments
|
163
|
+
// E.g., {"name": "getWeather", "arguments": {"location": "Tokyo"}}
|
164
|
+
const jsonRegex = /\{[\s\n]*"name"[\s\n]*:[\s\n]*"([^"]+)"[\s\n]*,[\s\n]*"arguments"[\s\n]*:[\s\n]*([\s\S]*?)\}/g;
|
165
|
+
const jsonMatches = [...fixedText.matchAll(jsonRegex)];
|
166
|
+
if (jsonMatches.length > 0) {
|
167
|
+
return jsonMatches.map((match) => {
|
168
|
+
try {
|
169
|
+
// Try to parse the arguments as JSON
|
170
|
+
let argsText = match[2];
|
171
|
+
// Fix potential incomplete JSON in arguments
|
172
|
+
argsText = this.tryFixIncompleteJSON(argsText);
|
173
|
+
let args;
|
174
|
+
try {
|
175
|
+
args = JSON.parse(argsText);
|
176
|
+
return {
|
177
|
+
name: match[1],
|
178
|
+
arguments: JSON.stringify(args),
|
179
|
+
};
|
180
|
+
}
|
181
|
+
catch (e) {
|
182
|
+
// If parsing fails, try to fix the JSON before returning
|
183
|
+
console.warn("Error parsing function arguments, trying to fix:", e);
|
184
|
+
return {
|
185
|
+
name: match[1],
|
186
|
+
arguments: this.tryFixIncompleteJSON(argsText, true),
|
187
|
+
};
|
188
|
+
}
|
189
|
+
}
|
190
|
+
catch (e) {
|
191
|
+
console.warn("Error parsing function call:", e);
|
192
|
+
return {
|
193
|
+
name: match[1],
|
194
|
+
arguments: "{}",
|
195
|
+
};
|
196
|
+
}
|
197
|
+
});
|
198
|
+
}
|
199
|
+
// Pattern 2: Function call pattern: functionName({"key": "value"})
|
200
|
+
const functionRegex = /([a-zA-Z0-9_]+)\s*\(\s*(\{[\s\S]*?\})\s*\)/g;
|
201
|
+
const functionMatches = [...fixedText.matchAll(functionRegex)];
|
202
|
+
if (functionMatches.length > 0) {
|
203
|
+
return functionMatches.map((match) => {
|
204
|
+
const argsText = this.tryFixIncompleteJSON(match[2]);
|
205
|
+
return {
|
206
|
+
name: match[1],
|
207
|
+
arguments: argsText,
|
208
|
+
};
|
209
|
+
});
|
210
|
+
}
|
211
|
+
// Pattern 3: Looking for direct JSON objects - for function specific forced calls
|
212
|
+
if (currentRequest.functionCall &&
|
213
|
+
typeof currentRequest.functionCall === "object") {
|
214
|
+
const forcedFunctionName = currentRequest.functionCall.name;
|
215
|
+
// For getWeather function
|
216
|
+
if (forcedFunctionName === "getWeather") {
|
217
|
+
const weatherMatch = fixedText.match(/\{[\s\n]*"location"[\s\n]*:[\s\n]*"([^"]*)"(?:[\s\n]*,[\s\n]*"unit"[\s\n]*:[\s\n]*"([^"]*)"|)(.*?)\}/s);
|
218
|
+
if (weatherMatch) {
|
219
|
+
const location = weatherMatch[1];
|
220
|
+
const unit = weatherMatch[2] || "celsius";
|
221
|
+
return [
|
222
|
+
{
|
223
|
+
name: "getWeather",
|
224
|
+
arguments: JSON.stringify({ location, unit }),
|
225
|
+
},
|
226
|
+
];
|
227
|
+
}
|
228
|
+
}
|
229
|
+
// For calculator function
|
230
|
+
if (forcedFunctionName === "calculator") {
|
231
|
+
const calculatorMatch = fixedText.match(/\{[\s\n]*"operation"[\s\n]*:[\s\n]*"([^"]*)"[\s\n]*,[\s\n]*"a"[\s\n]*:[\s\n]*(\d+)[\s\n]*,[\s\n]*"b"[\s\n]*:[\s\n]*(\d+)[\s\S]*?\}/s);
|
232
|
+
if (calculatorMatch) {
|
233
|
+
const operation = calculatorMatch[1];
|
234
|
+
const a = parseInt(calculatorMatch[2]);
|
235
|
+
const b = parseInt(calculatorMatch[3]);
|
236
|
+
return [
|
237
|
+
{
|
238
|
+
name: "calculator",
|
239
|
+
arguments: JSON.stringify({ operation, a, b }),
|
240
|
+
},
|
241
|
+
];
|
242
|
+
}
|
243
|
+
}
|
244
|
+
}
|
245
|
+
// If no matches found and we have a functionCall request, try one more pattern matching approach
|
246
|
+
if (currentRequest.functionCall) {
|
247
|
+
// Try to extract JSON-like structures even if they're not complete
|
248
|
+
const namedFunction = typeof currentRequest.functionCall === "object"
|
249
|
+
? currentRequest.functionCall.name
|
250
|
+
: null;
|
251
|
+
// Look for the function name in the text followed by arguments
|
252
|
+
const functionNamePattern = namedFunction
|
253
|
+
? new RegExp(`"name"\\s*:\\s*"${namedFunction}"\\s*,\\s*"arguments"\\s*:\\s*(\\{[\\s\\S]*?)(?:\\}|$)`, "s")
|
254
|
+
: null;
|
255
|
+
if (functionNamePattern) {
|
256
|
+
const extractedMatch = fixedText.match(functionNamePattern);
|
257
|
+
if (extractedMatch && extractedMatch[1]) {
|
258
|
+
let argsText = extractedMatch[1];
|
259
|
+
// Make sure the JSON is complete
|
260
|
+
if (!argsText.endsWith("}")) {
|
261
|
+
argsText += "}";
|
262
|
+
}
|
263
|
+
try {
|
264
|
+
// Try to parse the fixed arguments
|
265
|
+
const args = JSON.parse(argsText);
|
266
|
+
return [
|
267
|
+
{
|
268
|
+
name: namedFunction,
|
269
|
+
arguments: JSON.stringify(args),
|
270
|
+
},
|
271
|
+
];
|
272
|
+
}
|
273
|
+
catch (e) {
|
274
|
+
console.warn("Failed to parse extracted arguments:", e);
|
275
|
+
return [
|
276
|
+
{
|
277
|
+
name: namedFunction,
|
278
|
+
arguments: this.tryFixIncompleteJSON(argsText, true),
|
279
|
+
},
|
280
|
+
];
|
281
|
+
}
|
282
|
+
}
|
283
|
+
}
|
284
|
+
}
|
285
|
+
}
|
286
|
+
catch (e) {
|
287
|
+
console.warn("Error in extractFunctionCallsFromText:", e);
|
288
|
+
}
|
289
|
+
return undefined;
|
290
|
+
}
|
291
|
+
/**
|
292
|
+
* Tries to fix incomplete JSON strings by adding missing closing braces
|
293
|
+
*/
|
294
|
+
tryFixIncompleteJSON(text, returnAsString = false) {
|
295
|
+
// Skip if the string is already valid JSON
|
296
|
+
try {
|
297
|
+
JSON.parse(text);
|
298
|
+
return text; // Already valid
|
299
|
+
}
|
300
|
+
catch (e) {
|
301
|
+
// Not valid JSON, try to fix
|
302
|
+
}
|
303
|
+
// Count opening and closing braces
|
304
|
+
const openBraces = (text.match(/\{/g) || []).length;
|
305
|
+
const closeBraces = (text.match(/\}/g) || []).length;
|
306
|
+
// If we have more opening braces than closing, add the missing closing braces
|
307
|
+
if (openBraces > closeBraces) {
|
308
|
+
const missingBraces = openBraces - closeBraces;
|
309
|
+
let fixedText = text + "}".repeat(missingBraces);
|
310
|
+
// Try to parse it to see if it's valid now
|
311
|
+
try {
|
312
|
+
if (!returnAsString) {
|
313
|
+
JSON.parse(fixedText);
|
314
|
+
}
|
315
|
+
return fixedText;
|
316
|
+
}
|
317
|
+
catch (e) {
|
318
|
+
// Still not valid, return the original
|
319
|
+
console.warn("Failed to fix JSON even after adding braces", e);
|
320
|
+
return text;
|
321
|
+
}
|
322
|
+
}
|
323
|
+
return text;
|
324
|
+
}
|
80
325
|
/**
|
81
326
|
* Creates the request payload for Ollama, handling multimodal content if provided
|
82
327
|
*/
|
83
328
|
async createRequestPayload(request, config, isStream = false) {
|
84
329
|
// Base payload
|
85
330
|
const payload = {
|
86
|
-
model: config.model || "
|
331
|
+
model: config.model || "llama3", // Updated default to a model that better supports function calling
|
87
332
|
temperature: config.temperature,
|
88
|
-
num_predict: config.maxTokens,
|
89
333
|
top_p: config.topP,
|
90
334
|
};
|
335
|
+
// Add max tokens for the generate endpoint
|
336
|
+
if (config.maxTokens) {
|
337
|
+
payload.num_predict = config.maxTokens;
|
338
|
+
}
|
91
339
|
// Handle streaming
|
92
340
|
if (isStream) {
|
93
341
|
payload.stream = true;
|
94
342
|
}
|
95
|
-
//
|
96
|
-
|
97
|
-
(request.content &&
|
98
|
-
|
343
|
+
// Check if we should use chat format (messages array) or text format
|
344
|
+
const useMessagesFormat = request.image ||
|
345
|
+
(request.content &&
|
346
|
+
request.content.some((item) => item.type === "image")) ||
|
347
|
+
(request.functions && request.functions.length > 0) ||
|
348
|
+
request.systemPrompt; // Always use messages format when system prompt is provided
|
349
|
+
if (useMessagesFormat) {
|
350
|
+
// Modern message-based format for Ollama (chat endpoint)
|
99
351
|
const messages = [];
|
100
352
|
// Add system prompt if provided
|
101
353
|
if (request.systemPrompt) {
|
@@ -104,27 +356,48 @@ class OllamaModel extends base_model_1.BaseModel {
|
|
104
356
|
content: request.systemPrompt,
|
105
357
|
});
|
106
358
|
}
|
107
|
-
|
108
|
-
|
109
|
-
|
359
|
+
else if (request.functions && request.functions.length > 0) {
|
360
|
+
// Add function calling guidance in system prompt if none provided
|
361
|
+
let functionSystemPrompt = "You are a helpful AI assistant with access to functions.";
|
362
|
+
// Add function definitions as JSON
|
363
|
+
functionSystemPrompt += `\n\nAvailable functions:\n\`\`\`json\n${JSON.stringify(request.functions, null, 2)}\n\`\`\`\n\n`;
|
364
|
+
// Add instruction based on functionCall setting
|
365
|
+
if (typeof request.functionCall === "object") {
|
366
|
+
functionSystemPrompt += `You must call the function: ${request.functionCall.name}.\n`;
|
367
|
+
functionSystemPrompt += `Format your response as a function call using this exact format:\n`;
|
368
|
+
functionSystemPrompt += `{"name": "${request.functionCall.name}", "arguments": {...}}\n`;
|
369
|
+
}
|
370
|
+
else if (request.functionCall === "auto") {
|
371
|
+
functionSystemPrompt += `Call one of these functions if appropriate for the user's request.\n`;
|
372
|
+
functionSystemPrompt += `Format your response as a function call using this exact format:\n`;
|
373
|
+
functionSystemPrompt += `{"name": "functionName", "arguments": {...}}\n`;
|
374
|
+
}
|
375
|
+
messages.push({
|
376
|
+
role: "system",
|
377
|
+
content: functionSystemPrompt,
|
378
|
+
});
|
379
|
+
}
|
380
|
+
// Create user message content parts
|
381
|
+
let userContent = [];
|
382
|
+
// Add main text content
|
110
383
|
if (request.prompt) {
|
111
|
-
|
384
|
+
userContent.push({
|
112
385
|
type: "text",
|
113
386
|
text: request.prompt,
|
114
387
|
});
|
115
388
|
}
|
116
|
-
//
|
389
|
+
// Add structured content
|
117
390
|
if (request.content) {
|
118
391
|
for (const item of request.content) {
|
119
392
|
if (item.type === "text") {
|
120
|
-
|
393
|
+
userContent.push({
|
121
394
|
type: "text",
|
122
395
|
text: item.text,
|
123
396
|
});
|
124
397
|
}
|
125
398
|
else if (item.type === "image") {
|
126
399
|
const { base64, mimeType } = await (0, image_utils_1.processImage)(item.source);
|
127
|
-
|
400
|
+
userContent.push({
|
128
401
|
type: "image",
|
129
402
|
image: {
|
130
403
|
data: base64,
|
@@ -134,10 +407,10 @@ class OllamaModel extends base_model_1.BaseModel {
|
|
134
407
|
}
|
135
408
|
}
|
136
409
|
}
|
137
|
-
//
|
410
|
+
// Add simple image if provided
|
138
411
|
if (request.image) {
|
139
412
|
const { base64, mimeType } = await (0, image_utils_1.processImage)(request.image);
|
140
|
-
|
413
|
+
userContent.push({
|
141
414
|
type: "image",
|
142
415
|
image: {
|
143
416
|
data: base64,
|
@@ -145,14 +418,32 @@ class OllamaModel extends base_model_1.BaseModel {
|
|
145
418
|
},
|
146
419
|
});
|
147
420
|
}
|
148
|
-
//
|
421
|
+
// Create the user message
|
422
|
+
let userMessage = {
|
423
|
+
role: "user",
|
424
|
+
};
|
425
|
+
// If we have a single text content, use string format
|
426
|
+
if (userContent.length === 1 && userContent[0].type === "text") {
|
427
|
+
userMessage.content = userContent[0].text;
|
428
|
+
}
|
429
|
+
else if (userContent.length > 0) {
|
430
|
+
userMessage.content = userContent;
|
431
|
+
}
|
432
|
+
else {
|
433
|
+
// Add empty string if no content provided to avoid invalid request
|
434
|
+
userMessage.content = "";
|
435
|
+
}
|
149
436
|
messages.push(userMessage);
|
150
|
-
// Set the messages in the payload
|
151
437
|
payload.messages = messages;
|
438
|
+
// Remove any fields specific to the generate endpoint
|
439
|
+
// that might cause issues with the chat endpoint
|
440
|
+
if (payload.hasOwnProperty("num_predict")) {
|
441
|
+
delete payload.num_predict;
|
442
|
+
}
|
152
443
|
}
|
153
444
|
else {
|
154
|
-
// Traditional text-only format
|
155
|
-
let prompt = request.prompt;
|
445
|
+
// Traditional text-only format (generate endpoint)
|
446
|
+
let prompt = request.prompt || "";
|
156
447
|
// Add system prompt if provided
|
157
448
|
if (request.systemPrompt) {
|
158
449
|
prompt = `${request.systemPrompt}\n\n${prompt}`;
|
@@ -6,6 +6,14 @@ export declare class OpenAIModel extends BaseModel {
|
|
6
6
|
constructor(config: AIModelConfig);
|
7
7
|
generate(request: AIModelRequest): Promise<AIModelResponse>;
|
8
8
|
stream(request: AIModelRequest): AsyncGenerator<string, void, unknown>;
|
9
|
+
/**
|
10
|
+
* Prepare OpenAI function calling options
|
11
|
+
*/
|
12
|
+
private prepareFunctionCalling;
|
13
|
+
/**
|
14
|
+
* Process function calls from OpenAI response
|
15
|
+
*/
|
16
|
+
private processFunctionCalls;
|
9
17
|
/**
|
10
18
|
* Format messages for OpenAI API, including handling multimodal content
|
11
19
|
*/
|
@@ -20,13 +20,18 @@ class OpenAIModel extends base_model_1.BaseModel {
|
|
20
20
|
const config = this.mergeConfig(request.options);
|
21
21
|
// Process messages for OpenAI API
|
22
22
|
const messages = await this.formatMessages(request);
|
23
|
+
// Prepare function calling if requested
|
24
|
+
const functionOptions = this.prepareFunctionCalling(request);
|
23
25
|
const response = await this.client.chat.completions.create({
|
24
26
|
model: config.model || "gpt-3.5-turbo",
|
25
27
|
messages,
|
26
28
|
temperature: config.temperature,
|
27
29
|
max_tokens: config.maxTokens,
|
28
30
|
top_p: config.topP,
|
31
|
+
...functionOptions,
|
29
32
|
});
|
33
|
+
// Process function calls if any are present
|
34
|
+
const functionCalls = this.processFunctionCalls(response);
|
30
35
|
return {
|
31
36
|
text: response.choices[0].message.content || "",
|
32
37
|
usage: {
|
@@ -34,6 +39,7 @@ class OpenAIModel extends base_model_1.BaseModel {
|
|
34
39
|
completionTokens: response.usage?.completion_tokens,
|
35
40
|
totalTokens: response.usage?.total_tokens,
|
36
41
|
},
|
42
|
+
functionCalls,
|
37
43
|
raw: response,
|
38
44
|
};
|
39
45
|
}
|
@@ -41,14 +47,19 @@ class OpenAIModel extends base_model_1.BaseModel {
|
|
41
47
|
const config = this.mergeConfig(request.options);
|
42
48
|
// Process messages for OpenAI API
|
43
49
|
const messages = await this.formatMessages(request);
|
44
|
-
|
50
|
+
// Prepare function calling if requested
|
51
|
+
const functionOptions = this.prepareFunctionCalling(request);
|
52
|
+
const response = await this.client.chat.completions.create({
|
45
53
|
model: config.model || "gpt-3.5-turbo",
|
46
54
|
messages,
|
47
55
|
temperature: config.temperature,
|
48
56
|
max_tokens: config.maxTokens,
|
49
57
|
top_p: config.topP,
|
50
58
|
stream: true,
|
59
|
+
...functionOptions,
|
51
60
|
});
|
61
|
+
// Using a more compatible approach with Stream API
|
62
|
+
const stream = response;
|
52
63
|
for await (const chunk of stream) {
|
53
64
|
const content = chunk.choices[0]?.delta?.content || "";
|
54
65
|
if (content) {
|
@@ -56,6 +67,58 @@ class OpenAIModel extends base_model_1.BaseModel {
|
|
56
67
|
}
|
57
68
|
}
|
58
69
|
}
|
70
|
+
/**
|
71
|
+
* Prepare OpenAI function calling options
|
72
|
+
*/
|
73
|
+
prepareFunctionCalling(request) {
|
74
|
+
if (!request.functions || request.functions.length === 0) {
|
75
|
+
return {};
|
76
|
+
}
|
77
|
+
// Transform our function definitions to OpenAI's format
|
78
|
+
const tools = request.functions.map((func) => ({
|
79
|
+
type: "function",
|
80
|
+
function: {
|
81
|
+
name: func.name,
|
82
|
+
description: func.description,
|
83
|
+
parameters: func.parameters,
|
84
|
+
},
|
85
|
+
}));
|
86
|
+
// Handle tool_choice (functionCall in our API)
|
87
|
+
let tool_choice = undefined;
|
88
|
+
if (request.functionCall) {
|
89
|
+
if (request.functionCall === "auto") {
|
90
|
+
tool_choice = "auto";
|
91
|
+
}
|
92
|
+
else if (request.functionCall === "none") {
|
93
|
+
tool_choice = "none";
|
94
|
+
}
|
95
|
+
else if (typeof request.functionCall === "object") {
|
96
|
+
tool_choice = {
|
97
|
+
type: "function",
|
98
|
+
function: { name: request.functionCall.name },
|
99
|
+
};
|
100
|
+
}
|
101
|
+
}
|
102
|
+
return {
|
103
|
+
tools,
|
104
|
+
tool_choice,
|
105
|
+
};
|
106
|
+
}
|
107
|
+
/**
|
108
|
+
* Process function calls from OpenAI response
|
109
|
+
*/
|
110
|
+
processFunctionCalls(response) {
|
111
|
+
const toolCalls = response.choices[0]?.message?.tool_calls;
|
112
|
+
if (!toolCalls || toolCalls.length === 0) {
|
113
|
+
return undefined;
|
114
|
+
}
|
115
|
+
return toolCalls
|
116
|
+
.filter((call) => call.type === "function")
|
117
|
+
.map((call) => ({
|
118
|
+
name: call.function.name,
|
119
|
+
arguments: call.function.arguments,
|
120
|
+
}));
|
121
|
+
}
|
59
122
|
/**
|
60
123
|
* Format messages for OpenAI API, including handling multimodal content
|
61
124
|
*/
|
package/dist/types.d.ts
CHANGED
@@ -23,6 +23,19 @@ export interface ImageContent {
|
|
23
23
|
source: string | Buffer;
|
24
24
|
}
|
25
25
|
export type Content = TextContent | ImageContent;
|
26
|
+
export interface FunctionDefinition {
|
27
|
+
name: string;
|
28
|
+
description: string;
|
29
|
+
parameters: {
|
30
|
+
type: string;
|
31
|
+
properties?: Record<string, any>;
|
32
|
+
required?: string[];
|
33
|
+
};
|
34
|
+
}
|
35
|
+
export interface FunctionCall {
|
36
|
+
name: string;
|
37
|
+
arguments: string;
|
38
|
+
}
|
26
39
|
export interface AIModelResponse {
|
27
40
|
text: string;
|
28
41
|
usage?: {
|
@@ -30,6 +43,7 @@ export interface AIModelResponse {
|
|
30
43
|
completionTokens?: number;
|
31
44
|
totalTokens?: number;
|
32
45
|
};
|
46
|
+
functionCalls?: FunctionCall[];
|
33
47
|
raw?: any;
|
34
48
|
}
|
35
49
|
export interface AIModelRequest {
|
@@ -38,6 +52,10 @@ export interface AIModelRequest {
|
|
38
52
|
options?: Partial<AIModelConfig>;
|
39
53
|
content?: Content[];
|
40
54
|
image?: string | Buffer;
|
55
|
+
functions?: FunctionDefinition[];
|
56
|
+
functionCall?: "auto" | "none" | {
|
57
|
+
name: string;
|
58
|
+
};
|
41
59
|
}
|
42
60
|
export interface AIModel {
|
43
61
|
provider: AIProvider;
|