@shadowforge0/aquifer-memory 0.6.0 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,8 +29,19 @@ const DEFAULTS = {
29
29
  maxRetries: 3,
30
30
  temperature: 0,
31
31
  },
32
- entities: { enabled: false, mergeCall: true },
32
+ entities: { enabled: false, mergeCall: true, scope: 'default' },
33
33
  rank: { rrf: 0.65, timeDecay: 0.25, access: 0.10, entityBoost: 0.18 },
34
+ rerank: {
35
+ enabled: false,
36
+ provider: null, // 'tei' | 'jina' | 'custom'
37
+ baseUrl: null, // TEI base URL
38
+ apiKey: null, // Jina API key
39
+ model: null, // Jina model override
40
+ topK: 20,
41
+ maxChars: 1600,
42
+ timeoutMs: 2000,
43
+ maxRetries: 1,
44
+ },
34
45
  };
35
46
 
36
47
  // ---------------------------------------------------------------------------
@@ -57,6 +68,15 @@ const ENV_MAP = [
57
68
  ['AQUIFER_LLM_TIMEOUT_MS', 'llm.timeoutMs', Number],
58
69
  ['AQUIFER_LLM_TEMPERATURE', 'llm.temperature', Number],
59
70
  ['AQUIFER_ENTITIES_ENABLED', 'entities.enabled', Boolean],
71
+ ['AQUIFER_ENTITY_SCOPE', 'entities.scope'],
72
+ ['AQUIFER_RERANK_ENABLED', 'rerank.enabled', Boolean],
73
+ ['AQUIFER_RERANK_PROVIDER', 'rerank.provider'],
74
+ ['AQUIFER_RERANK_BASE_URL', 'rerank.baseUrl'],
75
+ ['AQUIFER_RERANK_API_KEY', 'rerank.apiKey'],
76
+ ['AQUIFER_RERANK_MODEL', 'rerank.model'],
77
+ ['AQUIFER_RERANK_TOP_K', 'rerank.topK', Number],
78
+ ['AQUIFER_RERANK_MAX_CHARS', 'rerank.maxChars', Number],
79
+ ['AQUIFER_RERANK_TIMEOUT_MS','rerank.timeoutMs', Number],
60
80
  ];
61
81
 
62
82
  // ---------------------------------------------------------------------------
@@ -1,7 +1,7 @@
1
1
  'use strict';
2
2
 
3
3
  const { Pool } = require('pg');
4
- const { createAquifer, createEmbedder } = require('../../index');
4
+ const { createAquifer, createEmbedder, createReranker } = require('../../index');
5
5
  const { loadConfig } = require('./config');
6
6
  const { createLlmFn } = require('./llm');
7
7
 
@@ -57,6 +57,24 @@ function createAquiferFromConfig(overrides) {
57
57
  llmFn = createLlmFn(config.llm);
58
58
  }
59
59
 
