@hazeljs/ai 0.2.0-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +192 -0
- package/README.md +497 -0
- package/dist/ai-enhanced.service.d.ts +108 -0
- package/dist/ai-enhanced.service.d.ts.map +1 -0
- package/dist/ai-enhanced.service.js +345 -0
- package/dist/ai-enhanced.service.test.d.ts +2 -0
- package/dist/ai-enhanced.service.test.d.ts.map +1 -0
- package/dist/ai-enhanced.service.test.js +501 -0
- package/dist/ai-enhanced.test.d.ts +2 -0
- package/dist/ai-enhanced.test.d.ts.map +1 -0
- package/dist/ai-enhanced.test.js +587 -0
- package/dist/ai-enhanced.types.d.ts +277 -0
- package/dist/ai-enhanced.types.d.ts.map +1 -0
- package/dist/ai-enhanced.types.js +2 -0
- package/dist/ai.decorator.d.ts +4 -0
- package/dist/ai.decorator.d.ts.map +1 -0
- package/dist/ai.decorator.js +57 -0
- package/dist/ai.decorator.test.d.ts +2 -0
- package/dist/ai.decorator.test.d.ts.map +1 -0
- package/dist/ai.decorator.test.js +189 -0
- package/dist/ai.module.d.ts +12 -0
- package/dist/ai.module.d.ts.map +1 -0
- package/dist/ai.module.js +44 -0
- package/dist/ai.module.test.d.ts +2 -0
- package/dist/ai.module.test.d.ts.map +1 -0
- package/dist/ai.module.test.js +23 -0
- package/dist/ai.service.d.ts +11 -0
- package/dist/ai.service.d.ts.map +1 -0
- package/dist/ai.service.js +266 -0
- package/dist/ai.service.test.d.ts +2 -0
- package/dist/ai.service.test.d.ts.map +1 -0
- package/dist/ai.service.test.js +222 -0
- package/dist/ai.types.d.ts +30 -0
- package/dist/ai.types.d.ts.map +1 -0
- package/dist/ai.types.js +2 -0
- package/dist/context/context.manager.d.ts +69 -0
- package/dist/context/context.manager.d.ts.map +1 -0
- package/dist/context/context.manager.js +168 -0
- package/dist/context/context.manager.test.d.ts +2 -0
- package/dist/context/context.manager.test.d.ts.map +1 -0
- package/dist/context/context.manager.test.js +180 -0
- package/dist/decorators/ai-function.decorator.d.ts +42 -0
- package/dist/decorators/ai-function.decorator.d.ts.map +1 -0
- package/dist/decorators/ai-function.decorator.js +80 -0
- package/dist/decorators/ai-validate.decorator.d.ts +46 -0
- package/dist/decorators/ai-validate.decorator.d.ts.map +1 -0
- package/dist/decorators/ai-validate.decorator.js +83 -0
- package/dist/index.d.ts +18 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +40 -0
- package/dist/prompts/task.prompt.d.ts +12 -0
- package/dist/prompts/task.prompt.d.ts.map +1 -0
- package/dist/prompts/task.prompt.js +12 -0
- package/dist/providers/anthropic.provider.d.ts +48 -0
- package/dist/providers/anthropic.provider.d.ts.map +1 -0
- package/dist/providers/anthropic.provider.js +194 -0
- package/dist/providers/anthropic.provider.test.d.ts +2 -0
- package/dist/providers/anthropic.provider.test.d.ts.map +1 -0
- package/dist/providers/anthropic.provider.test.js +222 -0
- package/dist/providers/cohere.provider.d.ts +57 -0
- package/dist/providers/cohere.provider.d.ts.map +1 -0
- package/dist/providers/cohere.provider.js +230 -0
- package/dist/providers/cohere.provider.test.d.ts +2 -0
- package/dist/providers/cohere.provider.test.d.ts.map +1 -0
- package/dist/providers/cohere.provider.test.js +267 -0
- package/dist/providers/gemini.provider.d.ts +45 -0
- package/dist/providers/gemini.provider.d.ts.map +1 -0
- package/dist/providers/gemini.provider.js +180 -0
- package/dist/providers/gemini.provider.test.d.ts +2 -0
- package/dist/providers/gemini.provider.test.d.ts.map +1 -0
- package/dist/providers/gemini.provider.test.js +219 -0
- package/dist/providers/ollama.provider.d.ts +45 -0
- package/dist/providers/ollama.provider.d.ts.map +1 -0
- package/dist/providers/ollama.provider.js +232 -0
- package/dist/providers/ollama.provider.test.d.ts +2 -0
- package/dist/providers/ollama.provider.test.d.ts.map +1 -0
- package/dist/providers/ollama.provider.test.js +267 -0
- package/dist/providers/openai.provider.d.ts +57 -0
- package/dist/providers/openai.provider.d.ts.map +1 -0
- package/dist/providers/openai.provider.js +320 -0
- package/dist/providers/openai.provider.test.d.ts +2 -0
- package/dist/providers/openai.provider.test.d.ts.map +1 -0
- package/dist/providers/openai.provider.test.js +364 -0
- package/dist/tracking/token.tracker.d.ts +72 -0
- package/dist/tracking/token.tracker.d.ts.map +1 -0
- package/dist/tracking/token.tracker.js +222 -0
- package/dist/tracking/token.tracker.test.d.ts +2 -0
- package/dist/tracking/token.tracker.test.d.ts.map +1 -0
- package/dist/tracking/token.tracker.test.js +272 -0
- package/dist/vector/vector.service.d.ts +50 -0
- package/dist/vector/vector.service.d.ts.map +1 -0
- package/dist/vector/vector.service.js +163 -0
- package/package.json +60 -0
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
var __importDefault = (this && this.__importDefault) || function (mod) {
|
|
3
|
+
return (mod && mod.__esModule) ? mod : { "default": mod };
|
|
4
|
+
};
|
|
5
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
6
|
+
exports.CohereProvider = void 0;
|
|
7
|
+
const core_1 = __importDefault(require("@hazeljs/core"));
|
|
8
|
+
const cohere_ai_1 = require("cohere-ai");
|
|
9
|
+
/**
|
|
10
|
+
* Cohere AI Provider
|
|
11
|
+
*
|
|
12
|
+
* Production-ready implementation using Cohere AI SDK.
|
|
13
|
+
*
|
|
14
|
+
* Setup:
|
|
15
|
+
* 1. Install the SDK: `npm install cohere-ai`
|
|
16
|
+
* 2. Set COHERE_API_KEY environment variable
|
|
17
|
+
* 3. Use the provider in your application
|
|
18
|
+
*
|
|
19
|
+
* Supported models:
|
|
20
|
+
* - command-r-plus: Most powerful model for complex tasks
|
|
21
|
+
* - command-r: Balanced performance and cost
|
|
22
|
+
* - command: Standard text generation
|
|
23
|
+
* - command-light: Fast, cost-effective model
|
|
24
|
+
* - embed-english-v3.0: English text embeddings
|
|
25
|
+
* - embed-multilingual-v3.0: Multilingual embeddings
|
|
26
|
+
* - rerank-english-v3.0: Document reranking
|
|
27
|
+
*/
|
|
28
|
+
class CohereProvider {
|
|
29
|
+
constructor(apiKey, endpoint) {
|
|
30
|
+
this.name = 'cohere';
|
|
31
|
+
this.apiKey = apiKey || process.env.COHERE_API_KEY || '';
|
|
32
|
+
this.endpoint = endpoint || 'https://api.cohere.ai/v1';
|
|
33
|
+
if (!this.apiKey) {
|
|
34
|
+
core_1.default.warn('Cohere API key not provided. Set COHERE_API_KEY environment variable.');
|
|
35
|
+
}
|
|
36
|
+
this.cohere = new cohere_ai_1.CohereClient({ token: this.apiKey });
|
|
37
|
+
core_1.default.info('Cohere provider initialized');
|
|
38
|
+
}
|
|
39
|
+
/**
|
|
40
|
+
* Generate completion
|
|
41
|
+
*/
|
|
42
|
+
async complete(request) {
|
|
43
|
+
const modelName = request.model || 'command';
|
|
44
|
+
core_1.default.debug(`Cohere completion request for model: ${modelName}`);
|
|
45
|
+
try {
|
|
46
|
+
// Convert messages to prompt format
|
|
47
|
+
const prompt = request.messages.map((m) => `${m.role}: ${m.content}`).join('\n\n');
|
|
48
|
+
// Generate completion
|
|
49
|
+
const response = await this.cohere.generate({
|
|
50
|
+
model: modelName,
|
|
51
|
+
prompt,
|
|
52
|
+
temperature: request.temperature,
|
|
53
|
+
maxTokens: request.maxTokens,
|
|
54
|
+
p: request.topP,
|
|
55
|
+
});
|
|
56
|
+
return {
|
|
57
|
+
id: response.id || `cohere-${Date.now()}`,
|
|
58
|
+
content: response.generations[0].text,
|
|
59
|
+
role: 'assistant',
|
|
60
|
+
model: modelName,
|
|
61
|
+
usage: {
|
|
62
|
+
promptTokens: response.meta?.billedUnits?.inputTokens || 0,
|
|
63
|
+
completionTokens: response.meta?.billedUnits?.outputTokens || 0,
|
|
64
|
+
totalTokens: (response.meta?.billedUnits?.inputTokens || 0) +
|
|
65
|
+
(response.meta?.billedUnits?.outputTokens || 0),
|
|
66
|
+
},
|
|
67
|
+
finishReason: 'COMPLETE',
|
|
68
|
+
};
|
|
69
|
+
}
|
|
70
|
+
catch (error) {
|
|
71
|
+
core_1.default.error('Cohere completion error:', error);
|
|
72
|
+
throw new Error(`Cohere API error: ${error instanceof Error ? error.message : 'Unknown error'}`);
|
|
73
|
+
}
|
|
74
|
+
}
|
|
75
|
+
/**
|
|
76
|
+
* Generate streaming completion
|
|
77
|
+
*/
|
|
78
|
+
async *streamComplete(request) {
|
|
79
|
+
const modelName = request.model || 'command';
|
|
80
|
+
core_1.default.debug('Cohere streaming completion started');
|
|
81
|
+
try {
|
|
82
|
+
// Convert messages to prompt format
|
|
83
|
+
const prompt = request.messages.map((m) => `${m.role}: ${m.content}`).join('\n\n');
|
|
84
|
+
// Generate streaming completion
|
|
85
|
+
const stream = await this.cohere.generateStream({
|
|
86
|
+
model: modelName,
|
|
87
|
+
prompt,
|
|
88
|
+
temperature: request.temperature,
|
|
89
|
+
maxTokens: request.maxTokens,
|
|
90
|
+
p: request.topP,
|
|
91
|
+
});
|
|
92
|
+
let fullContent = '';
|
|
93
|
+
const streamId = `cohere-stream-${Date.now()}`;
|
|
94
|
+
for await (const chunk of stream) {
|
|
95
|
+
if (chunk.eventType === 'text-generation') {
|
|
96
|
+
const text = chunk.text || '';
|
|
97
|
+
fullContent += text;
|
|
98
|
+
yield {
|
|
99
|
+
id: streamId,
|
|
100
|
+
content: fullContent,
|
|
101
|
+
delta: text,
|
|
102
|
+
done: false,
|
|
103
|
+
};
|
|
104
|
+
}
|
|
105
|
+
else if (chunk.eventType === 'stream-end') {
|
|
106
|
+
const response = chunk;
|
|
107
|
+
yield {
|
|
108
|
+
id: streamId,
|
|
109
|
+
content: fullContent,
|
|
110
|
+
delta: '',
|
|
111
|
+
done: true,
|
|
112
|
+
usage: response.response?.meta?.billedUnits
|
|
113
|
+
? {
|
|
114
|
+
promptTokens: response.response.meta.billedUnits.inputTokens || 0,
|
|
115
|
+
completionTokens: response.response.meta.billedUnits.outputTokens || 0,
|
|
116
|
+
totalTokens: (response.response.meta.billedUnits.inputTokens || 0) +
|
|
117
|
+
(response.response.meta.billedUnits.outputTokens || 0),
|
|
118
|
+
}
|
|
119
|
+
: undefined,
|
|
120
|
+
};
|
|
121
|
+
}
|
|
122
|
+
}
|
|
123
|
+
core_1.default.debug('Cohere streaming completed');
|
|
124
|
+
}
|
|
125
|
+
catch (error) {
|
|
126
|
+
core_1.default.error('Cohere streaming error:', error);
|
|
127
|
+
throw new Error(`Cohere streaming error: ${error instanceof Error ? error.message : 'Unknown error'}`);
|
|
128
|
+
}
|
|
129
|
+
}
|
|
130
|
+
/**
|
|
131
|
+
* Generate embeddings
|
|
132
|
+
*/
|
|
133
|
+
async embed(request) {
|
|
134
|
+
const modelName = request.model || 'embed-english-v3.0';
|
|
135
|
+
core_1.default.debug(`Cohere embedding request for model: ${modelName}`);
|
|
136
|
+
try {
|
|
137
|
+
const inputs = Array.isArray(request.input) ? request.input : [request.input];
|
|
138
|
+
// Generate embeddings
|
|
139
|
+
const response = await this.cohere.embed({
|
|
140
|
+
texts: inputs,
|
|
141
|
+
model: modelName,
|
|
142
|
+
inputType: 'search_document',
|
|
143
|
+
});
|
|
144
|
+
// Estimate token usage (Cohere doesn't provide exact counts for embeddings)
|
|
145
|
+
const estimatedTokens = inputs.reduce((sum, text) => sum + Math.ceil(text.length / 4), 0);
|
|
146
|
+
// Handle different response formats
|
|
147
|
+
const embeddings = Array.isArray(response.embeddings)
|
|
148
|
+
? response.embeddings
|
|
149
|
+
: response.embeddings.float || [];
|
|
150
|
+
return {
|
|
151
|
+
embeddings,
|
|
152
|
+
model: modelName,
|
|
153
|
+
usage: {
|
|
154
|
+
promptTokens: estimatedTokens,
|
|
155
|
+
totalTokens: estimatedTokens,
|
|
156
|
+
},
|
|
157
|
+
};
|
|
158
|
+
}
|
|
159
|
+
catch (error) {
|
|
160
|
+
core_1.default.error('Cohere embedding error:', error);
|
|
161
|
+
throw new Error(`Cohere embedding error: ${error instanceof Error ? error.message : 'Unknown error'}`);
|
|
162
|
+
}
|
|
163
|
+
}
|
|
164
|
+
/**
|
|
165
|
+
* Check if provider is available
|
|
166
|
+
*/
|
|
167
|
+
async isAvailable() {
|
|
168
|
+
if (!this.apiKey) {
|
|
169
|
+
core_1.default.warn('Cohere API key not configured');
|
|
170
|
+
return false;
|
|
171
|
+
}
|
|
172
|
+
try {
|
|
173
|
+
// Test with a minimal request
|
|
174
|
+
await this.cohere.generate({
|
|
175
|
+
model: 'command-light',
|
|
176
|
+
prompt: 'test',
|
|
177
|
+
maxTokens: 10,
|
|
178
|
+
});
|
|
179
|
+
return true;
|
|
180
|
+
}
|
|
181
|
+
catch (error) {
|
|
182
|
+
core_1.default.error('Cohere availability check failed:', error);
|
|
183
|
+
return false;
|
|
184
|
+
}
|
|
185
|
+
}
|
|
186
|
+
/**
|
|
187
|
+
* Get supported models
|
|
188
|
+
*/
|
|
189
|
+
getSupportedModels() {
|
|
190
|
+
return [
|
|
191
|
+
'command-r-plus',
|
|
192
|
+
'command-r',
|
|
193
|
+
'command',
|
|
194
|
+
'command-light',
|
|
195
|
+
'command-nightly',
|
|
196
|
+
'embed-english-v3.0',
|
|
197
|
+
'embed-multilingual-v3.0',
|
|
198
|
+
'embed-english-light-v3.0',
|
|
199
|
+
'embed-multilingual-light-v3.0',
|
|
200
|
+
'rerank-english-v3.0',
|
|
201
|
+
'rerank-multilingual-v3.0',
|
|
202
|
+
];
|
|
203
|
+
}
|
|
204
|
+
/**
|
|
205
|
+
* Rerank documents (Cohere-specific feature)
|
|
206
|
+
* Useful for RAG applications to improve retrieval quality
|
|
207
|
+
*/
|
|
208
|
+
async rerank(query, documents, topN, model) {
|
|
209
|
+
const modelName = model || 'rerank-english-v3.0';
|
|
210
|
+
core_1.default.debug(`Cohere rerank request for model: ${modelName}`);
|
|
211
|
+
try {
|
|
212
|
+
const response = await this.cohere.rerank({
|
|
213
|
+
query,
|
|
214
|
+
documents,
|
|
215
|
+
topN,
|
|
216
|
+
model: modelName,
|
|
217
|
+
});
|
|
218
|
+
return response.results.map((r) => ({
|
|
219
|
+
index: r.index,
|
|
220
|
+
score: r.relevanceScore,
|
|
221
|
+
document: documents[r.index],
|
|
222
|
+
}));
|
|
223
|
+
}
|
|
224
|
+
catch (error) {
|
|
225
|
+
core_1.default.error('Cohere rerank error:', error);
|
|
226
|
+
throw new Error(`Cohere rerank error: ${error instanceof Error ? error.message : 'Unknown error'}`);
|
|
227
|
+
}
|
|
228
|
+
}
|
|
229
|
+
}
|
|
230
|
+
exports.CohereProvider = CohereProvider;
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"cohere.provider.test.d.ts","sourceRoot":"","sources":["../../src/providers/cohere.provider.test.ts"],"names":[],"mappings":""}
|
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
jest.mock('@hazeljs/core', () => ({
|
|
4
|
+
__esModule: true,
|
|
5
|
+
default: { info: jest.fn(), debug: jest.fn(), warn: jest.fn(), error: jest.fn() },
|
|
6
|
+
}));
|
|
7
|
+
const mockGenerate = jest.fn();
|
|
8
|
+
const mockGenerateStream = jest.fn();
|
|
9
|
+
const mockEmbed = jest.fn();
|
|
10
|
+
const mockRerank = jest.fn();
|
|
11
|
+
jest.mock('cohere-ai', () => ({
|
|
12
|
+
CohereClient: jest.fn().mockImplementation(() => ({
|
|
13
|
+
generate: mockGenerate,
|
|
14
|
+
generateStream: mockGenerateStream,
|
|
15
|
+
embed: mockEmbed,
|
|
16
|
+
rerank: mockRerank,
|
|
17
|
+
})),
|
|
18
|
+
}));
|
|
19
|
+
const cohere_provider_1 = require("./cohere.provider");
|
|
20
|
+
const BASE_REQUEST = {
|
|
21
|
+
messages: [{ role: 'user', content: 'Hello Cohere' }],
|
|
22
|
+
model: 'command',
|
|
23
|
+
};
|
|
24
|
+
const MOCK_GENERATE_RESPONSE = {
|
|
25
|
+
id: 'cohere-123',
|
|
26
|
+
generations: [{ text: 'Cohere response' }],
|
|
27
|
+
meta: { billedUnits: { inputTokens: 10, outputTokens: 20 } },
|
|
28
|
+
};
|
|
29
|
+
describe('CohereProvider', () => {
|
|
30
|
+
let provider;
|
|
31
|
+
beforeEach(() => {
|
|
32
|
+
jest.clearAllMocks();
|
|
33
|
+
provider = new cohere_provider_1.CohereProvider('test-api-key');
|
|
34
|
+
});
|
|
35
|
+
describe('constructor', () => {
|
|
36
|
+
it('sets name to cohere', () => {
|
|
37
|
+
expect(provider.name).toBe('cohere');
|
|
38
|
+
});
|
|
39
|
+
it('warns when no API key', () => {
|
|
40
|
+
new cohere_provider_1.CohereProvider(); // Should not throw
|
|
41
|
+
});
|
|
42
|
+
it('uses COHERE_API_KEY env var', () => {
|
|
43
|
+
process.env.COHERE_API_KEY = 'env-key';
|
|
44
|
+
const p = new cohere_provider_1.CohereProvider();
|
|
45
|
+
expect(p).toBeDefined();
|
|
46
|
+
delete process.env.COHERE_API_KEY;
|
|
47
|
+
});
|
|
48
|
+
});
|
|
49
|
+
describe('getSupportedModels()', () => {
|
|
50
|
+
it('returns list of cohere models', () => {
|
|
51
|
+
const models = provider.getSupportedModels();
|
|
52
|
+
expect(models).toContain('command');
|
|
53
|
+
expect(models.length).toBeGreaterThan(0);
|
|
54
|
+
});
|
|
55
|
+
});
|
|
56
|
+
describe('complete()', () => {
|
|
57
|
+
it('returns a completion response', async () => {
|
|
58
|
+
mockGenerate.mockResolvedValue(MOCK_GENERATE_RESPONSE);
|
|
59
|
+
const result = await provider.complete(BASE_REQUEST);
|
|
60
|
+
expect(result.content).toBe('Cohere response');
|
|
61
|
+
expect(result.role).toBe('assistant');
|
|
62
|
+
expect(result.usage?.promptTokens).toBe(10);
|
|
63
|
+
expect(result.usage?.completionTokens).toBe(20);
|
|
64
|
+
expect(result.usage?.totalTokens).toBe(30);
|
|
65
|
+
expect(result.finishReason).toBe('COMPLETE');
|
|
66
|
+
});
|
|
67
|
+
it('uses default model when not specified', async () => {
|
|
68
|
+
mockGenerate.mockResolvedValue(MOCK_GENERATE_RESPONSE);
|
|
69
|
+
await provider.complete({ messages: [{ role: 'user', content: 'hi' }] });
|
|
70
|
+
expect(mockGenerate).toHaveBeenCalledWith(expect.objectContaining({ model: 'command' }));
|
|
71
|
+
});
|
|
72
|
+
it('passes temperature and maxTokens to API', async () => {
|
|
73
|
+
mockGenerate.mockResolvedValue(MOCK_GENERATE_RESPONSE);
|
|
74
|
+
await provider.complete({ ...BASE_REQUEST, temperature: 0.5, maxTokens: 200, topP: 0.9 });
|
|
75
|
+
expect(mockGenerate).toHaveBeenCalledWith(expect.objectContaining({ temperature: 0.5, maxTokens: 200, p: 0.9 }));
|
|
76
|
+
});
|
|
77
|
+
it('handles missing meta/billedUnits gracefully', async () => {
|
|
78
|
+
mockGenerate.mockResolvedValue({
|
|
79
|
+
generations: [{ text: 'ok' }],
|
|
80
|
+
meta: undefined,
|
|
81
|
+
});
|
|
82
|
+
const result = await provider.complete(BASE_REQUEST);
|
|
83
|
+
expect(result.usage?.totalTokens).toBe(0);
|
|
84
|
+
});
|
|
85
|
+
it('uses generated id when response has one', async () => {
|
|
86
|
+
mockGenerate.mockResolvedValue(MOCK_GENERATE_RESPONSE);
|
|
87
|
+
const result = await provider.complete(BASE_REQUEST);
|
|
88
|
+
expect(result.id).toBe('cohere-123');
|
|
89
|
+
});
|
|
90
|
+
it('generates a fallback id when response has no id', async () => {
|
|
91
|
+
mockGenerate.mockResolvedValue({ ...MOCK_GENERATE_RESPONSE, id: undefined });
|
|
92
|
+
const result = await provider.complete(BASE_REQUEST);
|
|
93
|
+
expect(result.id).toMatch(/^cohere-\d+$/);
|
|
94
|
+
});
|
|
95
|
+
it('converts messages to prompt format', async () => {
|
|
96
|
+
mockGenerate.mockResolvedValue(MOCK_GENERATE_RESPONSE);
|
|
97
|
+
await provider.complete({
|
|
98
|
+
messages: [
|
|
99
|
+
{ role: 'user', content: 'User msg' },
|
|
100
|
+
{ role: 'assistant', content: 'Asst msg' },
|
|
101
|
+
],
|
|
102
|
+
});
|
|
103
|
+
const callArg = mockGenerate.mock.calls[0][0];
|
|
104
|
+
expect(callArg.prompt).toContain('user: User msg');
|
|
105
|
+
expect(callArg.prompt).toContain('assistant: Asst msg');
|
|
106
|
+
});
|
|
107
|
+
it('throws wrapped error on API failure', async () => {
|
|
108
|
+
mockGenerate.mockRejectedValue(new Error('Quota exceeded'));
|
|
109
|
+
await expect(provider.complete(BASE_REQUEST)).rejects.toThrow('Cohere API error: Quota exceeded');
|
|
110
|
+
});
|
|
111
|
+
it('wraps non-Error thrown values', async () => {
|
|
112
|
+
mockGenerate.mockRejectedValue('string error');
|
|
113
|
+
await expect(provider.complete(BASE_REQUEST)).rejects.toThrow('Cohere API error: Unknown error');
|
|
114
|
+
});
|
|
115
|
+
});
|
|
116
|
+
describe('streamComplete()', () => {
|
|
117
|
+
it('yields text-generation chunks', async () => {
|
|
118
|
+
async function* mockStream() {
|
|
119
|
+
yield { eventType: 'text-generation', text: 'Hello ' };
|
|
120
|
+
yield { eventType: 'text-generation', text: 'world' };
|
|
121
|
+
yield {
|
|
122
|
+
eventType: 'stream-end',
|
|
123
|
+
response: {
|
|
124
|
+
meta: { billedUnits: { inputTokens: 5, outputTokens: 10 } },
|
|
125
|
+
},
|
|
126
|
+
};
|
|
127
|
+
}
|
|
128
|
+
mockGenerateStream.mockReturnValue(mockStream());
|
|
129
|
+
const results = [];
|
|
130
|
+
for await (const chunk of provider.streamComplete(BASE_REQUEST)) {
|
|
131
|
+
results.push(chunk);
|
|
132
|
+
}
|
|
133
|
+
// 2 text chunks + 1 stream-end
|
|
134
|
+
expect(results.length).toBe(3);
|
|
135
|
+
});
|
|
136
|
+
it('yields done=true chunk on stream-end with usage', async () => {
|
|
137
|
+
async function* mockStream() {
|
|
138
|
+
yield {
|
|
139
|
+
eventType: 'stream-end',
|
|
140
|
+
response: { meta: { billedUnits: { inputTokens: 5, outputTokens: 3 } } },
|
|
141
|
+
};
|
|
142
|
+
}
|
|
143
|
+
mockGenerateStream.mockReturnValue(mockStream());
|
|
144
|
+
const results = [];
|
|
145
|
+
for await (const chunk of provider.streamComplete(BASE_REQUEST)) {
|
|
146
|
+
results.push(chunk);
|
|
147
|
+
}
|
|
148
|
+
const last = results[results.length - 1];
|
|
149
|
+
expect(last.done).toBe(true);
|
|
150
|
+
expect(last.usage).toBeDefined();
|
|
151
|
+
});
|
|
152
|
+
it('yields stream-end with undefined usage when no billedUnits', async () => {
|
|
153
|
+
async function* mockStream() {
|
|
154
|
+
yield { eventType: 'stream-end', response: { meta: {} } };
|
|
155
|
+
}
|
|
156
|
+
mockGenerateStream.mockReturnValue(mockStream());
|
|
157
|
+
const results = [];
|
|
158
|
+
for await (const chunk of provider.streamComplete(BASE_REQUEST)) {
|
|
159
|
+
results.push(chunk);
|
|
160
|
+
}
|
|
161
|
+
expect(results[0].usage).toBeUndefined();
|
|
162
|
+
});
|
|
163
|
+
it('ignores unknown event types', async () => {
|
|
164
|
+
async function* mockStream() {
|
|
165
|
+
yield { eventType: 'unknown-event' };
|
|
166
|
+
yield { eventType: 'stream-end', response: {} };
|
|
167
|
+
}
|
|
168
|
+
mockGenerateStream.mockReturnValue(mockStream());
|
|
169
|
+
const results = [];
|
|
170
|
+
for await (const chunk of provider.streamComplete(BASE_REQUEST)) {
|
|
171
|
+
results.push(chunk);
|
|
172
|
+
}
|
|
173
|
+
expect(results.length).toBe(1); // only stream-end
|
|
174
|
+
});
|
|
175
|
+
it('throws wrapped error on streaming failure', async () => {
|
|
176
|
+
mockGenerateStream.mockImplementation(async function* () {
|
|
177
|
+
throw new Error('Stream crashed');
|
|
178
|
+
yield { eventType: 'stream-end' };
|
|
179
|
+
});
|
|
180
|
+
await expect(async () => {
|
|
181
|
+
for await (const _chunk of provider.streamComplete(BASE_REQUEST)) {
|
|
182
|
+
// consume
|
|
183
|
+
}
|
|
184
|
+
}).rejects.toThrow('Cohere streaming error: Stream crashed');
|
|
185
|
+
});
|
|
186
|
+
});
|
|
187
|
+
describe('embed()', () => {
|
|
188
|
+
it('returns embeddings for string array', async () => {
|
|
189
|
+
mockEmbed.mockResolvedValue({
|
|
190
|
+
embeddings: [
|
|
191
|
+
[0.1, 0.2],
|
|
192
|
+
[0.3, 0.4],
|
|
193
|
+
],
|
|
194
|
+
});
|
|
195
|
+
const result = await provider.embed({ input: ['first', 'second'] });
|
|
196
|
+
expect(result.embeddings).toHaveLength(2);
|
|
197
|
+
expect(result.model).toBe('embed-english-v3.0');
|
|
198
|
+
});
|
|
199
|
+
it('handles single string input', async () => {
|
|
200
|
+
mockEmbed.mockResolvedValue({ embeddings: [[0.5, 0.6]] });
|
|
201
|
+
const result = await provider.embed({ input: 'single' });
|
|
202
|
+
expect(result.embeddings).toHaveLength(1);
|
|
203
|
+
});
|
|
204
|
+
it('handles { float: number[][] } response format', async () => {
|
|
205
|
+
mockEmbed.mockResolvedValue({ embeddings: { float: [[0.7, 0.8]] } });
|
|
206
|
+
const result = await provider.embed({ input: 'test' });
|
|
207
|
+
expect(result.embeddings).toEqual([[0.7, 0.8]]);
|
|
208
|
+
});
|
|
209
|
+
it('uses custom model when specified', async () => {
|
|
210
|
+
mockEmbed.mockResolvedValue({ embeddings: [[0.1]] });
|
|
211
|
+
const result = await provider.embed({ input: 'test', model: 'embed-multilingual-v3.0' });
|
|
212
|
+
expect(result.model).toBe('embed-multilingual-v3.0');
|
|
213
|
+
});
|
|
214
|
+
it('estimates token usage from input length', async () => {
|
|
215
|
+
mockEmbed.mockResolvedValue({ embeddings: [[0.1]] });
|
|
216
|
+
const result = await provider.embed({ input: 'hello world' }); // 11 chars → ~3 tokens
|
|
217
|
+
expect(result.usage?.promptTokens).toBeGreaterThan(0);
|
|
218
|
+
});
|
|
219
|
+
it('throws wrapped error on failure', async () => {
|
|
220
|
+
mockEmbed.mockRejectedValue(new Error('Embedding failed'));
|
|
221
|
+
await expect(provider.embed({ input: 'test' })).rejects.toThrow('Cohere embedding error: Embedding failed');
|
|
222
|
+
});
|
|
223
|
+
});
|
|
224
|
+
describe('isAvailable()', () => {
|
|
225
|
+
it('returns false when no API key', async () => {
|
|
226
|
+
const p = new cohere_provider_1.CohereProvider('');
|
|
227
|
+
expect(await p.isAvailable()).toBe(false);
|
|
228
|
+
});
|
|
229
|
+
it('returns true when API responds', async () => {
|
|
230
|
+
mockGenerate.mockResolvedValue(MOCK_GENERATE_RESPONSE);
|
|
231
|
+
expect(await provider.isAvailable()).toBe(true);
|
|
232
|
+
});
|
|
233
|
+
it('returns false on API error', async () => {
|
|
234
|
+
mockGenerate.mockRejectedValue(new Error('Unauthorized'));
|
|
235
|
+
expect(await provider.isAvailable()).toBe(false);
|
|
236
|
+
});
|
|
237
|
+
});
|
|
238
|
+
describe('rerank()', () => {
|
|
239
|
+
it('returns ranked documents', async () => {
|
|
240
|
+
mockRerank.mockResolvedValue({
|
|
241
|
+
results: [
|
|
242
|
+
{ index: 1, relevanceScore: 0.9 },
|
|
243
|
+
{ index: 0, relevanceScore: 0.7 },
|
|
244
|
+
],
|
|
245
|
+
});
|
|
246
|
+
const result = await provider.rerank('query', ['doc0', 'doc1'], 2);
|
|
247
|
+
expect(result).toHaveLength(2);
|
|
248
|
+
expect(result[0].score).toBe(0.9);
|
|
249
|
+
expect(result[0].index).toBe(1);
|
|
250
|
+
expect(result[0].document).toBe('doc1');
|
|
251
|
+
});
|
|
252
|
+
it('uses default rerank model', async () => {
|
|
253
|
+
mockRerank.mockResolvedValue({ results: [] });
|
|
254
|
+
await provider.rerank('query', ['doc1']);
|
|
255
|
+
expect(mockRerank).toHaveBeenCalledWith(expect.objectContaining({ model: 'rerank-english-v3.0' }));
|
|
256
|
+
});
|
|
257
|
+
it('uses custom model when specified', async () => {
|
|
258
|
+
mockRerank.mockResolvedValue({ results: [] });
|
|
259
|
+
await provider.rerank('query', ['doc'], undefined, 'rerank-multilingual-v3.0');
|
|
260
|
+
expect(mockRerank).toHaveBeenCalledWith(expect.objectContaining({ model: 'rerank-multilingual-v3.0' }));
|
|
261
|
+
});
|
|
262
|
+
it('throws wrapped error on rerank failure', async () => {
|
|
263
|
+
mockRerank.mockRejectedValue(new Error('Rerank failed'));
|
|
264
|
+
await expect(provider.rerank('q', ['d'])).rejects.toThrow('Cohere rerank error: Rerank failed');
|
|
265
|
+
});
|
|
266
|
+
});
|
|
267
|
+
});
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import { IAIProvider, AIProvider, AICompletionRequest, AICompletionResponse, AIStreamChunk, AIEmbeddingRequest, AIEmbeddingResponse } from '../ai-enhanced.types';
/**
 * Google Gemini AI Provider
 *
 * Production-ready implementation using Google Generative AI SDK.
 *
 * Setup:
 * 1. Install the SDK: `npm install @google/generative-ai`
 * 2. Set GEMINI_API_KEY environment variable
 * 3. Use the provider in your application
 *
 * Supported models:
 * - gemini-pro: Text generation
 * - gemini-pro-vision: Multimodal (text + images)
 * - gemini-1.5-pro: Latest model with extended context
 * - text-embedding-004: Text embeddings
 */
