haroo 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +58 -0
- package/dist/index.js +84883 -0
- package/package.json +73 -0
- package/src/__tests__/e2e/EventService.test.ts +211 -0
- package/src/__tests__/unit/Event.test.ts +89 -0
- package/src/__tests__/unit/Memory.test.ts +130 -0
- package/src/application/graph/builder.ts +106 -0
- package/src/application/graph/edges.ts +37 -0
- package/src/application/graph/nodes/addEvent.ts +113 -0
- package/src/application/graph/nodes/chat.ts +128 -0
- package/src/application/graph/nodes/extractMemory.ts +135 -0
- package/src/application/graph/nodes/index.ts +8 -0
- package/src/application/graph/nodes/query.ts +194 -0
- package/src/application/graph/nodes/respond.ts +26 -0
- package/src/application/graph/nodes/router.ts +82 -0
- package/src/application/graph/nodes/toolExecutor.ts +79 -0
- package/src/application/graph/nodes/types.ts +2 -0
- package/src/application/index.ts +4 -0
- package/src/application/services/DiaryService.ts +188 -0
- package/src/application/services/EventService.ts +61 -0
- package/src/application/services/index.ts +2 -0
- package/src/application/tools/calendarTool.ts +179 -0
- package/src/application/tools/diaryTool.ts +182 -0
- package/src/application/tools/index.ts +68 -0
- package/src/config/env.ts +33 -0
- package/src/config/index.ts +1 -0
- package/src/domain/entities/DiaryEntry.ts +16 -0
- package/src/domain/entities/Event.ts +13 -0
- package/src/domain/entities/Memory.ts +20 -0
- package/src/domain/index.ts +5 -0
- package/src/domain/interfaces/IDiaryRepository.ts +21 -0
- package/src/domain/interfaces/IEventsRepository.ts +12 -0
- package/src/domain/interfaces/ILanguageModel.ts +23 -0
- package/src/domain/interfaces/IMemoriesRepository.ts +15 -0
- package/src/domain/interfaces/IMemory.ts +19 -0
- package/src/domain/interfaces/index.ts +4 -0
- package/src/domain/state/AgentState.ts +30 -0
- package/src/index.ts +5 -0
- package/src/infrastructure/database/factory.ts +52 -0
- package/src/infrastructure/database/index.ts +21 -0
- package/src/infrastructure/database/sqlite-checkpointer.ts +179 -0
- package/src/infrastructure/database/sqlite-client.ts +69 -0
- package/src/infrastructure/database/sqlite-diary-repository.ts +209 -0
- package/src/infrastructure/database/sqlite-events-repository.ts +167 -0
- package/src/infrastructure/database/sqlite-memories-repository.ts +284 -0
- package/src/infrastructure/database/sqlite-schema.ts +98 -0
- package/src/infrastructure/index.ts +3 -0
- package/src/infrastructure/llm/base.ts +14 -0
- package/src/infrastructure/llm/gemini.ts +139 -0
- package/src/infrastructure/llm/index.ts +22 -0
- package/src/infrastructure/llm/ollama.ts +126 -0
- package/src/infrastructure/llm/openai.ts +148 -0
- package/src/infrastructure/memory/checkpointer.ts +19 -0
- package/src/infrastructure/memory/index.ts +2 -0
- package/src/infrastructure/settings/index.ts +96 -0
- package/src/interface/cli/calendar.ts +120 -0
- package/src/interface/cli/chat.ts +185 -0
- package/src/interface/cli/commands.ts +337 -0
- package/src/interface/cli/printer.ts +65 -0
- package/src/interface/index.ts +1 -0
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
import type { Database } from "bun:sqlite";
|
|
2
|
+
import { z } from "zod";
|
|
3
|
+
import { type Memory, MemorySchema } from "../../domain/entities/Memory";
|
|
4
|
+
import type {
|
|
5
|
+
EmbeddingFunction,
|
|
6
|
+
IMemoriesRepository,
|
|
7
|
+
} from "../../domain/interfaces/IMemoriesRepository";
|
|
8
|
+
import type { IVectorStore } from "../../domain/interfaces/IMemory";
|
|
9
|
+
|
|
10
|
+
/** Memory categories; mirrors the CHECK constraint on `memories.type`. */
const MEMORY_ROW_TYPES = [
  "fact",
  "preference",
  "routine",
  "relationship",
  "communication_style",
  "interest",
] as const;

/**
 * Shape of a raw row read from the `memories` table.
 * Column names are snake_case and both timestamps are stored as TEXT.
 */
const MemoryRowSchema = z.object({
  id: z.string().uuid(),
  type: z.enum(MEMORY_ROW_TYPES),
  content: z.string(),
  source: z.string().nullable(),
  importance: z.number(),
  last_accessed: z.string(),
  created_at: z.string(),
});

type MemoryRow = z.infer<typeof MemoryRowSchema>;
|
|
28
|
+
|
|
29
|
+
/** Matches SQLite's `datetime('now')` text form (`YYYY-MM-DD HH:MM:SS[.fff]`), which is UTC. */
const SQLITE_UTC_DATETIME = /^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(?:\.\d+)?$/;

/**
 * Parse a TEXT timestamp from the `memories` table into a `Date`.
 *
 * Rows written by `create()` hold ISO-8601 strings, which parse unambiguously. Rows that
 * fell back to the schema's `DEFAULT (datetime('now'))` hold a UTC value without a zone
 * designator. JavaScript's parsing of that non-ISO form is implementation-defined
 * (commonly local time), so it is rewritten as an explicit ISO UTC string first.
 */
function parseSqliteTimestamp(value: string): Date {
  return new Date(SQLITE_UTC_DATETIME.test(value) ? `${value.replace(" ", "T")}Z` : value);
}

/**
 * Convert a validated `memories` row into a domain `Memory`.
 *
 * SQL NULL `source` becomes `undefined`, and both timestamps become `Date` objects.
 * The result is passed through `MemorySchema.parse`, so it throws if the mapped
 * object does not satisfy the domain schema.
 */
function rowToMemory(row: MemoryRow): Memory {
  return MemorySchema.parse({
    id: row.id,
    type: row.type,
    content: row.content,
    source: row.source ?? undefined,
    importance: row.importance,
    lastAccessed: parseSqliteTimestamp(row.last_accessed),
    createdAt: parseSqliteTimestamp(row.created_at),
  });
}
|
|
40
|
+
|
|
41
|
+
/**
 * SQLite implementation of the memory store.
 *
 * Metadata rows live in `memories`; embeddings live in the sqlite-vec virtual table
 * `vec_memories`, keyed by `memory_id`. The class serves both as the domain
 * `IMemoriesRepository` and as a generic `IVectorStore`.
 *
 * bun:sqlite calls are synchronous; methods are `async` to satisfy the interfaces and
 * because the embedding function is awaited.
 */
export class SqliteMemoriesRepository implements IMemoriesRepository, IVectorStore {
  /** Text-to-vector function used by `addDocuments`, `search` and content updates. */
  private embedFn: EmbeddingFunction | null = null;

  constructor(private readonly db: Database) {}

  /** Register (or replace) the embedding function. */
  setEmbeddingFunction(fn: EmbeddingFunction): void {
    this.embedFn = fn;
  }

  /**
   * Insert `memory` and, if provided, its embedding.
   *
   * Both inserts share one transaction. If the vector insert fails (for example, the
   * embedding length does not match the `vec_memories` column), the metadata row is
   * rolled back instead of being left behind as a memory semantic search can never find.
   *
   * @returns the `memory` argument unchanged.
   */
  async create(memory: Memory, embedding?: number[]): Promise<Memory> {
    this.db.transaction(() => {
      this.db
        .prepare(
          `INSERT INTO memories (id, type, content, source, importance, last_accessed, created_at)
           VALUES (?, ?, ?, ?, ?, ?, ?)`
        )
        .run(
          memory.id,
          memory.type,
          memory.content,
          memory.source ?? null,
          memory.importance,
          memory.lastAccessed.toISOString(),
          memory.createdAt.toISOString()
        );

      if (embedding) {
        this.db
          .prepare("INSERT INTO vec_memories (memory_id, embedding) VALUES (?, ?)")
          .run(memory.id, new Float32Array(embedding));
      }
    })();

    return memory;
  }

  /** Fetch one memory by id, or `null` if no row has that id. */
  async getById(id: string): Promise<Memory | null> {
    const row = this.db
      .prepare(
        `SELECT id, type, content, source, importance, last_accessed, created_at
         FROM memories WHERE id = ?`
      )
      .get(id) as MemoryRow | undefined;

    if (!row) {
      return null;
    }

    return rowToMemory(MemoryRowSchema.parse(row));
  }

  /** All memories of one category, most important first. */
  async getByType(type: Memory["type"]): Promise<Memory[]> {
    const rows = this.db
      .prepare(
        `SELECT id, type, content, source, importance, last_accessed, created_at
         FROM memories
         WHERE type = ?
         ORDER BY importance DESC`
      )
      .all(type) as MemoryRow[];

    return rows.map((row) => rowToMemory(MemoryRowSchema.parse(row)));
  }

