@disco_trooper/apple-notes-mcp 1.2.0 → 1.3.0
This diff shows the changes between publicly released versions of this package as they appear in their respective public registries, and is provided for informational purposes only.
- package/README.md +104 -24
- package/package.json +10 -8
- package/src/config/claude.test.ts +47 -0
- package/src/config/claude.ts +106 -0
- package/src/config/constants.ts +11 -2
- package/src/config/paths.test.ts +40 -0
- package/src/config/paths.ts +86 -0
- package/src/db/arrow-fix.test.ts +101 -0
- package/src/db/lancedb.test.ts +209 -2
- package/src/db/lancedb.ts +345 -7
- package/src/embeddings/cache.test.ts +150 -0
- package/src/embeddings/cache.ts +204 -0
- package/src/embeddings/index.ts +21 -2
- package/src/embeddings/local.ts +61 -10
- package/src/embeddings/openrouter.ts +233 -11
- package/src/graph/export.test.ts +81 -0
- package/src/graph/export.ts +163 -0
- package/src/graph/extract.test.ts +90 -0
- package/src/graph/extract.ts +52 -0
- package/src/graph/queries.test.ts +156 -0
- package/src/graph/queries.ts +224 -0
- package/src/index.ts +249 -9
- package/src/notes/crud.test.ts +26 -2
- package/src/notes/crud.ts +43 -5
- package/src/notes/read.ts +83 -68
- package/src/search/chunk-indexer.test.ts +353 -0
- package/src/search/chunk-indexer.ts +207 -0
- package/src/search/chunk-search.test.ts +327 -0
- package/src/search/chunk-search.ts +298 -0
- package/src/search/indexer.ts +151 -109
- package/src/setup.ts +46 -67
- package/src/utils/chunker.test.ts +182 -0
- package/src/utils/chunker.ts +170 -0
- package/src/utils/content-filter.test.ts +225 -0
- package/src/utils/content-filter.ts +275 -0
- package/src/utils/runtime.test.ts +70 -0
- package/src/utils/runtime.ts +40 -0
package/src/embeddings/cache.test.ts
ADDED

@@ -0,0 +1,150 @@
+import { describe, it, expect, vi, beforeEach } from "vitest";
+import { getEmbeddingCache, resetEmbeddingCache } from "./cache.js";
+
+describe("EmbeddingCache", () => {
+  beforeEach(() => {
+    resetEmbeddingCache();
+  });
+
+  describe("get/set", () => {
+    it("returns undefined for uncached query", () => {
+      const cache = getEmbeddingCache();
+      expect(cache.get("test query")).toBeUndefined();
+    });
+
+    it("returns cached embedding", () => {
+      const cache = getEmbeddingCache();
+      const embedding = [0.1, 0.2, 0.3];
+
+      cache.set("test query", embedding);
+      expect(cache.get("test query")).toEqual(embedding);
+    });
+
+    it("normalizes queries for better hit rate", () => {
+      const cache = getEmbeddingCache();
+      const embedding = [0.1, 0.2, 0.3];
+
+      cache.set("Test Query", embedding);
+      // Should match with different casing/spacing
+      expect(cache.get("test query")).toEqual(embedding);
+      expect(cache.get(" TEST QUERY ")).toEqual(embedding);
+    });
+  });
+
+  describe("getOrCompute", () => {
+    it("calls compute function on cache miss", async () => {
+      const cache = getEmbeddingCache();
+      const computeFn = vi.fn().mockResolvedValue([0.1, 0.2, 0.3]);
+
+      const result = await cache.getOrCompute("test query", computeFn);
+
+      expect(computeFn).toHaveBeenCalledWith("test query");
+      expect(result).toEqual([0.1, 0.2, 0.3]);
+    });
+
+    it("returns cached value without calling compute", async () => {
+      const cache = getEmbeddingCache();
+      const embedding = [0.1, 0.2, 0.3];
+      cache.set("test query", embedding);
+
+      const computeFn = vi.fn().mockResolvedValue([0.4, 0.5, 0.6]);
+      const result = await cache.getOrCompute("test query", computeFn);
+
+      expect(computeFn).not.toHaveBeenCalled();
+      expect(result).toEqual(embedding);
+    });
+
+    it("caches computed value for subsequent calls", async () => {
+      const cache = getEmbeddingCache();
+      const computeFn = vi.fn().mockResolvedValue([0.1, 0.2, 0.3]);
+
+      await cache.getOrCompute("test query", computeFn);
+      await cache.getOrCompute("test query", computeFn);
+
+      expect(computeFn).toHaveBeenCalledTimes(1);
+    });
+  });
+
+  describe("LRU eviction", () => {
+    it("evicts oldest entry when at capacity", () => {
+      // Create cache with small size for testing
+      resetEmbeddingCache();
+      const cache = getEmbeddingCache();
+      // We can't easily change max size, but we can test stats
+
+      // Fill cache with entries
+      for (let i = 0; i < 5; i++) {
+        cache.set(`query ${i}`, [i]);
+      }
+
+      const stats = cache.getStats();
+      expect(stats.size).toBeGreaterThan(0);
+    });
+  });
+
+  describe("stats", () => {
+    it("tracks hits and misses", () => {
+      const cache = getEmbeddingCache();
+      cache.set("query1", [0.1]);
+
+      cache.get("query1"); // hit
+      cache.get("query2"); // miss
+      cache.get("query1"); // hit
+      cache.get("query3"); // miss
+
+      const stats = cache.getStats();
+      expect(stats.hits).toBe(2);
+      expect(stats.misses).toBe(2);
+      expect(stats.hitRate).toBe(0.5);
+    });
+  });
+
+  describe("clear", () => {
+    it("clears all cached embeddings", () => {
+      const cache = getEmbeddingCache();
+      cache.set("query1", [0.1]);
+      cache.set("query2", [0.2]);
+
+      cache.clear();
+
+      expect(cache.get("query1")).toBeUndefined();
+      expect(cache.get("query2")).toBeUndefined();
+      expect(cache.getStats().size).toBe(0);
+    });
+
+    it("resets stats on clear", () => {
+      const cache = getEmbeddingCache();
+      cache.set("query1", [0.1]);
+      cache.get("query1");
+      cache.get("query2");
+
+      cache.clear();
+
+      const stats = cache.getStats();
+      expect(stats.hits).toBe(0);
+      expect(stats.misses).toBe(0);
+    });
+  });
+
+  describe("model version", () => {
+    it("invalidates cache when model version changes", () => {
+      const cache = getEmbeddingCache();
+      cache.set("query1", [0.1]);
+
+      cache.setModelVersion("new-model-v2");
+
+      expect(cache.get("query1")).toBeUndefined();
+    });
+
+    it("does not invalidate if version unchanged", () => {
+      const cache = getEmbeddingCache();
+      cache.set("query1", [0.1]);
+
+      cache.setModelVersion("default"); // Same as initial
+      cache.setModelVersion("default"); // Same again
+
+      // Cache should still have the value
+      expect(cache.get("query1")).toEqual([0.1]);
+    });
+  });
+});
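The normalization tests above rely on the cache keying queries by a lowercased, trimmed, space-collapsed form. A minimal sketch of that equivalence using only the exported API (vector values are illustrative):

import { getEmbeddingCache } from "./cache.js";

const cache = getEmbeddingCache();
cache.set("Test Query", [0.1, 0.2, 0.3]);

// All of these map to the same normalized key, "test query":
cache.get("test query");   // [0.1, 0.2, 0.3]
cache.get(" TEST QUERY "); // [0.1, 0.2, 0.3]
cache.get("other query");  // undefined (recorded as a miss in getStats())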
@@ -0,0 +1,204 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* LRU Cache for query embeddings.
|
|
3
|
+
* Dramatically speeds up hybrid search by caching repeated queries.
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
import { createDebugLogger } from "../utils/debug.js";
|
|
7
|
+
|
|
8
|
+
const debug = createDebugLogger("EMBED_CACHE");
|
|
9
|
+
|
|
10
|
+
/**
|
|
11
|
+
* Simple LRU Cache implementation for embeddings.
|
|
12
|
+
*/
|
|
13
|
+
class LRUCache<K, V> {
|
|
14
|
+
private cache = new Map<K, V>();
|
|
15
|
+
private readonly maxSize: number;
|
|
16
|
+
|
|
17
|
+
constructor(maxSize: number) {
|
|
18
|
+
this.maxSize = maxSize;
|
|
19
|
+
}
|
|
20
|
+
|
|
21
|
+
get(key: K): V | undefined {
|
|
22
|
+
const value = this.cache.get(key);
|
|
23
|
+
if (value !== undefined) {
|
|
24
|
+
// Move to end (most recently used)
|
|
25
|
+
this.cache.delete(key);
|
|
26
|
+
this.cache.set(key, value);
|
|
27
|
+
}
|
|
28
|
+
return value;
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
set(key: K, value: V): void {
|
|
32
|
+
// Delete if exists (to update position)
|
|
33
|
+
if (this.cache.has(key)) {
|
|
34
|
+
this.cache.delete(key);
|
|
35
|
+
}
|
|
36
|
+
// Evict oldest if at capacity
|
|
37
|
+
else if (this.cache.size >= this.maxSize) {
|
|
38
|
+
const firstKey = this.cache.keys().next().value;
|
|
39
|
+
if (firstKey !== undefined) {
|
|
40
|
+
this.cache.delete(firstKey);
|
|
41
|
+
}
|
|
42
|
+
}
|
|
43
|
+
this.cache.set(key, value);
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
has(key: K): boolean {
|
|
47
|
+
return this.cache.has(key);
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
clear(): void {
|
|
51
|
+
this.cache.clear();
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
get size(): number {
|
|
55
|
+
return this.cache.size;
|
|
56
|
+
}
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
/**
|
|
60
|
+
* Normalize query for better cache hit rate.
|
|
61
|
+
* - Lowercase
|
|
62
|
+
* - Trim whitespace
|
|
63
|
+
* - Collapse multiple spaces
|
|
64
|
+
*/
|
|
65
|
+
function normalizeQuery(query: string): string {
|
|
66
|
+
return query.toLowerCase().trim().replace(/\s+/g, " ");
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
/**
|
|
70
|
+
* Cache statistics for monitoring.
|
|
71
|
+
*/
|
|
72
|
+
export interface CacheStats {
|
|
73
|
+
hits: number;
|
|
74
|
+
misses: number;
|
|
75
|
+
size: number;
|
|
76
|
+
hitRate: number;
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
/**
|
|
80
|
+
* Embedding cache with LRU eviction.
|
|
81
|
+
*/
|
|
82
|
+
class EmbeddingCache {
|
|
83
|
+
private cache: LRUCache<string, number[]>;
|
|
84
|
+
private modelVersion: string;
|
|
85
|
+
private hits = 0;
|
|
86
|
+
private misses = 0;
|
|
87
|
+
|
|
88
|
+
constructor(maxSize = 1000, modelVersion = "default") {
|
|
89
|
+
this.cache = new LRUCache(maxSize);
|
|
90
|
+
this.modelVersion = modelVersion;
|
|
91
|
+
debug(`Embedding cache initialized (max: ${maxSize})`);
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
/**
|
|
95
|
+
* Create cache key from query and model version.
|
|
96
|
+
*/
|
|
97
|
+
private makeKey(query: string): string {
|
|
98
|
+
const normalized = normalizeQuery(query);
|
|
99
|
+
return `${this.modelVersion}:${normalized}`;
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
/**
|
|
103
|
+
* Get cached embedding for query.
|
|
104
|
+
* Returns undefined if not cached.
|
|
105
|
+
*/
|
|
106
|
+
get(query: string): number[] | undefined {
|
|
107
|
+
const key = this.makeKey(query);
|
|
108
|
+
const cached = this.cache.get(key);
|
|
109
|
+
|
|
110
|
+
if (cached) {
|
|
111
|
+
this.hits++;
|
|
112
|
+
debug(`Cache HIT for "${query.slice(0, 30)}..." (hits: ${this.hits})`);
|
|
113
|
+
return cached;
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
this.misses++;
|
|
117
|
+
return undefined;
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
/**
|
|
121
|
+
* Store embedding in cache.
|
|
122
|
+
*/
|
|
123
|
+
set(query: string, embedding: number[]): void {
|
|
124
|
+
const key = this.makeKey(query);
|
|
125
|
+
this.cache.set(key, embedding);
|
|
126
|
+
debug(`Cached embedding for "${query.slice(0, 30)}..." (size: ${this.cache.size})`);
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
/**
|
|
130
|
+
* Get or compute embedding using provided function.
|
|
131
|
+
* This is the main API for cached embedding retrieval.
|
|
132
|
+
*/
|
|
133
|
+
async getOrCompute(
|
|
134
|
+
query: string,
|
|
135
|
+
computeFn: (q: string) => Promise<number[]>
|
|
136
|
+
): Promise<number[]> {
|
|
137
|
+
const cached = this.get(query);
|
|
138
|
+
if (cached) {
|
|
139
|
+
return cached;
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
const embedding = await computeFn(query);
|
|
143
|
+
this.set(query, embedding);
|
|
144
|
+
return embedding;
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
/**
|
|
148
|
+
* Invalidate cache (e.g., when model changes).
|
|
149
|
+
*/
|
|
150
|
+
clear(): void {
|
|
151
|
+
this.cache.clear();
|
|
152
|
+
this.hits = 0;
|
|
153
|
+
this.misses = 0;
|
|
154
|
+
debug("Cache cleared");
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
/**
|
|
158
|
+
* Update model version and clear cache.
|
|
159
|
+
*/
|
|
160
|
+
setModelVersion(version: string): void {
|
|
161
|
+
if (version !== this.modelVersion) {
|
|
162
|
+
debug(`Model version changed: ${this.modelVersion} -> ${version}`);
|
|
163
|
+
this.modelVersion = version;
|
|
164
|
+
this.clear();
|
|
165
|
+
}
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
/**
|
|
169
|
+
* Get cache statistics.
|
|
170
|
+
*/
|
|
171
|
+
getStats(): CacheStats {
|
|
172
|
+
const total = this.hits + this.misses;
|
|
173
|
+
return {
|
|
174
|
+
hits: this.hits,
|
|
175
|
+
misses: this.misses,
|
|
176
|
+
size: this.cache.size,
|
|
177
|
+
hitRate: total > 0 ? this.hits / total : 0,
|
|
178
|
+
};
|
|
179
|
+
}
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
// Singleton instance
|
|
183
|
+
let cacheInstance: EmbeddingCache | null = null;
|
|
184
|
+
|
|
185
|
+
/**
|
|
186
|
+
* Get the embedding cache singleton.
|
|
187
|
+
*/
|
|
188
|
+
export function getEmbeddingCache(): EmbeddingCache {
|
|
189
|
+
if (!cacheInstance) {
|
|
190
|
+
// Max 1000 queries * ~1.5KB per embedding = ~1.5MB
|
|
191
|
+
cacheInstance = new EmbeddingCache(1000);
|
|
192
|
+
}
|
|
193
|
+
return cacheInstance;
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
/**
|
|
197
|
+
* Reset the cache (useful for testing).
|
|
198
|
+
*/
|
|
199
|
+
export function resetEmbeddingCache(): void {
|
|
200
|
+
if (cacheInstance) {
|
|
201
|
+
cacheInstance.clear();
|
|
202
|
+
}
|
|
203
|
+
cacheInstance = null;
|
|
204
|
+
}
|
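A minimal usage sketch tying the singleton cache to the package's getEmbedding dispatcher (the wiring is an assumption for illustration; the search code added in this release may consume the cache differently):

import { getEmbeddingCache } from "./embeddings/cache.js";
import { getEmbedding } from "./embeddings/index.js";

const cache = getEmbeddingCache();

// First call computes via the configured provider; repeated queries are served from memory.
const vector = await cache.getOrCompute("meeting notes from last week", getEmbedding);

const { hits, misses, hitRate } = cache.getStats();
console.log(`embedding cache: ${hits} hits, ${misses} misses (${(hitRate * 100).toFixed(0)}% hit rate)`);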
package/src/embeddings/index.ts
CHANGED

@@ -6,8 +6,8 @@
  * - Local HuggingFace (fallback)
  */
 
-import { getOpenRouterEmbedding, getOpenRouterDimensions } from "./openrouter.js";
-import { getLocalEmbedding, getLocalDimensions, getLocalModelName } from "./local.js";
+import { getOpenRouterEmbedding, getOpenRouterDimensions, getOpenRouterEmbeddingBatch } from "./openrouter.js";
+import { getLocalEmbedding, getLocalDimensions, getLocalModelName, getLocalEmbeddingBatch } from "./local.js";
 import { createDebugLogger } from "../utils/debug.js";
 
 // Debug logging

@@ -62,6 +62,23 @@ export async function getEmbedding(text: string): Promise<number[]> {
   }
 }
 
+/**
+ * Generate embeddings for multiple texts in batch.
+ * Uses native batch API for both OpenRouter and local providers.
+ *
+ * @param texts - Array of texts to embed
+ * @returns Promise resolving to array of embedding vectors
+ */
+export async function getEmbeddingBatch(texts: string[]): Promise<number[][]> {
+  const provider = getProvider();
+
+  if (provider === "openrouter") {
+    return getOpenRouterEmbeddingBatch(texts);
+  } else {
+    return getLocalEmbeddingBatch(texts);
+  }
+}
+
 /**
  * Get the embedding dimensions for the current provider.
  *

@@ -100,10 +117,12 @@ export function getProviderDescription(): string {
 export {
   getOpenRouterEmbedding,
   getOpenRouterDimensions,
+  getOpenRouterEmbeddingBatch,
 } from "./openrouter.js";
 
 export {
   getLocalEmbedding,
+  getLocalEmbeddingBatch,
   getLocalDimensions,
   getLocalModelName,
   isModelLoaded,
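For illustration, a short sketch of the new batch entry point (the chunk texts are made up; the chunk indexer added elsewhere in this release is presumably the real caller):

import { getEmbeddingBatch } from "./embeddings/index.js";

// One provider round-trip instead of one call per text.
const chunks = ["First chunk of a note.", "Second chunk.", "Third chunk."];
const vectors = await getEmbeddingBatch(chunks);

console.log(vectors.length);          // 3
console.log(vectors[0]?.length ?? 0); // dimensions for the active provider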
package/src/embeddings/local.ts
CHANGED

@@ -25,7 +25,7 @@ const debug = createDebugLogger("LOCAL");
 
 // Lazy-loaded pipeline
 type FeatureExtractionPipeline = (
-  text: string,
+  text: string | string[],
   options?: { pooling?: string; normalize?: boolean }
 ) => Promise<{ tolist: () => number[][] }>;
 

@@ -40,6 +40,27 @@ function getModelName(): string {
   return process.env.EMBEDDING_MODEL || DEFAULT_MODEL;
 }
 
+/**
+ * Check if the model is an E5 model that requires prefixed input.
+ */
+function isE5Model(): boolean {
+  return getModelName().toLowerCase().includes("e5");
+}
+
+/**
+ * Prepare text for embedding by adding E5 prefix if needed.
+ */
+function prepareText(text: string): string {
+  return isE5Model() ? `passage: ${text}` : text;
+}
+
+/**
+ * Prepare multiple texts for embedding by adding E5 prefix if needed.
+ */
+function prepareTexts(texts: string[]): string[] {
+  return isE5Model() ? texts.map(t => `passage: ${t}`) : texts;
+}
+
 /**
  * Lazy-load the HuggingFace transformers pipeline.
  * Only loads once, subsequent calls return the cached instance.

@@ -116,19 +137,11 @@ export async function getLocalEmbedding(text: string): Promise<number[]> {
   const startTime = Date.now();
 
   try {
-
-    // or "query: " for search queries - using passage for general text
-    const modelName = getModelName();
-    const isE5Model = modelName.toLowerCase().includes("e5");
-    const inputText = isE5Model ? `passage: ${text}` : text;
-
-    // Run inference with mean pooling and normalization
-    const output = await pipe(inputText, {
+    const output = await pipe(prepareText(text), {
       pooling: "mean",
       normalize: true,
     });
 
-    // Extract the embedding vector
     const embedding = output.tolist()[0];
 
     const inferenceTime = Date.now() - startTime;

@@ -178,3 +191,41 @@ export function getLocalModelName(): string {
 export function isModelLoaded(): boolean {
   return pipelineInstance !== null;
 }
+
+/**
+ * Generate embeddings for multiple texts in a single batch call.
+ * More efficient than calling getLocalEmbedding for each text individually.
+ *
+ * @param texts - Array of texts to embed
+ * @returns Promise resolving to array of embedding vectors
+ * @throws Error if model loading or inference fails
+ */
+export async function getLocalEmbeddingBatch(texts: string[]): Promise<number[][]> {
+  if (!texts || texts.length === 0) {
+    return [];
+  }
+
+  const pipe = await getPipeline();
+
+  debug(`Generating batch embeddings for ${texts.length} texts`);
+  const startTime = Date.now();
+
+  try {
+    const output = await pipe(prepareTexts(texts), {
+      pooling: "mean",
+      normalize: true,
+    });
+
+    const embeddings = output.tolist() as number[][];
+
+    const inferenceTime = Date.now() - startTime;
+    debug(`Batch embeddings generated in ${inferenceTime}ms (${embeddings.length} vectors, ${embeddings[0]?.length ?? 0} dims)`);
+
+    return embeddings;
+  } catch (error) {
+    const message = error instanceof Error ? error.message : String(error);
+    debug(`Batch embedding generation failed: ${message}`);
+
+    throw new Error(`Failed to generate batch embeddings: ${message}`);
+  }
+}
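To make the E5 handling concrete, a sketch of the batch path under an E5 model (the model name is an example for illustration, not necessarily the package's default; prepareText/prepareTexts are private, so their effect is shown in comments):

import { getLocalEmbeddingBatch } from "./embeddings/local.js";

// With EMBEDDING_MODEL containing "e5", inputs are prefixed before inference:
//   prepareText("buy milk")  -> "passage: buy milk"
//   prepareTexts(["a", "b"]) -> ["passage: a", "passage: b"]
// Models without "e5" in the name pass texts through unchanged.
process.env.EMBEDDING_MODEL = "intfloat/multilingual-e5-small"; // example E5 model
const embeddings = await getLocalEmbeddingBatch(["buy milk", "call dentist"]);
console.log(embeddings.length); // 2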