graphiti-core 0.20.4-py3-none-any.whl → 0.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of graphiti-core has been flagged as possibly problematic.

Files changed (39)
  1. graphiti_core/driver/driver.py +28 -0
  2. graphiti_core/driver/falkordb_driver.py +112 -0
  3. graphiti_core/driver/kuzu_driver.py +1 -0
  4. graphiti_core/driver/neo4j_driver.py +10 -2
  5. graphiti_core/driver/neptune_driver.py +4 -6
  6. graphiti_core/edges.py +67 -7
  7. graphiti_core/embedder/client.py +2 -1
  8. graphiti_core/graph_queries.py +35 -6
  9. graphiti_core/graphiti.py +27 -23
  10. graphiti_core/graphiti_types.py +0 -1
  11. graphiti_core/helpers.py +2 -2
  12. graphiti_core/llm_client/client.py +19 -4
  13. graphiti_core/llm_client/gemini_client.py +4 -2
  14. graphiti_core/llm_client/openai_base_client.py +3 -2
  15. graphiti_core/llm_client/openai_generic_client.py +3 -2
  16. graphiti_core/models/edges/edge_db_queries.py +36 -16
  17. graphiti_core/models/nodes/node_db_queries.py +30 -10
  18. graphiti_core/nodes.py +126 -25
  19. graphiti_core/prompts/dedupe_edges.py +40 -29
  20. graphiti_core/prompts/dedupe_nodes.py +51 -34
  21. graphiti_core/prompts/eval.py +3 -3
  22. graphiti_core/prompts/extract_edges.py +17 -9
  23. graphiti_core/prompts/extract_nodes.py +10 -9
  24. graphiti_core/prompts/prompt_helpers.py +3 -3
  25. graphiti_core/prompts/summarize_nodes.py +5 -5
  26. graphiti_core/search/search_filters.py +53 -0
  27. graphiti_core/search/search_helpers.py +5 -7
  28. graphiti_core/search/search_utils.py +227 -57
  29. graphiti_core/utils/bulk_utils.py +168 -69
  30. graphiti_core/utils/maintenance/community_operations.py +8 -20
  31. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  32. graphiti_core/utils/maintenance/edge_operations.py +187 -50
  33. graphiti_core/utils/maintenance/graph_data_operations.py +9 -5
  34. graphiti_core/utils/maintenance/node_operations.py +244 -88
  35. graphiti_core/utils/maintenance/temporal_operations.py +0 -4
  36. {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/METADATA +7 -1
  37. {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/RECORD +39 -38
  38. {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/WHEEL +0 -0
  39. {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/driver/driver.py CHANGED
@@ -16,13 +16,25 @@ limitations under the License.
 
 import copy
 import logging
+import os
 from abc import ABC, abstractmethod
 from collections.abc import Coroutine
 from enum import Enum
 from typing import Any
 
+from dotenv import load_dotenv
+
 logger = logging.getLogger(__name__)
 
+DEFAULT_SIZE = 10
+
+load_dotenv()
+
+ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
+EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
+COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
+ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
+
 
 class GraphProvider(Enum):
     NEO4J = 'neo4j'
@@ -61,6 +73,7 @@ class GraphDriver(ABC):
         ''  # Neo4j (default) syntax does not require a prefix for fulltext queries
     )
     _database: str
+    aoss_client: Any  # type: ignore
 
     @abstractmethod
     def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@@ -87,3 +100,18 @@
         cloned._database = database
 
         return cloned
+
+    def build_fulltext_query(
+        self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
+    ) -> str:
+        """
+        Specific fulltext query builder for database providers.
+        Only implemented by providers that need custom fulltext query building.
+        """
+        raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
+
+    async def save_to_aoss(self, name: str, data: list[dict]) -> int:
+        return 0
+
+    async def clear_aoss_indices(self):
+        return 1
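
Note: the new index names above are read from the environment (via load_dotenv) when graphiti_core.driver.driver is imported, so they can be overridden per deployment. A minimal sketch, assuming the variables are set before the module is imported; the overridden values are examples only:

    import os

    # Must happen before graphiti_core is imported; otherwise the defaults apply.
    os.environ['ENTITY_INDEX_NAME'] = 'my_entities'
    os.environ['ENTITY_EDGE_INDEX_NAME'] = 'my_entity_edges'

    from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, ENTITY_INDEX_NAME

    print(ENTITY_INDEX_NAME, ENTITY_EDGE_INDEX_NAME)  # my_entities my_entity_edges
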
graphiti_core/driver/falkordb_driver.py CHANGED
@@ -36,6 +36,42 @@ from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
 
 logger = logging.getLogger(__name__)
 
+STOPWORDS = [
+    'a',
+    'is',
+    'the',
+    'an',
+    'and',
+    'are',
+    'as',
+    'at',
+    'be',
+    'but',
+    'by',
+    'for',
+    'if',
+    'in',
+    'into',
+    'it',
+    'no',
+    'not',
+    'of',
+    'on',
+    'or',
+    'such',
+    'that',
+    'their',
+    'then',
+    'there',
+    'these',
+    'they',
+    'this',
+    'to',
+    'was',
+    'will',
+    'with',
+]
+
 
 class FalkorDriverSession(GraphDriverSession):
     provider = GraphProvider.FALKORDB
@@ -74,6 +110,7 @@ class FalkorDriverSession(GraphDriverSession):
 
 class FalkorDriver(GraphDriver):
     provider = GraphProvider.FALKORDB
+    aoss_client: None = None
 
     def __init__(
         self,
@@ -166,3 +203,78 @@ class FalkorDriver(GraphDriver):
         cloned = FalkorDriver(falkor_db=self.client, database=database)
 
         return cloned
+
+    def sanitize(self, query: str) -> str:
+        """
+        Replace FalkorDB special characters with whitespace.
+        Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~
+        """
+        # FalkorDB separator characters that break text into tokens
+        separator_map = str.maketrans(
+            {
+                ',': ' ',
+                '.': ' ',
+                '<': ' ',
+                '>': ' ',
+                '{': ' ',
+                '}': ' ',
+                '[': ' ',
+                ']': ' ',
+                '"': ' ',
+                "'": ' ',
+                ':': ' ',
+                ';': ' ',
+                '!': ' ',
+                '@': ' ',
+                '#': ' ',
+                '$': ' ',
+                '%': ' ',
+                '^': ' ',
+                '&': ' ',
+                '*': ' ',
+                '(': ' ',
+                ')': ' ',
+                '-': ' ',
+                '+': ' ',
+                '=': ' ',
+                '~': ' ',
+                '?': ' ',
+            }
+        )
+        sanitized = query.translate(separator_map)
+        # Clean up multiple spaces
+        sanitized = ' '.join(sanitized.split())
+        return sanitized
+
+    def build_fulltext_query(
+        self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
+    ) -> str:
+        """
+        Build a fulltext query string for FalkorDB using RedisSearch syntax.
+        FalkorDB uses RedisSearch-like syntax where:
+        - Field queries use @ prefix: @field:value
+        - Multiple values for same field: (@field:value1|value2)
+        - Text search doesn't need @ prefix for content fields
+        - AND is implicit with space: (@group_id:value) (text)
+        - OR uses pipe within parentheses: (@group_id:value1|value2)
+        """
+        if group_ids is None or len(group_ids) == 0:
+            group_filter = ''
+        else:
+            group_values = '|'.join(group_ids)
+            group_filter = f'(@group_id:{group_values})'
+
+        sanitized_query = self.sanitize(query)
+
+        # Remove stopwords from the sanitized query
+        query_words = sanitized_query.split()
+        filtered_words = [word for word in query_words if word.lower() not in STOPWORDS]
+        sanitized_query = ' | '.join(filtered_words)
+
+        # If the query is too long return no query
+        if len(sanitized_query.split(' ')) + len(group_ids or '') >= max_query_length:
+            return ''
+
+        full_query = group_filter + ' (' + sanitized_query + ')'
+
+        return full_query
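
Note: a rough sketch of what the new FalkorDB query builder produces; the constructor arguments are illustrative and assume a locally reachable FalkorDB instance:

    from graphiti_core.driver.falkordb_driver import FalkorDriver

    driver = FalkorDriver(host='localhost', port=6379)  # illustrative arguments

    # Separator characters and stopwords are stripped, the remaining terms are OR-joined,
    # and the group filter is prepended, yielding roughly:
    #   "(@group_id:demo) (Acme | hired | Alice | engineer)"
    print(driver.build_fulltext_query('Acme hired Alice as an engineer.', group_ids=['demo']))
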
@@ -92,6 +92,7 @@ SCHEMA_QUERIES = """
92
92
 
93
93
  class KuzuDriver(GraphDriver):
94
94
  provider: GraphProvider = GraphProvider.KUZU
95
+ aoss_client: None = None
95
96
 
96
97
  def __init__(
97
98
  self,

graphiti_core/driver/neo4j_driver.py CHANGED
@@ -29,7 +29,13 @@ logger = logging.getLogger(__name__)
 class Neo4jDriver(GraphDriver):
     provider = GraphProvider.NEO4J
 
-    def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
+    def __init__(
+        self,
+        uri: str,
+        user: str | None,
+        password: str | None,
+        database: str = 'neo4j',
+    ):
         super().__init__()
         self.client = AsyncGraphDatabase.driver(
             uri=uri,
@@ -37,6 +43,8 @@ class Neo4jDriver(GraphDriver):
         )
         self._database = database
 
+        self.aoss_client = None
+
     async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
         # Check if database_ is provided in kwargs.
         # If not populated, set the value to retain backwards compatibility
@@ -60,7 +68,7 @@ class Neo4jDriver(GraphDriver):
     async def close(self) -> None:
         return await self.client.close()
 
-    def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
+    def delete_all_indexes(self) -> Coroutine:
         return self.client.execute_query(
             'CALL db.indexes() YIELD name DROP INDEX name',
         )

graphiti_core/driver/neptune_driver.py CHANGED
@@ -257,15 +257,13 @@ class NeptuneDriver(GraphDriver):
             if name.lower() == index['index_name']:
                 to_index = []
                 for d in data:
-                    item = {'_index': name}
+                    item = {'_index': name, '_id': d['uuid']}
                     for p in index['body']['mappings']['properties']:
-                        item[p] = d[p]
+                        if p in d:
+                            item[p] = d[p]
                     to_index.append(item)
                 success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
-                if failed > 0:
-                    return success
-                else:
-                    return 0
+                return success
 
         return 0
 
graphiti_core/edges.py CHANGED
@@ -25,7 +25,7 @@ from uuid import uuid4
 from pydantic import BaseModel, Field
 from typing_extensions import LiteralString
 
-from graphiti_core.driver.driver import GraphDriver, GraphProvider
+from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
 from graphiti_core.helpers import parse_db_date
@@ -77,6 +77,13 @@ class Edge(BaseModel, ABC):
             uuid=self.uuid,
         )
 
+        if driver.aoss_client:
+            await driver.aoss_client.delete(
+                index=ENTITY_EDGE_INDEX_NAME,
+                id=self.uuid,
+                params={'routing': self.group_id},
+            )
+
         logger.debug(f'Deleted Edge: {self.uuid}')
 
     @classmethod
@@ -108,6 +115,12 @@
             uuids=uuids,
         )
 
+        if driver.aoss_client:
+            await driver.aoss_client.delete_by_query(
+                index=ENTITY_EDGE_INDEX_NAME,
+                body={'query': {'terms': {'uuid': uuids}}},
+            )
+
         logger.debug(f'Deleted Edges: {uuids}')
 
     def __hash__(self):
@@ -255,6 +268,21 @@ class EntityEdge(Edge):
                 MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
                 RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
             """
+        elif driver.aoss_client:
+            resp = await driver.aoss_client.search(
+                body={
+                    'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
+                    'size': 1,
+                },
+                index=ENTITY_EDGE_INDEX_NAME,
+                params={'routing': self.group_id},
+            )
+
+            if resp['hits']['hits']:
+                self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
+                return
+            else:
+                raise EdgeNotFoundError(self.uuid)
 
         if driver.provider == GraphProvider.KUZU:
             query = """
@@ -292,14 +320,14 @@ class EntityEdge(Edge):
         if driver.provider == GraphProvider.KUZU:
             edge_data['attributes'] = json.dumps(self.attributes)
             result = await driver.execute_query(
-                get_entity_edge_save_query(driver.provider),
+                get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
                 **edge_data,
             )
         else:
             edge_data.update(self.attributes or {})
 
-            if driver.provider == GraphProvider.NEPTUNE:
-                driver.save_to_aoss('edge_name_and_fact', [edge_data])  # pyright: ignore reportAttributeAccessIssue
+            if driver.aoss_client:
+                await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data])  # pyright: ignore reportAttributeAccessIssue
 
             result = await driver.execute_query(
                 get_entity_edge_save_query(driver.provider),
@@ -336,6 +364,35 @@ class EntityEdge(Edge):
                 raise EdgeNotFoundError(uuid)
             return edges[0]
 
+    @classmethod
+    async def get_between_nodes(
+        cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
+    ):
+        match_query = """
+            MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
+        """
+        if driver.provider == GraphProvider.KUZU:
+            match_query = """
+                MATCH (n:Entity {uuid: $source_node_uuid})
+                    -[:RELATES_TO]->(e:RelatesToNode_)
+                    -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
+            """
+
+        records, _, _ = await driver.execute_query(
+            match_query
+            + """
+            RETURN
+            """
+            + get_entity_edge_return_query(driver.provider),
+            source_node_uuid=source_node_uuid,
+            target_node_uuid=target_node_uuid,
+            routing_='r',
+        )
+
+        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
+
+        return edges
+
     @classmethod
     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
         if len(uuids) == 0:
@@ -587,8 +644,11 @@ def get_community_edge_from_record(record: Any):
 
 
 async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
-    if len(edges) == 0:
+    # filter out falsey values from edges
+    filtered_edges = [edge for edge in edges if edge.fact]
+
+    if len(filtered_edges) == 0:
         return
-    fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
-    for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
+    fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
+    for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
         edge.fact_embedding = fact_embedding
graphiti_core/embedder/client.py CHANGED
@@ -14,12 +14,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+import os
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 
 from pydantic import BaseModel, Field
 
-EMBEDDING_DIM = 1024
+EMBEDDING_DIM = int(os.getenv('EMBEDDING_DIM', 1024))
 
 
 class EmbedderConfig(BaseModel):

graphiti_core/graph_queries.py CHANGED
@@ -71,12 +71,41 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
 
 def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
     if provider == GraphProvider.FALKORDB:
-        return [
-            """CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
-            """CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
-            """CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
-            """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
-        ]
+        from typing import cast
+
+        from graphiti_core.driver.falkordb_driver import STOPWORDS
+
+        # Convert to string representation for embedding in queries
+        stopwords_str = str(STOPWORDS)
+
+        # Use type: ignore to satisfy LiteralString requirement while maintaining single source of truth
+        return cast(
+            list[LiteralString],
+            [
+                f"""CALL db.idx.fulltext.createNodeIndex(
+                    {{
+                        label: 'Episodic',
+                        stopwords: {stopwords_str}
+                    }},
+                    'content', 'source', 'source_description', 'group_id'
+                )""",
+                f"""CALL db.idx.fulltext.createNodeIndex(
+                    {{
+                        label: 'Entity',
+                        stopwords: {stopwords_str}
+                    }},
+                    'name', 'summary', 'group_id'
+                )""",
+                f"""CALL db.idx.fulltext.createNodeIndex(
+                    {{
+                        label: 'Community',
+                        stopwords: {stopwords_str}
+                    }},
+                    'name', 'group_id'
+                )""",
+                """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
+            ],
+        )
 
     if provider == GraphProvider.KUZU:
         return [
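
Note: the generated FalkorDB statements can be inspected without a database connection, which makes the stopword wiring easy to verify; this sketch assumes the FalkorDB extra is installed so the STOPWORDS import inside the function resolves:

    from graphiti_core.driver.driver import GraphProvider
    from graphiti_core.graph_queries import get_fulltext_indices

    # Prints three db.idx.fulltext.createNodeIndex calls (with the shared STOPWORDS list
    # inlined) plus the CREATE FULLTEXT INDEX statement for RELATES_TO edges.
    for statement in get_fulltext_indices(GraphProvider.FALKORDB):
        print(statement)
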
graphiti_core/graphiti.py CHANGED
@@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import (
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import (
     RELEVANT_SCHEMA_LIMIT,
-    get_edge_invalidation_candidates,
     get_mentioned_nodes,
-    get_relevant_edges,
 )
 from graphiti_core.telemetry import capture_event
 from graphiti_core.utils.bulk_utils import (
@@ -81,7 +79,6 @@ from graphiti_core.utils.maintenance.community_operations import (
     update_community,
 )
 from graphiti_core.utils.maintenance.edge_operations import (
-    build_duplicate_of_edges,
     build_episodic_edges,
     extract_edges,
     resolve_extracted_edge,
@@ -139,7 +136,6 @@ class Graphiti:
         store_raw_episode_content: bool = True,
         graph_driver: GraphDriver | None = None,
         max_coroutines: int | None = None,
-        ensure_ascii: bool = False,
     ):
         """
         Initialize a Graphiti instance.
@@ -172,10 +168,6 @@
         max_coroutines : int | None, optional
             The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
            If not set, the Graphiti default is used.
-        ensure_ascii : bool, optional
-            Whether to escape non-ASCII characters in JSON serialization for prompts. Defaults to False.
-            Set as False to preserve non-ASCII characters (e.g., Korean, Japanese, Chinese) in their
-            original form, making them readable in LLM logs and improving model understanding.
 
         Returns
         -------
@@ -205,7 +197,6 @@
 
         self.store_raw_episode_content = store_raw_episode_content
         self.max_coroutines = max_coroutines
-        self.ensure_ascii = ensure_ascii
         if llm_client:
             self.llm_client = llm_client
         else:
@@ -224,7 +215,6 @@
             llm_client=self.llm_client,
             embedder=self.embedder,
             cross_encoder=self.cross_encoder,
-            ensure_ascii=self.ensure_ascii,
         )
 
         # Capture telemetry event
@@ -458,12 +448,12 @@
         start = time()
         now = utc_now()
 
-        # if group_id is None, use the default group id by the provider
-        group_id = group_id or get_default_group_id(self.driver.provider)
         validate_entity_types(entity_types)
 
         validate_excluded_entity_types(excluded_entity_types, entity_types)
         validate_group_id(group_id)
+        # if group_id is None, use the default group id by the provider
+        group_id = group_id or get_default_group_id(self.driver.provider)
 
         previous_episodes = (
             await self.retrieve_episodes(
@@ -505,7 +495,7 @@
         )
 
         # Extract edges and resolve nodes
-        (nodes, uuid_map, node_duplicates), extracted_edges = await semaphore_gather(
+        (nodes, uuid_map, _), extracted_edges = await semaphore_gather(
            resolve_extracted_nodes(
                self.clients,
                extracted_nodes,
@@ -542,9 +532,7 @@
            max_coroutines=self.max_coroutines,
        )
 
-        duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
-
-        entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
+        entity_edges = resolved_edges + invalidated_edges
 
         episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
 
@@ -564,9 +552,7 @@
         if update_communities:
             communities, community_edges = await semaphore_gather(
                 *[
-                    update_community(
-                        self.driver, self.llm_client, self.embedder, node, self.ensure_ascii
-                    )
+                    update_community(self.driver, self.llm_client, self.embedder, node)
                     for node in nodes
                 ],
                 max_coroutines=self.max_coroutines,
@@ -1037,10 +1023,28 @@
 
         updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
 
-        related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
+        valid_edges = await EntityEdge.get_between_nodes(
+            self.driver, edge.source_node_uuid, edge.target_node_uuid
+        )
+
+        related_edges = (
+            await search(
+                self.clients,
+                updated_edge.fact,
+                group_ids=[updated_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
+            )
+        ).edges
         existing_edges = (
-            await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
-        )[0]
+            await search(
+                self.clients,
+                updated_edge.fact,
+                group_ids=[updated_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(),
+            )
+        ).edges
 
         resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
             self.llm_client,
@@ -1057,7 +1061,7 @@
                 group_id=edge.group_id,
             ),
             None,
-            self.ensure_ascii,
+            None,
         )
 
         edges: list[EntityEdge] = [resolved_edge] + invalidated_edges

graphiti_core/graphiti_types.py CHANGED
@@ -27,6 +27,5 @@ class GraphitiClients(BaseModel):
     llm_client: LLMClient
     embedder: EmbedderClient
     cross_encoder: CrossEncoderClient
-    ensure_ascii: bool = False
 
     model_config = ConfigDict(arbitrary_types_allowed=True)

graphiti_core/helpers.py CHANGED
@@ -54,7 +54,7 @@ def get_default_group_id(provider: GraphProvider) -> str:
     For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
     """
     if provider == GraphProvider.FALKORDB:
-        return '_'
+        return '\\_'
     else:
         return ''
 
@@ -116,7 +116,7 @@ async def semaphore_gather(
     return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
 
 
-def validate_group_id(group_id: str) -> bool:
+def validate_group_id(group_id: str | None) -> bool:
     """
     Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
 

graphiti_core/llm_client/client.py CHANGED
@@ -32,9 +32,23 @@ from .errors import RateLimitError
 DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
 
-MULTILINGUAL_EXTRACTION_RESPONSES = (
-    '\n\nAny extracted information should be returned in the same language as it was written in.'
-)
+
+def get_extraction_language_instruction(group_id: str | None = None) -> str:
+    """Returns instruction for language extraction behavior.
+
+    Override this function to customize language extraction:
+    - Return empty string to disable multilingual instructions
+    - Return custom instructions for specific language requirements
+    - Use group_id to provide different instructions per group/partition
+
+    Args:
+        group_id: Optional partition identifier for the graph
+
+    Returns:
+        str: Language instruction to append to system messages
+    """
+    return '\n\nAny extracted information should be returned in the same language as it was written in.'
+
 
 logger = logging.getLogger(__name__)
 
@@ -132,6 +146,7 @@ class LLMClient(ABC):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -145,7 +160,7 @@
            )
 
        # Add multilingual extraction instructions
-       messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
+       messages[0].content += get_extraction_language_instruction(group_id)
 
        if self.cache_enabled and self.cache_dir is not None:
            cache_key = self._get_cache_key(messages)
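
Note: the new get_extraction_language_instruction hook is meant to be overridden (per its docstring). A rough monkeypatching sketch; because gemini_client below imports the name directly, it needs its own patch, and the OpenAI clients touched in this release may as well:

    import graphiti_core.llm_client.client as llm_client
    import graphiti_core.llm_client.gemini_client as gemini_client  # requires the gemini extra


    def no_language_instruction(group_id: str | None = None) -> str:
        # Returning an empty string disables the multilingual instruction entirely.
        return ''


    # Patch both the defining module and any module that imported the name directly.
    llm_client.get_extraction_language_instruction = no_language_instruction
    gemini_client.get_extraction_language_instruction = no_language_instruction
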
graphiti_core/llm_client/gemini_client.py CHANGED
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, ClassVar
 from pydantic import BaseModel
 
 from ..prompts.models import Message
-from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
+from .client import LLMClient, get_extraction_language_instruction
 from .config import LLMConfig, ModelSize
 from .errors import RateLimitError
 
@@ -357,6 +357,7 @@ class GeminiClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         """
         Generate a response from the Gemini language model with retry logic and error handling.
@@ -367,6 +368,7 @@
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
             max_tokens (int | None): The maximum number of tokens to generate in the response.
             model_size (ModelSize): The size of the model to use (small or medium).
+            group_id (str | None): Optional partition identifier for the graph.
 
         Returns:
             dict[str, typing.Any]: The response from the language model.
@@ -376,7 +378,7 @@
         last_output = None
 
         # Add multilingual extraction instructions
-        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count < self.MAX_RETRIES:
             try: