@disco_trooper/apple-notes-mcp 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CLAUDE.md +56 -0
- package/LICENSE +21 -0
- package/README.md +216 -0
- package/package.json +61 -0
- package/src/config/constants.ts +41 -0
- package/src/config/env.test.ts +58 -0
- package/src/config/env.ts +25 -0
- package/src/db/lancedb.test.ts +141 -0
- package/src/db/lancedb.ts +263 -0
- package/src/db/validation.test.ts +76 -0
- package/src/db/validation.ts +57 -0
- package/src/embeddings/index.test.ts +54 -0
- package/src/embeddings/index.ts +111 -0
- package/src/embeddings/local.test.ts +70 -0
- package/src/embeddings/local.ts +191 -0
- package/src/embeddings/openrouter.test.ts +21 -0
- package/src/embeddings/openrouter.ts +285 -0
- package/src/index.ts +387 -0
- package/src/notes/crud.test.ts +199 -0
- package/src/notes/crud.ts +257 -0
- package/src/notes/read.test.ts +131 -0
- package/src/notes/read.ts +504 -0
- package/src/search/index.test.ts +52 -0
- package/src/search/index.ts +283 -0
- package/src/search/indexer.test.ts +42 -0
- package/src/search/indexer.ts +335 -0
- package/src/server.ts +386 -0
- package/src/setup.ts +540 -0
- package/src/types/index.ts +39 -0
- package/src/utils/debug.test.ts +41 -0
- package/src/utils/debug.ts +51 -0
- package/src/utils/errors.test.ts +29 -0
- package/src/utils/errors.ts +46 -0
- package/src/utils/text.ts +23 -0
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
import * as lancedb from "@lancedb/lancedb";
|
|
2
|
+
import path from "node:path";
|
|
3
|
+
import os from "node:os";
|
|
4
|
+
import { validateTitle, escapeForFilter } from "./validation.js";
|
|
5
|
+
import type { DBSearchResult as SearchResult } from "../types/index.js";
|
|
6
|
+
import { createDebugLogger } from "../utils/debug.js";
|
|
7
|
+
|
|
8
|
+
/**
 * Shape of a single note row as stored in the vector table.
 */
export interface NoteRecord {
  /** Note title; title-based lookups and deletes filter on this column. */
  title: string;
  /** Note body text; the full-text-search index is created on this column. */
  content: string;
  /** Embedding vector for the note. */
  vector: number[];
  /** Folder the note belongs to. */
  folder: string;
  /** Creation time (ISO date). */
  created: string;
  /** Last modification time (ISO date). */
  modified: string;
  /** When the embedding was generated (ISO date). */
  indexed_at: string;
  /** Index signature for LanceDB compatibility. */
  [key: string]: unknown;
}
|
|
19
|
+
|
|
20
|
+
// SearchResult is imported from ../types/index.js as DBSearchResult
|
|
21
|
+
export type { SearchResult };
|
|
22
|
+
|
|
23
|
+
/**
 * Storage abstraction for indexed notes, kept as an interface for future
 * extensibility (alternative backends).
 */
export interface VectorStore {
  /** Store the given records as the index contents. */
  index(records: NoteRecord[]): Promise<void>;

  /** Store a new version of a single record. */
  update(record: NoteRecord): Promise<void>;

  /** Remove records matching the given title. */
  delete(title: string): Promise<void>;

  /** Remove records matching both the folder and the title. */
  deleteByFolderAndTitle(folder: string, title: string): Promise<void>;

  /** Vector similarity search returning at most `limit` results. */
  search(queryVector: number[], limit: number): Promise<SearchResult[]>;

  /** Full-text search returning at most `limit` results. */
  searchFTS(query: string, limit: number): Promise<SearchResult[]>;

  /** Look up one record by title, or `null` when there is no match. */
  getByTitle(title: string): Promise<NoteRecord | null>;

  /** Return every stored record. */
  getAll(): Promise<NoteRecord[]>;

  /** Number of stored records. */
  count(): Promise<number>;

  /** Remove all stored records. */
  clear(): Promise<void>;
}
|
|
36
|
+
|
|
37
|
+
// Debug logging
|
|
38
|
+
const debug = createDebugLogger("DB");
|
|
39
|
+
|
|
40
|
+
/**
 * LanceDB-backed implementation of {@link VectorStore}.
 *
 * The database connection and the table handle are opened lazily and cached
 * on the instance. All notes live in a single table named `notes`.
 */
