graphiti-core 0.21.0rc5__py3-none-any.whl → 0.21.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of graphiti-core might be problematic.

graphiti_core/graphiti.py CHANGED
@@ -79,7 +79,6 @@ from graphiti_core.utils.maintenance.community_operations import (
     update_community,
 )
 from graphiti_core.utils.maintenance.edge_operations import (
-    build_duplicate_of_edges,
     build_episodic_edges,
     extract_edges,
     resolve_extracted_edge,
@@ -503,7 +502,7 @@ class Graphiti:
         )
 
         # Extract edges and resolve nodes
-        (nodes, uuid_map, node_duplicates), extracted_edges = await semaphore_gather(
+        (nodes, uuid_map, _), extracted_edges = await semaphore_gather(
            resolve_extracted_nodes(
                self.clients,
                extracted_nodes,
@@ -540,9 +539,7 @@ class Graphiti:
            max_coroutines=self.max_coroutines,
        )
 
-        duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
-
-        entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
+        entity_edges = resolved_edges + invalidated_edges
 
         episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
 
@@ -1073,6 +1070,7 @@ class Graphiti:
                 group_id=edge.group_id,
             ),
             None,
+            None,
             self.ensure_ascii,
         )
 
@@ -32,9 +32,19 @@ from .errors import RateLimitError
 DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
 
-MULTILINGUAL_EXTRACTION_RESPONSES = (
-    '\n\nAny extracted information should be returned in the same language as it was written in.'
-)
+
+def get_extraction_language_instruction() -> str:
+    """Returns instruction for language extraction behavior.
+
+    Override this function to customize language extraction:
+    - Return empty string to disable multilingual instructions
+    - Return custom instructions for specific language requirements
+
+    Returns:
+        str: Language instruction to append to system messages
+    """
+    return '\n\nAny extracted information should be returned in the same language as it was written in.'
+
 
 logger = logging.getLogger(__name__)
 
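Because the multilingual suffix is now produced by a module-level function instead of a constant, the behavior can be customized by rebinding that name, as the docstring above suggests. A minimal sketch of one way to do that, assuming the hook lives in graphiti_core.llm_client.client as the import lines elsewhere in this diff indicate; the spanish_only_instruction helper and its wording are purely illustrative. Note that modules which already ran `from .client import get_extraction_language_instruction` keep their own reference to the original function and would need the same treatment:

    # Hedged sketch: customize the language-instruction hook introduced in this release.
    import graphiti_core.llm_client.client as llm_client_module

    def spanish_only_instruction() -> str:
        # Illustrative replacement; returning '' would disable the multilingual suffix entirely.
        return '\n\nReturn all extracted information in Spanish.'

    # Rebind the hook where it is defined. Client modules that imported the function
    # by name bound the original object at import time, so patch those modules too
    # if their behavior should change as well.
    llm_client_module.get_extraction_language_instruction = spanish_only_instruction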
@@ -145,7 +155,7 @@ class LLMClient(ABC):
         )
 
         # Add multilingual extraction instructions
-        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
+        messages[0].content += get_extraction_language_instruction()
 
         if self.cache_enabled and self.cache_dir is not None:
             cache_key = self._get_cache_key(messages)
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, ClassVar
 from pydantic import BaseModel
 
 from ..prompts.models import Message
-from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
+from .client import LLMClient, get_extraction_language_instruction
 from .config import LLMConfig, ModelSize
 from .errors import RateLimitError
 
@@ -376,7 +376,7 @@ class GeminiClient(LLMClient):
         last_output = None
 
         # Add multilingual extraction instructions
-        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
+        messages[0].content += get_extraction_language_instruction()
 
         while retry_count < self.MAX_RETRIES:
             try:
@@ -25,7 +25,7 @@ from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
 from ..prompts.models import Message
-from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
+from .client import LLMClient, get_extraction_language_instruction
 from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 from .errors import RateLimitError, RefusalError
 
@@ -184,7 +184,7 @@ class BaseOpenAIClient(LLMClient):
         last_error = None
 
         # Add multilingual extraction instructions
-        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
+        messages[0].content += get_extraction_language_instruction()
 
         while retry_count <= self.MAX_RETRIES:
             try:
@@ -25,7 +25,7 @@ from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
 from ..prompts.models import Message
-from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
+from .client import LLMClient, get_extraction_language_instruction
 from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 from .errors import RateLimitError, RefusalError
 
@@ -136,7 +136,7 @@ class OpenAIGenericClient(LLMClient):
         )
 
         # Add multilingual extraction instructions
-        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
+        messages[0].content += get_extraction_language_instruction()
 
         while retry_count <= self.MAX_RETRIES:
             try:
@@ -92,12 +92,23 @@ def node(context: dict[str, Any]) -> list[Message]:
 
         TASK:
         1. Compare `new_entity` against each item in `existing_entities`.
-        2. If it refers to the same real‑world object or concept, collect its index.
-        3. Let `duplicate_idx` = the *first* collected index, or ‑1 if none.
-        4. Let `duplicates` = the list of *all* collected indices (empty list if none).
-
-        Also return the full name of the NEW ENTITY (whether it is the name of the NEW ENTITY, a node it
-        is a duplicate of, or a combination of the two).
+        2. If it refers to the same real-world object or concept, collect its index.
+        3. Let `duplicate_idx` = the smallest collected index, or -1 if none.
+        4. Let `duplicates` = the sorted list of all collected indices (empty list if none).
+
+        Respond with a JSON object containing an "entity_resolutions" array with a single entry:
+        {{
+            "entity_resolutions": [
+                {{
+                    "id": integer id from NEW ENTITY,
+                    "name": the best full name for the entity,
+                    "duplicate_idx": integer index of the best duplicate in EXISTING ENTITIES, or -1 if none,
+                    "duplicates": sorted list of all duplicate indices you collected (deduplicate the list, use [] when none)
+                }}
+            ]
+        }}
+
+        Only reference indices that appear in EXISTING ENTITIES, and return [] / -1 when unsure.
         """,
     ),
 ]
@@ -126,26 +137,26 @@ def nodes(context: dict[str, Any]) -> list[Message]:
         {{
             id: integer id of the entity,
             name: "name of the entity",
-            entity_type: "ontological classification of the entity",
-            entity_type_description: "Description of what the entity type represents",
-            duplication_candidates: [
-                {{
-                    idx: integer index of the candidate entity,
-                    name: "name of the candidate entity",
-                    entity_type: "ontological classification of the candidate entity",
-                    ...<additional attributes>
-                }}
-            ]
+            entity_type: ["Entity", "<optional additional label>", ...],
+            entity_type_description: "Description of what the entity type represents"
         }}
-
+
         <ENTITIES>
         {to_prompt_json(context['extracted_nodes'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
         </ENTITIES>
-
+
         <EXISTING ENTITIES>
         {to_prompt_json(context['existing_nodes'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
         </EXISTING ENTITIES>
 
+        Each entry in EXISTING ENTITIES is an object with the following structure:
+        {{
+            idx: integer index of the candidate entity (use this when referencing a duplicate),
+            name: "name of the candidate entity",
+            entity_types: ["Entity", "<optional additional label>", ...],
+            ...<additional attributes such as summaries or metadata>
+        }}
+
         For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
 
         Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
@@ -155,14 +166,19 @@ def nodes(context: dict[str, Any]) -> list[Message]:
         - They have similar names or purposes but refer to separate instances or concepts.
 
         Task:
-        Your response will be a list called entity_resolutions which contains one entry for each entity.
-
-        For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
-        as an integer.
-
-        - If an entity is a duplicate of one of the EXISTING ENTITIES, return the idx of the candidate it is a
-        duplicate of.
-        - If an entity is not a duplicate of one of the EXISTING ENTITIES, return the -1 as the duplication_idx
+        Respond with a JSON object that contains an "entity_resolutions" array with one entry for each entity in ENTITIES, ordered by the entity id.
+
+        For every entity, return an object with the following keys:
+        {{
+            "id": integer id from ENTITIES,
+            "name": the best full name for the entity (preserve the original name unless a duplicate has a more complete name),
+            "duplicate_idx": the idx of the EXISTING ENTITY that is the best duplicate match, or -1 if there is no duplicate,
+            "duplicates": a sorted list of all idx values from EXISTING ENTITIES that refer to duplicates (deduplicate the list, use [] when none or unsure)
+        }}
+
+        - Only use idx values that appear in EXISTING ENTITIES.
+        - Set duplicate_idx to the smallest idx you collected for that entity, or -1 if duplicates is empty.
+        - Never fabricate entities or indices.
         """,
     ),
 ]
@@ -152,7 +152,8 @@ Indicate the classified entity type by providing its entity_type_id.
 
 Guidelines:
 1. Always try to extract an entities that the JSON represents. This will often be something like a "name" or "user field
-2. Do NOT extract any properties that contain dates
+2. Extract all entities mentioned in all other properties throughout the JSON structure
+3. Do NOT extract any properties that contain dates
 """
     return [
         Message(role='system', content=sys_prompt),
@@ -43,8 +43,14 @@ from graphiti_core.models.nodes.node_db_queries import (
     get_entity_node_save_bulk_query,
     get_episode_node_save_bulk_query,
 )
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
+from graphiti_core.utils.maintenance.dedup_helpers import (
+    DedupResolutionState,
+    _build_candidate_indexes,
+    _normalize_string_exact,
+    _resolve_with_similarity,
+)
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
     resolve_extracted_edge,
@@ -63,6 +69,38 @@ logger = logging.getLogger(__name__)
 CHUNK_SIZE = 10
 
 
+def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
+    """Collapse alias -> canonical chains while preserving direction.
+
+    The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
+    union-find with iterative path compression to ensure every source UUID resolves to its ultimate
+    canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
+    """
+
+    parent: dict[str, str] = {}
+
+    def find(uuid: str) -> str:
+        """Directed union-find lookup using iterative path compression."""
+        parent.setdefault(uuid, uuid)
+        root = uuid
+        while parent[root] != root:
+            root = parent[root]
+
+        while parent[uuid] != root:
+            next_uuid = parent[uuid]
+            parent[uuid] = root
+            uuid = next_uuid
+
+        return root
+
+    for source_uuid, target_uuid in pairs:
+        parent.setdefault(source_uuid, source_uuid)
+        parent.setdefault(target_uuid, target_uuid)
+        parent[find(source_uuid)] = find(target_uuid)
+
+    return {uuid: find(uuid) for uuid in parent}
+
+
 class RawEpisode(BaseModel):
     name: str
     uuid: str | None = Field(default=None)
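To make the documented guarantee concrete, here is a small worked example with short stand-in strings instead of real UUIDs, assuming `_build_directed_uuid_map` from the hunk above is in scope: chained aliases all resolve to the chain's final canonical target.

    # 'a' -> 'b' and 'b' -> 'c' collapse so that 'a', 'b', and 'd' all map to 'c';
    # 'c' maps to itself because every key seen by the union-find gets an entry.
    pairs = [('a', 'b'), ('b', 'c'), ('d', 'c')]
    mapping = _build_directed_uuid_map(pairs)
    assert mapping == {'a': 'c', 'b': 'c', 'c': 'c', 'd': 'c'}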
@@ -266,83 +304,111 @@ async def dedupe_nodes_bulk(
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
-    embedder = clients.embedder
-    min_score = 0.8
-
-    # generate embeddings
-    await semaphore_gather(
-        *[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
-    )
-
-    # Find similar results
-    dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
-    for i, nodes_i in enumerate(extracted_nodes):
-        existing_nodes: list[EntityNode] = []
-        for j, nodes_j in enumerate(extracted_nodes):
-            if i == j:
-                continue
-            existing_nodes += nodes_j
-
-        candidates_i: list[EntityNode] = []
-        for node in nodes_i:
-            for existing_node in existing_nodes:
-                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
-                # This approach will cast a wider net than BM25, which is ideal for this use case
-                node_words = set(node.name.lower().split())
-                existing_node_words = set(existing_node.name.lower().split())
-                has_overlap = not node_words.isdisjoint(existing_node_words)
-                if has_overlap:
-                    candidates_i.append(existing_node)
-                    continue
+    """Resolve entity duplicates across an in-memory batch using a two-pass strategy.
 
-                # Check for semantic similarity even if there is no overlap
-                similarity = np.dot(
-                    normalize_l2(node.name_embedding or []),
-                    normalize_l2(existing_node.name_embedding or []),
-                )
-                if similarity >= min_score:
-                    candidates_i.append(existing_node)
-
-        dedupe_tuples.append((nodes_i, candidates_i))
+    1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
+       reconciled against the live graph just like the non-batch flow.
+    2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
+       duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
+       can apply to edges and persistence.
+    """
 
-    # Determine Node Resolutions
-    bulk_node_resolutions: list[
-        tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
-    ] = await semaphore_gather(
+    first_pass_results = await semaphore_gather(
         *[
             resolve_extracted_nodes(
                 clients,
-                dedupe_tuple[0],
+                nodes,
                 episode_tuples[i][0],
                 episode_tuples[i][1],
                 entity_types,
-                existing_nodes_override=dedupe_tuples[i][1],
             )
-            for i, dedupe_tuple in enumerate(dedupe_tuples)
+            for i, nodes in enumerate(extracted_nodes)
         ]
     )
 
-    # Collect all duplicate pairs sorted by uuid
+    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
+    per_episode_uuid_maps: list[dict[str, str]] = []
     duplicate_pairs: list[tuple[str, str]] = []
-    for _, _, duplicates in bulk_node_resolutions:
-        for duplicate in duplicates:
-            n, m = duplicate
-            duplicate_pairs.append((n.uuid, m.uuid))
 
-    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
-    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
+    for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
+        first_pass_results, episode_tuples, strict=True
+    ):
+        episode_resolutions.append((episode.uuid, resolved_nodes))
+        per_episode_uuid_maps.append(uuid_map)
+        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
+
+    canonical_nodes: dict[str, EntityNode] = {}
+    for _, resolved_nodes in episode_resolutions:
+        for node in resolved_nodes:
+            # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
+            # the MinHash index for the accumulated canonical pool each time. The LRU-backed
+            # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
+            # but if batches grow significantly we should switch to an incremental index or chunked
+            # processing.
+            if not canonical_nodes:
+                canonical_nodes[node.uuid] = node
+                continue
 
-    node_uuid_map: dict[str, EntityNode] = {
-        node.uuid: node for nodes in extracted_nodes for node in nodes
-    }
+            existing_candidates = list(canonical_nodes.values())
+            normalized = _normalize_string_exact(node.name)
+            exact_match = next(
+                (
+                    candidate
+                    for candidate in existing_candidates
+                    if _normalize_string_exact(candidate.name) == normalized
+                ),
+                None,
+            )
+            if exact_match is not None:
+                if exact_match.uuid != node.uuid:
+                    duplicate_pairs.append((node.uuid, exact_match.uuid))
+                continue
+
+            indexes = _build_candidate_indexes(existing_candidates)
+            state = DedupResolutionState(
+                resolved_nodes=[None],
+                uuid_map={},
+                unresolved_indices=[],
+            )
+            _resolve_with_similarity([node], indexes, state)
+
+            resolved = state.resolved_nodes[0]
+            if resolved is None:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            canonical_uuid = resolved.uuid
+            canonical_nodes.setdefault(canonical_uuid, resolved)
+            if canonical_uuid != node.uuid:
+                duplicate_pairs.append((node.uuid, canonical_uuid))
+
+    union_pairs: list[tuple[str, str]] = []
+    for uuid_map in per_episode_uuid_maps:
+        union_pairs.extend(uuid_map.items())
+    union_pairs.extend(duplicate_pairs)
+
+    compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
 
     nodes_by_episode: dict[str, list[EntityNode]] = {}
-    for i, nodes in enumerate(extracted_nodes):
-        episode = episode_tuples[i][0]
+    for episode_uuid, resolved_nodes in episode_resolutions:
+        deduped_nodes: list[EntityNode] = []
+        seen: set[str] = set()
+        for node in resolved_nodes:
+            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
+            if canonical_uuid in seen:
+                continue
+            seen.add(canonical_uuid)
+            canonical_node = canonical_nodes.get(canonical_uuid)
+            if canonical_node is None:
+                logger.error(
+                    'Canonical node %s missing during batch dedupe; falling back to %s',
+                    canonical_uuid,
+                    node.uuid,
+                )
+                canonical_node = node
+            deduped_nodes.append(canonical_node)
 
-        nodes_by_episode[episode.uuid] = [
-            node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
-        ]
+        nodes_by_episode[episode_uuid] = deduped_nodes
 
     return nodes_by_episode, compressed_map
 
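The second element of the return value is the canonical UUID map the docstring refers to; callers are expected to point edges at canonical nodes before persisting. A hedged sketch of that usage follows; the calling pattern and the extracted_edges list are illustrative rather than the package's actual bulk pipeline, though EntityEdge does expose source_node_uuid and target_node_uuid fields:

    # Illustrative only: remap edge endpoints through the canonical UUID map
    # returned by dedupe_nodes_bulk, as the docstring above describes.
    nodes_by_episode, compressed_map = await dedupe_nodes_bulk(
        clients, extracted_nodes, episode_tuples, entity_types
    )
    for edge in extracted_edges:  # assumed: a flat list of EntityEdge objects
        edge.source_node_uuid = compressed_map.get(edge.source_node_uuid, edge.source_node_uuid)
        edge.target_node_uuid = compressed_map.get(edge.target_node_uuid, edge.target_node_uuid)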
@@ -411,6 +477,7 @@ async def dedupe_edges_bulk(
                 candidates,
                 episode,
                 edge_types,
+                set(edge_types),
                 clients.ensure_ascii,
             )
             for episode, edge, candidates in dedupe_tuples