graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
  2. graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
  3. graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
  4. graphiti_core/decorators.py +110 -0
  5. graphiti_core/driver/__init__.py +19 -0
  6. graphiti_core/driver/driver.py +124 -0
  7. graphiti_core/driver/falkordb_driver.py +362 -0
  8. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  9. graphiti_core/driver/kuzu_driver.py +182 -0
  10. graphiti_core/driver/neo4j_driver.py +117 -0
  11. graphiti_core/driver/neptune_driver.py +305 -0
  12. graphiti_core/driver/search_interface/search_interface.py +89 -0
  13. graphiti_core/edges.py +287 -172
  14. graphiti_core/embedder/azure_openai.py +71 -0
  15. graphiti_core/embedder/client.py +2 -1
  16. graphiti_core/embedder/gemini.py +116 -22
  17. graphiti_core/embedder/voyage.py +13 -2
  18. graphiti_core/errors.py +8 -0
  19. graphiti_core/graph_queries.py +162 -0
  20. graphiti_core/graphiti.py +705 -193
  21. graphiti_core/graphiti_types.py +4 -2
  22. graphiti_core/helpers.py +87 -10
  23. graphiti_core/llm_client/__init__.py +16 -0
  24. graphiti_core/llm_client/anthropic_client.py +159 -56
  25. graphiti_core/llm_client/azure_openai_client.py +115 -0
  26. graphiti_core/llm_client/client.py +98 -21
  27. graphiti_core/llm_client/config.py +1 -1
  28. graphiti_core/llm_client/gemini_client.py +290 -41
  29. graphiti_core/llm_client/groq_client.py +14 -3
  30. graphiti_core/llm_client/openai_base_client.py +261 -0
  31. graphiti_core/llm_client/openai_client.py +56 -132
  32. graphiti_core/llm_client/openai_generic_client.py +91 -56
  33. graphiti_core/models/edges/edge_db_queries.py +259 -35
  34. graphiti_core/models/nodes/node_db_queries.py +311 -32
  35. graphiti_core/nodes.py +420 -205
  36. graphiti_core/prompts/dedupe_edges.py +46 -32
  37. graphiti_core/prompts/dedupe_nodes.py +67 -42
  38. graphiti_core/prompts/eval.py +4 -4
  39. graphiti_core/prompts/extract_edges.py +27 -16
  40. graphiti_core/prompts/extract_nodes.py +74 -31
  41. graphiti_core/prompts/prompt_helpers.py +39 -0
  42. graphiti_core/prompts/snippets.py +29 -0
  43. graphiti_core/prompts/summarize_nodes.py +23 -25
  44. graphiti_core/search/search.py +158 -82
  45. graphiti_core/search/search_config.py +39 -4
  46. graphiti_core/search/search_filters.py +126 -35
  47. graphiti_core/search/search_helpers.py +5 -6
  48. graphiti_core/search/search_utils.py +1405 -485
  49. graphiti_core/telemetry/__init__.py +9 -0
  50. graphiti_core/telemetry/telemetry.py +117 -0
  51. graphiti_core/tracer.py +193 -0
  52. graphiti_core/utils/bulk_utils.py +364 -285
  53. graphiti_core/utils/datetime_utils.py +13 -0
  54. graphiti_core/utils/maintenance/community_operations.py +67 -49
  55. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  56. graphiti_core/utils/maintenance/edge_operations.py +339 -197
  57. graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
  58. graphiti_core/utils/maintenance/node_operations.py +319 -238
  59. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  60. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  61. graphiti_core/utils/text_utils.py +53 -0
  62. graphiti_core-0.24.3.dist-info/METADATA +726 -0
  63. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  64. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  65. graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
  66. graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
  67. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  68. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/utils/bulk_utils.py
@@ -14,58 +14,93 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import json
 import logging
 import typing
-from collections import defaultdict
 from datetime import datetime
-from math import ceil

-from neo4j import AsyncDriver, AsyncManagedTransaction
-from numpy import dot, sqrt
-from pydantic import BaseModel
+import numpy as np
+from pydantic import BaseModel, Field
 from typing_extensions import Any

-from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
+from graphiti_core.driver.driver import (
+    GraphDriver,
+    GraphDriverSession,
+    GraphProvider,
+)
+from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
-from graphiti_core.llm_client import LLMClient
+from graphiti_core.helpers import normalize_l2, semaphore_gather
 from graphiti_core.models.edges.edge_db_queries import (
-    ENTITY_EDGE_SAVE_BULK,
-    EPISODIC_EDGE_SAVE_BULK,
+    get_entity_edge_save_bulk_query,
+    get_episodic_edge_save_bulk_query,
 )
 from graphiti_core.models.nodes.node_db_queries import (
-    ENTITY_NODE_SAVE_BULK,
-    EPISODIC_NODE_SAVE_BULK,
+    get_entity_node_save_bulk_query,
+    get_episode_node_save_bulk_query,
 )
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
-from graphiti_core.search.search_filters import SearchFilters
-from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
-from graphiti_core.utils.datetime_utils import utc_now
+from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
+from graphiti_core.utils.maintenance.dedup_helpers import (
+    DedupResolutionState,
+    _build_candidate_indexes,
+    _normalize_string_exact,
+    _resolve_with_similarity,
+)
 from graphiti_core.utils.maintenance.edge_operations import (
-    build_episodic_edges,
-    dedupe_edge_list,
-    dedupe_extracted_edges,
     extract_edges,
+    resolve_extracted_edge,
 )
 from graphiti_core.utils.maintenance.graph_data_operations import (
     EPISODE_WINDOW_LEN,
     retrieve_episodes,
 )
 from graphiti_core.utils.maintenance.node_operations import (
-    dedupe_extracted_nodes,
-    dedupe_node_list,
     extract_nodes,
+    resolve_extracted_nodes,
 )
-from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates

 logger = logging.getLogger(__name__)

 CHUNK_SIZE = 10


+def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
+    """Collapse alias -> canonical chains while preserving direction.
+
+    The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
+    union-find with iterative path compression to ensure every source UUID resolves to its ultimate
+    canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
+    """
+
+    parent: dict[str, str] = {}
+
+    def find(uuid: str) -> str:
+        """Directed union-find lookup using iterative path compression."""
+        parent.setdefault(uuid, uuid)
+        root = uuid
+        while parent[root] != root:
+            root = parent[root]
+
+        while parent[uuid] != root:
+            next_uuid = parent[uuid]
+            parent[uuid] = root
+            uuid = next_uuid
+
+        return root
+
+    for source_uuid, target_uuid in pairs:
+        parent.setdefault(source_uuid, source_uuid)
+        parent.setdefault(target_uuid, target_uuid)
+        parent[find(source_uuid)] = find(target_uuid)
+
+    return {uuid: find(uuid) for uuid in parent}
+
+
 class RawEpisode(BaseModel):
     name: str
+    uuid: str | None = Field(default=None)
     content: str
     source_description: str
     source: EpisodeType
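
To illustrate the new _build_directed_uuid_map helper above: chained alias pairs collapse to their final canonical target while the direction of each mapping is preserved. A minimal usage sketch, assuming the helper is importable from graphiti_core.utils.bulk_utils as defined in this hunk:

# Illustrative only: exercises _build_directed_uuid_map from the hunk above.
from graphiti_core.utils.bulk_utils import _build_directed_uuid_map

# uuid-a was marked a duplicate of uuid-b, and uuid-b a duplicate of uuid-c,
# in two different episodes of the same batch.
pairs = [('uuid-a', 'uuid-b'), ('uuid-b', 'uuid-c')]

mapping = _build_directed_uuid_map(pairs)
# Every UUID resolves to the ultimate canonical target, regardless of lexicographic order:
# {'uuid-a': 'uuid-c', 'uuid-b': 'uuid-c', 'uuid-c': 'uuid-c'}
print(mapping)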
@@ -73,7 +108,7 @@ class RawEpisode(BaseModel):


 async def retrieve_previous_episodes_bulk(
-    driver: AsyncDriver, episodes: list[EpisodicNode]
+    driver: GraphDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
     previous_episodes_list = await semaphore_gather(
         *[
@@ -91,14 +126,15 @@ async def retrieve_previous_episodes_bulk(


 async def add_nodes_and_edges_bulk(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     episodic_nodes: list[EpisodicNode],
     episodic_edges: list[EpisodicEdge],
     entity_nodes: list[EntityNode],
     entity_edges: list[EntityEdge],
     embedder: EmbedderClient,
 ):
-    async with driver.session(database=DEFAULT_DATABASE) as session:
+    session = driver.session()
+    try:
         await session.execute_write(
             add_nodes_and_edges_bulk_tx,
             episodic_nodes,
@@ -106,38 +142,51 @@ async def add_nodes_and_edges_bulk(
             entity_nodes,
             entity_edges,
             embedder,
+            driver=driver,
         )
+    finally:
+        await session.close()


 async def add_nodes_and_edges_bulk_tx(
-    tx: AsyncManagedTransaction,
+    tx: GraphDriverSession,
     episodic_nodes: list[EpisodicNode],
     episodic_edges: list[EpisodicEdge],
     entity_nodes: list[EntityNode],
     entity_edges: list[EntityEdge],
     embedder: EmbedderClient,
+    driver: GraphDriver,
 ):
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
         episode['source'] = str(episode['source'].value)
-    nodes: list[dict[str, Any]] = []
+        episode.pop('labels', None)
+
+    nodes = []
+
     for node in entity_nodes:
         if node.name_embedding is None:
             await node.generate_name_embedding(embedder)
+
         entity_data: dict[str, Any] = {
             'uuid': node.uuid,
             'name': node.name,
-            'name_embedding': node.name_embedding,
             'group_id': node.group_id,
             'summary': node.summary,
             'created_at': node.created_at,
+            'name_embedding': node.name_embedding,
+            'labels': list(set(node.labels + ['Entity'])),
         }

-        entity_data.update(node.attributes or {})
-        entity_data['labels'] = list(set(node.labels + ['Entity']))
+        if driver.provider == GraphProvider.KUZU:
+            attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
+            entity_data['attributes'] = json.dumps(attributes)
+        else:
+            entity_data.update(node.attributes or {})
+
         nodes.append(entity_data)

-    edges: list[dict[str, Any]] = []
+    edges = []
     for edge in entity_edges:
         if edge.fact_embedding is None:
             await edge.generate_embedding(embedder)
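
The hunk above splits attribute handling by provider: for Kuzu, node and edge attributes are serialized into a single JSON string property (datetimes converted to strings first), while other providers keep attributes as top-level properties. A rough sketch of the difference, using a hypothetical _to_strings stand-in for graphiti_core's convert_datetimes_to_strings helper:

# Sketch only, not library code. _to_strings stands in for convert_datetimes_to_strings
# and is assumed here to replace datetime values with ISO-8601 strings.
import json
from datetime import datetime, timezone

def _to_strings(value):
    if isinstance(value, dict):
        return {k: _to_strings(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_to_strings(v) for v in value]
    if isinstance(value, datetime):
        return value.isoformat()
    return value

attributes = {'role': 'engineer', 'hired_at': datetime(2024, 5, 1, tzinfo=timezone.utc)}
entity_data = {'uuid': 'uuid-1', 'name': 'Alice'}

kuzu = True
if kuzu:
    # Kuzu path: attributes stored as one JSON string property
    entity_data['attributes'] = json.dumps(_to_strings(attributes))
else:
    # Other providers: attributes merged directly into the node's properties
    entity_data.update(attributes)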
@@ -147,253 +196,343 @@ async def add_nodes_and_edges_bulk_tx(
             'target_node_uuid': edge.target_node_uuid,
             'name': edge.name,
             'fact': edge.fact,
-            'fact_embedding': edge.fact_embedding,
             'group_id': edge.group_id,
             'episodes': edge.episodes,
             'created_at': edge.created_at,
             'expired_at': edge.expired_at,
             'valid_at': edge.valid_at,
             'invalid_at': edge.invalid_at,
+            'fact_embedding': edge.fact_embedding,
         }

-        edge_data.update(edge.attributes or {})
+        if driver.provider == GraphProvider.KUZU:
+            attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
+            edge_data['attributes'] = json.dumps(attributes)
+        else:
+            edge_data.update(edge.attributes or {})
+
         edges.append(edge_data)

-    await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
-    await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
-    await tx.run(
-        EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
-    )
-    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
+    if driver.graph_operations_interface:
+        await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
+        await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
+        await driver.graph_operations_interface.episodic_edge_save_bulk(
+            None, driver, tx, [edge.model_dump() for edge in episodic_edges]
+        )
+        await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
+
+    elif driver.provider == GraphProvider.KUZU:
+        # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
+        episode_query = get_episode_node_save_bulk_query(driver.provider)
+        for episode in episodes:
+            await tx.run(episode_query, **episode)
+        entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
+        for node in nodes:
+            await tx.run(entity_node_query, **node)
+        entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
+        for edge in edges:
+            await tx.run(entity_edge_query, **edge)
+        episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
+        for edge in episodic_edges:
+            await tx.run(episodic_edge_query, **edge.model_dump())
+    else:
+        await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
+        await tx.run(
+            get_entity_node_save_bulk_query(driver.provider, nodes),
+            nodes=nodes,
+        )
+        await tx.run(
+            get_episodic_edge_save_bulk_query(driver.provider),
+            episodic_edges=[edge.model_dump() for edge in episodic_edges],
+        )
+        await tx.run(
+            get_entity_edge_save_bulk_query(driver.provider),
+            entity_edges=edges,
+        )


 async def extract_nodes_and_edges_bulk(
-    clients: GraphitiClients, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
-) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
-    extracted_nodes_bulk = await semaphore_gather(
+    clients: GraphitiClients,
+    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    edge_type_map: dict[tuple[str, str], list[str]],
+    entity_types: dict[str, type[BaseModel]] | None = None,
+    excluded_entity_types: list[str] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
+) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
+    extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
         *[
-            extract_nodes(clients, episode, previous_episodes)
+            extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
            for episode, previous_episodes in episode_tuples
        ]
    )

-    episodes, previous_episodes_list = (
-        [episode[0] for episode in episode_tuples],
-        [episode[1] for episode in episode_tuples],
-    )
-
-    extracted_edges_bulk = await semaphore_gather(
+    extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
         *[
             extract_edges(
                 clients,
                 episode,
                 extracted_nodes_bulk[i],
-                previous_episodes_list[i],
-                episode.group_id,
+                previous_episodes,
+                edge_type_map=edge_type_map,
+                group_id=episode.group_id,
+                edge_types=edge_types,
             )
-            for i, episode in enumerate(episodes)
+            for i, (episode, previous_episodes) in enumerate(episode_tuples)
         ]
     )

-    episodic_edges: list[EpisodicEdge] = []
-    for i, episode in enumerate(episodes):
-        episodic_edges += build_episodic_edges(extracted_nodes_bulk[i], episode, episode.created_at)
-
-    nodes: list[EntityNode] = []
-    for extracted_nodes in extracted_nodes_bulk:
-        nodes += extracted_nodes
-
-    edges: list[EntityEdge] = []
-    for extracted_edges in extracted_edges_bulk:
-        edges += extracted_edges
-
-    return nodes, edges, episodic_edges
+    return extracted_nodes_bulk, extracted_edges_bulk


 async def dedupe_nodes_bulk(
-    driver: AsyncDriver,
-    llm_client: LLMClient,
-    extracted_nodes: list[EntityNode],
-) -> tuple[list[EntityNode], dict[str, str]]:
-    # Compress nodes
-    nodes, uuid_map = node_name_match(extracted_nodes)
-
-    compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
-
-    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
-
-    existing_nodes_chunks: list[list[EntityNode]] = list(
-        await semaphore_gather(
-            *[get_relevant_nodes(driver, node_chunk, SearchFilters()) for node_chunk in node_chunks]
-        )
-    )
-
-    results: list[tuple[list[EntityNode], dict[str, str]]] = list(
-        await semaphore_gather(
-            *[
-                dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
-                for i, node_chunk in enumerate(node_chunks)
-            ]
-        )
+    clients: GraphitiClients,
+    extracted_nodes: list[list[EntityNode]],
+    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    entity_types: dict[str, type[BaseModel]] | None = None,
+) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
+    """Resolve entity duplicates across an in-memory batch using a two-pass strategy.
+
+    1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
+       reconciled against the live graph just like the non-batch flow.
+    2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
+       duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
+       can apply to edges and persistence.
+    """
+
+    first_pass_results = await semaphore_gather(
+        *[
+            resolve_extracted_nodes(
+                clients,
+                nodes,
+                episode_tuples[i][0],
+                episode_tuples[i][1],
+                entity_types,
+            )
+            for i, nodes in enumerate(extracted_nodes)
+        ]
     )

-    final_nodes: list[EntityNode] = []
-    for result in results:
-        final_nodes.extend(result[0])
-        partial_uuid_map = result[1]
-        compressed_map.update(partial_uuid_map)
-
-    return final_nodes, compressed_map
+    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
+    per_episode_uuid_maps: list[dict[str, str]] = []
+    duplicate_pairs: list[tuple[str, str]] = []
+
+    for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
+        first_pass_results, episode_tuples, strict=True
+    ):
+        episode_resolutions.append((episode.uuid, resolved_nodes))
+        per_episode_uuid_maps.append(uuid_map)
+        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
+
+    canonical_nodes: dict[str, EntityNode] = {}
+    for _, resolved_nodes in episode_resolutions:
+        for node in resolved_nodes:
+            # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
+            # the MinHash index for the accumulated canonical pool each time. The LRU-backed
+            # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
+            # but if batches grow significantly we should switch to an incremental index or chunked
+            # processing.
+            if not canonical_nodes:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            existing_candidates = list(canonical_nodes.values())
+            normalized = _normalize_string_exact(node.name)
+            exact_match = next(
+                (
+                    candidate
+                    for candidate in existing_candidates
+                    if _normalize_string_exact(candidate.name) == normalized
+                ),
+                None,
+            )
+            if exact_match is not None:
+                if exact_match.uuid != node.uuid:
+                    duplicate_pairs.append((node.uuid, exact_match.uuid))
+                continue
+
+            indexes = _build_candidate_indexes(existing_candidates)
+            state = DedupResolutionState(
+                resolved_nodes=[None],
+                uuid_map={},
+                unresolved_indices=[],
+            )
+            _resolve_with_similarity([node], indexes, state)
+
+            resolved = state.resolved_nodes[0]
+            if resolved is None:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            canonical_uuid = resolved.uuid
+            canonical_nodes.setdefault(canonical_uuid, resolved)
+            if canonical_uuid != node.uuid:
+                duplicate_pairs.append((node.uuid, canonical_uuid))
+
+    union_pairs: list[tuple[str, str]] = []
+    for uuid_map in per_episode_uuid_maps:
+        union_pairs.extend(uuid_map.items())
+    union_pairs.extend(duplicate_pairs)
+
+    compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
+
+    nodes_by_episode: dict[str, list[EntityNode]] = {}
+    for episode_uuid, resolved_nodes in episode_resolutions:
+        deduped_nodes: list[EntityNode] = []
+        seen: set[str] = set()
+        for node in resolved_nodes:
+            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
+            if canonical_uuid in seen:
+                continue
+            seen.add(canonical_uuid)
+            canonical_node = canonical_nodes.get(canonical_uuid)
+            if canonical_node is None:
+                logger.error(
+                    'Canonical node %s missing during batch dedupe; falling back to %s',
+                    canonical_uuid,
+                    node.uuid,
+                )
+                canonical_node = node
+            deduped_nodes.append(canonical_node)
+
+        nodes_by_episode[episode_uuid] = deduped_nodes
+
+    return nodes_by_episode, compressed_map


 async def dedupe_edges_bulk(
-    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
-) -> list[EntityEdge]:
-    # First compress edges
-    compressed_edges = await compress_edges(llm_client, extracted_edges)
-
-    edge_chunks = [
-        compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
-    ]
-
-    relevant_edges_chunks: list[list[EntityEdge]] = list(
-        await semaphore_gather(
-            *[get_relevant_edges(driver, edge_chunk, SearchFilters()) for edge_chunk in edge_chunks]
-        )
+    clients: GraphitiClients,
+    extracted_edges: list[list[EntityEdge]],
+    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    _entities: list[EntityNode],
+    edge_types: dict[str, type[BaseModel]],
+    _edge_type_map: dict[tuple[str, str], list[str]],
+) -> dict[str, list[EntityEdge]]:
+    embedder = clients.embedder
+    min_score = 0.6
+
+    # generate embeddings
+    await semaphore_gather(
+        *[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
     )

-    resolved_edge_chunks: list[list[EntityEdge]] = list(
-        await semaphore_gather(
-            *[
-                dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
-                for i, edge_chunk in enumerate(edge_chunks)
-            ]
-        )
+    # Find similar results
+    dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
+    for i, edges_i in enumerate(extracted_edges):
+        existing_edges: list[EntityEdge] = []
+        for edges_j in extracted_edges:
+            existing_edges += edges_j
+
+        for edge in edges_i:
+            candidates: list[EntityEdge] = []
+            for existing_edge in existing_edges:
+                # Skip self-comparison
+                if edge.uuid == existing_edge.uuid:
+                    continue
+                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
+                # This approach will cast a wider net than BM25, which is ideal for this use case
+                if (
+                    edge.source_node_uuid != existing_edge.source_node_uuid
+                    or edge.target_node_uuid != existing_edge.target_node_uuid
+                ):
+                    continue
+
+                edge_words = set(edge.fact.lower().split())
+                existing_edge_words = set(existing_edge.fact.lower().split())
+                has_overlap = not edge_words.isdisjoint(existing_edge_words)
+                if has_overlap:
+                    candidates.append(existing_edge)
+                    continue
+
+                # Check for semantic similarity even if there is no overlap
+                similarity = np.dot(
+                    normalize_l2(edge.fact_embedding or []),
+                    normalize_l2(existing_edge.fact_embedding or []),
+                )
+                if similarity >= min_score:
+                    candidates.append(existing_edge)
+
+            dedupe_tuples.append((episode_tuples[i][0], edge, candidates))
+
+    bulk_edge_resolutions: list[
+        tuple[EntityEdge, EntityEdge, list[EntityEdge]]
+    ] = await semaphore_gather(
+        *[
+            resolve_extracted_edge(
+                clients.llm_client,
+                edge,
+                candidates,
+                candidates,
+                episode,
+                edge_types,
+                set(edge_types),
+            )
+            for episode, edge, candidates in dedupe_tuples
+        ]
     )

-    edges = [edge for edge_chunk in resolved_edge_chunks for edge in edge_chunk]
-    return edges
-
-
-def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
-    uuid_map: dict[str, str] = {}
-    name_map: dict[str, EntityNode] = {}
-    for node in nodes:
-        if node.name in name_map:
-            uuid_map[node.uuid] = name_map[node.name].uuid
-            continue
-
-        name_map[node.name] = node
-
-    return [node for node in name_map.values()], uuid_map
-
-
-async def compress_nodes(
-    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
-) -> tuple[list[EntityNode], dict[str, str]]:
-    # We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
-    if len(nodes) == 0:
-        return nodes, uuid_map
+    # For now we won't track edge invalidation
+    duplicate_pairs: list[tuple[str, str]] = []
+    for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
+        episode, edge, candidates = dedupe_tuples[i]
+        for duplicate in duplicates:
+            duplicate_pairs.append((edge.uuid, duplicate.uuid))

-    # Our approach involves us deduplicating chunks of nodes in parallel.
-    # We want n chunks of size n so that n ** 2 == len(nodes).
-    # We want chunk sizes to be at least 10 for optimizing LLM processing time
-    chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)
+    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
+    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)

-    # First calculate similarity scores between nodes
-    similarity_scores: list[tuple[int, int, float]] = [
-        (i, j, dot(n.name_embedding or [], m.name_embedding or []))
-        for i, n in enumerate(nodes)
-        for j, m in enumerate(nodes[:i])
-    ]
-
-    # We now sort by semantic similarity
-    similarity_scores.sort(key=lambda score_tuple: score_tuple[2])
+    edge_uuid_map: dict[str, EntityEdge] = {
+        edge.uuid: edge for edges in extracted_edges for edge in edges
+    }

-    # initialize our chunks based on chunk size
-    node_chunks: list[list[EntityNode]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]
+    edges_by_episode: dict[str, list[EntityEdge]] = {}
+    for i, edges in enumerate(extracted_edges):
+        episode = episode_tuples[i][0]

-    # Draft the most similar nodes into the same chunk
-    while len(similarity_scores) > 0:
-        i, j, _ = similarity_scores.pop()
-        # determine if any of the nodes have already been drafted into a chunk
-        n = nodes[i]
-        m = nodes[j]
-        # make sure the shortest chunks get preference
-        node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))
+        edges_by_episode[episode.uuid] = [
+            edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
+        ]

-        n_chunk = max([i if n in chunk else -1 for i, chunk in enumerate(node_chunks)])
-        m_chunk = max([i if m in chunk else -1 for i, chunk in enumerate(node_chunks)])
+    return edges_by_episode

-        # both nodes already in a chunk
-        if n_chunk > -1 and m_chunk > -1:
-            continue

-        # n has a chunk and that chunk is not full
-        elif n_chunk > -1 and len(node_chunks[n_chunk]) < chunk_size:
-            # put m in the same chunk as n
-            node_chunks[n_chunk].append(m)
+class UnionFind:
+    def __init__(self, elements):
+        # start each element in its own set
+        self.parent = {e: e for e in elements}

-        # m has a chunk and that chunk is not full
-        elif m_chunk > -1 and len(node_chunks[m_chunk]) < chunk_size:
-            # put n in the same chunk as m
-            node_chunks[m_chunk].append(n)
+    def find(self, x):
+        # path‐compression
+        if self.parent[x] != x:
+            self.parent[x] = self.find(self.parent[x])
+        return self.parent[x]

-        # neither node has a chunk or the chunk is full
+    def union(self, a, b):
+        ra, rb = self.find(a), self.find(b)
+        if ra == rb:
+            return
+        # attach the lexicographically larger root under the smaller
+        if ra < rb:
+            self.parent[rb] = ra
         else:
-            # add both nodes to the shortest chunk
-            node_chunks[-1].extend([n, m])
-
-    results = await semaphore_gather(
-        *[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
-    )
-
-    extended_map = dict(uuid_map)
-    compressed_nodes: list[EntityNode] = []
-    for node_chunk, uuid_map_chunk in results:
-        compressed_nodes += node_chunk
-        extended_map.update(uuid_map_chunk)
-
-    # Check if we have removed all duplicates
-    if len(compressed_nodes) == len(nodes):
-        compressed_uuid_map = compress_uuid_map(extended_map)
-        return compressed_nodes, compressed_uuid_map
-
-    return await compress_nodes(llm_client, compressed_nodes, extended_map)
-
-
-async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
-    if len(edges) == 0:
-        return edges
-    # We only want to dedupe edges that are between the same pair of nodes
-    # We build a map of the edges based on their source and target nodes.
-    edge_chunks = chunk_edges_by_nodes(edges)
-
-    results = await semaphore_gather(
-        *[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
-    )
-
-    compressed_edges: list[EntityEdge] = []
-    for edge_chunk in results:
-        compressed_edges += edge_chunk
-
-    # Check if we have removed all duplicates
-    if len(compressed_edges) == len(edges):
-        return compressed_edges
-
-    return await compress_edges(llm_client, compressed_edges)
+            self.parent[ra] = rb


-def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
-    # make sure all uuid values aren't mapped to other uuids
-    compressed_map = {}
-    for key, uuid in uuid_map.items():
-        curr_value = uuid
-        while curr_value in uuid_map:
-            curr_value = uuid_map[curr_value]
+def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
+    """
+    all_ids: iterable of all entity IDs (strings)
+    duplicate_pairs: iterable of (id1, id2) pairs
+    returns: dict mapping each id -> lexicographically smallest id in its duplicate set
+    """
+    all_uuids = set()
+    for pair in duplicate_pairs:
+        all_uuids.add(pair[0])
+        all_uuids.add(pair[1])

-        compressed_map[key] = curr_value
-    return compressed_map
+    uf = UnionFind(all_uuids)
+    for a, b in duplicate_pairs:
+        uf.union(a, b)
+    # ensure full path‐compression before mapping
+    return {uuid: uf.find(uuid) for uuid in all_uuids}


 E = typing.TypeVar('E', bound=Edge)
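
The candidate filter in the new dedupe_edges_bulk first looks for any shared word between two facts on the same source/target pair and only falls back to embedding similarity when there is no overlap. A self-contained sketch of that check, with a local normalize_l2 stand-in for the helper imported from graphiti_core.helpers:

# Illustrative sketch of the word-overlap / cosine-similarity candidate check above.
import numpy as np

def normalize_l2(vector):
    # stand-in for graphiti_core.helpers.normalize_l2: scale a vector to unit length
    arr = np.asarray(vector, dtype=float)
    norm = np.linalg.norm(arr)
    return arr if norm == 0 else arr / norm

def is_dedupe_candidate(fact_a, emb_a, fact_b, emb_b, min_score=0.6):
    words_a = set(fact_a.lower().split())
    words_b = set(fact_b.lower().split())
    if not words_a.isdisjoint(words_b):
        return True  # cheap lexical match casts a wider net than BM25
    # otherwise compare L2-normalized embeddings (cosine similarity)
    return float(np.dot(normalize_l2(emb_a), normalize_l2(emb_b))) >= min_score

print(is_dedupe_candidate('Alice works at Acme', [0.1, 0.9],
                          'Alice is employed by Acme', [0.12, 0.88]))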
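
And a usage sketch for the rewritten compress_uuid_map, assuming the UnionFind and compress_uuid_map definitions above are importable from graphiti_core.utils.bulk_utils: every UUID in a connected set of duplicate pairs ends up mapped to the lexicographically smallest member of that set.

# Illustrative only: exercises compress_uuid_map from the hunk above.
from graphiti_core.utils.bulk_utils import compress_uuid_map

pairs = [('uuid-c', 'uuid-b'), ('uuid-b', 'uuid-a')]

# Both pairs fall into one duplicate set, so everything maps to 'uuid-a':
# {'uuid-a': 'uuid-a', 'uuid-b': 'uuid-a', 'uuid-c': 'uuid-a'}
print(compress_uuid_map(pairs))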
@@ -407,63 +546,3 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
         edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)

     return edges
-
-
-async def extract_edge_dates_bulk(
-    llm_client: LLMClient,
-    extracted_edges: list[EntityEdge],
-    episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
-) -> list[EntityEdge]:
-    edges: list[EntityEdge] = []
-    # confirm that all of our edges have at least one episode
-    for edge in extracted_edges:
-        if edge.episodes is not None and len(edge.episodes) > 0:
-            edges.append(edge)
-
-    episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
-        episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
-    }
-
-    results = await semaphore_gather(
-        *[
-            extract_edge_dates(
-                llm_client,
-                edge,
-                episode_uuid_map[edge.episodes[0]][0],  # type: ignore
-                episode_uuid_map[edge.episodes[0]][1],  # type: ignore
-            )
-            for edge in edges
-        ]
-    )
-
-    for i, result in enumerate(results):
-        valid_at = result[0]
-        invalid_at = result[1]
-        edge = edges[i]
-
-        edge.valid_at = valid_at
-        edge.invalid_at = invalid_at
-        if edge.invalid_at:
-            edge.expired_at = utc_now()
-
-    return edges
-
-
-def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
-    # We only want to dedupe edges that are between the same pair of nodes
-    # We build a map of the edges based on their source and target nodes.
-    edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
-    for edge in edges:
-        # We drop loop edges
-        if edge.source_node_uuid == edge.target_node_uuid:
-            continue
-
-        # Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
-        pointers = [edge.source_node_uuid, edge.target_node_uuid]
-        pointers.sort()
-
-        edge_chunk_map[pointers[0] + pointers[1]].append(edge)
-
-    edge_chunks = [chunk for chunk in edge_chunk_map.values()]
-
-    return edge_chunks
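
For context on how the canonical UUID maps produced above are consumed, resolve_edge_pointers (the retained helper whose tail appears as context in the last hunk) remaps edge endpoints through such a map. A rough sketch with a stand-in edge class, not graphiti's own models:

# Sketch only: FakeEdge stands in for graphiti's edge models.
from dataclasses import dataclass

@dataclass
class FakeEdge:
    source_node_uuid: str
    target_node_uuid: str

def remap_edge_pointers(edges, uuid_map):
    # mirrors resolve_edge_pointers: endpoints are replaced by their canonical UUID,
    # and UUIDs absent from the map are left untouched
    for edge in edges:
        edge.source_node_uuid = uuid_map.get(edge.source_node_uuid, edge.source_node_uuid)
        edge.target_node_uuid = uuid_map.get(edge.target_node_uuid, edge.target_node_uuid)
    return edges

edges = [FakeEdge('uuid-b', 'uuid-x')]
print(remap_edge_pointers(edges, {'uuid-b': 'uuid-a'}))
# [FakeEdge(source_node_uuid='uuid-a', target_node_uuid='uuid-x')]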