memory-braid 0.2.0 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,354 @@
1
+ import os from "node:os";
2
+ import path from "node:path";
3
+ import { normalizeWhitespace } from "./chunking.js";
4
+ import type { MemoryBraidConfig } from "./config.js";
5
+ import { MemoryBraidLogger } from "./logger.js";
6
+
7
/**
 * Callable returned by the @xenova/transformers `token-classification` pipeline.
 * Its result is left as `unknown` and validated row-by-row by the caller.
 */
type NerPipeline = (
  text: string,
  options?: Record<string, unknown>,
) => Promise<unknown>;

/** Loosely-typed shape of one row emitted by the NER pipeline; every field is checked before use. */
type NerRecord = Partial<Record<"word" | "entity_group" | "entity" | "score", unknown>>;

/** A named entity found in memory text, keyed by a stable canonical URI. */
export type ExtractedEntity = {
  /** Surface text after whitespace and subword-marker normalization. */
  text: string;
  /** Coarse category derived from the model's label. */
  type: "person" | "organization" | "location" | "misc";
  /** Model confidence clamped to [0, 1]. */
  score: number;
  /** Identifier of the form `entity://<type>/<slug>`. */
  canonicalUri: string;
};
22
+
23
/**
 * Count entities per type, e.g. `{ person: 2, location: 1 }`, for compact log payloads.
 * Types that do not occur are simply absent from the result.
 */
function summarizeEntityTypes(entities: ExtractedEntity[]): Record<string, number> {
  return entities.reduce<Record<string, number>>((counts, { type }) => {
    counts[type] = (counts[type] ?? 0) + 1;
    return counts;
  }, {});
}
30
+
31
/**
 * Pick the OpenClaw state directory as an absolute path.
 *
 * Precedence: the explicit argument, then `OPENCLAW_STATE_DIR`, then `~/.openclaw`.
 * Blank (whitespace-only) values are treated as unset.
 */
function resolveStateDir(explicitStateDir?: string): string {
  const candidates = [explicitStateDir, process.env.OPENCLAW_STATE_DIR];
  for (const candidate of candidates) {
    const trimmed = candidate?.trim();
    if (trimmed) {
      return path.resolve(trimmed);
    }
  }
  return path.resolve(path.join(os.homedir(), ".openclaw"));
}

/** Directory where entity-extraction model files are cached: `<stateDir>/memory-braid/models/entity-extraction`. */
export function resolveEntityModelCacheDir(stateDir?: string): string {
  const root = resolveStateDir(stateDir);
  return path.join(root, "memory-braid", "models", "entity-extraction");
}
42
+
43
/**
 * Turn free text into a lowercase, hyphen-separated slug for canonical URIs.
 *
 * Latin diacritics are stripped ("São Paulo" -> "sao-paulo"), so ASCII and accented-Latin names
 * slug exactly as before. Letters and digits from other scripts are now kept ("東京" -> "東京",
 * "Москва" -> "москва"). Previously everything outside `[a-z0-9]` became a separator, so every
 * non-Latin name collapsed to "unknown". Distinct entities then shared one canonical URI and were
 * merged by URI-based deduplication.
 *
 * @returns The slug, or "unknown" when nothing sluggable remains.
 */
function slugify(value: string): string {
  const slug = value
    .normalize("NFKD")
    // Drop Latin combining diacritics (U+0300–U+036F) so accented Latin names fold to ASCII.
    .replace(/[\u0300-\u036f]/g, "")
    // Recompose what remains (e.g. Hangul jamo back into syllables) so slugs stay readable.
    .normalize("NFC")
    .toLowerCase()
    // Any run of characters that is not a letter, combining mark or digit (any script) becomes one hyphen.
    .replace(/[^\p{L}\p{M}\p{N}]+/gu, "-")
    .replace(/^-+|-+$/g, "");
  return slug || "unknown";
}

/**
 * Build the canonical identifier for an entity: `entity://<type>/<slug(text)>`.
 * Entities with the same type and equivalent text (case, accents, punctuation) share one URI.
 */
