@geravant/sinain 1.22.8 → 1.23.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -328,6 +328,70 @@ def _cooccurring_entities(
328
328
  return ranked[:max_entities]
329
329
 
330
330
 
331
# Per-db_path cache of graph entity names and their embeddings:
#   {db_path: {"names": list[str], "embs": ndarray | None, "ts": float}}
# "ts" is a time.monotonic() stamp; entries older than _SEMANTIC_CACHE_TTL_S
# seconds are rebuilt on the next call.
_SEMANTIC_CACHE: dict = {}
_SEMANTIC_CACHE_TTL_S = 300.0
# Keywords and entity names shorter than this are not used for embedding
# matches: embeddings of short abbreviations ("ml", "ai") are unreliable.
_SEMANTIC_MIN_LEN = 4


def _rank_similar_names(kw_emb, entity_embs, entity_names, exclude, threshold, max_expansions):
    """Pick the entity names most similar to a single keyword embedding.

    Args:
        kw_emb: 1-D embedding vector of the keyword.
        entity_embs: 2-D array with one row per entry of *entity_names*.
        entity_names: names aligned with the rows of *entity_embs*.
        exclude: container of names that must not be returned
            (e.g. keywords/expansions already present).
        threshold: minimum cosine similarity for a name to qualify.
        max_expansions: maximum number of names to return.

    Returns:
        Up to *max_expansions* distinct names whose cosine similarity is
        >= *threshold*, most similar first; ties keep *entity_names* order.
    """
    import numpy as np

    if max_expansions <= 0 or len(entity_names) == 0:
        return []
    embs = np.asarray(entity_embs, dtype=np.float64)
    query = np.asarray(kw_emb, dtype=np.float64)
    # One vectorized cosine-similarity pass instead of a Python loop per entity.
    sims = embs @ query / (np.linalg.norm(embs, axis=1) * np.linalg.norm(query) + 1e-9)
    picked: list[str] = []
    for j in np.argsort(-sims, kind="stable"):
        # Scores are visited in descending order, so the first miss ends the
        # scan (`not >=` also stops on NaN, which argsort places last).
        if not sims[j] >= threshold:
            break
        name = entity_names[j]
        if name in exclude or name in picked:
            continue
        picked.append(name)
        if len(picked) >= max_expansions:
            break
    return picked


def _expand_keywords_semantic(
    keywords: list[str],
    db_path: str,
    threshold: float = 0.50,
    max_expansions: int = 3,
) -> list[str]:
    """Expand keywords with semantically similar entity names from the graph.

    Each keyword with at least ``_SEMANTIC_MIN_LEN`` characters is embedded
    with the ``all-MiniLM-L6-v2`` SentenceTransformer. The embedding is
    compared by cosine similarity against the names of ``entity:*`` nodes in
    the TripleStore at *db_path*. Up to *max_expansions* names scoring
    >= *threshold* are appended per keyword, e.g. "learning" might gain
    "machine-learning" (illustrative only). Shorter keywords are kept but not
    expanded.

    The model is loaded once per process and cached on this function object.
    Entity names and their embeddings are cached per *db_path* for
    ``_SEMANTIC_CACHE_TTL_S`` seconds.

    Best-effort: if sentence-transformers or triplestore cannot be imported,
    or anything else fails, the keywords are returned unchanged. Unexpected
    failures are logged at DEBUG level.

    Returns:
        A new list containing the original keywords in order, followed by
        any expansions. No name appears twice.
    """
    import logging
    import time as _t

    if not any(len(kw) >= _SEMANTIC_MIN_LEN for kw in keywords):
        # Nothing is eligible for expansion; skip loading the model and DB.
        return list(keywords)

    try:
        from triplestore import TripleStore
        from sentence_transformers import SentenceTransformer
    except ImportError:
        # Optional dependency missing: degrade to plain keyword matching.
        return list(keywords)

    try:
        if not hasattr(_expand_keywords_semantic, "_model"):
            _expand_keywords_semantic._model = SentenceTransformer("all-MiniLM-L6-v2")
        model = _expand_keywords_semantic._model

        cache = _SEMANTIC_CACHE.get(db_path)
        if not cache or _t.monotonic() - cache["ts"] > _SEMANTIC_CACHE_TTL_S:
            store = TripleStore(db_path)
            try:
                names = [
                    n
                    for eid, n in store.entities_with_attr("name")
                    if eid.startswith("entity:") and len(n) >= _SEMANTIC_MIN_LEN
                ]
            finally:
                store.close()
            # Empty graphs are cached too, so they are not rescanned every call.
            embs = model.encode(names, show_progress_bar=False) if names else None
            cache = {"names": names, "embs": embs, "ts": _t.monotonic()}
            _SEMANTIC_CACHE[db_path] = cache

        entity_names = cache["names"]
        if not entity_names:
            return list(keywords)
        entity_embs = cache["embs"]

        # Only eligible keywords are encoded; short ones pass through as-is.
        long_kws = [kw for kw in keywords if len(kw) >= _SEMANTIC_MIN_LEN]
        kw_embs = model.encode(long_kws, show_progress_bar=False)

        expanded = list(keywords)
        seen = set(expanded)
        for kw_emb in kw_embs:
            new_names = _rank_similar_names(
                kw_emb, entity_embs, entity_names, seen, threshold, max_expansions
            )
            expanded.extend(new_names)
            seen.update(new_names)
        return expanded
    except Exception:
        logging.getLogger(__name__).debug("semantic keyword expansion failed", exc_info=True)
        return list(keywords)
393
+
394
+
331
395
  def query_facts_hybrid(
332
396
  db_path: str,
333
397
  query: str,
@@ -342,15 +406,32 @@ def query_facts_hybrid(
342
406
  import time
343
407
  keywords = [w.lower() for w in re.findall(r"[a-zA-Z][a-zA-Z0-9-]+", query) if len(w) > 2]
344
408
 
345
- # Entity graph pre-filter: find facts linked to mentioned entities via backrefs.
346
- # Used to BOOST relevant facts in RRF, not as a separate tier (avoids dilution).
409
+ # Change 0: Semantic entity expansion "ML" ["ml", "machine-learning", "ai", ...]
410
+ expanded_keywords = keywords
411
+ if len(keywords) >= 1:
412
+ expanded_keywords = _expand_keywords_semantic(keywords, db_path)
413
+
414
+ # Entity graph pre-filter with per-entity tracking for intersection (Change A)
347
415
  graph_fact_ids: set[str] = set()
416
+ graph_intersection: set[str] = set()
348
417
  community_fact_ids: set[str] = set()
349
- for kw in keywords:
418
+ per_entity_facts: dict[str, set[str]] = {}
419
+ for kw in expanded_keywords:
420
+ kw_facts: set[str] = set()
350
421
  for f in query_facts_by_entity_graph(db_path, kw, max_facts=50):
351
422
  eid = f.get("entity_id", "")
352
423
  if eid:
424
+ kw_facts.add(eid)
353
425
  graph_fact_ids.add(eid)
426
+ if kw_facts:
427
+ per_entity_facts[kw] = kw_facts
428
+
429
+ # Compute intersection: facts linked to ALL original query keywords
430
+ if len(per_entity_facts) >= 2:
431
+ try:
432
+ graph_intersection = set.intersection(*per_entity_facts.values())
433
+ except TypeError:
434
+ pass
354
435
 
355
436
  # Community expansion: follow mentions edges to find related entities
356
437
  t0 = time.monotonic()
@@ -359,14 +440,14 @@ def query_facts_hybrid(
359
440
  store = TripleStore(db_path)
360
441
 
361
442
  matched_entities = set()
362
- for kw in keywords:
443
+ for kw in expanded_keywords:
363
444
  node_id = f"entity:{kw}"
364
445
  if store.entity(node_id):
365
446
  matched_entities.add(kw)
366
447
 
367
448
  for ent in matched_entities:
368
449
  if time.monotonic() - t0 > 0.5:
369
- break # timing guard
450
+ break
370
451
  community = expand_entity_community(store, ent, max_related=3)
371
452
  for related_name, _count in community:
372
453
  for f in query_facts_by_entity_graph(db_path, related_name, max_facts=20):
@@ -378,12 +459,50 @@ def query_facts_hybrid(
378
459
  except Exception:
379
460
  pass
380
461
 
381
- # Run three retrieval methods independently
462
+ # Run retrieval methods independently
382
463
  candidate_limit = max_facts * 3
383
- fts_results = query_facts_fts(db_path, query, max_facts=candidate_limit)
384
- tag_results = query_facts_by_entities(db_path, keywords, max_facts=candidate_limit) if keywords else []
464
+
465
+ # Change C: FTS5 AND mode for multi-keyword queries
466
+ if len(keywords) > 1:
467
+ fts_and_query = " AND ".join(keywords)
468
+ fts_results = query_facts_fts(db_path, fts_and_query, max_facts=candidate_limit)
469
+ if len(fts_results) < candidate_limit:
470
+ fts_or = query_facts_fts(db_path, " OR ".join(keywords), max_facts=candidate_limit)
471
+ fts_results.extend(fts_or)
472
+ else:
473
+ fts_results = query_facts_fts(db_path, query, max_facts=candidate_limit)
474
+
475
+ tag_results = query_facts_by_entities(db_path, expanded_keywords, max_facts=candidate_limit) if expanded_keywords else []
385
476
  top_results = query_top_facts(db_path, limit=candidate_limit)
386
477
 
478
+ # Change B: Tag intersection tier (facts tagged with ALL keywords)
479
+ intersection_results: list[dict] = []
480
+ if len(keywords) >= 2:
481
+ try:
482
+ from triplestore import TripleStore
483
+ _istore = TripleStore(db_path)
484
+ placeholders = ",".join("?" for _ in keywords)
485
+ rows = _istore._conn.execute(
486
+ f"""SELECT entity_id, COUNT(DISTINCT value) as matches
487
+ FROM triples WHERE attribute = 'tag' AND NOT retracted
488
+ AND value IN ({placeholders})
489
+ GROUP BY entity_id HAVING COUNT(DISTINCT value) >= ?
490
+ ORDER BY matches DESC LIMIT ?""",
491
+ (*keywords, len(keywords), candidate_limit),
492
+ ).fetchall()
493
+ for r in rows:
494
+ fid = r["entity_id"]
495
+ attrs = _istore.entity(fid)
496
+ if attrs and "value" in attrs:
497
+ fact = {"entity_id": fid}
498
+ for attr_name, values in attrs.items():
499
+ if attr_name != "tag":
500
+ fact[attr_name] = values[0] if len(values) == 1 else values
501
+ intersection_results.append(fact)
502
+ _istore.close()
503
+ except Exception:
504
+ pass
505
+
387
506
  # Build ranked lists by entity_id
388
507
  def _ranked_ids(facts: list[dict]) -> list[str]:
389
508
  seen = set()
@@ -398,41 +517,58 @@ def query_facts_hybrid(
398
517
  fts_ranked = _ranked_ids(fts_results)
399
518
  tag_ranked = _ranked_ids(tag_results)
400
519
  top_ranked = _ranked_ids(top_results)
520
+ intersection_ranked = _ranked_ids(intersection_results)
401
521
 
402
522
  # Reciprocal Rank Fusion: RRF(d) = Σ 1/(k + rank_i(d))
403
- K = 60 # standard RRF constant
523
+ K = 60
404
524
  rrf_scores: dict[str, float] = {}
405
- for ranked_list in [fts_ranked, tag_ranked, top_ranked]:
525
+ tiers = [fts_ranked, tag_ranked, top_ranked]
526
+ if intersection_ranked:
527
+ tiers.append(intersection_ranked)
528
+ for ranked_list in tiers:
406
529
  for rank, eid in enumerate(ranked_list):
407
530
  rrf_scores[eid] = rrf_scores.get(eid, 0.0) + 1.0 / (K + rank)
408
531
 
409
- # Co-occurrence boost: use FTS/tag results to find temporally related entities
410
- import time as _time
411
- _t_cooccur = _time.monotonic()
412
- query_matched_ids = {f.get("entity_id", "") for f in fts_results + tag_results if f.get("entity_id")}
413
- if query_matched_ids and _time.monotonic() - _t_cooccur < 0.3:
532
+ # Change D: Session co-occurrence for multi-entity queries
533
+ if len(keywords) >= 2 and time.monotonic() - t0 < 1.0:
414
534
  try:
415
535
  from triplestore import TripleStore
416
- _store = TripleStore(db_path)
417
- cooccur = _cooccurring_entities(_store, query_matched_ids, max_entities=5)
418
- for ent_name in cooccur:
419
- for f in query_facts_by_entity_graph(db_path, ent_name, max_facts=10):
420
- eid = f.get("entity_id", "")
421
- if eid and eid not in graph_fact_ids:
536
+ _sstore = TripleStore(db_path)
537
+ # Find sessions where facts about BOTH keywords exist
538
+ kw_a, kw_b = keywords[0], keywords[1]
539
+ sess_rows = _sstore._conn.execute(
540
+ """SELECT DISTINCT t1.value as ts FROM triples t1
541
+ JOIN triples t2 ON t2.attribute='first_seen' AND t2.value=t1.value AND t2.retracted=0
542
+ WHERE t1.attribute='first_seen' AND t1.retracted=0
543
+ AND t1.entity_id IN (SELECT entity_id FROM triples WHERE attribute='tag' AND value=? AND NOT retracted)
544
+ AND t2.entity_id IN (SELECT entity_id FROM triples WHERE attribute='tag' AND value=? AND NOT retracted)
545
+ LIMIT 10""",
546
+ (kw_a, kw_b),
547
+ ).fetchall()
548
+ if sess_rows:
549
+ ts_values = [r[0] for r in sess_rows]
550
+ ph = ",".join("?" for _ in ts_values)
551
+ fact_rows = _sstore._conn.execute(
552
+ f"SELECT DISTINCT entity_id FROM triples WHERE attribute='first_seen' AND value IN ({ph}) AND NOT retracted AND entity_id LIKE 'fact:%' LIMIT 30",
553
+ ts_values,
554
+ ).fetchall()
555
+ for r in fact_rows:
556
+ eid = r[0]
557
+ if eid not in graph_fact_ids:
422
558
  community_fact_ids.add(eid)
423
- _store.close()
559
+ _sstore.close()
424
560
  except Exception:
425
561
  pass
426
562
 
427
- # Graph boost: facts linked to mentioned entities via backrefs get priority
428
- # +0.05 is significant vs RRF scores of ~0.015-0.033 — ensures entity-linked facts
429
- # rank above FTS noise in large graphs (100K+ triples)
430
- if graph_fact_ids or community_fact_ids:
563
+ # Graph boost with intersection bonus (Change A continued)
564
+ if graph_fact_ids or community_fact_ids or graph_intersection:
431
565
  for eid in rrf_scores:
432
- if eid in graph_fact_ids:
566
+ if eid in graph_intersection:
567
+ rrf_scores[eid] += 0.10 # intersection: linked to ALL queried entities
568
+ elif eid in graph_fact_ids:
433
569
  rrf_scores[eid] += 0.05 # direct graph-linked facts
434
570
  elif eid in community_fact_ids:
435
- rrf_scores[eid] += 0.025 # community-expanded facts (half weight)
571
+ rrf_scores[eid] += 0.025 # community-expanded facts
436
572
 
437
573
  # Apply confidence decay as secondary signal (fresh facts rank above stale ones)
438
574
  from triplestore import decayed_confidence
@@ -462,11 +598,30 @@ def query_facts_hybrid(
462
598
  if eid and eid not in fact_map:
463
599
  fact_map[eid] = f
464
600
 
465
- # Return top RRF candidates. Embedding re-ranking is done by the caller
466
- # (sinain-core Node.js) to avoid deadlock — the Python subprocess can't call
467
- # back to sinain-core's /embed endpoint while sinain-core is blocked waiting
468
- # for the subprocess.
469
- results = [fact_map[eid] for eid in sorted_ids[:max_facts] if eid in fact_map]
601
+ # Return top RRF candidates, optionally re-ranked by embedding similarity.
602
+ # When called from sinain-core subprocess, embedding re-ranking happens in
603
+ # Node.js (to avoid deadlock). When called standalone (benchmark, CLI),
604
+ # we re-rank in-process if sentence-transformers is available.
605
+ rrf_candidates = [fact_map[eid] for eid in sorted_ids[:max_facts * 2] if eid in fact_map]
606
+
607
+ results = rrf_candidates[:max_facts]
608
+ try:
609
+ from sentence_transformers import SentenceTransformer
610
+ import numpy as np
611
+ if not hasattr(query_facts_hybrid, "_embed_model"):
612
+ query_facts_hybrid._embed_model = SentenceTransformer("all-MiniLM-L6-v2")
613
+ model = query_facts_hybrid._embed_model
614
+ texts = [query] + [f.get("value", "") for f in rrf_candidates]
615
+ embs = model.encode(texts, show_progress_bar=False)
616
+ q_emb = embs[0]
617
+ scored = []
618
+ for i, f in enumerate(rrf_candidates):
619
+ sim = float(np.dot(q_emb, embs[i + 1]) / (np.linalg.norm(q_emb) * np.linalg.norm(embs[i + 1]) + 1e-9))
620
+ scored.append((sim, f))
621
+ scored.sort(key=lambda x: -x[0])
622
+ results = [f for _, f in scored[:max_facts]]
623
+ except ImportError:
624
+ pass # sentence-transformers not installed — use RRF order
470
625
 
471
626
  # Expand top results with 1-hop graph neighbors
472
627
  if results and len(results) < max_facts: