@susu-eng/gralkor 27.0.0 → 27.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/server/main.py CHANGED
@@ -7,6 +7,7 @@ import logging
7
7
  import os
8
8
  import time
9
9
  from contextlib import asynccontextmanager
10
+ from copy import deepcopy
10
11
  from dataclasses import dataclass, field
11
12
  from datetime import datetime, timezone
12
13
  from typing import Any, Literal
@@ -23,6 +24,7 @@ from graphiti_core.driver.falkordb_driver import FalkorDriver
23
24
  from graphiti_core.edges import EntityEdge
24
25
  from graphiti_core.nodes import EpisodicNode, EpisodeType
25
26
  from graphiti_core.llm_client import LLMConfig
27
+ from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_CROSS_ENCODER
26
28
 
27
29
 
28
30
  # ── Config ────────────────────────────────────────────────────
@@ -134,16 +136,14 @@ def _build_ontology(
134
136
  dict[tuple[str, str], list[str]] | None,
135
137
  list[str] | None,
136
138
  ]:
137
- """Build ontology from config. Returns (entity_types, edge_types, edge_type_map, excluded)."""
139
+ """Build ontology from config. Returns (entity_types, edge_types, edge_type_map)."""
138
140
  raw = cfg.get("ontology")
139
141
  if not raw:
140
- return None, None, None, None
142
+ return None, None, None
141
143
 
142
144
  entity_defs = raw.get("entities") or {}
143
145
  edge_defs = raw.get("edges") or {}
144
146
  edge_map_raw = raw.get("edgeMap") or {}
145
- excluded_raw = raw.get("excludedEntityTypes")
146
-
147
147
  entity_types = _build_type_defs(entity_defs) if entity_defs else None
148
148
  edge_types = _build_type_defs(edge_defs) if edge_defs else None
149
149
 
@@ -154,12 +154,10 @@ def _build_ontology(
154
154
  parts = key.split(",")
155
155
  edge_type_map[(parts[0], parts[1])] = values
156
156
 
157
- excluded = list(excluded_raw) if excluded_raw else None
158
-
159
- if not entity_types and not edge_types and not edge_type_map and not excluded:
160
- return None, None, None, None
157
+ if not entity_types and not edge_types and not edge_type_map:
158
+ return None, None, None
161
159
 
162
- return entity_types, edge_types, edge_type_map, excluded
160
+ return entity_types, edge_types, edge_type_map
163
161
 
164
162
 
165
163
  def _log_falkordblite_diagnostics(error: Exception) -> None:
@@ -200,12 +198,11 @@ graphiti: Graphiti | None = None
200
198
  ontology_entity_types: dict[str, type[BaseModel]] | None = None
201
199
  ontology_edge_types: dict[str, type[BaseModel]] | None = None
202
200
  ontology_edge_type_map: dict[tuple[str, str], list[str]] | None = None
203
- ontology_excluded: list[str] | None = None
204
201
 
205
202
 
206
203
  @asynccontextmanager
207
204
  async def lifespan(_app: FastAPI):
208
- global graphiti, ontology_entity_types, ontology_edge_types, ontology_edge_type_map, ontology_excluded
205
+ global graphiti, ontology_entity_types, ontology_edge_types, ontology_edge_type_map
209
206
  cfg = _load_config()
210
207
 
211
208
  falkordb_uri = os.getenv("FALKORDB_URI")
@@ -260,7 +257,7 @@ async def lifespan(_app: FastAPI):
260
257
  handler.setFormatter(logging.Formatter("%(message)s"))
261
258
  logger.addHandler(handler)
262
259
 
263
- ontology_entity_types, ontology_edge_types, ontology_edge_type_map, ontology_excluded = _build_ontology(cfg)
260
+ ontology_entity_types, ontology_edge_types, ontology_edge_type_map = _build_ontology(cfg)
264
261
  if ontology_entity_types or ontology_edge_types:
265
262
  entity_names = list(ontology_entity_types or {})
266
263
  edge_names = list(ontology_edge_types or {})
@@ -282,10 +279,10 @@ def _find_rate_limit_error(exc: Exception) -> Exception | None:
282
279
  seen: set[int] = set()
283
280
  while current is not None and id(current) not in seen:
284
281
  seen.add(id(current))
285
- # Match openai.RateLimitError, anthropic.RateLimitError, etc.
286
- if type(current).__name__ == "RateLimitError" or (
287
- hasattr(current, "status_code") and getattr(current, "status_code", None) == 429
288
- ):
282
+ # Match openai.RateLimitError, anthropic.RateLimitError, google.genai.errors.ClientError, etc.
283
+ # Note: Google's APIError uses .code, most others use .status_code.
284
+ http_code = getattr(current, "status_code", None) or getattr(current, "code", None)
285
+ if type(current).__name__ == "RateLimitError" or http_code == 429:
289
286
  return current
290
287
  current = current.__cause__ or current.__context__
291
288
  return None
@@ -386,6 +383,7 @@ class SearchRequest(BaseModel):
386
383
  query: str
387
384
  group_ids: list[str]
388
385
  num_results: int = 10
386
+ mode: Literal["fast", "slow"] = "fast"
389
387
 
390
388
 
391
389
  class GroupIdRequest(BaseModel):
@@ -399,6 +397,15 @@ def _ts(dt: datetime | None) -> str | None:
399
397
  return dt.isoformat() if dt else None
400
398
 
401
399
 
400
+ def _serialize_node(node) -> dict[str, Any]:
401
+ return {
402
+ "uuid": node.uuid,
403
+ "name": node.name,
404
+ "summary": node.summary,
405
+ "group_id": node.group_id,
406
+ }
407
+
408
+
402
409
  def _serialize_fact(edge: EntityEdge) -> dict[str, Any]:
403
410
  return {
404
411
  "uuid": edge.uuid,
@@ -592,7 +599,7 @@ async def add_episode(req: AddEpisodeRequest):
592
599
  entity_types=ontology_entity_types,
593
600
  edge_types=ontology_edge_types,
594
601
  edge_type_map=ontology_edge_type_map,
595
- excluded_entity_types=ontology_excluded,
602
+ excluded_entity_types=None,
596
603
  )
597
604
  duration_ms = (time.monotonic() - t0) * 1000
598
605
  episode = result.episode
@@ -634,7 +641,7 @@ async def ingest_messages(req: IngestMessagesRequest):
634
641
  entity_types=ontology_entity_types,
635
642
  edge_types=ontology_edge_types,
636
643
  edge_type_map=ontology_edge_type_map,
637
- excluded_entity_types=ontology_excluded,
644
+ excluded_entity_types=None,
638
645
  )