export function buildCanonicalEntityUri(
  type: ExtractedEntity["type"],
  text: string,
): string {
  return `entity://${type}/${slugify(text)}`;
}
60
+
61
/**
 * Map a raw NER label (e.g. "B-PER", "ORG", "GPE") to a coarse entity type.
 *
 * Matching is case-insensitive substring search, checked in order PER, ORG, LOC/GPE.
 * Non-string or unrecognized labels map to "misc".
 */
function normalizeEntityType(raw: unknown): ExtractedEntity["type"] {
  if (typeof raw !== "string") {
    return "misc";
  }
  const label = raw.toUpperCase();
  const rules: Array<[readonly string[], ExtractedEntity["type"]]> = [
    [["PER"], "person"],
    [["ORG"], "organization"],
    [["LOC", "GPE"], "location"],
  ];
  const hit = rules.find(([needles]) => needles.some((needle) => label.includes(needle)));
  return hit ? hit[1] : "misc";
}
74
+
75
/**
 * Clean a raw token/entity string from the NER pipeline.
 *
 * Removes one leading WordPiece continuation marker ("##"), then one leading SentencePiece
 * word marker ("▁"), then normalizes whitespace. Non-strings yield "".
 */
function normalizeEntityText(raw: unknown): string {
  if (typeof raw === "string") {
    const withoutWordPiece = raw.startsWith("##") ? raw.slice(2) : raw;
    const withoutSentencePiece = withoutWordPiece.startsWith("▁")
      ? withoutWordPiece.slice(1)
      : withoutWordPiece;
    return normalizeWhitespace(withoutSentencePiece);
  }
  return "";
}
81
+
82
type EntityExtractionOptions = {
  /** Overrides the state directory used to locate the model cache (see `resolveEntityModelCacheDir`). */
  stateDir?: string;
};

/** One entity mention assembled from one or more raw pipeline rows. */
type NerSpan = {
  /** Surface text; may still carry a leading `##`/`▁` marker (stripped later by `normalizeEntityText`). */
  word: string;
  /** Label of the span's first row: `entity_group`, or a BIO-style `entity` such as "B-PER". */
  label: unknown;
  /** Mean of the clamped per-row scores. */
  score: number;
};

/** Mutable accumulator used while merging consecutive token rows into one span. */
type NerTokenGroup = {
  /** Tag without its B-/I- prefix, upper-cased; null for rows that arrived already aggregated. */
  tag: string | null;
  label: unknown;
  text: string;
  scores: number[];
  /** Token position reported by the pipeline (if any); merging is refused across gaps. */
  lastIndex: number | undefined;
};

const BIO_LABEL_PATTERN = /^([BI])-(.+)$/i;

/** Clamp a raw score into [0, 1]; anything that is not a real number counts as 0. */
function clampNerScore(raw: unknown): number {
  return typeof raw === "number" && !Number.isNaN(raw) ? Math.max(0, Math.min(1, raw)) : 0;
}

/**
 * Merge raw token-classification rows into entity spans.
 *
 * The pipeline is asked for `aggregation_strategy: "simple"`, but rows can still arrive one per
 * token with BIO labels ("B-LOC", "I-LOC") and WordPiece continuations ("##hn"). Stripping "##"
 * in `normalizeEntityText` only makes sense for such rows.
 * NOTE(review): this matches how @xenova/transformers appears to behave; confirm against the pinned version.
 *
 * Without merging, "New York" would be stored as two locations and "Jo" + "##hn" as two people.
 * Rows carrying `entity_group` are treated as complete spans. "O" (outside) rows are dropped.
 */