export class LanceDBStore implements VectorStore {
  private db: lancedb.Connection | null = null;
  private table: lancedb.Table | null = null;
  private readonly dbPath: string;
  private readonly tableName = "notes";

  /**
   * @param dataDir - Directory holding the LanceDB data. When omitted (or
   *   empty) it defaults to `~/.apple-notes-mcp/data`.
   */
  constructor(dataDir?: string) {
    this.dbPath = dataDir || path.join(os.homedir(), ".apple-notes-mcp", "data");
  }

  /** Open the database connection on first use and reuse it afterwards. */
  private async ensureConnection(): Promise<lancedb.Connection> {
    if (!this.db) {
      debug(`Connecting to LanceDB at ${this.dbPath}`);
      this.db = await lancedb.connect(this.dbPath);
    }
    return this.db;
  }

  /**
   * Open the notes table on first use and reuse the handle afterwards.
   *
   * @throws Error("Index not found. Run index-notes first.") when the table
   *   cannot be opened; the underlying error is only written to the debug log.
   */
  private async ensureTable(): Promise<lancedb.Table> {
    if (!this.table) {
      const db = await this.ensureConnection();
      try {
        this.table = await db.openTable(this.tableName);
        debug(`Opened existing table: ${this.tableName}`);
      } catch (error) {
        // Table doesn't exist yet - will be created on first index
        debug(`Table ${this.tableName} not found, will create on first index. Error:`, error);
        throw new Error("Index not found. Run index-notes first.");
      }
    }
    return this.table;
  }

  /** Convert raw result rows into rank-scored search results. */
  private toSearchResults(rows: Record<string, unknown>[]): SearchResult[] {
    return rows.map((row, index) => ({
      title: row.title as string,
      folder: row.folder as string,
      content: row.content as string,
      modified: row.modified as string,
      score: 1 / (1 + index), // Simple rank-based score: 1, 1/2, 1/3, ...
    }));
  }

  /**
   * Rebuild the index from scratch: drop the existing table (if any), create
   * a new one from `records`, and build the full-text index on `content`.
   *
   * An empty `records` array is a no-op and leaves any existing table intact.
   */
  async index(records: NoteRecord[]): Promise<void> {
    if (records.length === 0) {
      debug("No records to index");
      return;
    }

    const db = await this.ensureConnection();

    // Forget the cached handle before touching the table, so that a failure
    // in createTable below cannot leave us holding a handle to a dropped table.
    this.table = null;

    // Drop existing table if exists
    try {
      await db.dropTable(this.tableName);
      debug(`Dropped existing table: ${this.tableName}`);
    } catch {
      // Table didn't exist, that's fine
    }

    // Create new table with records
    debug(`Creating table with ${records.length} records`);
    this.table = await db.createTable(this.tableName, records);

    // Create FTS index for hybrid search
    debug("Creating FTS index on content");
    await this.table.createIndex("content", {
      config: lancedb.Index.fts(),
      replace: true,
    });

    debug(`Indexed ${records.length} records`);
  }

