graphiti-core 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

This version of graphiti-core was flagged as potentially problematic by the registry.
Files changed (41)
  1. graphiti_core/cross_encoder/bge_reranker_client.py +1 -2
  2. graphiti_core/cross_encoder/client.py +3 -4
  3. graphiti_core/cross_encoder/openai_reranker_client.py +2 -2
  4. graphiti_core/edges.py +56 -7
  5. graphiti_core/embedder/client.py +3 -3
  6. graphiti_core/embedder/openai.py +2 -2
  7. graphiti_core/embedder/voyage.py +3 -3
  8. graphiti_core/graphiti.py +39 -37
  9. graphiti_core/helpers.py +26 -0
  10. graphiti_core/llm_client/anthropic_client.py +4 -1
  11. graphiti_core/llm_client/client.py +45 -5
  12. graphiti_core/llm_client/errors.py +8 -0
  13. graphiti_core/llm_client/groq_client.py +4 -1
  14. graphiti_core/llm_client/openai_client.py +71 -7
  15. graphiti_core/llm_client/openai_generic_client.py +163 -0
  16. graphiti_core/nodes.py +58 -8
  17. graphiti_core/prompts/dedupe_edges.py +20 -17
  18. graphiti_core/prompts/dedupe_nodes.py +15 -1
  19. graphiti_core/prompts/eval.py +17 -14
  20. graphiti_core/prompts/extract_edge_dates.py +15 -7
  21. graphiti_core/prompts/extract_edges.py +18 -19
  22. graphiti_core/prompts/extract_nodes.py +11 -21
  23. graphiti_core/prompts/invalidate_edges.py +13 -25
  24. graphiti_core/prompts/lib.py +5 -1
  25. graphiti_core/prompts/prompt_helpers.py +1 -0
  26. graphiti_core/prompts/summarize_nodes.py +17 -16
  27. graphiti_core/search/search.py +5 -5
  28. graphiti_core/search/search_utils.py +55 -14
  29. graphiti_core/utils/__init__.py +0 -15
  30. graphiti_core/utils/bulk_utils.py +22 -15
  31. graphiti_core/utils/datetime_utils.py +42 -0
  32. graphiti_core/utils/maintenance/community_operations.py +13 -9
  33. graphiti_core/utils/maintenance/edge_operations.py +32 -26
  34. graphiti_core/utils/maintenance/graph_data_operations.py +3 -4
  35. graphiti_core/utils/maintenance/node_operations.py +19 -13
  36. graphiti_core/utils/maintenance/temporal_operations.py +17 -9
  37. {graphiti_core-0.4.2.dist-info → graphiti_core-0.5.0.dist-info}/METADATA +1 -1
  38. graphiti_core-0.5.0.dist-info/RECORD +60 -0
  39. graphiti_core-0.4.2.dist-info/RECORD +0 -57
  40. {graphiti_core-0.4.2.dist-info → graphiti_core-0.5.0.dist-info}/LICENSE +0 -0
  41. {graphiti_core-0.4.2.dist-info → graphiti_core-0.5.0.dist-info}/WHEEL +0 -0
--- a/graphiti_core/prompts/extract_edges.py
+++ b/graphiti_core/prompts/extract_edges.py
@@ -17,9 +17,26 @@ limitations under the License.
 import json
 from typing import Any, Protocol, TypedDict
 
+from pydantic import BaseModel, Field
+
 from .models import Message, PromptFunction, PromptVersion
 
 
+class Edge(BaseModel):
+    relation_type: str = Field(..., description='RELATION_TYPE_IN_CAPS')
+    source_entity_name: str = Field(..., description='name of the source entity')
+    target_entity_name: str = Field(..., description='name of the target entity')
+    fact: str = Field(..., description='extracted factual information')
+
+
+class ExtractedEdges(BaseModel):
+    edges: list[Edge]
+
+
+class MissingFacts(BaseModel):
+    missing_facts: list[str] = Field(..., description="facts that weren't extracted")
+
+
 class Prompt(Protocol):
     edge: PromptVersion
     reflexion: PromptVersion
@@ -54,25 +71,12 @@ def edge(context: dict[str, Any]) -> list[Message]:
 
 Given the above MESSAGES and ENTITIES, extract all facts pertaining to the listed ENTITIES from the CURRENT MESSAGE.
 
-
 Guidelines:
 1. Extract facts only between the provided entities.
 2. Each fact should represent a clear relationship between two DISTINCT nodes.
 3. The relation_type should be a concise, all-caps description of the fact (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
 4. Provide a more detailed fact containing all relevant information.
 5. Consider temporal aspects of relationships when relevant.
-
-Respond with a JSON object in the following format:
-{{
-    "edges": [
-        {{
-            "relation_type": "RELATION_TYPE_IN_CAPS",
-            "source_entity_name": "name of the source entity",
-            "target_entity_name": "name of the target entity",
-            "fact": "extracted factual information",
-        }}
-    ]
-}}
 """,
         ),
     ]
@@ -98,12 +102,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
 </EXTRACTED FACTS>
 
 Given the above MESSAGES, list of EXTRACTED ENTITIES entities, and list of EXTRACTED FACTS;
-determine if any facts haven't been extracted:
-
-Respond with a JSON object in the following format:
-{{
-    "missing_facts": [ "facts that weren't extracted", ...]
-}}
+determine if any facts haven't been extracted.
 """
     return [
         Message(role='system', content=sys_prompt),
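A pattern worth noting before the remaining prompt files: the hand-written "Respond with a JSON object..." blocks are removed because the expected output shape now lives in Pydantic models such as ExtractedEdges above, which a structured-output client can enforce directly. A minimal sketch of the validation side using only the models shown here; the commented generate_response call is an assumption about the client API, not something this diff confirms:

from pydantic import BaseModel, Field


class Edge(BaseModel):
    relation_type: str = Field(..., description='RELATION_TYPE_IN_CAPS')
    source_entity_name: str = Field(..., description='name of the source entity')
    target_entity_name: str = Field(..., description='name of the target entity')
    fact: str = Field(..., description='extracted factual information')


class ExtractedEdges(BaseModel):
    edges: list[Edge]


# Hypothetical call site: the model's JSON schema is handed to the LLM client,
# and the raw reply is validated back into typed objects.
# response = await llm_client.generate_response(messages, response_model=ExtractedEdges)
raw = '{"edges": [{"relation_type": "WORKS_FOR", "source_entity_name": "Alice", "target_entity_name": "Acme", "fact": "Alice works for Acme"}]}'
extracted = ExtractedEdges.model_validate_json(raw)
print(extracted.edges[0].relation_type)  # WORKS_FOR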
--- a/graphiti_core/prompts/extract_nodes.py
+++ b/graphiti_core/prompts/extract_nodes.py
@@ -17,9 +17,19 @@ limitations under the License.
 import json
 from typing import Any, Protocol, TypedDict
 
+from pydantic import BaseModel, Field
+
 from .models import Message, PromptFunction, PromptVersion
 
 
+class ExtractedNodes(BaseModel):
+    extracted_node_names: list[str] = Field(..., description='Name of the extracted entity')
+
+
+class MissedEntities(BaseModel):
+    missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")
+
+
 class Prompt(Protocol):
     extract_message: PromptVersion
     extract_json: PromptVersion
@@ -56,11 +66,6 @@ Guidelines:
 4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later).
 5. Be as explicit as possible in your node names, using full names.
 6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context.
-
-Respond with a JSON object in the following format:
-{{
-    "extracted_node_names": ["Name of the extracted entity", ...],
-}}
 """
     return [
         Message(role='system', content=sys_prompt),
@@ -87,11 +92,6 @@ Given the above source description and JSON, extract relevant entity nodes from
 Guidelines:
 1. Always try to extract an entities that the JSON represents. This will often be something like a "name" or "user field
 2. Do NOT extract any properties that contain dates
-
-Respond with a JSON object in the following format:
-{{
-    "extracted_node_names": ["Name of the extracted entity", ...],
-}}
 """
     return [
         Message(role='system', content=sys_prompt),
@@ -116,11 +116,6 @@ Guidelines:
 2. Avoid creating nodes for relationships or actions.
 3. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
 4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
-
-Respond with a JSON object in the following format:
-{{
-    "extracted_node_names": ["Name of the extracted entity", ...],
-}}
 """
     return [
         Message(role='system', content=sys_prompt),
@@ -144,12 +139,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
 </EXTRACTED ENTITIES>
 
 Given the above previous messages, current message, and list of extracted entities; determine if any entities haven't been
-extracted:
-
-Respond with a JSON object in the following format:
-{{
-    "missed_entities": [ "name of entity that wasn't extracted", ...]
-}}
+extracted.
 """
     return [
         Message(role='system', content=sys_prompt),
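The reflexion prompts above give extraction a second pass: the model is shown its own output and asked what it missed, with MissedEntities typing the answer. A hedged sketch of the loop this enables; extract_once and find_missed are illustrative stand-ins for the real prompt-backed LLM calls, not helpers from this package:

import asyncio


async def extract_once(episode: str) -> list[str]:
    # Stand-in for the extract_message prompt + LLM call (returns ExtractedNodes data).
    return ['Alice', 'Acme Corp']


async def find_missed(episode: str, found: list[str]) -> list[str]:
    # Stand-in for the reflexion prompt + LLM call (returns MissedEntities data).
    return []


async def extract_with_reflexion(episode: str, max_passes: int = 2) -> list[str]:
    names = await extract_once(episode)
    for _ in range(max_passes):
        missed = await find_missed(episode, names)
        if not missed:  # nothing was missed; stop early
            break
        names.extend(missed)
    return names


print(asyncio.run(extract_with_reflexion('Alice joined Acme Corp.')))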
--- a/graphiti_core/prompts/invalidate_edges.py
+++ b/graphiti_core/prompts/invalidate_edges.py
@@ -16,9 +16,22 @@ limitations under the License.
 
 from typing import Any, Protocol, TypedDict
 
+from pydantic import BaseModel, Field
+
 from .models import Message, PromptFunction, PromptVersion
 
 
+class InvalidatedEdge(BaseModel):
+    uuid: str = Field(..., description='The UUID of the edge to be invalidated')
+    fact: str = Field(..., description='Updated fact of the edge')
+
+
+class InvalidatedEdges(BaseModel):
+    invalidated_edges: list[InvalidatedEdge] = Field(
+        ..., description='List of edges that should be invalidated'
+    )
+
+
 class Prompt(Protocol):
     v1: PromptVersion
     v2: PromptVersion
@@ -56,18 +69,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
 {context['new_edges']}
 
 Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT), START_DATE (END_DATE, optional))"
-
-For each existing edge that should be invalidated, respond with a JSON object in the following format:
-{{
-    "invalidated_edges": [
-        {{
-            "edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
-            "fact": "Updated fact of the edge"
-        }}
-    ]
-}}
-
-If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
 """,
         ),
     ]
@@ -89,19 +90,6 @@ def v2(context: dict[str, Any]) -> list[Message]:
 
 New Edge:
 {context['new_edge']}
-
-
-For each existing edge that should be invalidated, respond with a JSON object in the following format:
-{{
-    "invalidated_edges": [
-        {{
-            "uuid": "The UUID of the edge to be invalidated",
-            "fact": "Updated fact of the edge"
-        }}
-    ]
-}}
-
-If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
 """,
         ),
     ]
--- a/graphiti_core/prompts/lib.py
+++ b/graphiti_core/prompts/lib.py
@@ -74,6 +74,7 @@ from .invalidate_edges import (
     versions as invalidate_edges_versions,
 )
 from .models import Message, PromptFunction
+from .prompt_helpers import DO_NOT_ESCAPE_UNICODE
 from .summarize_nodes import Prompt as SummarizeNodesPrompt
 from .summarize_nodes import Versions as SummarizeNodesVersions
 from .summarize_nodes import versions as summarize_nodes_versions
@@ -106,7 +107,10 @@ class VersionWrapper:
         self.func = func
 
     def __call__(self, context: dict[str, Any]) -> list[Message]:
-        return self.func(context)
+        messages = self.func(context)
+        for message in messages:
+            message.content += DO_NOT_ESCAPE_UNICODE if message.role == 'system' else ''
+        return messages
 
 
 class PromptTypeWrapper:
--- /dev/null
+++ b/graphiti_core/prompts/prompt_helpers.py
@@ -0,0 +1 @@
+DO_NOT_ESCAPE_UNICODE = '\nDo not escape unicode characters.\n'
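prompt_helpers.py is the new one-line module behind the lib.py change above: VersionWrapper.__call__ now appends DO_NOT_ESCAPE_UNICODE to every system message, nudging models to emit characters like 'Zoë' directly rather than as \u escapes. Its effect in isolation, with a stand-in Message class mirroring the role/content fields:

DO_NOT_ESCAPE_UNICODE = '\nDo not escape unicode characters.\n'


class Message:
    def __init__(self, role: str, content: str):
        self.role = role
        self.content = content


messages = [
    Message('system', 'You are an entity extractor.'),
    Message('user', 'Extract entities from: Zoë met François.'),
]

# Mirrors VersionWrapper.__call__: only system messages receive the suffix.
for message in messages:
    message.content += DO_NOT_ESCAPE_UNICODE if message.role == 'system' else ''

print(repr(messages[0].content))  # ends with the unicode instruction
print(repr(messages[1].content))  # unchanged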
--- a/graphiti_core/prompts/summarize_nodes.py
+++ b/graphiti_core/prompts/summarize_nodes.py
@@ -17,9 +17,21 @@ limitations under the License.
 import json
 from typing import Any, Protocol, TypedDict
 
+from pydantic import BaseModel, Field
+
 from .models import Message, PromptFunction, PromptVersion
 
 
+class Summary(BaseModel):
+    summary: str = Field(
+        ..., description='Summary containing the important information from both summaries'
+    )
+
+
+class SummaryDescription(BaseModel):
+    description: str = Field(..., description='One sentence description of the provided summary')
+
+
 class Prompt(Protocol):
     summarize_pair: PromptVersion
     summarize_context: PromptVersion
@@ -42,14 +54,11 @@ def summarize_pair(context: dict[str, Any]) -> list[Message]:
             role='user',
             content=f"""
         Synthesize the information from the following two summaries into a single succinct summary.
+
+        Summaries must be under 500 words.
 
         Summaries:
         {json.dumps(context['node_summaries'], indent=2)}
-
-        Respond with a JSON object in the following format:
-        {{
-            "summary": "Summary containing the important information from both summaries"
-        }}
         """,
         ),
     ]
@@ -74,15 +83,11 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
         information from the provided MESSAGES. Your summary should also only contain information relevant to the
         provided ENTITY.
 
+        Summaries must be under 500 words.
+
         <ENTITY>
         {context['node_name']}
         </ENTITY>
-
-
-        Respond with a JSON object in the following format:
-        {{
-            "summary": "Entity summary"
-        }}
         """,
         ),
     ]
@@ -98,14 +103,10 @@ def summary_description(context: dict[str, Any]) -> list[Message]:
             role='user',
             content=f"""
         Create a short one sentence description of the summary that explains what kind of information is summarized.
+        Summaries must be under 500 words.
 
         Summary:
         {json.dumps(context['summary'], indent=2)}
-
-        Respond with a JSON object in the following format:
-        {{
-            "description": "One sentence description of the provided summary"
-        }}
         """,
         ),
     ]
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 from collections import defaultdict
 from time import time
@@ -25,6 +24,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
 from graphiti_core.edges import EntityEdge
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import SearchRerankerError
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.nodes import CommunityNode, EntityNode
 from graphiti_core.search.search_config import (
     DEFAULT_SEARCH_LIMIT,
@@ -78,7 +78,7 @@ async def search(
 
     # if group_ids is empty, set it to None
     group_ids = group_ids if group_ids else None
-    edges, nodes, communities = await asyncio.gather(
+    edges, nodes, communities = await semaphore_gather(
         edge_search(
             driver,
             cross_encoder,
@@ -141,7 +141,7 @@ async def edge_search(
         return []
 
     search_results: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[
                edge_fulltext_search(driver, query, group_ids, 2 * limit),
                edge_similarity_search(
@@ -226,7 +226,7 @@ async def node_search(
         return []
 
     search_results: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[
                node_fulltext_search(driver, query, group_ids, 2 * limit),
                node_similarity_search(
@@ -295,7 +295,7 @@ async def community_search(
         return []
 
     search_results: list[list[CommunityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[
                community_fulltext_search(driver, query, group_ids, 2 * limit),
                community_similarity_search(
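Every asyncio.gather call in this file (and in search_utils.py and bulk_utils.py below) is replaced with semaphore_gather, the helper added in graphiti_core/helpers.py (+26 lines), so fan-out to the database and the LLM is bounded instead of unlimited. The helper's body is not part of this diff; a plausible minimal implementation, with an assumed default limit:

import asyncio
from collections.abc import Coroutine
from typing import Any

SEMAPHORE_LIMIT = 20  # assumed default; the real value lives in helpers.py


async def semaphore_gather(
    *coroutines: Coroutine[Any, Any, Any],
    max_coroutines: int = SEMAPHORE_LIMIT,
):
    # A drop-in replacement for asyncio.gather that caps concurrency:
    # each coroutine must acquire the semaphore before it runs.
    semaphore = asyncio.Semaphore(max_coroutines)

    async def _wrap(coroutine: Coroutine[Any, Any, Any]):
        async with semaphore:
            return await coroutine

    return await asyncio.gather(*(_wrap(c) for c in coroutines))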
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -14,10 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 from collections import defaultdict
 from time import time
+from typing import Any
 
 import numpy as np
 from neo4j import AsyncDriver, Query
@@ -29,6 +29,7 @@ from graphiti_core.helpers import (
     USE_PARALLEL_RUNTIME,
     lucene_sanitize,
     normalize_l2,
+    semaphore_gather,
 )
 from graphiti_core.nodes import (
     CommunityNode,
@@ -40,7 +41,7 @@ from graphiti_core.nodes import (
 
 logger = logging.getLogger(__name__)
 
-RELEVANT_SCHEMA_LIMIT = 3
+RELEVANT_SCHEMA_LIMIT = 10
 DEFAULT_MIN_SCORE = 0.6
 DEFAULT_MMR_LAMBDA = 0.5
 MAX_SEARCH_DEPTH = 3
@@ -191,12 +192,27 @@ async def edge_similarity_search(
         'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
     )
 
-    query: LiteralString = """
-        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-        WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
-        AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
-        AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
-        WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
+    query_params: dict[str, Any] = {}
+
+    group_filter_query: LiteralString = ''
+    if group_ids is not None:
+        group_filter_query += 'WHERE r.group_id IN $group_ids'
+        query_params['group_ids'] = group_ids
+        query_params['source_node_uuid'] = source_node_uuid
+        query_params['target_node_uuid'] = target_node_uuid
+
+    if source_node_uuid is not None:
+        group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
+
+    if target_node_uuid is not None:
+        group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
+
+    query: LiteralString = (
+        """
+        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+        """
+        + group_filter_query
+        + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
         WHERE score > $min_score
         RETURN
             r.uuid AS uuid,
@@ -214,9 +230,11 @@ async def edge_similarity_search(
         ORDER BY score DESC
         LIMIT $limit
         """
+    )
 
     records, _, _ = await driver.execute_query(
         runtime_query + query,
+        query_params,
         search_vector=search_vector,
         source_uuid=source_node_uuid,
         target_uuid=target_node_uuid,
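The rewrite above drops the ($param IS NULL OR ...) guards that previously ran on every row and instead splices filter clauses into the Cypher only when the corresponding argument is present, shipping values through an explicit query_params dict. The same pattern reduced to its core (label and property names here are illustrative):

from typing import Any


def build_group_filter(group_ids: list[str] | None) -> tuple[str, dict[str, Any]]:
    # Emit a WHERE clause only when there is actually something to filter on.
    params: dict[str, Any] = {}
    clause = ''
    if group_ids is not None:
        clause = 'WHERE n.group_id IN $group_ids'
        params['group_ids'] = group_ids
    return clause, params


clause, params = build_group_filter(['group-1'])
query = 'MATCH (n:Entity)\n' + clause + '\nRETURN n.uuid AS uuid'
print(query)
print(params)  # {'group_ids': ['group-1']}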
@@ -325,11 +343,20 @@ async def node_similarity_search(
         'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
     )
 
+    query_params: dict[str, Any] = {}
+
+    group_filter_query: LiteralString = ''
+    if group_ids is not None:
+        group_filter_query += 'WHERE n.group_id IN $group_ids'
+        query_params['group_ids'] = group_ids
+
     records, _, _ = await driver.execute_query(
         runtime_query
         + """
         MATCH (n:Entity)
-        WHERE $group_ids IS NULL OR n.group_id IN $group_ids
+        """
+        + group_filter_query
+        + """
         WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
         WHERE score > $min_score
         RETURN
@@ -342,6 +369,7 @@ async def node_similarity_search(
         ORDER BY score DESC
         LIMIT $limit
         """,
+        query_params,
         search_vector=search_vector,
         group_ids=group_ids,
         limit=limit,
@@ -436,11 +464,20 @@ async def community_similarity_search(
         'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
     )
 
+    query_params: dict[str, Any] = {}
+
+    group_filter_query: LiteralString = ''
+    if group_ids is not None:
+        group_filter_query += 'WHERE comm.group_id IN $group_ids'
+        query_params['group_ids'] = group_ids
+
     records, _, _ = await driver.execute_query(
         runtime_query
         + """
         MATCH (comm:Community)
-        WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
+        """
+        + group_filter_query
+        + """
         WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
         WHERE score > $min_score
         RETURN
@@ -512,7 +549,7 @@ async def hybrid_node_search(
 
     start = time()
     results: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
            *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
        )
@@ -582,7 +619,7 @@ async def get_relevant_edges(
     relevant_edges: list[EntityEdge] = []
     relevant_edge_uuids = set()
 
-    results = await asyncio.gather(
+    results = await semaphore_gather(
        *[
            edge_similarity_search(
                driver,
@@ -631,7 +668,7 @@ async def node_distance_reranker(
 ) -> list[str]:
     # filter out node_uuid center node node uuid
     filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
-    scores: dict[str, float] = {}
+    scores: dict[str, float] = {center_node_uuid: 0.0}
 
     # Find the shortest path to center node
     query = Query("""
@@ -649,9 +686,13 @@ async def node_distance_reranker(
 
     for result in path_results:
         uuid = result['uuid']
-        score = result['score'] if 'score' in result else float('inf')
+        score = result['score']
         scores[uuid] = score
 
+    for uuid in filtered_uuids:
+        if uuid not in scores:
+            scores[uuid] = float('inf')
+
     # rerank on shortest distance
     filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
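The node_distance_reranker changes close two gaps in the distance map: the center node is seeded at distance 0.0, and candidates the shortest-path query never returned are backfilled with infinity instead of raising KeyError during the sort. The resulting ordering, reduced to plain data:

center = 'uuid-center'
candidates = ['uuid-a', 'uuid-b', 'uuid-c']
path_results = [{'uuid': 'uuid-b', 'score': 1.0}]  # only b has a path to the center

scores: dict[str, float] = {center: 0.0}  # center node seeded at distance 0
for result in path_results:
    scores[result['uuid']] = result['score']

filtered = [u for u in candidates if u != center]
for uuid in filtered:
    if uuid not in scores:
        scores[uuid] = float('inf')  # unreachable nodes sort last

filtered.sort(key=lambda u: scores[u])
print(filtered)  # ['uuid-b', 'uuid-a', 'uuid-c']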
--- a/graphiti_core/utils/__init__.py
+++ b/graphiti_core/utils/__init__.py
@@ -1,15 +0,0 @@
-from .maintenance import (
-    build_episodic_edges,
-    clear_data,
-    extract_edges,
-    extract_nodes,
-    retrieve_episodes,
-)
-
-__all__ = [
-    'extract_edges',
-    'build_episodic_edges',
-    'extract_nodes',
-    'clear_data',
-    'retrieve_episodes',
-]
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -14,11 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 import typing
 from collections import defaultdict
-from datetime import datetime, timezone
+from datetime import datetime
 from math import ceil
 
 from neo4j import AsyncDriver, AsyncManagedTransaction
@@ -26,6 +25,7 @@ from numpy import dot, sqrt
 from pydantic import BaseModel
 
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.models.edges.edge_db_queries import (
     ENTITY_EDGE_SAVE_BULK,
@@ -37,14 +37,17 @@ from graphiti_core.models.nodes.node_db_queries import (
 )
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
-from graphiti_core.utils import retrieve_episodes
+from graphiti_core.utils.datetime_utils import utc_now
 from graphiti_core.utils.maintenance.edge_operations import (
     build_episodic_edges,
     dedupe_edge_list,
     dedupe_extracted_edges,
     extract_edges,
 )
-from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
+from graphiti_core.utils.maintenance.graph_data_operations import (
+    EPISODE_WINDOW_LEN,
+    retrieve_episodes,
+)
 from graphiti_core.utils.maintenance.node_operations import (
     dedupe_extracted_nodes,
     dedupe_node_list,
@@ -68,7 +71,7 @@ class RawEpisode(BaseModel):
 async def retrieve_previous_episodes_bulk(
     driver: AsyncDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
-    previous_episodes_list = await asyncio.gather(
+    previous_episodes_list = await semaphore_gather(
        *[
            retrieve_episodes(
                driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
@@ -115,7 +118,7 @@ async def add_nodes_and_edges_bulk_tx(
 async def extract_nodes_and_edges_bulk(
     llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
 ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
-    extracted_nodes_bulk = await asyncio.gather(
+    extracted_nodes_bulk = await semaphore_gather(
        *[
            extract_nodes(llm_client, episode, previous_episodes)
            for episode, previous_episodes in episode_tuples
@@ -127,7 +130,7 @@ async def extract_nodes_and_edges_bulk(
         [episode[1] for episode in episode_tuples],
     )
 
-    extracted_edges_bulk = await asyncio.gather(
+    extracted_edges_bulk = await semaphore_gather(
        *[
            extract_edges(
                llm_client,
@@ -168,13 +171,13 @@ async def dedupe_nodes_bulk(
     node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
 
     existing_nodes_chunks: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
        )
     )
 
     results: list[tuple[list[EntityNode], dict[str, str]]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[
                dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
                for i, node_chunk in enumerate(node_chunks)
@@ -202,13 +205,13 @@ async def dedupe_edges_bulk(
     ]
 
     relevant_edges_chunks: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
        )
     )
 
     resolved_edge_chunks: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[
                dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
                for i, edge_chunk in enumerate(edge_chunks)
@@ -289,7 +292,9 @@ async def compress_nodes(
         # add both nodes to the shortest chunk
         node_chunks[-1].extend([n, m])
 
-    results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
+    results = await semaphore_gather(
+        *[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
+    )
 
     extended_map = dict(uuid_map)
     compressed_nodes: list[EntityNode] = []
@@ -312,7 +317,9 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list
     # We build a map of the edges based on their source and target nodes.
     edge_chunks = chunk_edges_by_nodes(edges)
 
-    results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
+    results = await semaphore_gather(
+        *[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
+    )
 
     compressed_edges: list[EntityEdge] = []
     for edge_chunk in results:
@@ -365,7 +372,7 @@ async def extract_edge_dates_bulk(
         episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
     }
 
-    results = await asyncio.gather(
+    results = await semaphore_gather(
        *[
            extract_edge_dates(
                llm_client,
@@ -385,7 +392,7 @@ async def extract_edge_dates_bulk(
         edge.valid_at = valid_at
         edge.invalid_at = invalid_at
         if edge.invalid_at:
-            edge.expired_at = datetime.now(timezone.utc)
+            edge.expired_at = utc_now()
 
     return edges
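Finally, datetime.now(timezone.utc) call sites are routed through utc_now() from the new graphiti_core/utils/datetime_utils.py (+42 lines, not shown in this diff). A minimal definition consistent with this call site would be:

from datetime import datetime, timezone


def utc_now() -> datetime:
    # One timezone-aware "now" helper, replacing scattered datetime.now(timezone.utc) calls.
    return datetime.now(timezone.utc)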