function groupNerRows(rows: readonly unknown[]): NerSpan[] {
  const groups: NerTokenGroup[] = [];
  let current: NerTokenGroup | undefined;

  for (const row of rows) {
    if (!row || typeof row !== "object") {
      current = undefined;
      continue;
    }
    const record = row as NerRecord & { index?: unknown };
    const word = record.word;
    if (typeof word !== "string") {
      current = undefined;
      continue;
    }
    const score = clampNerScore(record.score);
    const index = typeof record.index === "number" ? record.index : undefined;

    if (record.entity_group != null) {
      // Already aggregated by the pipeline: keep it as a standalone span.
      current = undefined;
      groups.push({ tag: null, label: record.entity_group, text: word, scores: [score], lastIndex: index });
      continue;
    }

    const rawLabel = typeof record.entity === "string" ? record.entity.trim() : "";
    if (rawLabel.toUpperCase() === "O") {
      // "Outside" tokens are not entities and close any open span.
      current = undefined;
      continue;
    }
    const bio = BIO_LABEL_PATTERN.exec(rawLabel);
    const prefix = bio?.[1]?.toUpperCase();
    const tag = (bio?.[2] ?? rawLabel).toUpperCase();
    const isSubword = word.startsWith("##");

    if (
      current &&
      current.tag === tag &&
      (prefix === "I" || isSubword) &&
      (current.lastIndex === undefined || index === undefined || index === current.lastIndex + 1)
    ) {
      // WordPiece continuations glue onto the previous piece; other I- tokens start a new word.
      current.text += isSubword ? word.slice(2) : ` ${word.replace(/^▁/, "")}`;
      current.scores.push(score);
      current.lastIndex = index;
      continue;
    }

    current = { tag, label: record.entity, text: word, scores: [score], lastIndex: index };
    groups.push(current);
  }

  return groups.map((group) => ({
    word: group.text,
    label: group.label,
    score: group.scores.reduce((sum, value) => sum + value, 0) / group.scores.length,
  }));
}

/**
 * Lazily loads a local token-classification (NER) model via @xenova/transformers and uses it to
 * extract named entities from memory text.
 *
 * When `cfg.enabled` is false, `extract` returns [] and `warmup` reports "entity_extraction_disabled".
 * Model-load and inference failures are logged and surfaced as empty/failed results rather than thrown.
 */
export class EntityExtractionManager {
  /** Minimum wait (ms) before a failed model load is retried automatically. */
  private static readonly LOAD_RETRY_COOLDOWN_MS = 60000;

  private readonly cfg: MemoryBraidConfig["entityExtraction"];
  private readonly log: MemoryBraidLogger;
  private stateDir?: string;
  /** In-flight or successfully completed model load; reset on state-dir change, forced reload, or failure. */
  private pipelinePromise: Promise<NerPipeline | null> | null = null;
  /** `Date.now()` of the most recent failed load (drives the retry cooldown), or null. */
  private lastLoadFailureAt: number | null = null;

  constructor(
    cfg: MemoryBraidConfig["entityExtraction"],
    log: MemoryBraidLogger,
    options?: EntityExtractionOptions,
  ) {
    this.cfg = cfg;
    this.log = log;
    this.stateDir = options?.stateDir;
  }

  /**
   * Point the manager at a new state directory.
   *
   * Blank values and the current directory are ignored. Otherwise the cached pipeline and any
   * failure backoff are dropped, so the next use loads from the new cache location.
   */
  setStateDir(stateDir?: string): void {
    const next = stateDir?.trim();
    if (!next || next === this.stateDir) {
      return;
    }
    this.stateDir = next;
    this.pipelinePromise = null;
    this.lastLoadFailureAt = null;
  }

  /** Snapshot of the effective configuration, including the resolved model cache directory. */
  getStatus(): {
    enabled: boolean;
    provider: MemoryBraidConfig["entityExtraction"]["provider"];
    model: string;
    minScore: number;
    maxEntitiesPerMemory: number;
    cacheDir: string;
  } {
    return {
      enabled: this.cfg.enabled,
      provider: this.cfg.provider,
      model: this.cfg.model,
      minScore: this.cfg.minScore,
      maxEntitiesPerMemory: this.cfg.maxEntitiesPerMemory,
      cacheDir: resolveEntityModelCacheDir(this.stateDir),
    };
  }