  /**
   * Store a new version of `record` and then remove older versions that share
   * its title.
   *
   * The new row is added before the old ones are deleted, so a failed add
   * leaves the previous version untouched. A failure while deleting old
   * versions is logged and swallowed (the new row is already stored).
   *
   * @throws Error when the index does not exist, when the title is invalid
   *   (checked before anything is written), or when the add itself fails.
   */
  async update(record: NoteRecord): Promise<void> {
    const table = await this.ensureTable();

    // Validate BEFORE writing: an invalid title must not leave a freshly
    // inserted row behind while the call itself rejects.
    const validTitle = validateTitle(record.title);
    const escapedTitle = escapeForFilter(validTitle);

    // Add new record first (LanceDB allows duplicates with same title).
    // If this throws, the old record still exists and the error propagates.
    await table.add([record]);
    debug(`Added new version of record: ${record.title}`);

    // Now delete old record(s) - indexed_at identifies which rows are old.
    // NOTE(review): the filter matches on title only, so same-titled notes in
    // other folders are treated as old versions too - confirm this is intended.
    // NOTE(review): the filter uses the trimmed title while the row is stored
    // with record.title verbatim - verify callers always pass trimmed titles.
    try {
      const allWithTitle = await table
        .query()
        .where(`title = '${escapedTitle}'`)
        .toArray();

      // Collect distinct old timestamps so each delete statement runs once.
      const oldTimestamps = new Set<string>();
      for (const existing of allWithTitle) {
        if (existing.indexed_at !== record.indexed_at) {
          oldTimestamps.add(existing.indexed_at as string);
        }
      }

      for (const oldIndexedAt of oldTimestamps) {
        const escapedOldIndexedAt = escapeForFilter(oldIndexedAt);
        await table.delete(`title = '${escapedTitle}' AND indexed_at = '${escapedOldIndexedAt}'`);
        debug(`Deleted old version: ${record.title} (indexed_at: ${oldIndexedAt})`);
      }
    } catch (deleteError) {
      // Log but don't fail - we have the new record, old one is just orphaned
      debug(`Warning: Failed to delete old record versions for: ${record.title}`, deleteError);
    }

    debug(`Updated record: ${record.title}`);
  }

  /**
   * Delete every row whose title equals the validated (trimmed) `title`.
   *
   * @throws Error when the index does not exist or the title is invalid.
   */
  async delete(title: string): Promise<void> {
    const table = await this.ensureTable();
    const validTitle = validateTitle(title);
    const escapedTitle = escapeForFilter(validTitle);
    await table.delete(`title = '${escapedTitle}'`);
    debug(`Deleted record: ${title}`);
  }

  /**
   * Delete every row matching both `folder` and the validated `title`.
   *
   * @throws Error when the index does not exist or the title is invalid.
   */
  async deleteByFolderAndTitle(folder: string, title: string): Promise<void> {
    const table = await this.ensureTable();
    const validTitle = validateTitle(title);
    const escapedTitle = escapeForFilter(validTitle);
    const escapedFolder = escapeForFilter(folder);
    await table.delete(`folder = '${escapedFolder}' AND title = '${escapedTitle}'`);
    debug(`Deleted record: ${folder}/${title}`);
  }

  /**
   * Vector search. Results keep the order returned by LanceDB and are scored
   * purely by rank (`1 / (1 + position)`).
   */
  async search(queryVector: number[], limit: number): Promise<SearchResult[]> {
    const table = await this.ensureTable();

    const results = await table
      .search(queryVector)
      .limit(limit)
      .toArray();

    return this.toSearchResults(results);
  }

  /**
   * Full-text search over the notes table, scored by rank like
   * {@link LanceDBStore.search}.
   *
   * Best-effort: any failure of the search itself is logged and yields an
   * empty array. A missing index still throws (from `ensureTable`).
   */
  async searchFTS(query: string, limit: number): Promise<SearchResult[]> {
    const table = await this.ensureTable();

    try {
      // LanceDB FTS search - use queryType option
      // NOTE(review): the `any` cast hides the real signature of
      // `Table.search`; confirm the installed @lancedb/lancedb version accepts
      // this options-object form, otherwise this silently returns [] via the
      // catch below.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const results = await (table as any)
        .search(query, { queryType: "fts" })
        .limit(limit)
        .toArray();

      return this.toSearchResults(results);
    } catch (error) {
      // FTS might fail if no index or no matches
      debug("FTS search failed, returning empty results. Error:", error);
      return [];
    }
  }

  /**
   * Fetch one row whose title equals the validated `title`, or `null` when
   * there is no match. The row is returned as LanceDB produced it (cast, not
   * re-mapped).
   *
   * @throws Error when the index does not exist or the title is invalid.
   */
  async getByTitle(title: string): Promise<NoteRecord | null> {
    const table = await this.ensureTable();

    const validTitle = validateTitle(title);
    const escapedTitle = escapeForFilter(validTitle);
    const results = await table
      .query()
      .where(`title = '${escapedTitle}'`)
      .limit(1)
      .toArray();

    if (results.length === 0) return null;

    return results[0] as unknown as NoteRecord;
  }