  /** Up to `limit` memories ordered by importance, then newest first. */
  async getAll(limit = 20): Promise<Memory[]> {
    const rows = this.db
      .prepare(
        `SELECT id, type, content, source, importance, last_accessed, created_at
         FROM memories
         ORDER BY importance DESC, created_at DESC
         LIMIT ?`
      )
      .all(limit) as MemoryRow[];

    return rows.map((row) => rowToMemory(MemoryRowSchema.parse(row)));
  }

  /**
   * Cosine-nearest memories to `embedding`.
   *
   * Fetches the `limit` nearest neighbours, then drops any whose similarity
   * (1 - cosine distance) is below `threshold`. Fewer than `limit` results may
   * therefore be returned.
   */
  async semanticSearch(embedding: number[], limit = 5, threshold = 0.7): Promise<Memory[]> {
    // vec_memories is declared with distance_metric=cosine: similarity = 1 - distance.
    const maxDistance = 1 - threshold;

    const rows = this.db
      .prepare(
        `SELECT
           m.id, m.type, m.content, m.source, m.importance, m.last_accessed, m.created_at,
           v.distance
         FROM vec_memories v
         JOIN memories m ON m.id = v.memory_id
         WHERE v.embedding MATCH ?
           AND k = ?
         ORDER BY v.distance`
      )
      .all(new Float32Array(embedding), limit) as (MemoryRow & {
      distance: number;
    })[];

    // The KNN query takes `k`, not a distance bound, so the threshold is applied afterwards.
    return rows
      .filter((row) => row.distance <= maxDistance)
      .map((row) => rowToMemory(MemoryRowSchema.parse(row)));
  }

  /**
   * Apply a partial update to a memory.
   *
   * Only fields that are not `undefined` are written. When `content` changes and an
   * embedding function is registered, the stored embedding is recomputed, so semantic
   * search matches the new text rather than the old. The row update and the vector
   * replacement share one transaction.
   *
   * @returns the updated memory, or `null` if no row has this id.
   */
  async update(
    id: string,
    updates: Partial<Omit<Memory, "id" | "createdAt">>
  ): Promise<Memory | null> {
    const setters: string[] = [];
    const values: (string | number | null)[] = [];

    if (updates.type !== undefined) {
      setters.push("type = ?");
      values.push(updates.type);
    }
    if (updates.content !== undefined) {
      setters.push("content = ?");
      values.push(updates.content);
    }
    if (updates.source !== undefined) {
      setters.push("source = ?");
      values.push(updates.source ?? null);
    }
    if (updates.importance !== undefined) {
      setters.push("importance = ?");
      values.push(updates.importance);
    }
    if (updates.lastAccessed !== undefined) {
      setters.push("last_accessed = ?");
      values.push(updates.lastAccessed.toISOString());
    }

    if (setters.length === 0) {
      return this.getById(id);
    }

    // Embed before opening the transaction: bun:sqlite transactions are synchronous,
    // so nothing may be awaited inside one.
    const newEmbedding =
      updates.content !== undefined && this.embedFn
        ? await this.embedFn(updates.content)
        : undefined;

    values.push(id);

    let changed = 0;
    this.db.transaction(() => {
      changed = this.db
        .prepare(`UPDATE memories SET ${setters.join(", ")} WHERE id = ?`)
        .run(...values).changes;

      if (changed > 0 && newEmbedding) {
        this.db.prepare("DELETE FROM vec_memories WHERE memory_id = ?").run(id);
        this.db
          .prepare("INSERT INTO vec_memories (memory_id, embedding) VALUES (?, ?)")
          .run(id, new Float32Array(newEmbedding));
      }
    })();

    if (changed === 0) {
      return null;
    }

    return this.getById(id);
  }

  /**
   * Delete a memory and its embedding atomically.
   *
   * @returns `true` if a `memories` row was removed.
   */
  async delete(id: string): Promise<boolean> {
    let removed = 0;
    this.db.transaction(() => {
      // vec_memories declares no foreign key to memories, so its row is removed explicitly.
      this.db.prepare("DELETE FROM vec_memories WHERE memory_id = ?").run(id);
      removed = this.db.prepare("DELETE FROM memories WHERE id = ?").run(id).changes;
    })();

    return removed > 0;
  }

  /** Stamp `last_accessed` with the current time (no-op if the id does not exist). */
  async updateLastAccessed(id: string): Promise<void> {
    this.db
      .prepare("UPDATE memories SET last_accessed = ? WHERE id = ?")
      .run(new Date().toISOString(), id);
  }

  /** Up to `limit` memories, most recently accessed first. */
  async getRecent(limit = 10): Promise<Memory[]> {
    const rows = this.db
      .prepare(
        `SELECT id, type, content, source, importance, last_accessed, created_at
         FROM memories
         ORDER BY last_accessed DESC
         LIMIT ?`
      )
      .all(limit) as MemoryRow[];

    return rows.map((row) => rowToMemory(MemoryRowSchema.parse(row)));
  }

  // IVectorStore implementation

  /**
   * Embed and store each document as a new memory with a fresh UUID.
   *
   * Optional `metadata.type`, `metadata.source` and `metadata.importance` are copied
   * without runtime validation; out-of-range values are rejected by the table's CHECK
   * constraints. Documents are processed one at a time, so if one fails, the earlier
   * documents remain stored.
   *
   * @throws Error if no embedding function has been registered.
   */
  async addDocuments(
    docs: { content: string; metadata?: Record<string, unknown> }[]
  ): Promise<void> {
    if (!this.embedFn) {
      throw new Error("Embedding function not set. Call setEmbeddingFunction first.");
    }

    for (const doc of docs) {
      const embedding = await this.embedFn(doc.content);
      const id = crypto.randomUUID();
      const now = new Date();

      const memory: Memory = {
        id,
        type: (doc.metadata?.type as Memory["type"]) ?? "fact",
        content: doc.content,
        source: doc.metadata?.source as string | undefined,
        importance: (doc.metadata?.importance as number) ?? 5,
        lastAccessed: now,
        createdAt: now,
      };

      await this.create(memory, embedding);
    }
  }

  /**
   * Free-text nearest-neighbour search.
   *
   * Embeds `query` and returns up to `limit` memory contents with
   * `score = 1 - cosine distance`. Unlike `semanticSearch`, no threshold is applied.
   *
   * @throws Error if no embedding function has been registered.
   */
  async search(query: string, limit = 5): Promise<{ content: string; score: number }[]> {
    if (!this.embedFn) {
      throw new Error("Embedding function not set. Call setEmbeddingFunction first.");
    }

    const queryEmbedding = await this.embedFn(query);

    const rows = this.db
      .prepare(
        `SELECT m.content, v.distance
         FROM vec_memories v
         JOIN memories m ON m.id = v.memory_id
         WHERE v.embedding MATCH ?
           AND k = ?
         ORDER BY v.distance`
      )
      .all(new Float32Array(queryEmbedding), limit) as {
      content: string;
      distance: number;
    }[];

    return rows.map((row) => ({
      content: row.content,
      score: 1 - row.distance,
    }));
  }
}
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import type { Database } from "bun:sqlite";
|
|
2
|
+
|
|
3
|
+
/**
 * Initialize the SQLite schema with all required tables and indexes.
 *
 * Safe to call multiple times:
 * - Ordinary tables and indexes use `IF NOT EXISTS`.
 * - The sqlite-vec `vec_memories` virtual table does not support `IF NOT EXISTS`, so it is
 *   only created when `sqlite_master` does not already list it.
 *
 * @param db - Open database handle. NOTE(review): the sqlite-vec extension is assumed to be
 *   loaded by the caller; `vec0` table creation fails otherwise — confirm at the call site.
 * @param embeddingDimensions - Length of the vectors stored in `vec_memories.embedding`.
 *   Defaults to 1536. It must match the output size of the embedding function in use.
 *   It only takes effect when the table is first created; an existing table keeps its size.
 * @throws RangeError if `embeddingDimensions` is not a positive integer. The value is
 *   interpolated into DDL, where bound parameters cannot be used.
 */
export function initializeSqliteSchema(db: Database, embeddingDimensions = 1536): void {
  if (!Number.isInteger(embeddingDimensions) || embeddingDimensions <= 0) {
    throw new RangeError(
      `embeddingDimensions must be a positive integer, got ${embeddingDimensions}`
    );
  }

  // Events table
  db.exec(`
    CREATE TABLE IF NOT EXISTS events (
      id TEXT PRIMARY KEY,
      title TEXT NOT NULL,
      datetime TEXT NOT NULL,
      end_time TEXT,
      notes TEXT,
      tags TEXT,
      created_at TEXT DEFAULT (datetime('now'))
    )
  `);

  db.exec(`
    CREATE INDEX IF NOT EXISTS idx_events_datetime ON events(datetime)
  `);

  // Memories metadata table
  db.exec(`
    CREATE TABLE IF NOT EXISTS memories (
      id TEXT PRIMARY KEY,
      type TEXT CHECK (type IN ('fact', 'preference', 'routine', 'relationship', 'communication_style', 'interest')),
      content TEXT NOT NULL,
      source TEXT,
      importance INTEGER CHECK (importance BETWEEN 1 AND 10),
      last_accessed TEXT,
      created_at TEXT DEFAULT (datetime('now'))
    )
  `);

  db.exec(`
    CREATE INDEX IF NOT EXISTS idx_memories_type ON memories(type)
  `);

  db.exec(`
    CREATE INDEX IF NOT EXISTS idx_memories_importance ON memories(importance DESC)
  `);

  // Vector table for memory embeddings (sqlite-vec).
  // vec0 tables don't support IF NOT EXISTS, so existence is checked manually.
  const vecTableExists = db
    .prepare("SELECT name FROM sqlite_master WHERE type='table' AND name='vec_memories'")
    .get();

  if (!vecTableExists) {
    db.exec(`
      CREATE VIRTUAL TABLE vec_memories USING vec0(
        memory_id TEXT PRIMARY KEY,
        embedding float[${embeddingDimensions}] distance_metric=cosine
      )
    `);
  }

  // Checkpoints table
  db.exec(`
    CREATE TABLE IF NOT EXISTS checkpoints (
      session_id TEXT PRIMARY KEY,
      state TEXT NOT NULL,
      created_at TEXT DEFAULT (datetime('now')),
      updated_at TEXT DEFAULT (datetime('now'))
    )
  `);

  db.exec(`
    CREATE INDEX IF NOT EXISTS idx_checkpoints_updated ON checkpoints(updated_at DESC)
  `);

  // Diary entries table
  db.exec(`
    CREATE TABLE IF NOT EXISTS diary_entries (
      id TEXT PRIMARY KEY,
      entry_date TEXT NOT NULL UNIQUE,
      summary TEXT NOT NULL,
      mood TEXT NOT NULL,
      mood_score INTEGER CHECK (mood_score BETWEEN 1 AND 10),
      therapeutic_advice TEXT NOT NULL,
      session_ids TEXT,
      message_count INTEGER,
      created_at TEXT DEFAULT (datetime('now')),
      updated_at TEXT DEFAULT (datetime('now'))
    )
  `);

  db.exec(`
    CREATE INDEX IF NOT EXISTS idx_diary_entry_date ON diary_entries(entry_date DESC)
  `);

  db.exec(`
    CREATE INDEX IF NOT EXISTS idx_diary_mood ON diary_entries(mood)
  `);
}
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import type { ILanguageModel, LLMResponse, ToolDefinition } from "../../domain";
|
|
2
|
+
import type { BaseMessage } from "@langchain/core/messages";
|
|
3
|
+
|
|
4
|
+
/**
 * Shared base for provider-specific chat model adapters.
 *
 * It only stores the model identifier. Request formatting (`generate`) and
 * structured-output cloning (`withStructuredOutput`) are left to each concrete adapter.
 */
export abstract class BaseLLMAdapter implements ILanguageModel {
  /**
   * @param model - Provider-specific model identifier, exposed to subclasses as `this.model`.
   */
  constructor(protected model: string) {}

  /** Send `messages` (plus optional tool definitions) to the provider and return its reply. */
  abstract generate(messages: BaseMessage[], tools?: ToolDefinition[]): Promise<LLMResponse>;

  /** Return a model instance whose replies are constrained to `schema`. */
  abstract withStructuredOutput<T>(schema: unknown): ILanguageModel;
}
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
import type { LLMResponse, ToolDefinition, ToolCall, ILanguageModel } from "../../domain";
|
|
2
|
+
import type { BaseMessage } from "@langchain/core/messages";
|
|
3
|
+
import { zodToJsonSchema } from "zod-to-json-schema";
|
|
4
|
+
import { BaseLLMAdapter } from "./base";
|
|
5
|
+
|
|
6
|
+
/** A function invocation requested by the model inside a response part. */
interface GeminiFunctionCall {
  name: string;
  args: Record<string, unknown>;
}

/** One content part; the fields read here are text and/or a function call. */
interface GeminiPart {
  text?: string;
  functionCall?: GeminiFunctionCall;
}

/**
 * One generated candidate.
 * `content` and `parts` are optional: per the Gemini API docs a candidate may carry no
 * content (e.g. when generation is blocked), and callers must not assume they exist.
 */