  /**
   * Load the model (reloading it when `forceReload` is set) and run one extraction over `text`.
   * When no text is given, `cfg.startup.warmupText` is used instead.
   *
   * Never throws. Failures are reported as `ok: false` with `error` set to
   * "entity_extraction_disabled", "model_load_failed", or the inference error message.
   */
  async warmup(params?: {
    runId?: string;
    reason?: string;
    forceReload?: boolean;
    text?: string;
  }): Promise<{
    ok: boolean;
    cacheDir: string;
    model: string;
    entities: number;
    durMs: number;
    error?: string;
  }> {
    const startedAt = Date.now();
    if (!this.cfg.enabled) {
      return {
        ok: false,
        cacheDir: resolveEntityModelCacheDir(this.stateDir),
        model: this.cfg.model,
        entities: 0,
        durMs: Date.now() - startedAt,
        error: "entity_extraction_disabled",
      };
    }

    const pipeline = await this.ensurePipeline(params?.forceReload);
    if (!pipeline) {
      return {
        ok: false,
        cacheDir: resolveEntityModelCacheDir(this.stateDir),
        model: this.cfg.model,
        entities: 0,
        durMs: Date.now() - startedAt,
        error: "model_load_failed",
      };
    }

    try {
      const entities = await this.extractWithPipeline({
        pipeline,
        text: params?.text ?? this.cfg.startup.warmupText,
      });
      this.log.info("memory_braid.entity.warmup", {
        runId: params?.runId,
        reason: params?.reason ?? "manual",
        provider: this.cfg.provider,
        model: this.cfg.model,
        cacheDir: resolveEntityModelCacheDir(this.stateDir),
        entities: entities.length,
        entityTypes: summarizeEntityTypes(entities),
        sampleEntityUris: entities.slice(0, 5).map((entry) => entry.canonicalUri),
        durMs: Date.now() - startedAt,
      });
      return {
        ok: true,
        cacheDir: resolveEntityModelCacheDir(this.stateDir),
        model: this.cfg.model,
        entities: entities.length,
        durMs: Date.now() - startedAt,
      };
    } catch (err) {
      const message = err instanceof Error ? err.message : String(err);
      this.log.warn("memory_braid.entity.warmup", {
        runId: params?.runId,
        reason: params?.reason ?? "manual",
        provider: this.cfg.provider,
        model: this.cfg.model,
        cacheDir: resolveEntityModelCacheDir(this.stateDir),
        error: message,
      });
      return {
        ok: false,
        cacheDir: resolveEntityModelCacheDir(this.stateDir),
        model: this.cfg.model,
        entities: 0,
        durMs: Date.now() - startedAt,
        error: message,
      };
    }
  }

  /**
   * Extract entities from `params.text`.
   *
   * Returns [] when extraction is disabled, the text is blank after whitespace normalization,
   * the model is unavailable, or inference fails (inference failures are logged).
   */
  async extract(params: { text: string; runId?: string }): Promise<ExtractedEntity[]> {
    if (!this.cfg.enabled) {
      return [];
    }

    const text = normalizeWhitespace(params.text);
    if (!text) {
      return [];
    }

    const pipeline = await this.ensurePipeline();
    if (!pipeline) {
      return [];
    }

    try {
      const entities = await this.extractWithPipeline({ pipeline, text });
      this.log.debug("memory_braid.entity.extract", {
        runId: params.runId,
        provider: this.cfg.provider,
        model: this.cfg.model,
        entities: entities.length,
        entityTypes: summarizeEntityTypes(entities),
        sampleEntityUris: entities.slice(0, 5).map((entry) => entry.canonicalUri),
      });
      return entities;
    } catch (err) {
      this.log.warn("memory_braid.entity.extract", {
        runId: params.runId,
        provider: this.cfg.provider,
        model: this.cfg.model,
        error: err instanceof Error ? err.message : String(err),
      });
      return [];
    }
  }