  /** Return every row, projected onto the known {@link NoteRecord} fields. */
  async getAll(): Promise<NoteRecord[]> {
    const table = await this.ensureTable();

    const results = await table.query().toArray();

    return results.map((row) => ({
      title: row.title as string,
      content: row.content as string,
      vector: row.vector as number[],
      folder: row.folder as string,
      created: row.created as string,
      modified: row.modified as string,
      indexed_at: row.indexed_at as string,
    }));
  }

  /** Number of rows in the table; `0` when the table is missing or counting fails. */
  async count(): Promise<number> {
    try {
      const table = await this.ensureTable();
      return await table.countRows();
    } catch {
      return 0;
    }
  }

  /** Drop the notes table. A failed drop (e.g. the table never existed) is ignored. */
  async clear(): Promise<void> {
    const db = await this.ensureConnection();

    // Release the cached handle whatever the outcome, so the next operation
    // re-opens the table instead of reusing a possibly invalid handle.
    this.table = null;

    try {
      await db.dropTable(this.tableName);
      debug("Cleared table");
    } catch {
      // Table didn't exist
    }
  }
}
|
|
254
|
+
|
|
255
|
+
// Shared store instance, created lazily by getVectorStore()
let storeInstance: LanceDBStore | null = null;

/**
 * Return the shared vector store, constructing it with default settings on
 * the first call. Every later call returns the same instance.
 */
export function getVectorStore(): VectorStore {
  storeInstance ??= new LanceDBStore();
  return storeInstance;
}
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import { describe, it, expect } from "vitest";
|
|
2
|
+
import { validateTitle, escapeForFilter } from "./validation.js";
|
|
3
|
+
|
|
4
|
+
// Core contract of validateTitle: trimming, emptiness, length and control characters.
describe("validateTitle", () => {
  it("accepts valid titles", () => {
    const accepted: Array<[input: string, expected: string]> = [
      ["My Note", "My Note"],
      [" Trimmed ", "Trimmed"],
      ["Note with numbers 123", "Note with numbers 123"],
      ["Unicode: cestina 日本語", "Unicode: cestina 日本語"],
    ];

    for (const [input, expected] of accepted) {
      expect(validateTitle(input)).toBe(expected);
    }
  });

  it("rejects empty titles", () => {
    const rejected: Array<[input: string, message: string]> = [
      ["", "non-empty"],
      [" ", "empty or whitespace"],
    ];

    for (const [input, message] of rejected) {
      expect(() => validateTitle(input)).toThrow(message);
    }
  });

  it("rejects very long titles", () => {
    const tooLong = "a".repeat(501);
    expect(() => validateTitle(tooLong)).toThrow("maximum length");
  });

  it("rejects invalid characters", () => {
    const withNulls = "Note\x00with\x00null";
    expect(() => validateTitle(withNulls)).toThrow("invalid characters");
  });
});
|
|
26
|
+
|
|
27
|
+
// Character-set policy of validateTitle: everything printable passes, control characters do not.
describe("validateTitle - allowed characters", () => {
  // Assert that each title is returned unchanged.
  const expectUnchanged = (titles: string[]): void => {
    for (const title of titles) {
      expect(validateTitle(title)).toBe(title);
    }
  };

  it("allows common punctuation", () => {
    expectUnchanged([
      "Note: My Title",
      "Meeting (2024-01-08)",
      "Q&A Session",
      "To-Do List",
      "Notes #1",
      "50% Complete!",
    ]);
  });

  it("allows special characters that appear in real note titles", () => {
    expectUnchanged([
      // Forward slashes (common in paths/dates)
      "Airdrops/zisky",
      // Apostrophes
      "Disco's Notes",
      // Unicode symbols
      "Task ➜ Done",
      // Ellipsis
      "Long title…",
      // Backticks, pipes, angle brackets (now allowed)
      "Note `code` here",
      "Option A | Option B",
      "A > B < C",
    ]);
  });

  it("rejects control characters", () => {
    const withControlChars = ["Note\x00null", "Note\x1ftab", "Note\x7fdel"];

    for (const title of withControlChars) {
      expect(() => validateTitle(title)).toThrow("invalid characters");
    }
  });
});
|
|
58
|
+
|
|
59
|
+
// Escaping rules of escapeForFilter, one rule per test plus a combined case.
describe("escapeForFilter", () => {
  it("escapes single quotes", () => {
    const escaped = escapeForFilter("O'Brien");
    expect(escaped).toBe("O''Brien");
  });

  it("escapes backslashes", () => {
    const escaped = escapeForFilter("path\\to\\note");
    expect(escaped).toBe("path\\\\to\\\\note");
  });

  it("escapes newlines and tabs", () => {
    const cases: Array<[input: string, expected: string]> = [
      ["line1\nline2", "line1\\nline2"],
      ["col1\tcol2", "col1\\tcol2"],
    ];

    for (const [input, expected] of cases) {
      expect(escapeForFilter(input)).toBe(expected);
    }
  });

  it("handles combined escapes", () => {
    const escaped = escapeForFilter("O'Brien's\nnote");
    expect(escaped).toBe("O''Brien''s\\nnote");
  });
});
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Input validation for database operations.
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
import { MAX_TITLE_LENGTH } from "../config/constants.js";
|
|
6
|
+
|
|
7
|
+
// Pattern for allowed characters in titles
|
|
8
|
+
// Allows: letters, numbers, whitespace, punctuation, symbols, and most printable chars
|
|
9
|
+
// Security is handled by escapeForFilter(), so we can be permissive here
|
|
10
|
+
// Only reject control characters and null bytes
|
|
11
|
+
const SAFE_TITLE_PATTERN = /^[^\x00-\x1f\x7f]+$/u;
|
|
12
|
+
|
|
13
|
+
/**
 * Check a note title and return its trimmed form for use in database
 * operations.
 *
 * Checks run in this order: non-empty string, not whitespace-only, within
 * MAX_TITLE_LENGTH (measured after trimming), and matching
 * SAFE_TITLE_PATTERN.
 *
 * @param title - The title to validate
 * @throws Error describing the first check that failed
 * @returns The title with surrounding whitespace removed
 */