export declare class GeminiProvider implements IAIProvider {
    readonly name: AIProvider;
    private apiKey;
    private genAI;
    private endpoint;
    /**
     * @param apiKey   Gemini API key (falls back to GEMINI_API_KEY per the
     *                 setup notes above).
     * @param endpoint Optional API base URL override.
     */
    constructor(apiKey?: string, endpoint?: string);
    /**
     * Generate completion
     */
    complete(request: AICompletionRequest): Promise<AICompletionResponse>;
    /**
     * Generate streaming completion
     */
    streamComplete(request: AICompletionRequest): AsyncGenerator<AIStreamChunk>;
    /**
     * Generate embeddings
     */
    embed(request: AIEmbeddingRequest): Promise<AIEmbeddingResponse>;
    /**
     * Check if provider is available
     */
    isAvailable(): Promise<boolean>;
    /**
     * Get supported models
     */
    getSupportedModels(): string[];
}
//# sourceMappingURL=gemini.provider.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"gemini.provider.d.ts","sourceRoot":"","sources":["../../src/providers/gemini.provider.ts"],"names":[],"mappings":"AAAA,OAAO,EACL,WAAW,EACX,UAAU,EACV,mBAAmB,EACnB,oBAAoB,EACpB,aAAa,EACb,kBAAkB,EAClB,mBAAmB,EACpB,MAAM,sBAAsB,CAAC;AAI9B;;;;;;;;;;;;;;;GAeG;AACH,qBAAa,cAAe,YAAW,WAAW;IAChD,QAAQ,CAAC,IAAI,EAAE,UAAU,CAAY;IACrC,OAAO,CAAC,MAAM,CAAS;IACvB,OAAO,CAAC,KAAK,CAAqB;IAClC,OAAO,CAAC,QAAQ,CAAS;gBAEb,MAAM,CAAC,EAAE,MAAM,EAAE,QAAQ,CAAC,EAAE,MAAM;IAY9C;;OAEG;IACG,QAAQ,CAAC,OAAO,EAAE,mBAAmB,GAAG,OAAO,CAAC,oBAAoB,CAAC;IAwC3E;;OAEG;IACI,cAAc,CAAC,OAAO,EAAE,mBAAmB,GAAG,cAAc,CAAC,aAAa,CAAC;IAoDlF;;OAEG;IACG,KAAK,CAAC,OAAO,EAAE,kBAAkB,GAAG,OAAO,CAAC,mBAAmB,CAAC;IAmCtE;;OAEG;IACG,WAAW,IAAI,OAAO,CAAC,OAAO,CAAC;IAiBrC;;OAEG;IACH,kBAAkB,IAAI,MAAM,EAAE;CAS/B"}
|