  /**
   * Return the shared pipeline, starting a load if none is cached.
   *
   * Failed loads are not cached. Once LOAD_RETRY_COOLDOWN_MS has elapsed, the next call tries
   * again, so one transient failure (e.g. a model download error) does not disable extraction
   * for the rest of the process. `forceReload` bypasses both the cache and the cooldown.
   */
  private async ensurePipeline(forceReload = false): Promise<NerPipeline | null> {
    if (!this.cfg.enabled) {
      return null;
    }

    if (forceReload) {
      this.pipelinePromise = null;
      this.lastLoadFailureAt = null;
    }

    if (this.pipelinePromise) {
      return this.pipelinePromise;
    }

    if (
      this.lastLoadFailureAt !== null &&
      Date.now() - this.lastLoadFailureAt < EntityExtractionManager.LOAD_RETRY_COOLDOWN_MS
    ) {
      return null;
    }

    const pending = this.loadPipeline();
    this.pipelinePromise = pending;
    const pipeline = await pending;
    // Only clear the cache if no newer load (forceReload / setStateDir) replaced this one meanwhile.
    if (!pipeline && this.pipelinePromise === pending) {
      this.pipelinePromise = null;
      this.lastLoadFailureAt = Date.now();
    }
    return pipeline;
  }

  /**
   * Import @xenova/transformers, point its cache at the resolved model cache dir, and build the
   * token-classification pipeline. Logs the error and returns null on any failure.
   */
  private async loadPipeline(): Promise<NerPipeline | null> {
    const cacheDir = resolveEntityModelCacheDir(this.stateDir);
    this.log.info("memory_braid.entity.model_load", {
      provider: this.cfg.provider,
      model: this.cfg.model,
      cacheDir,
    });

    try {
      const mod = (await import("@xenova/transformers")) as {
        env?: Record<string, unknown>;
        pipeline?: (
          task: string,
          model: string,
          options?: Record<string, unknown>,
        ) => Promise<unknown>;
      };

      if (!mod.pipeline) {
        throw new Error("@xenova/transformers pipeline export not found");
      }

      if (mod.env) {
        mod.env.cacheDir = cacheDir;
        mod.env.allowRemoteModels = true;
        mod.env.allowLocalModels = true;
        mod.env.useFS = true;
      }

      const classifier = await mod.pipeline("token-classification", this.cfg.model, {
        quantized: true,
      });

      if (typeof classifier !== "function") {
        throw new Error("token-classification pipeline is not callable");
      }

      return classifier as NerPipeline;
    } catch (err) {
      this.log.error("memory_braid.entity.model_load", {
        provider: this.cfg.provider,
        model: this.cfg.model,
        cacheDir,
        error: err instanceof Error ? err.message : String(err),
      });
      return null;
    }
  }

  /**
   * Run the pipeline on `text` and turn its rows into entities.
   *
   * Rows are merged into spans, spans below `cfg.minScore` are dropped, and spans are deduplicated
   * by canonical URI (keeping the highest score). The result is sorted by descending score and
   * capped at `cfg.maxEntitiesPerMemory`. Errors from the pipeline call propagate to the caller.
   */
  private async extractWithPipeline(params: {
    pipeline: NerPipeline;
    text: string;
  }): Promise<ExtractedEntity[]> {
    const raw = await params.pipeline(params.text, {
      aggregation_strategy: "simple",
    });
    const rows: unknown[] = Array.isArray(raw) ? raw : [];

    const deduped = new Map<string, ExtractedEntity>();
    for (const span of groupNerRows(rows)) {
      const entityText = normalizeEntityText(span.word);
      if (!entityText || span.score < this.cfg.minScore) {
        continue;
      }

      const type = normalizeEntityType(span.label);
      const canonicalUri = buildCanonicalEntityUri(type, entityText);
      const current = deduped.get(canonicalUri);
      if (!current || span.score > current.score) {
        deduped.set(canonicalUri, {
          text: entityText,
          type,
          score: span.score,
          canonicalUri,
        });
      }
    }

    return Array.from(deduped.values())
      .sort((a, b) => b.score - a.score)
      .slice(0, this.cfg.maxEntitiesPerMemory);
  }
}
package/src/extract.ts CHANGED
@@ -3,6 +3,8 @@ import type { MemoryBraidConfig } from "./config.js";
3
3
  import { MemoryBraidLogger } from "./logger.js";
4
4
  import type { ExtractedCandidate } from "./types.js";
5
5
 
6
/** LLM backends the ML capture helpers in this module can call. */
type MlProvider =
  | "openai"
  | "anthropic"
  | "gemini";