export function validateTitle(title: string): string {
  // Runtime guard: callers may pass non-string values despite the type.
  const isNonEmptyString = typeof title === "string" && title.length > 0;
  if (!isNonEmptyString) {
    throw new Error("Title must be a non-empty string");
  }

  const candidate = title.trim();

  if (candidate === "") {
    throw new Error("Title cannot be empty or whitespace only");
  }

  if (candidate.length > MAX_TITLE_LENGTH) {
    throw new Error(`Title exceeds maximum length of ${MAX_TITLE_LENGTH} characters`);
  }

  // Reject anything the safe-title pattern does not accept
  const hasOnlyAllowedChars = SAFE_TITLE_PATTERN.test(candidate);
  if (!hasOnlyAllowedChars) {
    throw new Error("Title contains invalid characters");
  }

  return candidate;
}
|
|
42
|
+
|
|
43
|
+
/**
 * Escape a string for embedding inside a single-quoted literal of a LanceDB
 * SQL-like filter.
 *
 * Rewrites, in a single pass: backslash -> two backslashes, single quote ->
 * two single quotes, and newline / carriage return / tab -> backslash
 * followed by `n` / `r` / `t`.
 *
 * NOTE(review): whether the filter parser interprets backslash sequences
 * inside string literals is not visible from this file - confirm against the
 * LanceDB filter dialect.
 *
 * @param value - The value to escape
 * @returns Escaped string for use in filter expressions
 */