639
646
  duration_ms = (time.monotonic() - t0) * 1000
640
647
  episode = result.episode
@@ -673,33 +680,10 @@ def _ensure_driver_graph(group_ids: list[str] | None) -> None:
673
680
  print(f"[gralkor] driver graph routed: {target}", flush=True)
674
681
 
675
682
 
676
- def _prioritize_facts(
677
- edges: list[EntityEdge], limit: int, reserved_ratio: float = 0.7,
678
- ) -> list[EntityEdge]:
679
- """Reserve slots for valid facts, fill the rest by relevance.
680
-
681
- First ~70% of slots are reserved for valid facts (no invalid_at).
682
- Remaining slots are filled from whatever Graphiti ranked highest
683
- among the leftovers — valid or not — preserving relevance scoring.
684
- """
685
- reserved_count = max(1, round(limit * reserved_ratio))
686
-
687
- reserved: list[EntityEdge] = []
688
- rest: list[EntityEdge] = []
689
- for e in edges:
690
- if len(reserved) < reserved_count and e.invalid_at is None:
691
- reserved.append(e)
692
- else:
693
- rest.append(e)
694
-
695
- remainder_count = limit - len(reserved)
696
- return reserved + rest[:remainder_count]
697
-
698
-
699
683
  @app.post("/search")
700
684
  async def search(req: SearchRequest):
701
- logger.info("[gralkor] search — query:%d chars group_ids:%s num_results:%d",
702
- len(req.query), req.group_ids, req.num_results)
685
+ logger.info("[gralkor] search — mode:%s query:%d chars group_ids:%s num_results:%d",
686
+ req.mode, len(req.query), req.group_ids, req.num_results)
703
687
  # graphiti.add_episode() clones the driver to target the correct FalkorDB
704
688
  # named graph (database=group_id), but graphiti.search() does not — it just
705
689
  # uses whatever graph the driver currently points at. Before the first
@@ -707,26 +691,38 @@ async def search(req: SearchRequest):
707
691
  # searches return 0 results. Fix: route to the correct graph here.
708
692
  _ensure_driver_graph(req.group_ids)
709
693
  t0 = time.monotonic()
710
- # Over-fetch to compensate for expired facts that will be deprioritized.
711
- fetch_limit = req.num_results * 2
712
694
  try:
713
- edges = await graphiti.search(
714
- query=_sanitize_query(req.query),
715
- group_ids=req.group_ids,
716
- num_results=fetch_limit,
717
- )
695
+ if req.mode == "slow":
696
+ # Cross-encoder + BFS: higher quality, also returns entity node summaries.
697
+ # deepcopy required — COMBINED_HYBRID_SEARCH_CROSS_ENCODER is a module-level
698
+ # constant; mutating .limit directly would corrupt it across requests.
699
+ config = deepcopy(COMBINED_HYBRID_SEARCH_CROSS_ENCODER)
700
+ config.limit = req.num_results
701
+ search_result = await graphiti.search_(
702
+ query=_sanitize_query(req.query),
703
+ group_ids=req.group_ids,
704
+ config=config,
705
+ )
706
+ edges = search_result.edges
707
+ nodes = search_result.nodes
708
+ else:
709
+ edges = await graphiti.search(
710
+ query=_sanitize_query(req.query),
711
+ group_ids=req.group_ids,
712
+ num_results=req.num_results,
713
+ )
714
+ nodes = []
718
715
  except Exception as e:
719
716
  duration_ms = (time.monotonic() - t0) * 1000
720
- logger.error("[gralkor] search failed — %.0fms: %s", duration_ms, e)
717
+ logger.error("[gralkor] search failed — mode:%s %.0fms: %s", req.mode, duration_ms, e)
721
718
  raise
722
719
  duration_ms = (time.monotonic() - t0) * 1000
723
- prioritized = _prioritize_facts(edges, req.num_results)
724
- valid_count = sum(1 for e in prioritized if e.invalid_at is None)
725
- result = [_serialize_fact(e) for e in prioritized]
726
- logger.info("[gralkor] search result — %d facts (%d valid, %d non-valid) from %d fetched %.0fms",
727
- len(prioritized), valid_count, len(prioritized) - valid_count, len(edges), duration_ms)
720
+ result = [_serialize_fact(e) for e in edges]
721
+ serialized_nodes = [_serialize_node(n) for n in nodes]
722
+ logger.info("[gralkor] search result mode:%s %d facts %d nodes %.0fms",
723
+ req.mode, len(result), len(serialized_nodes), duration_ms)
728
724
  logger.debug("[gralkor] search facts: %s", result)
729
- return {"facts": result}
725
+ return {"facts": result, "nodes": serialized_nodes}
730
726
 
731
727
 
732
728