interface GeminiCandidate {
  content?: {
    parts?: GeminiPart[];
  };
}

/**
 * `generateContent` response body, limited to the fields this adapter reads.
 * `candidates` may be absent (e.g. when the prompt itself is blocked).
 */
interface GeminiResponse {
  candidates?: GeminiCandidate[];
}
|
|
25
|
+
|
|
26
|
+
/**
 * Adapter for the Gemini Generative Language REST API (`v1beta` `generateContent`).
 *
 * Structured output uses forced function calling. A single `structured_response`
 * function is declared with the configured schema as its parameters, and
 * `functionCallingConfig.mode = "ANY"` makes the model call it. The call's arguments
 * are returned as the JSON `content`.
 */
export class GeminiAdapter extends BaseLLMAdapter {
  private apiKey: string;
  private structuredSchema: unknown = null;

  /**
   * @param model - Gemini model id used in the request path. Defaults to `"gemini-pro"`.
   * @param apiKey - API key. Falls back to `process.env.GEMINI_API_KEY`, then `""`.
   *   A missing key is only reported when `generate` is called.
   */
  constructor(model: string = "gemini-pro", apiKey?: string) {
    super(model);
    this.apiKey = apiKey ?? process.env.GEMINI_API_KEY ?? "";
  }

  /**
   * Remove JSON Schema keywords that Gemini's function-declaration `parameters`
   * (an OpenAPI-style Schema subset) does not accept: `$schema` and
   * `additionalProperties`, both emitted by `zodToJsonSchema`.
   * Property names inside a `properties` map are kept even if they match those keywords.
   */
  private static toGeminiSchema(schema: unknown): unknown {
    if (Array.isArray(schema)) {
      return schema.map((item) => GeminiAdapter.toGeminiSchema(item));
    }
    if (schema === null || typeof schema !== "object") {
      return schema;
    }
    const cleaned: Record<string, unknown> = {};
    for (const [key, value] of Object.entries(schema)) {
      if (key === "$schema" || key === "additionalProperties") {
        continue;
      }
      if (key === "properties" && value !== null && typeof value === "object" && !Array.isArray(value)) {
        cleaned[key] = Object.fromEntries(
          Object.entries(value).map(([name, sub]) => [name, GeminiAdapter.toGeminiSchema(sub)])
        );
      } else {
        cleaned[key] = GeminiAdapter.toGeminiSchema(value);
      }
    }
    return cleaned;
  }

  /**
   * Send `messages` to Gemini and return its reply.
   *
   * Message mapping:
   * - `system` messages are merged into `systemInstruction`, because `contents` only
   *   has `user` and `model` roles.
   * - `ai` messages become `model` turns.
   * - Everything else (human, tool results, ...) becomes a `user` turn.
   * - Messages with empty text are skipped.
   *
   * When this instance has a structured-output schema and no `tools` are supplied, the
   * reply's `content` is the JSON-encoded arguments of the forced function call.
   * Otherwise function calls are returned in `toolCalls`.
   *
   * @throws Error if no API key is configured or the HTTP response is not 2xx.
   */
  async generate(messages: BaseMessage[], tools?: ToolDefinition[]): Promise<LLMResponse> {
    if (!this.apiKey) {
      throw new Error("GEMINI_API_KEY is required");
    }

    // Decide once so the request and the response interpretation always agree: regular
    // tools take precedence, and their calls must never be read as structured output.
    const useStructuredOutput = Boolean(this.structuredSchema) && !(tools && tools.length > 0);

    const systemTexts: string[] = [];
    const contents: { role: "user" | "model"; parts: { text: string }[] }[] = [];
    for (const message of messages) {
      const text =
        typeof message.content === "string" ? message.content : JSON.stringify(message.content);
      // An empty text part (e.g. an AI turn that only carried tool calls) adds nothing
      // and can be rejected by the API.
      if (text === "") {
        continue;
      }
      const type = message._getType();
      if (type === "system") {
        systemTexts.push(text);
      } else {
        contents.push({ role: type === "ai" ? "model" : "user", parts: [{ text }] });
      }
    }

    const requestBody: Record<string, unknown> = { contents };
    if (systemTexts.length > 0) {
      requestBody.systemInstruction = { parts: [{ text: systemTexts.join("\n\n") }] };
    }

    if (useStructuredOutput) {
      // Gemini cannot resolve $ref pointers, so inline every sub-schema.
      const parameters = GeminiAdapter.toGeminiSchema(
        zodToJsonSchema(this.structuredSchema as any, { $refStrategy: "none" })
      );
      requestBody.tools = [
        {
          functionDeclarations: [
            {
              name: "structured_response",
              description: "Return structured data matching the schema",
              parameters,
            },
          ],
        },
      ];
      requestBody.toolConfig = {
        functionCallingConfig: {
          mode: "ANY",
          allowedFunctionNames: ["structured_response"],
        },
      };
    } else if (tools && tools.length > 0) {
      requestBody.tools = [
        {
          functionDeclarations: tools.map((t) => ({
            name: t.name,
            description: t.description,
            parameters: GeminiAdapter.toGeminiSchema(t.parameters),
          })),
        },
      ];
    }

    // The key goes in a header rather than the query string so it does not end up in
    // URLs that proxies or error messages may log.
    const response = await fetch(
      `https://generativelanguage.googleapis.com/v1beta/models/${this.model}:generateContent`,
      {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          "x-goog-api-key": this.apiKey,
        },
        body: JSON.stringify(requestBody),
      }
    );