export function escapeForFilter(value: string): string {
  // A single scan is equivalent to escaping backslashes first and the other
  // characters afterwards: inserted backslashes are never revisited.
  return value.replace(/[\\'\n\r\t]/g, (ch) => {
    switch (ch) {
      case "\\":
        return "\\\\";
      case "'":
        return "''";
      case "\n":
        return "\\n";
      case "\r":
        return "\\r";
      default:
        // Only a tab can reach here given the character class above
        return "\\t";
    }
  });
}
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
|
2
|
+
|
|
3
|
+
// Provider auto-detection tests. Each test imports a fresh copy of the module
// (vi.resetModules) because the detected provider is cached at module level.
describe("embeddings index", () => {
  const originalApiKey = process.env.OPENROUTER_API_KEY;
  // EMBEDDING_MODEL changes the local model name and dimensions, so the
  // "local provider" expectations below (384 dims) only hold when it is unset.
  const originalModel = process.env.EMBEDDING_MODEL;

  // Put an environment variable back to the value it had before the suite ran.
  const restoreEnv = (name: string, value: string | undefined): void => {
    if (value !== undefined) {
      process.env[name] = value;
    } else {
      delete process.env[name];
    }
  };

  beforeEach(() => {
    vi.resetModules();
    delete process.env.OPENROUTER_API_KEY;
    delete process.env.EMBEDDING_MODEL;
  });

  afterEach(() => {
    restoreEnv("OPENROUTER_API_KEY", originalApiKey);
    restoreEnv("EMBEDDING_MODEL", originalModel);
  });

  describe("detectProvider", () => {
    it("should detect local provider when no API key", async () => {
      const { detectProvider } = await import("./index.js");
      expect(detectProvider()).toBe("local");
    });

    it("should detect openrouter provider when API key is set", async () => {
      process.env.OPENROUTER_API_KEY = "test-key";
      const { detectProvider } = await import("./index.js");
      expect(detectProvider()).toBe("openrouter");
    });
  });

  describe("getProvider", () => {
    it("should return detected provider", async () => {
      const { getProvider } = await import("./index.js");
      expect(getProvider()).toBe("local");
    });
  });

  describe("getEmbeddingDimensions", () => {
    it("should return dimensions for local provider", async () => {
      const { getEmbeddingDimensions } = await import("./index.js");
      expect(getEmbeddingDimensions()).toBe(384);
    });
  });

  describe("getProviderDescription", () => {
    it("should return description for local provider", async () => {
      const { getProviderDescription } = await import("./index.js");
      const desc = getProviderDescription();
      expect(desc).toContain("Local");
      expect(desc).toContain("384");
    });
  });
});
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Embedding provider auto-detection and unified interface.
|
|
3
|
+
*
|
|
4
|
+
* Automatically selects between:
|
|
5
|
+
* - OpenRouter (if OPENROUTER_API_KEY is set)
|
|
6
|
+
* - Local HuggingFace (fallback)
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
import { getOpenRouterEmbedding, getOpenRouterDimensions } from "./openrouter.js";
|
|
10
|
+
import { getLocalEmbedding, getLocalDimensions, getLocalModelName } from "./local.js";
|
|
11
|
+
import { createDebugLogger } from "../utils/debug.js";
|
|
12
|
+
|
|
13
|
+
// Debug logging
|
|
14
|
+
const debug = createDebugLogger("EMBED");
|
|
15
|
+
|
|
16
|
+
// Provider type
|
|
17
|
+
export type EmbeddingProvider = "openrouter" | "local";
|
|
18
|
+
|
|
19
|
+
// Detect which provider to use based on environment
|
|
20
|
+
let detectedProvider: EmbeddingProvider | null = null;
|
|
21
|
+
|
|
22
|
+
/**
 * Decide which embedding provider to use and remember the decision.
 *
 * The first call picks "openrouter" when OPENROUTER_API_KEY is set to a
 * non-empty value and "local" otherwise; later calls return the cached choice
 * without re-reading the environment.
 */
export function detectProvider(): EmbeddingProvider {
  if (detectedProvider !== null) {
    return detectedProvider;
  }

  const chosen: EmbeddingProvider = process.env.OPENROUTER_API_KEY ? "openrouter" : "local";

  // Cache before logging, so the choice sticks even if building the log line fails
  detectedProvider = chosen;

  if (chosen === "openrouter") {
    debug("Using OpenRouter embeddings (OPENROUTER_API_KEY found)");
  } else {
    debug(`Using local embeddings (${getLocalModelName()})`);
  }

  return chosen;
}
|
|
41
|
+
|
|
42
|
+
/**
 * Get the current embedding provider.
 *
 * Returns the cached provider when one has been detected; otherwise runs
 * detectProvider() and returns its result.
 */
export function getProvider(): EmbeddingProvider {
  return detectedProvider ?? detectProvider();
}
|
|
49
|
+
|
|
50
|
+
/**
 * Embed `text` with whichever provider is currently selected.
 *
 * @param text - Text to embed
 * @returns Promise resolving to the embedding vector
 */
