@hazeljs/ai 0.2.0-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +192 -0
- package/README.md +497 -0
- package/dist/ai-enhanced.service.d.ts +108 -0
- package/dist/ai-enhanced.service.d.ts.map +1 -0
- package/dist/ai-enhanced.service.js +345 -0
- package/dist/ai-enhanced.service.test.d.ts +2 -0
- package/dist/ai-enhanced.service.test.d.ts.map +1 -0
- package/dist/ai-enhanced.service.test.js +501 -0
- package/dist/ai-enhanced.test.d.ts +2 -0
- package/dist/ai-enhanced.test.d.ts.map +1 -0
- package/dist/ai-enhanced.test.js +587 -0
- package/dist/ai-enhanced.types.d.ts +277 -0
- package/dist/ai-enhanced.types.d.ts.map +1 -0
- package/dist/ai-enhanced.types.js +2 -0
- package/dist/ai.decorator.d.ts +4 -0
- package/dist/ai.decorator.d.ts.map +1 -0
- package/dist/ai.decorator.js +57 -0
- package/dist/ai.decorator.test.d.ts +2 -0
- package/dist/ai.decorator.test.d.ts.map +1 -0
- package/dist/ai.decorator.test.js +189 -0
- package/dist/ai.module.d.ts +12 -0
- package/dist/ai.module.d.ts.map +1 -0
- package/dist/ai.module.js +44 -0
- package/dist/ai.module.test.d.ts +2 -0
- package/dist/ai.module.test.d.ts.map +1 -0
- package/dist/ai.module.test.js +23 -0
- package/dist/ai.service.d.ts +11 -0
- package/dist/ai.service.d.ts.map +1 -0
- package/dist/ai.service.js +266 -0
- package/dist/ai.service.test.d.ts +2 -0
- package/dist/ai.service.test.d.ts.map +1 -0
- package/dist/ai.service.test.js +222 -0
- package/dist/ai.types.d.ts +30 -0
- package/dist/ai.types.d.ts.map +1 -0
- package/dist/ai.types.js +2 -0
- package/dist/context/context.manager.d.ts +69 -0
- package/dist/context/context.manager.d.ts.map +1 -0
- package/dist/context/context.manager.js +168 -0
- package/dist/context/context.manager.test.d.ts +2 -0
- package/dist/context/context.manager.test.d.ts.map +1 -0
- package/dist/context/context.manager.test.js +180 -0
- package/dist/decorators/ai-function.decorator.d.ts +42 -0
- package/dist/decorators/ai-function.decorator.d.ts.map +1 -0
- package/dist/decorators/ai-function.decorator.js +80 -0
- package/dist/decorators/ai-validate.decorator.d.ts +46 -0
- package/dist/decorators/ai-validate.decorator.d.ts.map +1 -0
- package/dist/decorators/ai-validate.decorator.js +83 -0
- package/dist/index.d.ts +18 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +40 -0
- package/dist/prompts/task.prompt.d.ts +12 -0
- package/dist/prompts/task.prompt.d.ts.map +1 -0
- package/dist/prompts/task.prompt.js +12 -0
- package/dist/providers/anthropic.provider.d.ts +48 -0
- package/dist/providers/anthropic.provider.d.ts.map +1 -0
- package/dist/providers/anthropic.provider.js +194 -0
- package/dist/providers/anthropic.provider.test.d.ts +2 -0
- package/dist/providers/anthropic.provider.test.d.ts.map +1 -0
- package/dist/providers/anthropic.provider.test.js +222 -0
- package/dist/providers/cohere.provider.d.ts +57 -0
- package/dist/providers/cohere.provider.d.ts.map +1 -0
- package/dist/providers/cohere.provider.js +230 -0
- package/dist/providers/cohere.provider.test.d.ts +2 -0
- package/dist/providers/cohere.provider.test.d.ts.map +1 -0
- package/dist/providers/cohere.provider.test.js +267 -0
- package/dist/providers/gemini.provider.d.ts +45 -0
- package/dist/providers/gemini.provider.d.ts.map +1 -0
- package/dist/providers/gemini.provider.js +180 -0
- package/dist/providers/gemini.provider.test.d.ts +2 -0
- package/dist/providers/gemini.provider.test.d.ts.map +1 -0
- package/dist/providers/gemini.provider.test.js +219 -0
- package/dist/providers/ollama.provider.d.ts +45 -0
- package/dist/providers/ollama.provider.d.ts.map +1 -0
- package/dist/providers/ollama.provider.js +232 -0
- package/dist/providers/ollama.provider.test.d.ts +2 -0
- package/dist/providers/ollama.provider.test.d.ts.map +1 -0
- package/dist/providers/ollama.provider.test.js +267 -0
- package/dist/providers/openai.provider.d.ts +57 -0
- package/dist/providers/openai.provider.d.ts.map +1 -0
- package/dist/providers/openai.provider.js +320 -0
- package/dist/providers/openai.provider.test.d.ts +2 -0
- package/dist/providers/openai.provider.test.d.ts.map +1 -0
- package/dist/providers/openai.provider.test.js +364 -0
- package/dist/tracking/token.tracker.d.ts +72 -0
- package/dist/tracking/token.tracker.d.ts.map +1 -0
- package/dist/tracking/token.tracker.js +222 -0
- package/dist/tracking/token.tracker.test.d.ts +2 -0
- package/dist/tracking/token.tracker.test.d.ts.map +1 -0
- package/dist/tracking/token.tracker.test.js +272 -0
- package/dist/vector/vector.service.d.ts +50 -0
- package/dist/vector/vector.service.d.ts.map +1 -0
- package/dist/vector/vector.service.js +163 -0
- package/package.json +60 -0

package/dist/providers/gemini.provider.js
@@ -0,0 +1,180 @@
"use strict";
var __importDefault = (this && this.__importDefault) || function (mod) {
    return (mod && mod.__esModule) ? mod : { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.GeminiProvider = void 0;
const core_1 = __importDefault(require("@hazeljs/core"));
const generative_ai_1 = require("@google/generative-ai");
/**
 * Google Gemini AI Provider
 *
 * Production-ready implementation using Google Generative AI SDK.
 *
 * Setup:
 * 1. Install the SDK: `npm install @google/generative-ai`
 * 2. Set GEMINI_API_KEY environment variable
 * 3. Use the provider in your application
 *
 * Supported models:
 * - gemini-pro: Text generation
 * - gemini-pro-vision: Multimodal (text + images)
 * - gemini-1.5-pro: Latest model with extended context
 * - text-embedding-004: Text embeddings
 */
class GeminiProvider {
    constructor(apiKey, endpoint) {
        this.name = 'gemini';
        this.apiKey = apiKey || process.env.GEMINI_API_KEY || '';
        this.endpoint = endpoint || 'https://generativelanguage.googleapis.com/v1';
        if (!this.apiKey) {
            core_1.default.warn('Gemini API key not provided. Set GEMINI_API_KEY environment variable.');
        }
        this.genAI = new generative_ai_1.GoogleGenerativeAI(this.apiKey);
        core_1.default.info('Gemini provider initialized');
    }
    /**
     * Generate completion
     */
    async complete(request) {
        const modelName = request.model || 'gemini-pro';
        core_1.default.debug(`Gemini completion request for model: ${modelName}`);
        try {
            const model = this.genAI.getGenerativeModel({ model: modelName });
            // Convert messages to Gemini format
            const prompt = request.messages
                .map((m) => {
                    const role = m.role === 'assistant' ? 'model' : m.role;
                    return `${role}: ${m.content}`;
                })
                .join('\n\n');
            // Generate content
            const result = await model.generateContent(prompt);
            const response = result.response;
            const text = response.text();
            return {
                id: `gemini-${Date.now()}`,
                content: text,
                role: 'assistant',
                model: modelName,
                usage: {
                    promptTokens: response.usageMetadata?.promptTokenCount || 0,
                    completionTokens: response.usageMetadata?.candidatesTokenCount || 0,
                    totalTokens: response.usageMetadata?.totalTokenCount || 0,
                },
                finishReason: response.candidates?.[0]?.finishReason || 'STOP',
            };
        }
        catch (error) {
            core_1.default.error('Gemini completion error:', error);
            throw new Error(`Gemini API error: ${error instanceof Error ? error.message : 'Unknown error'}`);
        }
    }
    /**
     * Generate streaming completion
     */
    async *streamComplete(request) {
        const modelName = request.model || 'gemini-pro';
        core_1.default.debug('Gemini streaming completion started');
        try {
            const model = this.genAI.getGenerativeModel({ model: modelName });
            // Convert messages to Gemini format
            const prompt = request.messages
                .map((m) => {
                    const role = m.role === 'assistant' ? 'model' : m.role;
                    return `${role}: ${m.content}`;
                })
                .join('\n\n');
            // Generate streaming content
            const result = await model.generateContentStream(prompt);
            let fullContent = '';
            let chunkCount = 0;
            for await (const chunk of result.stream) {
                const text = chunk.text();
                fullContent += text;
                chunkCount++;
                const isLast = chunk.candidates?.[0]?.finishReason !== undefined;
                yield {
                    id: `gemini-stream-${Date.now()}-${chunkCount}`,
                    content: fullContent,
                    delta: text,
                    done: isLast,
                    usage: isLast && chunk.usageMetadata
                        ? {
                            promptTokens: chunk.usageMetadata.promptTokenCount || 0,
                            completionTokens: chunk.usageMetadata.candidatesTokenCount || 0,
                            totalTokens: chunk.usageMetadata.totalTokenCount || 0,
                        }
                        : undefined,
                };
            }
            core_1.default.debug('Gemini streaming completed');
        }
        catch (error) {
            core_1.default.error('Gemini streaming error:', error);
            throw new Error(`Gemini streaming error: ${error instanceof Error ? error.message : 'Unknown error'}`);
        }
    }
    /**
     * Generate embeddings
     */
    async embed(request) {
        const modelName = request.model || 'text-embedding-004';
        core_1.default.debug(`Gemini embedding request for model: ${modelName}`);
        try {
            const model = this.genAI.getGenerativeModel({ model: modelName });
            const inputs = Array.isArray(request.input) ? request.input : [request.input];
            // Generate embeddings for each input
            const embeddings = await Promise.all(inputs.map(async (text) => {
                const result = await model.embedContent(text);
                return result.embedding.values;
            }));
            // Estimate token usage (Gemini doesn't provide exact counts for embeddings)
            const estimatedTokens = inputs.reduce((sum, text) => sum + Math.ceil(text.length / 4), 0);
            return {
                embeddings,
                model: modelName,
                usage: {
                    promptTokens: estimatedTokens,
                    totalTokens: estimatedTokens,
                },
            };
        }
        catch (error) {
            core_1.default.error('Gemini embedding error:', error);
            throw new Error(`Gemini embedding error: ${error instanceof Error ? error.message : 'Unknown error'}`);
        }
    }
    /**
     * Check if provider is available
     */
    async isAvailable() {
        if (!this.apiKey) {
            core_1.default.warn('Gemini API key not configured');
            return false;
        }
        try {
            // Test with a minimal request
            const model = this.genAI.getGenerativeModel({ model: 'gemini-pro' });
            await model.generateContent('test');
            return true;
        }
        catch (error) {
            core_1.default.error('Gemini availability check failed:', error);
            return false;
        }
    }
    /**
     * Get supported models
     */
    getSupportedModels() {
        return [
            'gemini-pro',
            'gemini-pro-vision',
            'gemini-1.5-pro',
            'gemini-1.5-flash',
            'text-embedding-004',
        ];
    }
}
exports.GeminiProvider = GeminiProvider;
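
The JSDoc at the top of gemini.provider.js lists the setup steps (install @google/generative-ai, set GEMINI_API_KEY, use the provider). The TypeScript sketch below shows minimal consumption of complete() and embed(); it is a sketch, not part of the package. It assumes GeminiProvider is re-exported from the package root (dist/index.js is listed above but its contents are not shown in this diff), and the request and response shapes simply mirror what the code above reads and returns.

import { GeminiProvider } from '@hazeljs/ai'; // assumed re-export; adjust to the actual entry point

async function geminiDemo() {
    const gemini = new GeminiProvider(); // falls back to process.env.GEMINI_API_KEY

    // Non-streaming completion: messages are flattened into a single prompt,
    // with the 'assistant' role mapped to Gemini's 'model' role.
    const reply = await gemini.complete({
        model: 'gemini-1.5-pro',
        messages: [
            { role: 'system', content: 'You are a concise assistant.' },
            { role: 'user', content: 'Summarize HazelJS in one sentence.' },
        ],
    });
    console.log(reply.content, reply.usage?.totalTokens);

    // Embeddings: usage is estimated at roughly 4 characters per token,
    // because the provider does not receive exact counts for embeddings.
    const { embeddings } = await gemini.embed({ input: ['first text', 'second text'] });
    console.log(embeddings.length); // 2
}

geminiDemo().catch(console.error);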

package/dist/providers/gemini.provider.test.d.ts.map
@@ -0,0 +1 @@
{"version":3,"file":"gemini.provider.test.d.ts","sourceRoot":"","sources":["../../src/providers/gemini.provider.test.ts"],"names":[],"mappings":""}

package/dist/providers/gemini.provider.test.js
@@ -0,0 +1,219 @@
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
jest.mock('@hazeljs/core', () => ({
    __esModule: true,
    default: { info: jest.fn(), debug: jest.fn(), warn: jest.fn(), error: jest.fn() },
}));
const mockGenerateContent = jest.fn();
const mockGenerateContentStream = jest.fn();
const mockEmbedContent = jest.fn();
const mockGetGenerativeModel = jest.fn().mockReturnValue({
    generateContent: mockGenerateContent,
    generateContentStream: mockGenerateContentStream,
    embedContent: mockEmbedContent,
});
jest.mock('@google/generative-ai', () => ({
    GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
        getGenerativeModel: mockGetGenerativeModel,
    })),
}));
const gemini_provider_1 = require("./gemini.provider");
const BASE_REQUEST = {
    messages: [{ role: 'user', content: 'Hello Gemini' }],
    model: 'gemini-pro',
};
const MOCK_GENERATE_RESULT = {
    response: {
        text: jest.fn().mockReturnValue('Gemini response'),
        usageMetadata: {
            promptTokenCount: 8,
            candidatesTokenCount: 12,
            totalTokenCount: 20,
        },
        candidates: [{ finishReason: 'STOP' }],
    },
};
describe('GeminiProvider', () => {
    let provider;
    beforeEach(() => {
        jest.clearAllMocks();
        MOCK_GENERATE_RESULT.response.text.mockReturnValue('Gemini response');
        provider = new gemini_provider_1.GeminiProvider('test-api-key');
    });
    describe('constructor', () => {
        it('sets name to gemini', () => {
            expect(provider.name).toBe('gemini');
        });
        it('warns when no API key', () => {
            new gemini_provider_1.GeminiProvider(); // Should not throw
        });
        it('uses GEMINI_API_KEY env var', () => {
            process.env.GEMINI_API_KEY = 'env-key';
            const p = new gemini_provider_1.GeminiProvider();
            expect(p).toBeDefined();
            delete process.env.GEMINI_API_KEY;
        });
    });
    describe('getSupportedModels()', () => {
        it('returns list of gemini models', () => {
            const models = provider.getSupportedModels();
            expect(models).toContain('gemini-pro');
            expect(models.length).toBeGreaterThan(0);
        });
    });
    describe('complete()', () => {
        it('returns a completion response', async () => {
            mockGenerateContent.mockResolvedValue(MOCK_GENERATE_RESULT);
            const result = await provider.complete(BASE_REQUEST);
            expect(result.content).toBe('Gemini response');
            expect(result.role).toBe('assistant');
            expect(result.usage?.promptTokens).toBe(8);
            expect(result.usage?.completionTokens).toBe(12);
            expect(result.usage?.totalTokens).toBe(20);
        });
        it('uses default model when not specified', async () => {
            mockGenerateContent.mockResolvedValue(MOCK_GENERATE_RESULT);
            await provider.complete({ messages: [{ role: 'user', content: 'hi' }] });
            expect(mockGetGenerativeModel).toHaveBeenCalledWith({ model: 'gemini-pro' });
        });
        it('converts messages to prompt format', async () => {
            mockGenerateContent.mockResolvedValue(MOCK_GENERATE_RESULT);
            await provider.complete({
                messages: [
                    { role: 'user', content: 'User msg' },
                    { role: 'assistant', content: 'Asst msg' },
                    { role: 'system', content: 'Sys msg' },
                ],
            });
            const callArg = mockGenerateContent.mock.calls[0][0];
            expect(callArg).toContain('user: User msg');
            expect(callArg).toContain('model: Asst msg');
            expect(callArg).toContain('system: Sys msg');
        });
        it('handles missing usageMetadata gracefully', async () => {
            mockGenerateContent.mockResolvedValue({
                response: {
                    text: jest.fn().mockReturnValue('ok'),
                    usageMetadata: undefined,
                    candidates: undefined,
                },
            });
            const result = await provider.complete(BASE_REQUEST);
            expect(result.usage?.totalTokens).toBe(0);
            expect(result.finishReason).toBe('STOP');
        });
        it('throws wrapped error on API failure', async () => {
            mockGenerateContent.mockRejectedValue(new Error('Quota exceeded'));
            await expect(provider.complete(BASE_REQUEST)).rejects.toThrow('Gemini API error: Quota exceeded');
        });
        it('wraps non-Error thrown values', async () => {
            mockGenerateContent.mockRejectedValue('string error');
            await expect(provider.complete(BASE_REQUEST)).rejects.toThrow('Gemini API error: Unknown error');
        });
    });
    describe('streamComplete()', () => {
        it('yields chunks from stream', async () => {
            async function* mockStream() {
                yield {
                    text: jest.fn().mockReturnValue('Hello '),
                    candidates: undefined,
                    usageMetadata: undefined,
                };
                yield {
                    text: jest.fn().mockReturnValue('world'),
                    candidates: [{ finishReason: 'STOP' }],
                    usageMetadata: { promptTokenCount: 5, candidatesTokenCount: 3, totalTokenCount: 8 },
                };
            }
            mockGenerateContentStream.mockResolvedValue({ stream: mockStream() });
            const results = [];
            for await (const chunk of provider.streamComplete(BASE_REQUEST)) {
                results.push(chunk);
            }
            expect(results.length).toBe(2);
        });
        it('marks last chunk as done when finishReason is set', async () => {
            async function* mockStream() {
                yield {
                    text: jest.fn().mockReturnValue('end'),
                    candidates: [{ finishReason: 'STOP' }],
                    usageMetadata: { promptTokenCount: 3, candidatesTokenCount: 2, totalTokenCount: 5 },
                };
            }
            mockGenerateContentStream.mockResolvedValue({ stream: mockStream() });
            const results = [];
            for await (const chunk of provider.streamComplete(BASE_REQUEST)) {
                results.push(chunk);
            }
            expect(results[0].done).toBe(true);
            expect(results[0].usage).toBeDefined();
        });
        it('yields chunks without usage when not last', async () => {
            async function* mockStream() {
                yield {
                    text: jest.fn().mockReturnValue('partial'),
                    candidates: undefined,
                    usageMetadata: undefined,
                };
            }
            mockGenerateContentStream.mockResolvedValue({ stream: mockStream() });
            const results = [];
            for await (const chunk of provider.streamComplete(BASE_REQUEST)) {
                results.push(chunk);
            }
            expect(results[0].done).toBe(false);
            expect(results[0].usage).toBeUndefined();
        });
        it('throws wrapped error on stream failure', async () => {
            mockGenerateContentStream.mockRejectedValue(new Error('Stream failed'));
            await expect(async () => {
                for await (const _chunk of provider.streamComplete(BASE_REQUEST)) {
                    // consume
                }
            }).rejects.toThrow('Gemini streaming error: Stream failed');
        });
    });
    describe('embed()', () => {
        it('returns embeddings for single string input', async () => {
            mockEmbedContent.mockResolvedValue({ embedding: { values: [0.1, 0.2, 0.3] } });
            const result = await provider.embed({ input: 'hello world', model: 'text-embedding-004' });
            expect(result.embeddings).toHaveLength(1);
            expect(result.embeddings[0]).toEqual([0.1, 0.2, 0.3]);
        });
        it('returns embeddings for array input', async () => {
            mockEmbedContent.mockResolvedValue({ embedding: { values: [0.5, 0.6] } });
            const result = await provider.embed({ input: ['first', 'second'] });
            expect(result.embeddings).toHaveLength(2);
            expect(mockEmbedContent).toHaveBeenCalledTimes(2);
        });
        it('uses default model text-embedding-004', async () => {
            mockEmbedContent.mockResolvedValue({ embedding: { values: [0.1] } });
            const result = await provider.embed({ input: 'test' });
            expect(mockGetGenerativeModel).toHaveBeenCalledWith({ model: 'text-embedding-004' });
            expect(result.model).toBe('text-embedding-004');
        });
        it('estimates token usage based on input length', async () => {
            mockEmbedContent.mockResolvedValue({ embedding: { values: [0.1] } });
            const result = await provider.embed({ input: 'test input' }); // 10 chars → 3 tokens
            expect(result.usage?.promptTokens).toBeGreaterThan(0);
        });
        it('throws wrapped error on API failure', async () => {
            mockEmbedContent.mockRejectedValue(new Error('Embedding failed'));
            await expect(provider.embed({ input: 'test' })).rejects.toThrow('Gemini embedding error: Embedding failed');
        });
    });
    describe('isAvailable()', () => {
        it('returns false when no API key', async () => {
            const p = new gemini_provider_1.GeminiProvider('');
            expect(await p.isAvailable()).toBe(false);
        });
        it('returns true when API responds', async () => {
            mockGenerateContent.mockResolvedValue(MOCK_GENERATE_RESULT);
            expect(await provider.isAvailable()).toBe(true);
        });
        it('returns false on API error', async () => {
            mockGenerateContent.mockRejectedValue(new Error('API error'));
            expect(await provider.isAvailable()).toBe(false);
        });
    });
});

package/dist/providers/ollama.provider.d.ts
@@ -0,0 +1,45 @@
import { IAIProvider, AIProvider, AICompletionRequest, AICompletionResponse, AIStreamChunk, AIEmbeddingRequest, AIEmbeddingResponse } from '../ai-enhanced.types';
/**
 * Ollama Provider
 * Production-ready implementation for local LLM support via Ollama
 * Supports models like Llama 2, Mistral, CodeLlama, and other open-source models
 */
export declare class OllamaProvider implements IAIProvider {
    readonly name: AIProvider;
    private baseURL;
    private defaultModel;
    constructor(config?: {
        baseURL?: string;
        defaultModel?: string;
    });
    /**
     * Transform messages to Ollama prompt format
     */
    private transformMessages;
    /**
     * Generate completion
     */
    complete(request: AICompletionRequest): Promise<AICompletionResponse>;
    /**
     * Generate streaming completion
     */
    streamComplete(request: AICompletionRequest): AsyncGenerator<AIStreamChunk>;
    /**
     * Generate embeddings
     */
    embed(request: AIEmbeddingRequest): Promise<AIEmbeddingResponse>;
    /**
     * Check if provider is available
     */
    isAvailable(): Promise<boolean>;
    /**
     * Get supported models
     * Note: This returns common models, but Ollama supports any model you pull
     */
    getSupportedModels(): string[];
    /**
     * Get supported embedding models
     */
    getSupportedEmbeddingModels(): string[];
}
//# sourceMappingURL=ollama.provider.d.ts.map
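
The declaration above fixes the public surface of OllamaProvider: an optional config with baseURL and defaultModel, plus complete, streamComplete, embed, and isAvailable. A short construction sketch follows; it is not part of the package and assumes the class is re-exported from the package root. The defaults it mentions (http://localhost:11434 and 'llama2') come from the implementation in ollama.provider.js further below.

import { OllamaProvider } from '@hazeljs/ai'; // assumed re-export; adjust to the actual entry point

async function checkOllama() {
    const ollama = new OllamaProvider({
        baseURL: 'http://localhost:11434', // implementation default; point at a remote host if needed
        defaultModel: 'mistral',           // implementation default is 'llama2'
    });

    // isAvailable() probes GET {baseURL}/api/tags, so this only succeeds
    // when an Ollama server is actually reachable.
    if (await ollama.isAvailable()) {
        console.log(ollama.getSupportedModels());
    }
}

checkOllama().catch(console.error);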

package/dist/providers/ollama.provider.d.ts.map
@@ -0,0 +1 @@
{"version":3,"file":"ollama.provider.d.ts","sourceRoot":"","sources":["../../src/providers/ollama.provider.ts"],"names":[],"mappings":"AAAA,OAAO,EACL,WAAW,EACX,UAAU,EACV,mBAAmB,EACnB,oBAAoB,EACpB,aAAa,EACb,kBAAkB,EAClB,mBAAmB,EAEpB,MAAM,sBAAsB,CAAC;AAiC9B;;;;GAIG;AACH,qBAAa,cAAe,YAAW,WAAW;IAChD,QAAQ,CAAC,IAAI,EAAE,UAAU,CAAY;IACrC,OAAO,CAAC,OAAO,CAAS;IACxB,OAAO,CAAC,YAAY,CAAS;gBAEjB,MAAM,CAAC,EAAE;QAAE,OAAO,CAAC,EAAE,MAAM,CAAC;QAAC,YAAY,CAAC,EAAE,MAAM,CAAA;KAAE;IAMhE;;OAEG;IACH,OAAO,CAAC,iBAAiB;IAUzB;;OAEG;IACG,QAAQ,CAAC,OAAO,EAAE,mBAAmB,GAAG,OAAO,CAAC,oBAAoB,CAAC;IA+C3E;;OAEG;IACI,cAAc,CAAC,OAAO,EAAE,mBAAmB,GAAG,cAAc,CAAC,aAAa,CAAC;IAqFlF;;OAEG;IACG,KAAK,CAAC,OAAO,EAAE,kBAAkB,GAAG,OAAO,CAAC,mBAAmB,CAAC;IAqCtE;;OAEG;IACG,WAAW,IAAI,OAAO,CAAC,OAAO,CAAC;IAWrC;;;OAGG;IACH,kBAAkB,IAAI,MAAM,EAAE;IAkB9B;;OAEG;IACH,2BAA2B,IAAI,MAAM,EAAE;CAGxC"}

package/dist/providers/ollama.provider.js
@@ -0,0 +1,232 @@
"use strict";
var __importDefault = (this && this.__importDefault) || function (mod) {
    return (mod && mod.__esModule) ? mod : { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.OllamaProvider = void 0;
const core_1 = __importDefault(require("@hazeljs/core"));
/**
 * Ollama Provider
 * Production-ready implementation for local LLM support via Ollama
 * Supports models like Llama 2, Mistral, CodeLlama, and other open-source models
 */
class OllamaProvider {
    constructor(config) {
        this.name = 'ollama';
        this.baseURL = config?.baseURL || process.env.OLLAMA_BASE_URL || 'http://localhost:11434';
        this.defaultModel = config?.defaultModel || 'llama2';
        core_1.default.info(`Ollama provider initialized with base URL: ${this.baseURL}`);
    }
    /**
     * Transform messages to Ollama prompt format
     */
    transformMessages(messages) {
        return messages
            .map((msg) => {
                const role = msg.role === 'assistant' ? 'Assistant' : msg.role === 'system' ? 'System' : 'User';
                return `${role}: ${msg.content}`;
            })
            .join('\n\n');
    }
    /**
     * Generate completion
     */
    async complete(request) {
        try {
            const model = request.model || this.defaultModel;
            core_1.default.debug(`Ollama completion request for model: ${model}`);
            const prompt = this.transformMessages(request.messages);
            const ollamaRequest = {
                model,
                prompt,
                stream: false,
                temperature: request.temperature ?? 0.7,
                num_predict: request.maxTokens,
                top_p: request.topP,
            };
            const response = await fetch(`${this.baseURL}/api/generate`, {
                method: 'POST',
                headers: { 'Content-Type': 'application/json' },
                body: JSON.stringify(ollamaRequest),
            });
            if (!response.ok) {
                const errorText = await response.text();
                throw new Error(`Ollama API error: ${response.status} ${errorText}`);
            }
            const data = (await response.json());
            return {
                id: `ollama-${Date.now()}`,
                content: data.response,
                role: 'assistant',
                model: data.model,
                usage: {
                    promptTokens: data.prompt_eval_count || 0,
                    completionTokens: data.eval_count || 0,
                    totalTokens: (data.prompt_eval_count || 0) + (data.eval_count || 0),
                },
                finishReason: data.done ? 'stop' : 'length',
            };
        }
        catch (error) {
            core_1.default.error('Ollama completion error:', error);
            throw error;
        }
    }
    /**
     * Generate streaming completion
     */
    async *streamComplete(request) {
        try {
            const model = request.model || this.defaultModel;
            core_1.default.debug(`Ollama streaming completion request for model: ${model}`);
            const prompt = this.transformMessages(request.messages);
            const ollamaRequest = {
                model,
                prompt,
                stream: true,
                temperature: request.temperature ?? 0.7,
                num_predict: request.maxTokens,
                top_p: request.topP,
            };
            const response = await fetch(`${this.baseURL}/api/generate`, {
                method: 'POST',
                headers: { 'Content-Type': 'application/json' },
                body: JSON.stringify(ollamaRequest),
            });
            if (!response.ok) {
                const errorText = await response.text();
                throw new Error(`Ollama API error: ${response.status} ${errorText}`);
            }
            if (!response.body) {
                throw new Error('No response body available for streaming');
            }
            const reader = response.body.getReader();
            const decoder = new TextDecoder();
            let fullContent = '';
            let totalPromptTokens = 0;
            let totalCompletionTokens = 0;
            const chunkId = `ollama-${Date.now()}`;
            try {
                while (true) {
                    const { done, value } = await reader.read();
                    if (done)
                        break;
                    const chunk = decoder.decode(value, { stream: true });
                    const lines = chunk.split('\n').filter(Boolean);
                    for (const line of lines) {
                        try {
                            const data = JSON.parse(line);
                            if (data.response) {
                                fullContent += data.response;
                                totalPromptTokens = data.prompt_eval_count || totalPromptTokens;
                                totalCompletionTokens = data.eval_count || totalCompletionTokens;
                                yield {
                                    id: chunkId,
                                    content: fullContent,
                                    delta: data.response,
                                    done: data.done || false,
                                    usage: {
                                        promptTokens: totalPromptTokens,
                                        completionTokens: totalCompletionTokens,
                                        totalTokens: totalPromptTokens + totalCompletionTokens,
                                    },
                                };
                            }
                            if (data.done) {
                                return;
                            }
                        }
                        catch {
                            // Skip invalid JSON lines
                            continue;
                        }
                    }
                }
            }
            finally {
                reader.releaseLock();
            }
        }
        catch (error) {
            core_1.default.error('Ollama streaming error:', error);
            throw error;
        }
    }
    /**
     * Generate embeddings
     */
    async embed(request) {
        try {
            const model = request.model || this.defaultModel;
            core_1.default.debug(`Ollama embedding request for model: ${model}`);
            const input = Array.isArray(request.input) ? request.input[0] : request.input;
            const response = await fetch(`${this.baseURL}/api/embeddings`, {
                method: 'POST',
                headers: { 'Content-Type': 'application/json' },
                body: JSON.stringify({
                    model,
                    prompt: input,
                }),
            });
            if (!response.ok) {
                const errorText = await response.text();
                throw new Error(`Ollama API error: ${response.status} ${errorText}`);
            }
            const data = (await response.json());
            return {
                embeddings: [data.embedding],
                model,
                usage: {
                    promptTokens: 0, // Ollama doesn't provide token usage for embeddings
                    totalTokens: 0,
                },
            };
        }
        catch (error) {
            core_1.default.error('Ollama embedding error:', error);
            throw error;
        }
    }
    /**
     * Check if provider is available
     */
    async isAvailable() {
        try {
            const response = await fetch(`${this.baseURL}/api/tags`, {
                method: 'GET',
            });
            return response.ok;
        }
        catch {
            return false;
        }
    }
    /**
     * Get supported models
     * Note: This returns common models, but Ollama supports any model you pull
     */
    getSupportedModels() {
        return [
            'llama2',
            'llama2:13b',
            'llama2:70b',
            'mistral',
            'mixtral',
            'codellama',
            'neural-chat',
            'starling-lm',
            'phi',
            'orca-mini',
            'vicuna',
            'wizardcoder',
            'wizard-vicuna',
        ];
    }
    /**
     * Get supported embedding models
     */
    getSupportedEmbeddingModels() {
        return ['llama2', 'mistral', 'nomic-embed-text'];
    }
}
exports.OllamaProvider = OllamaProvider;
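
The streamComplete() implementation above parses Ollama's newline-delimited JSON stream and yields chunks that carry the accumulated content, the incremental delta, a done flag, and running token counts. The sketch below shows one way to consume that generator; it is not part of the package and reuses the assumed OllamaProvider import from the earlier sketch.

async function streamHaiku(ollama: OllamaProvider) {
    for await (const chunk of ollama.streamComplete({
        model: 'llama2',
        messages: [{ role: 'user', content: 'Write a haiku about type safety.' }],
    })) {
        process.stdout.write(chunk.delta); // print only the newly generated text
        if (chunk.done) {
            console.log(`\n${chunk.usage?.totalTokens ?? 0} tokens`);
        }
    }
}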