graphiti-core 0.17.4__py3-none-any.whl → 0.25.3__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (59)
  1. graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  3. graphiti_core/decorators.py +110 -0
  4. graphiti_core/driver/driver.py +62 -2
  5. graphiti_core/driver/falkordb_driver.py +215 -23
  6. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  7. graphiti_core/driver/kuzu_driver.py +182 -0
  8. graphiti_core/driver/neo4j_driver.py +70 -8
  9. graphiti_core/driver/neptune_driver.py +305 -0
  10. graphiti_core/driver/search_interface/search_interface.py +89 -0
  11. graphiti_core/edges.py +264 -132
  12. graphiti_core/embedder/azure_openai.py +10 -3
  13. graphiti_core/embedder/client.py +2 -1
  14. graphiti_core/graph_queries.py +114 -101
  15. graphiti_core/graphiti.py +635 -260
  16. graphiti_core/graphiti_types.py +2 -0
  17. graphiti_core/helpers.py +37 -15
  18. graphiti_core/llm_client/anthropic_client.py +142 -52
  19. graphiti_core/llm_client/azure_openai_client.py +57 -19
  20. graphiti_core/llm_client/client.py +83 -21
  21. graphiti_core/llm_client/config.py +1 -1
  22. graphiti_core/llm_client/gemini_client.py +75 -57
  23. graphiti_core/llm_client/openai_base_client.py +92 -48
  24. graphiti_core/llm_client/openai_client.py +39 -9
  25. graphiti_core/llm_client/openai_generic_client.py +91 -56
  26. graphiti_core/models/edges/edge_db_queries.py +259 -35
  27. graphiti_core/models/nodes/node_db_queries.py +311 -32
  28. graphiti_core/nodes.py +388 -164
  29. graphiti_core/prompts/dedupe_edges.py +42 -31
  30. graphiti_core/prompts/dedupe_nodes.py +56 -39
  31. graphiti_core/prompts/eval.py +4 -4
  32. graphiti_core/prompts/extract_edges.py +24 -15
  33. graphiti_core/prompts/extract_nodes.py +76 -35
  34. graphiti_core/prompts/prompt_helpers.py +39 -0
  35. graphiti_core/prompts/snippets.py +29 -0
  36. graphiti_core/prompts/summarize_nodes.py +23 -25
  37. graphiti_core/search/search.py +154 -74
  38. graphiti_core/search/search_config.py +39 -4
  39. graphiti_core/search/search_filters.py +110 -31
  40. graphiti_core/search/search_helpers.py +5 -6
  41. graphiti_core/search/search_utils.py +1360 -473
  42. graphiti_core/tracer.py +193 -0
  43. graphiti_core/utils/bulk_utils.py +216 -90
  44. graphiti_core/utils/content_chunking.py +702 -0
  45. graphiti_core/utils/datetime_utils.py +13 -0
  46. graphiti_core/utils/maintenance/community_operations.py +62 -38
  47. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  48. graphiti_core/utils/maintenance/edge_operations.py +306 -156
  49. graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
  50. graphiti_core/utils/maintenance/node_operations.py +466 -206
  51. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  52. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  53. graphiti_core/utils/text_utils.py +53 -0
  54. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/METADATA +221 -87
  55. graphiti_core-0.25.3.dist-info/RECORD +87 -0
  56. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/WHEEL +1 -1
  57. graphiti_core-0.17.4.dist-info/RECORD +0 -77
  58. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  59. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/tracer.py
@@ -0,0 +1,193 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ from abc import ABC, abstractmethod
+ from collections.abc import Generator
+ from contextlib import AbstractContextManager, contextmanager, suppress
+ from typing import TYPE_CHECKING, Any
+
+ if TYPE_CHECKING:
+     from opentelemetry.trace import Span, StatusCode
+
+ try:
+     from opentelemetry.trace import Span, StatusCode
+
+     OTEL_AVAILABLE = True
+ except ImportError:
+     OTEL_AVAILABLE = False
+
+
+ class TracerSpan(ABC):
+     """Abstract base class for tracer spans."""
+
+     @abstractmethod
+     def add_attributes(self, attributes: dict[str, Any]) -> None:
+         """Add attributes to the span."""
+         pass
+
+     @abstractmethod
+     def set_status(self, status: str, description: str | None = None) -> None:
+         """Set the status of the span."""
+         pass
+
+     @abstractmethod
+     def record_exception(self, exception: Exception) -> None:
+         """Record an exception in the span."""
+         pass
+
+
+ class Tracer(ABC):
+     """Abstract base class for tracers."""
+
+     @abstractmethod
+     def start_span(self, name: str) -> AbstractContextManager[TracerSpan]:
+         """Start a new span with the given name."""
+         pass
+
+
+ class NoOpSpan(TracerSpan):
+     """No-op span implementation that does nothing."""
+
+     def add_attributes(self, attributes: dict[str, Any]) -> None:
+         pass
+
+     def set_status(self, status: str, description: str | None = None) -> None:
+         pass
+
+     def record_exception(self, exception: Exception) -> None:
+         pass
+
+
+ class NoOpTracer(Tracer):
+     """No-op tracer implementation that does nothing."""
+
+     @contextmanager
+     def start_span(self, name: str) -> Generator[NoOpSpan, None, None]:
+         """Return a no-op span."""
+         yield NoOpSpan()
+
+
+ class OpenTelemetrySpan(TracerSpan):
+     """Wrapper for OpenTelemetry span."""
+
+     def __init__(self, span: 'Span'):
+         self._span = span
+
+     def add_attributes(self, attributes: dict[str, Any]) -> None:
+         """Add attributes to the OpenTelemetry span."""
+         try:
+             # Filter out None values and convert all values to appropriate types
+             filtered_attrs = {}
+             for key, value in attributes.items():
+                 if value is not None:
+                     # Convert to string if not a primitive type
+                     if isinstance(value, str | int | float | bool):
+                         filtered_attrs[key] = value
+                     else:
+                         filtered_attrs[key] = str(value)
+
+             if filtered_attrs:
+                 self._span.set_attributes(filtered_attrs)
+         except Exception:
+             # Silently ignore tracing errors
+             pass
+
+     def set_status(self, status: str, description: str | None = None) -> None:
+         """Set the status of the OpenTelemetry span."""
+         try:
+             if OTEL_AVAILABLE:
+                 if status == 'error':
+                     self._span.set_status(StatusCode.ERROR, description)
+                 elif status == 'ok':
+                     self._span.set_status(StatusCode.OK, description)
+         except Exception:
+             # Silently ignore tracing errors
+             pass
+
+     def record_exception(self, exception: Exception) -> None:
+         """Record an exception in the OpenTelemetry span."""
+         with suppress(Exception):
+             self._span.record_exception(exception)
+
+
+ class OpenTelemetryTracer(Tracer):
+     """Wrapper for OpenTelemetry tracer with configurable span name prefix."""
+
+     def __init__(self, tracer: Any, span_prefix: str = 'graphiti'):
+         """
+         Initialize the OpenTelemetry tracer wrapper.
+
+         Parameters
+         ----------
+         tracer : opentelemetry.trace.Tracer
+             The OpenTelemetry tracer instance.
+         span_prefix : str, optional
+             Prefix to prepend to all span names. Defaults to 'graphiti'.
+         """
+         if not OTEL_AVAILABLE:
+             raise ImportError(
+                 'OpenTelemetry is not installed. Install it with: pip install opentelemetry-api'
+             )
+         self._tracer = tracer
+         self._span_prefix = span_prefix.rstrip('.')
+
+     @contextmanager
+     def start_span(self, name: str) -> Generator[OpenTelemetrySpan | NoOpSpan, None, None]:
+         """Start a new OpenTelemetry span with the configured prefix."""
+         try:
+             full_name = f'{self._span_prefix}.{name}'
+             with self._tracer.start_as_current_span(full_name) as span:
+                 yield OpenTelemetrySpan(span)
+         except Exception:
+             # If tracing fails, yield a no-op span to prevent breaking the operation
+             yield NoOpSpan()
+
+
+ def create_tracer(otel_tracer: Any | None = None, span_prefix: str = 'graphiti') -> Tracer:
+     """
+     Create a tracer instance.
+
+     Parameters
+     ----------
+     otel_tracer : opentelemetry.trace.Tracer | None, optional
+         An OpenTelemetry tracer instance. If None, a no-op tracer is returned.
+     span_prefix : str, optional
+         Prefix to prepend to all span names. Defaults to 'graphiti'.
+
+     Returns
+     -------
+     Tracer
+         A tracer instance (either OpenTelemetryTracer or NoOpTracer).
+
+     Examples
+     --------
+     Using with OpenTelemetry:
+
+     >>> from opentelemetry import trace
+     >>> otel_tracer = trace.get_tracer(__name__)
+     >>> tracer = create_tracer(otel_tracer, span_prefix='myapp.graphiti')
+
+     Using no-op tracer:
+
+     >>> tracer = create_tracer()  # Returns NoOpTracer
+     """
+     if otel_tracer is None:
+         return NoOpTracer()
+
+     if not OTEL_AVAILABLE:
+         return NoOpTracer()
+
+     return OpenTelemetryTracer(otel_tracer, span_prefix)
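
The create_tracer docstring above shows the intended entry point. As a rough usage sketch (assuming the module ships as graphiti_core/tracer.py, per the file list), wiring in a real OpenTelemetry tracer looks like this; the span name 'add_episode' and the attribute keys are illustrative only:

    from opentelemetry import trace

    from graphiti_core.tracer import create_tracer

    # With a real OTel tracer the spans are exported; create_tracer(None) returns a NoOpTracer.
    otel_tracer = trace.get_tracer(__name__)
    tracer = create_tracer(otel_tracer, span_prefix='myapp.graphiti')

    with tracer.start_span('add_episode') as span:
        span.add_attributes({'group_id': 'demo', 'episode_count': 3})
        span.set_status('ok')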
graphiti_core/utils/bulk_utils.py
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """
 
+ import json
  import logging
  import typing
  from datetime import datetime
@@ -22,22 +23,31 @@ import numpy as np
  from pydantic import BaseModel, Field
  from typing_extensions import Any
 
- from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
+ from graphiti_core.driver.driver import (
+     GraphDriver,
+     GraphDriverSession,
+     GraphProvider,
+ )
  from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
  from graphiti_core.embedder import EmbedderClient
- from graphiti_core.graph_queries import (
-     get_entity_edge_save_bulk_query,
-     get_entity_node_save_bulk_query,
- )
  from graphiti_core.graphiti_types import GraphitiClients
  from graphiti_core.helpers import normalize_l2, semaphore_gather
  from graphiti_core.models.edges.edge_db_queries import (
-     EPISODIC_EDGE_SAVE_BULK,
+     get_entity_edge_save_bulk_query,
+     get_episodic_edge_save_bulk_query,
  )
  from graphiti_core.models.nodes.node_db_queries import (
-     EPISODIC_NODE_SAVE_BULK,
+     get_entity_node_save_bulk_query,
+     get_episode_node_save_bulk_query,
+ )
+ from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+ from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
+ from graphiti_core.utils.maintenance.dedup_helpers import (
+     DedupResolutionState,
+     _build_candidate_indexes,
+     _normalize_string_exact,
+     _resolve_with_similarity,
  )
- from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
  from graphiti_core.utils.maintenance.edge_operations import (
      extract_edges,
      resolve_extracted_edge,
@@ -56,6 +66,38 @@ logger = logging.getLogger(__name__)
  CHUNK_SIZE = 10
 
 
+ def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
+     """Collapse alias -> canonical chains while preserving direction.
+
+     The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
+     union-find with iterative path compression to ensure every source UUID resolves to its ultimate
+     canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
+     """
+
+     parent: dict[str, str] = {}
+
+     def find(uuid: str) -> str:
+         """Directed union-find lookup using iterative path compression."""
+         parent.setdefault(uuid, uuid)
+         root = uuid
+         while parent[root] != root:
+             root = parent[root]
+
+         while parent[uuid] != root:
+             next_uuid = parent[uuid]
+             parent[uuid] = root
+             uuid = next_uuid
+
+         return root
+
+     for source_uuid, target_uuid in pairs:
+         parent.setdefault(source_uuid, source_uuid)
+         parent.setdefault(target_uuid, target_uuid)
+         parent[find(source_uuid)] = find(target_uuid)
+
+     return {uuid: find(uuid) for uuid in parent}
+
+
  class RawEpisode(BaseModel):
      name: str
      uuid: str | None = Field(default=None)
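
A quick illustration of _build_directed_uuid_map (the UUIDs here are hypothetical): directed alias -> canonical pairs collapse transitively, so an alias of an alias still resolves to the final canonical target.

    # 'node-c' was marked a duplicate of 'node-b', and 'node-b' a duplicate of 'node-a'.
    pairs = [('node-c', 'node-b'), ('node-b', 'node-a')]

    assert _build_directed_uuid_map(pairs) == {
        'node-a': 'node-a',
        'node-b': 'node-a',
        'node-c': 'node-a',
    }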
@@ -118,24 +160,33 @@ async def add_nodes_and_edges_bulk_tx(
      episodes = [dict(episode) for episode in episodic_nodes]
      for episode in episodes:
          episode['source'] = str(episode['source'].value)
-     nodes: list[dict[str, Any]] = []
+         episode.pop('labels', None)
+
+     nodes = []
+
      for node in entity_nodes:
          if node.name_embedding is None:
              await node.generate_name_embedding(embedder)
+
          entity_data: dict[str, Any] = {
              'uuid': node.uuid,
              'name': node.name,
-             'name_embedding': node.name_embedding,
              'group_id': node.group_id,
              'summary': node.summary,
              'created_at': node.created_at,
+             'name_embedding': node.name_embedding,
+             'labels': list(set(node.labels + ['Entity'])),
          }
 
-         entity_data.update(node.attributes or {})
-         entity_data['labels'] = list(set(node.labels + ['Entity']))
+         if driver.provider == GraphProvider.KUZU:
+             attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
+             entity_data['attributes'] = json.dumps(attributes)
+         else:
+             entity_data.update(node.attributes or {})
+
          nodes.append(entity_data)
 
-     edges: list[dict[str, Any]] = []
+     edges = []
      for edge in entity_edges:
          if edge.fact_embedding is None:
              await edge.generate_embedding(embedder)
@@ -145,35 +196,68 @@ async def add_nodes_and_edges_bulk_tx(
              'target_node_uuid': edge.target_node_uuid,
              'name': edge.name,
              'fact': edge.fact,
-             'fact_embedding': edge.fact_embedding,
              'group_id': edge.group_id,
              'episodes': edge.episodes,
              'created_at': edge.created_at,
              'expired_at': edge.expired_at,
              'valid_at': edge.valid_at,
              'invalid_at': edge.invalid_at,
+             'fact_embedding': edge.fact_embedding,
          }
 
-         edge_data.update(edge.attributes or {})
+         if driver.provider == GraphProvider.KUZU:
+             attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
+             edge_data['attributes'] = json.dumps(attributes)
+         else:
+             edge_data.update(edge.attributes or {})
+
          edges.append(edge_data)
 
-     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
-     entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider)
-     await tx.run(entity_node_save_bulk, nodes=nodes)
-     await tx.run(
-         EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
-     )
-     entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
-     await tx.run(entity_edge_save_bulk, entity_edges=edges)
+     if driver.graph_operations_interface:
+         await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
+         await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
+         await driver.graph_operations_interface.episodic_edge_save_bulk(
+             None, driver, tx, [edge.model_dump() for edge in episodic_edges]
+         )
+         await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
+
+     elif driver.provider == GraphProvider.KUZU:
+         # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
+         episode_query = get_episode_node_save_bulk_query(driver.provider)
+         for episode in episodes:
+             await tx.run(episode_query, **episode)
+         entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
+         for node in nodes:
+             await tx.run(entity_node_query, **node)
+         entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
+         for edge in edges:
+             await tx.run(entity_edge_query, **edge)
+         episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
+         for edge in episodic_edges:
+             await tx.run(episodic_edge_query, **edge.model_dump())
+     else:
+         await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
+         await tx.run(
+             get_entity_node_save_bulk_query(driver.provider, nodes),
+             nodes=nodes,
+         )
+         await tx.run(
+             get_episodic_edge_save_bulk_query(driver.provider),
+             episodic_edges=[edge.model_dump() for edge in episodic_edges],
+         )
+         await tx.run(
+             get_entity_edge_save_bulk_query(driver.provider),
+             entity_edges=edges,
+         )
 
 
  async def extract_nodes_and_edges_bulk(
      clients: GraphitiClients,
      episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
      edge_type_map: dict[tuple[str, str], list[str]],
-     entity_types: dict[str, BaseModel] | None = None,
+     entity_types: dict[str, type[BaseModel]] | None = None,
      excluded_entity_types: list[str] | None = None,
-     edge_types: dict[str, BaseModel] | None = None,
+     edge_types: dict[str, type[BaseModel]] | None = None,
  ) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
      extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
          *[
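
The Kuzu branches above fold free-form attributes into a single JSON-encoded property instead of spreading them onto the node or relationship. A minimal sketch of that transformation, assuming convert_datetimes_to_strings accepts an arbitrary attribute dict (its exact signature is not shown in this diff):

    import json
    from datetime import datetime, timezone

    from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings

    attributes = {'role': 'engineer', 'hired_at': datetime(2024, 5, 1, tzinfo=timezone.utc)}

    # Kuzu path: one JSON string under 'attributes'; other providers get the keys spread directly.
    entity_data = {'uuid': 'node-1', 'name': 'Alice'}
    entity_data['attributes'] = json.dumps(convert_datetimes_to_strings(attributes))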
@@ -204,85 +288,113 @@ async def dedupe_nodes_bulk(
      clients: GraphitiClients,
      extracted_nodes: list[list[EntityNode]],
      episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
-     entity_types: dict[str, BaseModel] | None = None,
+     entity_types: dict[str, type[BaseModel]] | None = None,
  ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
-     embedder = clients.embedder
-     min_score = 0.8
-
-     # generate embeddings
-     await semaphore_gather(
-         *[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
-     )
+     """Resolve entity duplicates across an in-memory batch using a two-pass strategy.
 
-     # Find similar results
-     dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
-     for i, nodes_i in enumerate(extracted_nodes):
-         existing_nodes: list[EntityNode] = []
-         for j, nodes_j in enumerate(extracted_nodes):
-             if i == j:
-                 continue
-             existing_nodes += nodes_j
-
-         candidates_i: list[EntityNode] = []
-         for node in nodes_i:
-             for existing_node in existing_nodes:
-                 # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
-                 # This approach will cast a wider net than BM25, which is ideal for this use case
-                 node_words = set(node.name.lower().split())
-                 existing_node_words = set(existing_node.name.lower().split())
-                 has_overlap = not node_words.isdisjoint(existing_node_words)
-                 if has_overlap:
-                     candidates_i.append(existing_node)
-                     continue
-
-                 # Check for semantic similarity even if there is no overlap
-                 similarity = np.dot(
-                     normalize_l2(node.name_embedding or []),
-                     normalize_l2(existing_node.name_embedding or []),
-                 )
-                 if similarity >= min_score:
-                     candidates_i.append(existing_node)
-
-         dedupe_tuples.append((nodes_i, candidates_i))
+     1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
+        reconciled against the live graph just like the non-batch flow.
+     2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
+        duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
+        can apply to edges and persistence.
+     """
 
-     # Determine Node Resolutions
-     bulk_node_resolutions: list[
-         tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
-     ] = await semaphore_gather(
+     first_pass_results = await semaphore_gather(
          *[
              resolve_extracted_nodes(
                  clients,
-                 dedupe_tuple[0],
+                 nodes,
                  episode_tuples[i][0],
                  episode_tuples[i][1],
                  entity_types,
-                 existing_nodes_override=dedupe_tuples[i][1],
              )
-             for i, dedupe_tuple in enumerate(dedupe_tuples)
+             for i, nodes in enumerate(extracted_nodes)
          ]
      )
 
-     # Collect all duplicate pairs sorted by uuid
+     episode_resolutions: list[tuple[str, list[EntityNode]]] = []
+     per_episode_uuid_maps: list[dict[str, str]] = []
      duplicate_pairs: list[tuple[str, str]] = []
-     for _, _, duplicates in bulk_node_resolutions:
-         for duplicate in duplicates:
-             n, m = duplicate
-             duplicate_pairs.append((n.uuid, m.uuid))
 
-     # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
-     compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
+     for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
+         first_pass_results, episode_tuples, strict=True
+     ):
+         episode_resolutions.append((episode.uuid, resolved_nodes))
+         per_episode_uuid_maps.append(uuid_map)
+         duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
+
+     canonical_nodes: dict[str, EntityNode] = {}
+     for _, resolved_nodes in episode_resolutions:
+         for node in resolved_nodes:
+             # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
+             # the MinHash index for the accumulated canonical pool each time. The LRU-backed
+             # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
+             # but if batches grow significantly we should switch to an incremental index or chunked
+             # processing.
+             if not canonical_nodes:
+                 canonical_nodes[node.uuid] = node
+                 continue
 
-     node_uuid_map: dict[str, EntityNode] = {
-         node.uuid: node for nodes in extracted_nodes for node in nodes
-     }
+             existing_candidates = list(canonical_nodes.values())
+             normalized = _normalize_string_exact(node.name)
+             exact_match = next(
+                 (
+                     candidate
+                     for candidate in existing_candidates
+                     if _normalize_string_exact(candidate.name) == normalized
+                 ),
+                 None,
+             )
+             if exact_match is not None:
+                 if exact_match.uuid != node.uuid:
+                     duplicate_pairs.append((node.uuid, exact_match.uuid))
+                 continue
+
+             indexes = _build_candidate_indexes(existing_candidates)
+             state = DedupResolutionState(
+                 resolved_nodes=[None],
+                 uuid_map={},
+                 unresolved_indices=[],
+             )
+             _resolve_with_similarity([node], indexes, state)
+
+             resolved = state.resolved_nodes[0]
+             if resolved is None:
+                 canonical_nodes[node.uuid] = node
+                 continue
+
+             canonical_uuid = resolved.uuid
+             canonical_nodes.setdefault(canonical_uuid, resolved)
+             if canonical_uuid != node.uuid:
+                 duplicate_pairs.append((node.uuid, canonical_uuid))
+
+     union_pairs: list[tuple[str, str]] = []
+     for uuid_map in per_episode_uuid_maps:
+         union_pairs.extend(uuid_map.items())
+     union_pairs.extend(duplicate_pairs)
+
+     compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
 
      nodes_by_episode: dict[str, list[EntityNode]] = {}
-     for i, nodes in enumerate(extracted_nodes):
-         episode = episode_tuples[i][0]
+     for episode_uuid, resolved_nodes in episode_resolutions:
+         deduped_nodes: list[EntityNode] = []
+         seen: set[str] = set()
+         for node in resolved_nodes:
+             canonical_uuid = compressed_map.get(node.uuid, node.uuid)
+             if canonical_uuid in seen:
+                 continue
+             seen.add(canonical_uuid)
+             canonical_node = canonical_nodes.get(canonical_uuid)
+             if canonical_node is None:
+                 logger.error(
+                     'Canonical node %s missing during batch dedupe; falling back to %s',
+                     canonical_uuid,
+                     node.uuid,
+                 )
+                 canonical_node = node
+             deduped_nodes.append(canonical_node)
 
-         nodes_by_episode[episode.uuid] = [
-             node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
-         ]
+         nodes_by_episode[episode_uuid] = deduped_nodes
 
      return nodes_by_episode, compressed_map
 
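The second return value is a canonical UUID map that callers still have to apply to anything holding pre-dedupe UUIDs. A hypothetical helper showing that remapping for edge endpoints (the field names mirror EntityEdge as used above):

    def remap_edge_pointers(edges, uuid_map):
        # Point each endpoint at its canonical node UUID; unknown UUIDs pass through unchanged.
        for edge in edges:
            edge.source_node_uuid = uuid_map.get(edge.source_node_uuid, edge.source_node_uuid)
            edge.target_node_uuid = uuid_map.get(edge.target_node_uuid, edge.target_node_uuid)
        return edges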
@@ -292,7 +404,7 @@ async def dedupe_edges_bulk(
      extracted_edges: list[list[EntityEdge]],
      episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
      _entities: list[EntityNode],
-     edge_types: dict[str, BaseModel],
+     edge_types: dict[str, type[BaseModel]],
      _edge_type_map: dict[tuple[str, str], list[str]],
  ) -> dict[str, list[EntityEdge]]:
      embedder = clients.embedder
@@ -307,16 +419,23 @@
      dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
      for i, edges_i in enumerate(extracted_edges):
          existing_edges: list[EntityEdge] = []
-         for j, edges_j in enumerate(extracted_edges):
-             if i == j:
-                 continue
+         for edges_j in extracted_edges:
              existing_edges += edges_j
 
          for edge in edges_i:
              candidates: list[EntityEdge] = []
              for existing_edge in existing_edges:
+                 # Skip self-comparison
+                 if edge.uuid == existing_edge.uuid:
+                     continue
                  # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
                  # This approach will cast a wider net than BM25, which is ideal for this use case
+                 if (
+                     edge.source_node_uuid != existing_edge.source_node_uuid
+                     or edge.target_node_uuid != existing_edge.target_node_uuid
+                 ):
+                     continue
+
                  edge_words = set(edge.fact.lower().split())
                  existing_edge_words = set(existing_edge.fact.lower().split())
                  has_overlap = not edge_words.isdisjoint(existing_edge_words)
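
The word-overlap shortcut kept above is a cheap stand-in for BM25: two facts are considered candidate duplicates if they share any token. In isolation the check is just a set-intersection test:

    fact_a = 'Alice works at Acme'
    fact_b = 'Acme employs Alice'

    words_a = set(fact_a.lower().split())
    words_b = set(fact_b.lower().split())
    has_overlap = not words_a.isdisjoint(words_b)  # True: 'alice' and 'acme' are shared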
@@ -339,12 +458,19 @@
      ] = await semaphore_gather(
          *[
              resolve_extracted_edge(
-                 clients.llm_client, edge, candidates, candidates, episode, edge_types
+                 clients.llm_client,
+                 edge,
+                 candidates,
+                 candidates,
+                 episode,
+                 edge_types,
+                 set(edge_types),
              )
              for episode, edge, candidates in dedupe_tuples
          ]
      )
 
+     # For now we won't track edge invalidation
      duplicate_pairs: list[tuple[str, str]] = []
      for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
          episode, edge, candidates = dedupe_tuples[i]