export async function getEmbedding(text: string): Promise<number[]> {
  const useOpenRouter = getProvider() === "openrouter";
  return useOpenRouter ? getOpenRouterEmbedding(text) : getLocalEmbedding(text);
}
|
|
65
|
+
|
|
66
|
+
/**
 * Report the vector size produced by the currently selected provider.
 *
 * @returns Number of dimensions in embedding vectors
 */
export function getEmbeddingDimensions(): number {
  return getProvider() === "openrouter"
    ? getOpenRouterDimensions()
    : getLocalDimensions();
}
|
|
80
|
+
|
|
81
|
+
/**
 * Build a human-readable label for the currently selected provider, of the
 * form `<Provider> (<model>, <dims> dims)`.
 *
 * @returns Provider description string
 */
export function getProviderDescription(): string {
  if (getProvider() !== "openrouter") {
    const localModel = getLocalModelName();
    const localDims = getLocalDimensions();
    return `Local (${localModel}, ${localDims} dims)`;
  }

  // Falls back to this model name when EMBEDDING_MODEL is unset or empty
  const remoteModel = process.env.EMBEDDING_MODEL || "qwen/qwen3-embedding-8b";
  const remoteDims = getOpenRouterDimensions();
  return `OpenRouter (${remoteModel}, ${remoteDims} dims)`;
}
|
|
99
|
+
|
|
100
|
+
// Re-export individual providers for direct access if needed
|
|
101
|
+
export {
|
|
102
|
+
getOpenRouterEmbedding,
|
|
103
|
+
getOpenRouterDimensions,
|
|
104
|
+
} from "./openrouter.js";
|
|
105
|
+
|
|
106
|
+
export {
|
|
107
|
+
getLocalEmbedding,
|
|
108
|
+
getLocalDimensions,
|
|
109
|
+
getLocalModelName,
|
|
110
|
+
isModelLoaded,
|
|
111
|
+
} from "./local.js";
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
|
2
|
+
|
|
3
|
+
// Local embedding provider tests. Each test imports a fresh copy of the module
// (vi.resetModules) after setting EMBEDDING_MODEL as needed.
describe("local embeddings", () => {
  const originalEnv = process.env.EMBEDDING_MODEL;

  // Import a fresh copy of the module under test.
  const loadLocal = () => import("./local.js");

  beforeEach(() => {
    vi.resetModules();
    delete process.env.EMBEDDING_MODEL;
  });

  afterEach(() => {
    if (originalEnv === undefined) {
      delete process.env.EMBEDDING_MODEL;
    } else {
      process.env.EMBEDDING_MODEL = originalEnv;
    }
  });

  describe("getLocalDimensions", () => {
    it("should return 384 for default model", async () => {
      const local = await loadLocal();
      expect(local.getLocalDimensions()).toBe(384);
    });

    it("should return 1024 for bge-m3 model", async () => {
      process.env.EMBEDDING_MODEL = "Xenova/bge-m3";
      const local = await loadLocal();
      expect(local.getLocalDimensions()).toBe(1024);
    });

    it("should return default for unknown model", async () => {
      process.env.EMBEDDING_MODEL = "unknown/model";
      const local = await loadLocal();
      expect(local.getLocalDimensions()).toBe(384); // DEFAULT_LOCAL_EMBEDDING_DIMS
    });
  });

  describe("getLocalModelName", () => {
    it("should return default model when env not set", async () => {
      const local = await loadLocal();
      expect(local.getLocalModelName()).toBe("Xenova/multilingual-e5-small");
    });

    it("should return env model when set", async () => {
      process.env.EMBEDDING_MODEL = "Xenova/all-MiniLM-L6-v2";
      const local = await loadLocal();
      expect(local.getLocalModelName()).toBe("Xenova/all-MiniLM-L6-v2");
    });
  });

  describe("isModelLoaded", () => {
    it("should return false before first embedding call", async () => {
      const local = await loadLocal();
      expect(local.isModelLoaded()).toBe(false);
    });
  });

  describe("getLocalEmbedding", () => {
    it("should throw on empty text", async () => {
      const local = await loadLocal();
      await expect(local.getLocalEmbedding("")).rejects.toThrow("non-empty string");
    });

    it("should throw on non-string input", async () => {
      const local = await loadLocal();
      // @ts-expect-error - testing runtime validation
      await expect(local.getLocalEmbedding(null)).rejects.toThrow("non-empty string");
    });
  });
});
|