7
+
6
8
  const HEURISTIC_PATTERNS = [
7
9
  /remember|remember that|keep in mind|note that/i,
8
10
  /i prefer|prefer to|don't like|do not like|hate|love/i,
@@ -145,14 +147,11 @@ function parseJsonObjectArray(raw: string): Array<Record<string, unknown>> {
145
147
  }
146
148
 
147
149
  async function callMlEnrichment(params: {
148
- provider: "openai" | "anthropic" | "gemini";
150
+ provider: MlProvider;
149
151
  model: string;
150
152
  timeoutMs: number;
151
153
  candidates: ExtractedCandidate[];
152
154
  }): Promise<Array<Record<string, unknown>>> {
153
- const controller = new AbortController();
154
- const timer = setTimeout(() => controller.abort(), params.timeoutMs);
155
-
156
155
  const prompt = [
157
156
  "Classify the memory candidates.",
158
157
  "Return ONLY JSON array.",
@@ -160,6 +159,52 @@ async function callMlEnrichment(params: {
160
159
  "Category one of: preference, decision, fact, task, other.",
161
160
  JSON.stringify(params.candidates.map((candidate, index) => ({ index, text: candidate.text }))),
162
161
  ].join("\n");
162
+ return callMlJson({
163
+ provider: params.provider,
164
+ model: params.model,
165
+ timeoutMs: params.timeoutMs,
166
+ prompt,
167
+ });
168
+ }
169
+
170
/**
 * Ask the configured LLM provider to extract durable memories directly from a conversation.
 *
 * Only the most recent 30 messages are sent, each reduced to `{role, text}`. The parsed JSON
 * objects are returned unvalidated; the caller normalizes them, including clamping `score` to
 * [0, 1]. The prompt therefore states that range explicitly. Otherwise a model that scores on a
 * 1–10 scale would have every item clamped to 1.
 *
 * @param params.maxItems - Upper bound on items, communicated to the model in the prompt.
 */
async function callMlExtraction(params: {
  provider: MlProvider;
  model: string;
  timeoutMs: number;
  maxItems: number;
  messages: Array<{ role: string; text: string }>;
}): Promise<Array<Record<string, unknown>>> {
  // Size of the trailing conversation window sent to the model.
  const maxContextMessages = 30;
  const recent = params.messages.slice(-maxContextMessages).map((item) => ({
    role: item.role,
    text: item.text,
  }));

  const prompt = [
    "Extract durable user memories from this conversation.",
    "Return ONLY JSON array.",
    "Each item: {text:string, category:string, score:number}.",
    "Score is a number between 0 and 1 (higher = more important to remember).",
    "Category one of: preference, decision, fact, task, other.",
    "Keep each text concise and atomic.",
    `Maximum items: ${params.maxItems}.`,
    JSON.stringify(recent),
  ].join("\n");

  return callMlJson({
    provider: params.provider,
    model: params.model,
    timeoutMs: params.timeoutMs,
    prompt,
  });
}
199
+
200
+ async function callMlJson(params: {
201
+ provider: MlProvider;
202
+ model: string;
203
+ timeoutMs: number;
204
+ prompt: string;
205
+ }): Promise<Array<Record<string, unknown>>> {
206
+ const controller = new AbortController();
207
+ const timer = setTimeout(() => controller.abort(), params.timeoutMs);
163
208
 
164
209
  try {
165
210
  if (params.provider === "openai") {
@@ -183,7 +228,7 @@ async function callMlEnrichment(params: {
183
228
  },
184
229
  {
185
230
  role: "user",
186
- content: prompt,
231
+ content: params.prompt,
187
232
  },
188
233
  ],
189
234
  }),
@@ -212,7 +257,7 @@ async function callMlEnrichment(params: {
212
257
  model: params.model,
213
258
  max_tokens: 1000,
214
259
  temperature: 0,
215
- messages: [{ role: "user", content: prompt }],
260
+ messages: [{ role: "user", content: params.prompt }],
216
261
  }),
217
262
  signal: controller.signal,
218
263
  });
@@ -236,7 +281,7 @@ async function callMlEnrichment(params: {
236
281
  },
237
282
  body: JSON.stringify({
238
283
  generationConfig: { temperature: 0 },
239
- contents: [{ role: "user", parts: [{ text: prompt }] }],
284
+ contents: [{ role: "user", parts: [{ text: params.prompt }] }],
240
285
  }),
241
286
  signal: controller.signal,
242
287
  },
@@ -251,6 +296,19 @@ async function callMlEnrichment(params: {
251
296
  }
252
297
  }
253
298
 
299
/**
 * Coerce an untrusted (model-provided) category value into a known candidate category.
 *
 * @param value - Raw value, e.g. `category` from parsed ML JSON output.
 * @param fallback - Returned when `value` is not exactly one of the known category strings; defaults to "other".
 */
function normalizeCategory(
  value: unknown,
  fallback: ExtractedCandidate["category"] = "other",
): ExtractedCandidate["category"] {
  const known: ReadonlyArray<ExtractedCandidate["category"]> = [
    "preference",
    "decision",
    "fact",
    "task",
    "other",
  ];
  const match = known.find((category) => category === value);
  return match ?? fallback;
}
311
+
254
312
  function applyMlResult(
255
313
  candidates: ExtractedCandidate[],
256
314
  result: Array<Record<string, unknown>>,
@@ -282,14 +340,7 @@ function applyMlResult(
282
340
  if (!keep) {
283
341
  continue;
284
342
  }
285
- const category =
286
- ml.category === "preference" ||
287
- ml.category === "decision" ||
288
- ml.category === "fact" ||
289
- ml.category === "task" ||
290
- ml.category === "other"
291
- ? (ml.category as ExtractedCandidate["category"])
292
- : candidate.category;
343
+ const category = normalizeCategory(ml.category, candidate.category);
293
344
  const score = typeof ml.score === "number" ? Math.max(0, Math.min(1, ml.score)) : candidate.score;
294
345
  out.push({
295
346
  ...candidate,
@@ -301,6 +352,39 @@ function applyMlResult(
301
352
  return out;
302
353
  }
303
354
 
355
/**
 * Convert raw ML extraction objects into capture candidates.
 *
 * - Text is whitespace-normalized.
 * - Items whose text is empty, shorter than 20 characters, or longer than 3000 are skipped.
 * - Duplicates (by hash of the normalized text) are skipped.
 * - Categories are coerced via `normalizeCategory`.
 * - Scores are clamped to [0, 1], defaulting to 0.5 when missing.
 *
 * At most `maxItems` candidates are returned. The cap is checked before adding, so
 * `maxItems <= 0` yields no candidates; previously one item always slipped through.
 */
function applyMlExtractionResult(
  result: Array<Record<string, unknown>>,
  maxItems: number,
): ExtractedCandidate[] {
  const out: ExtractedCandidate[] = [];
  const seen = new Set<string>();

  for (const item of result) {
    if (out.length >= maxItems) {
      break;
    }
    const rawText = typeof item.text === "string" ? item.text : "";
    const text = normalizeWhitespace(rawText);
    if (!text || text.length < 20 || text.length > 3000) {
      continue;
    }
    const key = sha256(normalizeForHash(text));
    if (seen.has(key)) {
      continue;
    }
    seen.add(key);

    out.push({
      text,
      category: normalizeCategory(item.category),
      score: typeof item.score === "number" ? Math.max(0, Math.min(1, item.score)) : 0.5,
      source: "ml",
    });
  }

  return out;
}
387
+
304
388
  export async function extractCandidates(params: {
305
389
  messages: unknown[];
306
390
  cfg: MemoryBraidConfig;
@@ -308,43 +392,86 @@ export async function extractCandidates(params: {
308
392
  runId?: string;
309
393
  }): Promise<ExtractedCandidate[]> {
310
394
  const normalized = normalizeMessages(params.messages);
311
- const heuristic = pickHeuristicCandidates(normalized, params.cfg.capture.ml.maxItemsPerRun);
395
+ const heuristic = pickHeuristicCandidates(normalized, params.cfg.capture.maxItemsPerRun);
312
396
 
313
397
  params.log.debug("memory_braid.capture.extract", {
314
398
  runId: params.runId,
399
+ mode: params.cfg.capture.mode,
400
+ maxItemsPerRun: params.cfg.capture.maxItemsPerRun,
315
401
  totalMessages: normalized.length,
316
402
  heuristicCandidates: heuristic.length,
317
403
  });
318
404
 
319
- if (
320
- params.cfg.capture.extraction.mode !== "heuristic_plus_ml" ||
321
- !params.cfg.capture.ml.provider ||
322
- !params.cfg.capture.ml.model
323
- ) {
405
+ if (params.cfg.capture.mode === "local") {
406
+ params.log.debug("memory_braid.capture.mode", {
407
+ runId: params.runId,
408
+ mode: params.cfg.capture.mode,
409
+ decision: "heuristic_only",
410
+ candidates: heuristic.length,
411
+ });
412
+ return heuristic;
413
+ }
414
+
415
+ if (!params.cfg.capture.ml.provider || !params.cfg.capture.ml.model) {
416
+ params.log.warn("memory_braid.capture.ml", {
417
+ runId: params.runId,
418
+ reason: "missing_provider_or_model",
419
+ mode: params.cfg.capture.mode,
420
+ hasProvider: Boolean(params.cfg.capture.ml.provider),
421
+ hasModel: Boolean(params.cfg.capture.ml.model),
422
+ fallback: "heuristic",
423
+ candidates: heuristic.length,
424
+ });
324
425
  return heuristic;
325
426
  }
326
427
 
327
428
  try {
328
- const ml = await callMlEnrichment({
429
+ if (params.cfg.capture.mode === "hybrid") {
430
+ const ml = await callMlEnrichment({
431
+ provider: params.cfg.capture.ml.provider,
432
+ model: params.cfg.capture.ml.model,
433
+ timeoutMs: params.cfg.capture.ml.timeoutMs,
434
+ candidates: heuristic,
435
+ });
436
+ const enriched = applyMlResult(heuristic, ml);
437
+ params.log.debug("memory_braid.capture.ml", {
438
+ runId: params.runId,
439
+ mode: params.cfg.capture.mode,
440
+ provider: params.cfg.capture.ml.provider,
441
+ model: params.cfg.capture.ml.model,
442
+ requested: heuristic.length,
443
+ returned: ml.length,
444
+ enriched: enriched.length,
445
+ fallbackUsed: ml.length === 0,
446
+ });
447
+ return enriched;
448
+ }
449
+
450
+ const mlExtractedRaw = await callMlExtraction({
329
451
  provider: params.cfg.capture.ml.provider,
330
452
  model: params.cfg.capture.ml.model,
331
453
  timeoutMs: params.cfg.capture.ml.timeoutMs,
332
- candidates: heuristic,
454
+ maxItems: params.cfg.capture.maxItemsPerRun,
455
+ messages: normalized,
333
456
  });
334
- const enriched = applyMlResult(heuristic, ml);
457
+ const mlExtracted = applyMlExtractionResult(mlExtractedRaw, params.cfg.capture.maxItemsPerRun);
335
458
  params.log.debug("memory_braid.capture.ml", {
336
459
  runId: params.runId,
460
+ mode: params.cfg.capture.mode,
337
461
  provider: params.cfg.capture.ml.provider,
338
462
  model: params.cfg.capture.ml.model,
339
- requested: heuristic.length,
340
- returned: ml.length,
341
- enriched: enriched.length,
463
+ returned: mlExtractedRaw.length,
464
+ extracted: mlExtracted.length,
465
+ fallbackUsed: mlExtracted.length === 0,
342
466
  });
343
- return enriched;
467
+ return mlExtracted.length > 0 ? mlExtracted : heuristic;
344
468
  } catch (err) {
345
469
  params.log.warn("memory_braid.capture.ml", {
346
470
  runId: params.runId,
471
+ mode: params.cfg.capture.mode,
347
472
  error: err instanceof Error ? err.message : String(err),
473
+ fallback: "heuristic",
474
+ candidates: heuristic.length,
348
475
  });
349
476
  return heuristic;
350
477
  }