    if (!response.ok) {
      const errorBody = await response.text();
      throw new Error(
        `Gemini API error: ${response.status} ${response.statusText} - ${errorBody}`
      );
    }

    const data = (await response.json()) as GeminiResponse;
    const parts: GeminiPart[] = data.candidates?.[0]?.content?.parts ?? [];
    const functionCalls = parts.flatMap((p) => (p.functionCall ? [p.functionCall] : []));

    const [firstCall] = functionCalls;
    if (useStructuredOutput && firstCall) {
      return {
        content: JSON.stringify(firstCall.args),
        toolCalls: undefined,
      };
    }

    const toolCalls: ToolCall[] | undefined =
      functionCalls.length > 0
        ? functionCalls.map((call, idx) => ({
            id: `call_${idx}`,
            name: call.name,
            arguments: call.args,
          }))
        : undefined;

    const content = parts.map((p) => p.text ?? "").join("");

    return {
      content,
      toolCalls,
    };
  }

  /** Return a copy of this adapter (same model and key) that enforces `schema` on replies. */
  withStructuredOutput<T>(schema: unknown): ILanguageModel {
    const adapter = new GeminiAdapter(this.model, this.apiKey);
    adapter.structuredSchema = schema;
    return adapter;
  }
}
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import type { ILanguageModel } from "../../domain";
|
|
2
|
+
import type { LLMProvider } from "../../interface/cli/commands";
|
|
3
|
+
import { GeminiAdapter } from "./gemini";
|
|
4
|
+
import { OllamaAdapter } from "./ollama";
|
|
5
|
+
import { OpenAIAdapter } from "./openai";
|
|
6
|
+
|
|
7
|
+
export { GeminiAdapter } from "./gemini";
|
|
8
|
+
export { OllamaAdapter } from "./ollama";
|
|
9
|
+
export { OpenAIAdapter } from "./openai";
|
|
10
|
+
|
|
11
|
+
/**
 * Instantiate the chat-model adapter for `provider`.
 *
 * @param provider - Backend to use: `"ollama"`, `"openai"` or `"gemini"`.
 * @param model - Optional model id. When omitted, the adapter's constructor default applies
 *   (visible here for Ollama and Gemini; presumably the same for OpenAI).
 * @throws Error for any other provider value. NOTE(review): if `LLMProvider` is exactly these
 *   three literals, this is only reachable when a caller bypasses the type (e.g. unchecked input).
 */
export function createLLM(provider: LLMProvider, model?: string): ILanguageModel {
  if (provider === "ollama") {
    return new OllamaAdapter(model);
  }
  if (provider === "openai") {
    return new OpenAIAdapter(model);
  }
  if (provider === "gemini") {
    return new GeminiAdapter(model);
  }
  throw new Error(`Unknown provider: ${provider}`);
}
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import type { LLMResponse, ToolDefinition, ToolCall, ILanguageModel } from "../../domain";
|
|
2
|
+
import type { BaseMessage } from "@langchain/core/messages";
|
|
3
|
+
import { zodToJsonSchema } from "zod-to-json-schema";
|
|
4
|
+
import { BaseLLMAdapter } from "./base";
|
|
5
|
+
|
|
6
|
+
/** The `function` payload of a tool call; `arguments` is typed as an object, not a JSON string. */
interface OllamaFunctionInvocation {
  name: string;
  arguments: Record<string, unknown>;
}

/** One entry of `message.tool_calls` in a `/api/chat` reply. */
interface OllamaToolCall {
  function: OllamaFunctionInvocation;
}

/** The assistant message of a `/api/chat` reply. */
interface OllamaMessage {
  role: string;
  content: string;
  tool_calls?: OllamaToolCall[];
}

/** Non-streaming `/api/chat` response body, limited to the field this adapter reads. */
interface OllamaResponse {
  message: OllamaMessage;
}
|
|
22
|
+
|
|
23
|
+
/**
 * Adapter for a local Ollama server's `/api/chat` endpoint (non-streaming).
 *
 * Structured output is requested through two channels at once:
 * - the JSON Schema is sent as the `format` parameter, and
 * - it is also exposed as a single `structured_response` tool.
 *
 * The reply is read from whichever channel the model used.
 */
export class OllamaAdapter extends BaseLLMAdapter {
  private baseUrl: string;
  private structuredSchema: unknown = null;

  /**
   * @param model - Ollama model tag. Defaults to `"llama3.1:latest"`.
   * @param baseUrl - Server root URL. Defaults to `"http://localhost:11434"`. Trailing slashes
   *   are stripped so the endpoint path never becomes `//api/chat`.
   */
  constructor(model: string = "llama3.1:latest", baseUrl: string = "http://localhost:11434") {
    super(model);
    this.baseUrl = baseUrl.replace(/\/+$/, "");
  }

  /**
   * Send `messages` to the model and return its reply.
   *
   * LangChain message types map to Ollama roles as follows: `human` → `user`,
   * `ai` → `assistant`, and any other type (e.g. `system`, `tool`) passes through unchanged.
   *
   * When this instance has a structured-output schema and no `tools` are supplied, the
   * reply's `content` is the structured JSON and `toolCalls` is undefined. Otherwise tool
   * calls, if any, are returned in `toolCalls`.
   *
   * @throws Error on a non-2xx HTTP response, including the status and response body.
   */
  async generate(messages: BaseMessage[], tools?: ToolDefinition[]): Promise<LLMResponse> {
    // Decide once so the request and the response interpretation always agree: regular
    // tools take precedence, and their calls must never be read as structured output.
    const useStructuredOutput = Boolean(this.structuredSchema) && !(tools && tools.length > 0);

    const formattedMessages = messages.map((m) => {
      const type = m._getType();
      return {
        role: type === "human" ? "user" : type === "ai" ? "assistant" : type,
        content: typeof m.content === "string" ? m.content : JSON.stringify(m.content),
      };
    });

    const requestBody: Record<string, unknown> = {
      model: this.model,
      messages: formattedMessages,
      stream: false,
    };

    if (useStructuredOutput) {
      const jsonSchema = zodToJsonSchema(this.structuredSchema as any);
      // `format` constrains the plain-text reply to the JSON schema.
      requestBody.format = jsonSchema;
      // Also offered as a tool for models that prefer function calling.
      requestBody.tools = [
        {
          type: "function",
          function: {
            name: "structured_response",
            description: "Return structured data matching the schema",
            parameters: jsonSchema,
          },
        },
      ];
    } else if (tools && tools.length > 0) {
      requestBody.tools = tools.map((t) => ({
        type: "function",
        function: {
          name: t.name,
          description: t.description,
          parameters: t.parameters,
        },
      }));
    }

    const response = await fetch(`${this.baseUrl}/api/chat`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(requestBody),
    });

    if (!response.ok) {
      const errorBody = await response.text();
      throw new Error(
        `Ollama API error: ${response.status} ${response.statusText} - ${errorBody}`
      );
    }

    const data = (await response.json()) as OllamaResponse;
    const message = data.message;

    if (useStructuredOutput) {
      // The answer arrives either as the structured_response tool call or as schema-shaped content.
      const structuredCall = message?.tool_calls?.[0];
      if (structuredCall) {
        return {
          content: JSON.stringify(structuredCall.function.arguments),
          toolCalls: undefined,
        };
      }
      return {
        content: message?.content ?? "",
        toolCalls: undefined,
      };
    }

    let toolCalls: ToolCall[] | undefined;
    if (message?.tool_calls && message.tool_calls.length > 0) {
      toolCalls = message.tool_calls.map((tc, idx) => ({
        id: `call_${idx}`,
        name: tc.function.name,
        arguments: tc.function.arguments,
      }));
    }

    return {
      content: message?.content ?? "",
      toolCalls,
    };
  }

  /** Return a copy of this adapter (same model and server) that enforces `schema` on replies. */
  withStructuredOutput<T>(schema: unknown): ILanguageModel {
    const adapter = new OllamaAdapter(this.model, this.baseUrl);
    adapter.structuredSchema = schema;
    return adapter;
  }
}
|