60
+ // Rerank config (optional)
61
+ let rerankOpts = null;
62
+ if (config.rerank && config.rerank.enabled && config.rerank.provider) {
63
+ const rc = config.rerank;
64
+ const rerankConfig = { provider: rc.provider, topK: rc.topK, maxChars: rc.maxChars };
65
+ if (rc.provider === 'tei') {
66
+ rerankConfig.teiBaseUrl = rc.baseUrl || 'http://localhost:8080';
67
+ rerankConfig.timeout = rc.timeoutMs || 2000;
68
+ rerankConfig.maxRetries = rc.maxRetries ?? 1;
69
+ } else if (rc.provider === 'jina') {
70
+ rerankConfig.jinaApiKey = rc.apiKey;
71
+ if (rc.model) rerankConfig.jinaModel = rc.model;
72
+ rerankConfig.timeout = rc.timeoutMs || 2000;
73
+ rerankConfig.maxRetries = rc.maxRetries ?? 1;
74
+ }
75
+ rerankOpts = rerankConfig;
76
+ }
77
+
60
78
  const aquifer = createAquifer({
61
79
  db: pool,
62
80
  schema: config.schema,
@@ -65,6 +83,7 @@ function createAquiferFromConfig(overrides) {
65
83
  llm: llmFn ? { fn: llmFn } : null,
66
84
  entities: config.entities,
67
85
  rank: config.rank,
86
+ rerank: rerankOpts,
68
87
  });
69
88
 
70
89
  // Attach pool for lifecycle management
package/core/aquifer.js CHANGED
@@ -36,6 +36,24 @@ function loadSql(filename, schema) {
36
36
  return raw.replace(/\$\{schema\}/g, qi(schema));
37
37
  }
38
38
 
39
+ // ---------------------------------------------------------------------------
40
+ // buildRerankDocument — assemble text for cross-encoder reranking
41
+ // ---------------------------------------------------------------------------
42
+
43
+ function buildRerankDocument(row, maxChars) {
44
+ let text = (row.summary_text || row.summary_snippet || '').replace(/\s+/g, ' ').trim();
45
+ const turn = (row.matched_turn_text || '').replace(/\s+/g, ' ').trim();
46
+
47
+ if (!text) {
48
+ text = turn;
49
+ } else if (turn && !text.includes(turn)) {
50
+ text = `${text}\n\nMatched turn:\n${turn}`;
51
+ }
52
+
53
+ if (text.length > maxChars) text = text.slice(0, maxChars);
54
+ return text;
55
+ }
56
+
39
57
  // ---------------------------------------------------------------------------
40
58
  // createAquifer
41
59
  // ---------------------------------------------------------------------------
@@ -81,6 +99,9 @@ function createAquifer(config) {
81
99
  const entityPromptFn = config.entities && config.entities.prompt ? config.entities.prompt : null;
82
100
  const entityScope = (config.entities && config.entities.scope) || 'default';
83
101
 
102
+ // FTS config (default: 'simple'; set to 'zhcfg' for Chinese tokenization)
103
+ const ftsConfig = config.ftsConfig || 'simple';
104
+
84
105
  // Rank weights
85
106
  const rankWeights = {
86
107
  rrf: 0.65,
@@ -90,6 +111,16 @@ function createAquifer(config) {
90
111
  ...(config.rank || {}),
91
112
  };
92
113
 
114
+ // Reranker config (optional)
115
+ const rerankConfig = config.rerank || null;
116
+ let reranker = null;
117
+ if (rerankConfig) {
118
+ const { createReranker } = require('../pipeline/rerank');
119
+ reranker = createReranker(rerankConfig);
120
+ }
121
+ const defaultRerankTopK = rerankConfig ? Math.max(1, rerankConfig.topK || 20) : 0;
122
+ const rerankMaxChars = rerankConfig ? Math.max(200, rerankConfig.maxChars || 1600) : 0;
123
+
93
124
  // Source registry (in-memory)
94
125
  const sources = new Map();
95
126
 
@@ -106,7 +137,7 @@ function createAquifer(config) {
106
137
 
107
138
  // --- Helper: embed search on summaries ---
108
139
  async function embeddingSearchSummaries(queryVec, opts) {
109
- const { agentId, source, dateFrom, dateTo, limit = 20 } = opts;
140
+ const { agentIds, source, dateFrom, dateTo, limit = 20 } = opts;
110
141
  const where = [`s.tenant_id = $1`];
111
142
  const params = [tenantId];
112
143
 
@@ -121,9 +152,9 @@ function createAquifer(config) {
121
152
  params.push(dateTo);
122
153
  where.push(`($${params.length}::date IS NULL OR s.started_at::date <= $${params.length}::date)`);
123
154
  }
124
- if (agentId) {
125
- params.push(agentId);
126
- where.push(`s.agent_id = $${params.length}`);
155
+ if (agentIds && agentIds.length > 0) {
156
+ params.push(agentIds);
157
+ where.push(`s.agent_id = ANY($${params.length})`);
127
158
  }
128
159
  if (source) {
129
160
  params.push(source);
@@ -524,6 +555,7 @@ function createAquifer(config) {
524
555
 
525
556
  const {
526
557
  agentId,
558
+ agentIds: rawAgentIds,
527
559
  source,
528
560
  dateFrom,
529
561
  dateTo,
@@ -533,6 +565,12 @@ function createAquifer(config) {
533
565
  entityMode = 'any',
534
566
  } = opts;
535
567
 
568
+ // Normalize agentId/agentIds into a single resolved value
569
+ // agentIds takes precedence; agentId is sugar for agentIds: [agentId]
570
+ const resolvedAgentIds = rawAgentIds && rawAgentIds.length > 0
571
+ ? rawAgentIds
572
+ : (agentId ? [agentId] : null);
573
+
536
574
  // Validate before touching DB
537
575
  if (explicitEntities && explicitEntities.length > 0 && !entitiesEnabled) {
538
576
  throw new Error('Entities are not enabled');
@@ -540,7 +578,9 @@ function createAquifer(config) {
540
578
 
541
579
  await ensureMigrated();
542
580
 
543
- const fetchLimit = limit * 4;
581
+ const rerankEnabled = !!reranker && opts.rerank !== false;
582
+ const rerankTopK = rerankEnabled ? Math.max(limit, opts.rerankTopK || defaultRerankTopK) : limit;
583
+ const fetchLimit = rerankTopK * 4;
544
584
 
545
585
  // 1. Embed query
546
586
  const queryVecResult = await embedFn([query]);
@@ -624,13 +664,13 @@ function createAquifer(config) {
624
664
  // 3. Run 3 search paths in parallel
625
665
  const [ftsRows, embRows, turnResult] = await Promise.all([
626
666
  storage.searchSessions(pool, query, {
627
- schema, tenantId, agentId, source, dateFrom, dateTo, limit: fetchLimit,
667
+ schema, tenantId, agentIds: resolvedAgentIds, source, dateFrom, dateTo, limit: fetchLimit, ftsConfig,
628
668
  }).catch(() => []),
629
669
  embeddingSearchSummaries(queryVec, {
630
- agentId, source, dateFrom, dateTo, limit: fetchLimit,
670
+ agentIds: resolvedAgentIds, source, dateFrom, dateTo, limit: fetchLimit,
631
671
  }).catch(() => []),
632
672
  storage.searchTurnEmbeddings(pool, {
633
- schema, tenantId, queryVec, dateFrom, dateTo, agentId, source, limit: fetchLimit,
673
+ schema, tenantId, queryVec, dateFrom, dateTo, agentIds: resolvedAgentIds, source, limit: fetchLimit,
634
674
  }).catch(() => ({ rows: [] })),
635
675
  ]);
636
676
 
@@ -691,15 +731,45 @@ function createAquifer(config) {
691
731
  [...filteredEmb, ...filterFn(externalRows)],
692
732
  filteredTurn,
693
733
  {
694
- limit,
734
+ limit: rerankTopK,
695
735
  weights: mergedWeights,
696
736
  entityScoreBySession,
697
737
  openLoopSet,
698
738
  },
699
739
  );
700
740
 
741
+ // 6b. Rerank (optional)
742
+ let finalRanked = ranked;
743
+ if (rerankEnabled && ranked.length > 1) {
744
+ try {
745
+ const docs = ranked.map(r => buildRerankDocument(r, rerankMaxChars));
746
+ const rerankResult = await reranker.rerank(query, docs, { topN: ranked.length });
747
+ const scoreMap = new Map(rerankResult.map(r => [r.index, r.score]));
748
+
749
+ finalRanked = ranked.map((r, i) => ({
750
+ ...r,
751
+ _hybridScore: r._score,
752
+ _rerankScore: scoreMap.has(i) ? scoreMap.get(i) : null,
753
+ }));
754
+
755
+ finalRanked.sort((a, b) => {
756
+ const aR = a._rerankScore ?? -Infinity;
757
+ const bR = b._rerankScore ?? -Infinity;
758
+ if (aR !== bR) return bR - aR;
759
+ return (b._hybridScore || 0) - (a._hybridScore || 0);
760
+ });
761
+ finalRanked = finalRanked.slice(0, limit);
762
+ } catch (rerankErr) {
763
+ // Fallback: use original hybrid-rank order, flag in debug
764
+ if (process.env.AQUIFER_DEBUG) console.error('[aquifer] rerank error:', rerankErr.message);
765
+ finalRanked = ranked.slice(0, limit).map(r => ({ ...r, _rerankFallback: true }));
766
+ }
767
+ } else {
768
+ finalRanked = ranked.slice(0, limit);
769
+ }
770
+
701
771
  // 7. Record access
702
- const sessionRowIds = ranked
772
+ const sessionRowIds = finalRanked
703
773
  .map(r => r.id || r.session_row_id)
704
774
  .filter(Boolean);
705
775
 
@@ -710,7 +780,7 @@ function createAquifer(config) {
710
780
  }
711
781
 
712
782
  // 8. Format results
713
- return ranked.map(r => ({
783
+ return finalRanked.map(r => ({
714
784
  sessionId: r.session_id,
715
785
  agentId: r.agent_id,
716
786
  source: r.source,
@@ -720,7 +790,7 @@ function createAquifer(config) {
720
790
  summarySnippet: r.summary_snippet || null,
721
791
  matchedTurnText: r.matched_turn_text || null,
722
792
  matchedTurnIndex: r.matched_turn_index || null,
723
- score: r._score,
793
+ score: r._rerankScore ?? r._score,
724
794
  trustScore: r._trustScore ?? 0.5,
725
795
  _debug: {
726
796
  rrf: r._rrf,
@@ -730,6 +800,9 @@ function createAquifer(config) {
730
800
  trustScore: r._trustScore,
731
801
  trustMultiplier: r._trustMultiplier,
732
802
  openLoopBoost: r._openLoopBoost,
803
+ hybridScore: r._hybridScore ?? r._score,
804
+ rerankScore: r._rerankScore ?? null,
805
+ rerankFallback: r._rerankFallback || false,
733
806
  },
734
807
  }));
735
808
  },
package/core/storage.js CHANGED
@@ -331,12 +331,46 @@ async function searchSessions(pool, query, {
331
331
  schema,
332
332
  tenantId,
333
333
  agentId,
334
+ agentIds: rawAgentIds,
334
335
  source,
335
336
  dateFrom, // m1: add date filtering
336
337
  dateTo,
337
338
  limit = 20,
339
+ ftsConfig = 'simple',
338
340
  } = {}) {
339
341
  const clampedLimit = Math.max(1, Math.min(100, limit));
342
+ // Sanitize ftsConfig to prevent SQL injection (must be a valid regconfig name)
343
+ const safeFts = /^[a-zA-Z_][a-zA-Z0-9_]*$/.test(ftsConfig) ? ftsConfig : 'simple';
344
+
345
+ // Normalize agentId/agentIds
346
+ const agentIds = rawAgentIds && rawAgentIds.length > 0
347
+ ? rawAgentIds
348
+ : (agentId ? [agentId] : null);
349
+
350
+ const where = [
351
+ `ss.search_tsv @@ plainto_tsquery('${safeFts}', $1)`,
352
+ `s.tenant_id = $2`,
353
+ ];
354
+ const params = [query, tenantId];
355
+
356
+ if (agentIds) {
357
+ params.push(agentIds);
358
+ where.push(`s.agent_id = ANY($${params.length})`);
359
+ }
360
+ if (source) {
361
+ params.push(source);
362
+ where.push(`s.source = $${params.length}`);
363
+ }
364
+ if (dateFrom) {
365
+ params.push(dateFrom);
366
+ where.push(`s.started_at::date >= $${params.length}::date`);
367
+ }
368
+ if (dateTo) {
369
+ params.push(dateTo);
370
+ where.push(`s.started_at::date <= $${params.length}::date`);
371
+ }
372
+ params.push(clampedLimit);
373
+
340
374
  const result = await pool.query(
341
375
  `SELECT
342
376
  s.id,
@@ -351,19 +385,14 @@ async function searchSessions(pool, query, {
351
385
  ss.access_count,
352
386
  ss.last_accessed_at,
353
387
  ss.trust_score,
354
- ts_headline('simple', COALESCE(ss.summary_text, ''), plainto_tsquery('simple', $1)) AS summary_snippet,
355
- ts_rank(ss.search_tsv, plainto_tsquery('simple', $1)) AS fts_rank
388
+ ts_headline('${safeFts}', COALESCE(ss.summary_text, ''), plainto_tsquery('${safeFts}', $1)) AS summary_snippet,
389
+ ts_rank(ss.search_tsv, plainto_tsquery('${safeFts}', $1)) AS fts_rank
356
390
  FROM ${qi(schema)}.sessions s
357
391
  LEFT JOIN ${qi(schema)}.session_summaries ss ON ss.session_row_id = s.id
358
- WHERE ss.search_tsv @@ plainto_tsquery('simple', $1)
359
- AND s.tenant_id = $2
360
- AND ($3::text IS NULL OR s.agent_id = $3)
361
- AND ($4::text IS NULL OR s.source = $4)
362
- AND ($5::date IS NULL OR s.started_at::date >= $5::date)
363
- AND ($6::date IS NULL OR s.started_at::date <= $6::date)
392
+ WHERE ${where.join(' AND ')}
364
393
  ORDER BY fts_rank DESC, s.last_message_at DESC NULLS LAST
365
- LIMIT $7`,
366
- [query, tenantId, agentId || null, source || null, dateFrom || null, dateTo || null, clampedLimit]
394
+ LIMIT $${params.length}`,
395
+ params
367
396
  );
368
397
  return result.rows;
369
398
  }
@@ -479,23 +508,29 @@ async function searchTurnEmbeddings(pool, {
479
508
  dateFrom,
480
509
  dateTo,
481
510
  agentId,
511
+ agentIds: rawAgentIds,
482
512
  source,
483
513
  limit = 15,
484
514
  }) {
485
515
  const where = ['s.tenant_id = $1'];
486
516
  const params = [tenantId];
487
517
 
518
+ // Normalize agentId/agentIds
519
+ const agentIds = rawAgentIds && rawAgentIds.length > 0
520
+ ? rawAgentIds
521
+ : (agentId ? [agentId] : null);
522
+
488
523
  if (dateFrom) {
489
524
  params.push(dateFrom);
490
- where.push(`($${params.length}::date IS NULL OR s.started_at::date >= $${params.length}::date)`);
525
+ where.push(`s.started_at::date >= $${params.length}::date`);
491
526
  }
492
527
  if (dateTo) {
493
528
  params.push(dateTo);
494
- where.push(`($${params.length}::date IS NULL OR s.started_at::date <= $${params.length}::date)`);
529
+ where.push(`s.started_at::date <= $${params.length}::date`);
495
530
  }
496
- if (agentId) {
497
- params.push(agentId);
498
- where.push(`t.agent_id = $${params.length}`);
531
+ if (agentIds) {
532
+ params.push(agentIds);
533
+ where.push(`t.agent_id = ANY($${params.length})`);
499
534
  }
500
535
  if (source) {
501
536
  params.push(source);
package/index.js CHANGED
@@ -2,5 +2,6 @@
2
2
 
3
3
  const { createAquifer } = require('./core/aquifer');
4
4
  const { createEmbedder } = require('./pipeline/embed');
5
+ const { createReranker } = require('./pipeline/rerank');
5
6
 
6
- module.exports = { createAquifer, createEmbedder };
7
+ module.exports = { createAquifer, createEmbedder, createReranker };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@shadowforge0/aquifer-memory",
3
- "version": "0.6.0",
3
+ "version": "0.7.0",
4
4
  "description": "PG-native long-term memory for AI agents. Turn-level embedding, hybrid RRF ranking, optional knowledge graph. Includes CLI, MCP server, and OpenClaw plugin.",
5
5
  "main": "index.js",
6
6
  "files": [
@@ -0,0 +1,67 @@
1
+ 'use strict';
2
+
3
+ const http = require('http');
4
+ const https = require('https');
5
+
6
+ // ---------------------------------------------------------------------------
7
+ // HTTP helpers (shared by embed.js and rerank.js)
8
+ // ---------------------------------------------------------------------------
9
+
10
+ function httpRequest(url, options, body) {
11
+ return new Promise((resolve, reject) => {
12
+ const parsedUrl = new URL(url);
13
+ const transport = parsedUrl.protocol === 'https:' ? https : http;
14
+
15
+ // M8 fix: settled flag to prevent double-settle on timeout race
16
+ let settled = false;
17
+ const finish = (fn, val) => { if (!settled) { settled = true; fn(val); } };
18
+
19
+ const req = transport.request(parsedUrl, options, (res) => {
20
+ const chunks = [];
21
+ res.on('data', (chunk) => chunks.push(chunk));
22
+ res.on('end', () => {
23
+ if (timer) clearTimeout(timer);
24
+ const raw = Buffer.concat(chunks).toString();
25
+ if (res.statusCode < 200 || res.statusCode >= 300) {
26
+ finish(reject, new Error(`HTTP ${res.statusCode}: ${raw.slice(0, 500)}`));
27
+ return;
28
+ }
29
+ try {
30
+ finish(resolve, JSON.parse(raw));
31
+ } catch (e) {
32
+ finish(reject, new Error(`Invalid JSON response: ${raw.slice(0, 200)}`));
33
+ }
34
+ });
35
+ });
36
+
37
+ const timer = options.timeout
38
+ ? setTimeout(() => { req.destroy(); finish(reject, new Error('Request timeout')); }, options.timeout)
39
+ : null;
40
+
41
+ req.on('error', (err) => { if (timer) clearTimeout(timer); finish(reject, err); });
42
+ if (body) req.write(JSON.stringify(body));
43
+ req.end();
44
+ });
45
+ }
46
+
47
+ // ---------------------------------------------------------------------------
48
+ // Retry wrapper
49
+ // ---------------------------------------------------------------------------
50
+
51
+ async function withRetry(fn, { maxRetries = 3, initialBackoffMs = 2000 } = {}) {
52
+ let lastErr;
53
+ for (let attempt = 0; attempt < maxRetries; attempt++) {
54
+ try {
55
+ return await fn();
56
+ } catch (err) {
57
+ lastErr = err;
58
+ if (attempt < maxRetries - 1) {
59
+ const delay = initialBackoffMs * Math.pow(2, attempt);
60
+ await new Promise(r => setTimeout(r, delay));
61
+ }
62
+ }
63
+ }
64
+ throw lastErr;
65
+ }
66
+
67
+ module.exports = { httpRequest, withRetry };
package/pipeline/embed.js CHANGED
@@ -1,68 +1,6 @@
1
1
  'use strict';
2
2
 
3
- const http = require('http');
4
- const https = require('https');
5
-
6
- // ---------------------------------------------------------------------------
7
- // HTTP helpers
8
- // ---------------------------------------------------------------------------
9
-
10
- function httpRequest(url, options, body) {
11
- return new Promise((resolve, reject) => {
12
- const parsedUrl = new URL(url);
13
- const transport = parsedUrl.protocol === 'https:' ? https : http;
14
-
15
- // M8 fix: settled flag to prevent double-settle on timeout race
16
- let settled = false;
17
- const finish = (fn, val) => { if (!settled) { settled = true; fn(val); } };
18
-
19
- const req = transport.request(parsedUrl, options, (res) => {
20
- const chunks = [];
21
- res.on('data', (chunk) => chunks.push(chunk));
22
- res.on('end', () => {
23
- if (timer) clearTimeout(timer);
24
- const raw = Buffer.concat(chunks).toString();
25
- if (res.statusCode < 200 || res.statusCode >= 300) {
26
- finish(reject, new Error(`HTTP ${res.statusCode}: ${raw.slice(0, 500)}`));
27
- return;
28
- }
29
- try {
30
- finish(resolve, JSON.parse(raw));
31
- } catch (e) {
32
- finish(reject, new Error(`Invalid JSON response: ${raw.slice(0, 200)}`));
33
- }
34
- });
35
- });
36
-
37
- const timer = options.timeout
38
- ? setTimeout(() => { req.destroy(); finish(reject, new Error('Request timeout')); }, options.timeout)
39
- : null;
40
-
41
- req.on('error', (err) => { if (timer) clearTimeout(timer); finish(reject, err); });
42
- if (body) req.write(JSON.stringify(body));
43
- req.end();
44
- });
45
- }
46
-
47
- // ---------------------------------------------------------------------------
48
- // Retry wrapper
49
- // ---------------------------------------------------------------------------
50
-
51
- async function withRetry(fn, { maxRetries = 3, initialBackoffMs = 2000 } = {}) {
52
- let lastErr;
53
- for (let attempt = 0; attempt < maxRetries; attempt++) {
54
- try {
55
- return await fn();
56
- } catch (err) {
57
- lastErr = err;
58
- if (attempt < maxRetries - 1) {
59
- const delay = initialBackoffMs * Math.pow(2, attempt);
60
- await new Promise(r => setTimeout(r, delay));
61
- }
62
- }
63
- }
64
- throw lastErr;
65
- }
3
+ const { httpRequest, withRetry } = require('./_http');
66
4
 
67
5
  // ---------------------------------------------------------------------------
68
6
  // Ollama adapter
@@ -0,0 +1,161 @@
1
+ 'use strict';
2
+
3
+ const { httpRequest, withRetry } = require('./_http');
4
+
5
+ // ---------------------------------------------------------------------------
6
+ // Custom adapter
7
+ // ---------------------------------------------------------------------------
8
+
9
+ function validateResults(results) {
10
+ return results.filter(r =>
11
+ r && typeof r.index === 'number' && Number.isFinite(r.index)
12
+ && typeof r.score === 'number' && Number.isFinite(r.score)
13
+ );
14
+ }
15
+
16
+ function createCustomReranker(config) {
17
+ const fn = config.fn;
18
+ if (!fn) throw new Error('fn is required for custom reranker');
19
+
20
+ return {
21
+ async rerank(query, documents, opts = {}) {
22
+ if (!query || !documents || documents.length === 0) return [];
23
+ const topN = opts.topN || documents.length;
24
+ const results = await fn({ query, documents, topN });
25
+ if (!Array.isArray(results)) throw new Error('Custom reranker fn must return an array');
26
+ return validateResults(results).sort((a, b) => b.score - a.score);
27
+ },
28
+ };
29
+ }
30
+
31
+ // ---------------------------------------------------------------------------
32
+ // TEI adapter (HuggingFace Text Embeddings Inference)
33
+ // ---------------------------------------------------------------------------
34
+
35
+ function createTEIReranker(config) {
36
+ const baseUrl = (config.teiBaseUrl || config.baseUrl || 'http://localhost:8080').replace(/\/+$/, '');
37
+ const timeout = config.timeout || 2000;
38
+ const maxRetries = config.maxRetries ?? 1;
39
+ const initialBackoffMs = config.initialBackoffMs || 250;
40
+
41
+ return {
42
+ async rerank(query, documents, opts = {}) {
43
+ if (!query || !documents || documents.length === 0) return [];
44
+
45
+ const result = await withRetry(
46
+ () => httpRequest(`${baseUrl}/rerank`, {
47
+ method: 'POST',
48
+ headers: { 'Content-Type': 'application/json' },
49
+ timeout,
50
+ }, { query, texts: documents, raw_scores: false }),
51
+ { maxRetries, initialBackoffMs },
52
+ );
53
+
54
+ // TEI returns array of { index, score }
55
+ const arr = Array.isArray(result) ? result : [];
56
+ return validateResults(arr.map(r => ({ index: r.index, score: r.score })))
57
+ .sort((a, b) => b.score - a.score);
58
+ },
59
+ };
60
+ }
61
+
62
+ // ---------------------------------------------------------------------------
63
+ // Jina adapter
64
+ // ---------------------------------------------------------------------------
65
+
66
+ function createJinaReranker(config) {
67
+ const apiKey = config.jinaApiKey;
68
+ if (!apiKey) throw new Error('jinaApiKey is required for Jina reranker');
69
+
70
+ const model = config.jinaModel || 'jina-reranker-v2-base-multilingual';
71
+ const baseUrl = (config.jinaBaseUrl || 'https://api.jina.ai/v1/rerank').replace(/\/+$/, '');
72
+ const timeout = config.timeout || 2000;
73
+ const maxRetries = config.maxRetries ?? 1;
74
+ const initialBackoffMs = config.initialBackoffMs || 250;
75
+
76
+ return {
77
+ async rerank(query, documents, opts = {}) {
78
+ if (!query || !documents || documents.length === 0) return [];
79
+ const topN = opts.topN || documents.length;
80
+
81
+ const result = await withRetry(
82
+ () => httpRequest(baseUrl, {
83
+ method: 'POST',
84
+ headers: {
85
+ 'Content-Type': 'application/json',
86
+ 'Authorization': `Bearer ${apiKey}`,
87
+ },
88
+ timeout,
89
+ }, { model, query, documents, top_n: topN }),
90
+ { maxRetries, initialBackoffMs },
91
+ );
92
+
93
+ // Jina returns { results: [{ index, relevance_score }] }
94
+ const arr = result.results || [];
95
+ return validateResults(arr.map(r => ({ index: r.index, score: r.relevance_score })))
96
+ .sort((a, b) => b.score - a.score);
97
+ },
98
+ };
99
+ }
100
+
101
+ // ---------------------------------------------------------------------------
102
+ // OpenRouter adapter (Cohere rerank etc. via OpenRouter)
103
+ // ---------------------------------------------------------------------------
104
+
105
+ function createOpenRouterReranker(config) {
106
+ const apiKey = config.openrouterApiKey || config.apiKey;
107
+ if (!apiKey) throw new Error('openrouterApiKey is required for OpenRouter reranker');
108
+
109
+ const model = config.model || 'cohere/rerank-v3.5';
110
+ const baseUrl = (config.openrouterBaseUrl || 'https://openrouter.ai/api/v1/rerank').replace(/\/+$/, '');
111
+ const timeout = config.timeout || 5000;
112
+ const maxRetries = config.maxRetries ?? 1;
113
+ const initialBackoffMs = config.initialBackoffMs || 250;
114
+
115
+ return {
116
+ async rerank(query, documents, opts = {}) {
117
+ if (!query || !documents || documents.length === 0) return [];
118
+ const topN = opts.topN || documents.length;
119
+
120
+ const result = await withRetry(
121
+ () => httpRequest(baseUrl, {
122
+ method: 'POST',
123
+ headers: {
124
+ 'Content-Type': 'application/json',
125
+ 'Authorization': `Bearer ${apiKey}`,
126
+ },
127
+ timeout,
128
+ }, { model, query, documents, top_n: topN }),
129
+ { maxRetries, initialBackoffMs },
130
+ );
131
+
132
+ // OpenRouter returns { results: [{ index, relevance_score }] }
133
+ const arr = result.results || [];
134
+ return validateResults(arr.map(r => ({ index: r.index, score: r.relevance_score })))
135
+ .sort((a, b) => b.score - a.score);
136
+ },
137
+ };
138
+ }
139
+
140
+ // ---------------------------------------------------------------------------
141
+ // Factory
142
+ // ---------------------------------------------------------------------------
143
+
144
+ function createReranker(config = {}) {
145
+ const provider = config.provider || 'custom';
146
+
147
+ switch (provider) {
148
+ case 'custom':
149
+ return createCustomReranker(config);
150
+ case 'tei':
151
+ return createTEIReranker(config);
152
+ case 'jina':
153
+ return createJinaReranker(config);
154
+ case 'openrouter':
155
+ return createOpenRouterReranker(config);
156
+ default:
157
+ throw new Error(`Unknown rerank provider: ${provider}`);
158
+ }
159
+ }
160
+
161
+ module.exports = { createReranker };