graphiti-core 0.5.0rc5__tar.gz → 0.5.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (60)
  1. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/PKG-INFO +1 -1
  2. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/cross_encoder/openai_reranker_client.py +2 -2
  3. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/graphiti.py +19 -20
  4. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/helpers.py +16 -2
  5. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/anthropic_client.py +5 -2
  6. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/client.py +15 -7
  7. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/config.py +1 -1
  8. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/groq_client.py +5 -2
  9. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/openai_client.py +16 -6
  10. graphiti_core-0.5.2/graphiti_core/llm_client/openai_generic_client.py +171 -0
  11. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/search/search.py +5 -5
  12. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/search/search_utils.py +48 -11
  13. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/bulk_utils.py +15 -11
  14. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/community_operations.py +6 -4
  15. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/edge_operations.py +8 -5
  16. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/graph_data_operations.py +3 -4
  17. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/node_operations.py +3 -4
  18. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/temporal_operations.py +2 -2
  19. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/pyproject.toml +1 -1
  20. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/LICENSE +0 -0
  21. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/README.md +0 -0
  22. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/__init__.py +0 -0
  23. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/cross_encoder/__init__.py +0 -0
  24. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
  25. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/cross_encoder/client.py +0 -0
  26. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/edges.py +0 -0
  27. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/embedder/__init__.py +0 -0
  28. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/embedder/client.py +0 -0
  29. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/embedder/openai.py +0 -0
  30. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/embedder/voyage.py +0 -0
  31. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/errors.py +0 -0
  32. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/__init__.py +0 -0
  33. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/errors.py +0 -0
  34. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/utils.py +0 -0
  35. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/models/__init__.py +0 -0
  36. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/models/edges/__init__.py +0 -0
  37. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/models/edges/edge_db_queries.py +0 -0
  38. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/models/nodes/__init__.py +0 -0
  39. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/models/nodes/node_db_queries.py +0 -0
  40. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/nodes.py +0 -0
  41. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/__init__.py +0 -0
  42. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/dedupe_edges.py +0 -0
  43. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/dedupe_nodes.py +0 -0
  44. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/eval.py +0 -0
  45. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/extract_edge_dates.py +0 -0
  46. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/extract_edges.py +0 -0
  47. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/extract_nodes.py +0 -0
  48. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/invalidate_edges.py +0 -0
  49. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/lib.py +0 -0
  50. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/models.py +0 -0
  51. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/prompt_helpers.py +0 -0
  52. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/summarize_nodes.py +0 -0
  53. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/py.typed +0 -0
  54. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/search/__init__.py +0 -0
  55. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/search/search_config.py +0 -0
  56. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/search/search_config_recipes.py +0 -0
  57. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/__init__.py +0 -0
  58. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/datetime_utils.py +0 -0
  59. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/__init__.py +0 -0
  60. {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: graphiti-core
- Version: 0.5.0rc5
+ Version: 0.5.2
  Summary: A temporal graph building library
  License: Apache-2.0
  Author: Paul Paliychuk
graphiti_core/cross_encoder/openai_reranker_client.py
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- import asyncio
  import logging
  from typing import Any

@@ -22,6 +21,7 @@ import openai
  from openai import AsyncOpenAI
  from pydantic import BaseModel

+ from ..helpers import semaphore_gather
  from ..llm_client import LLMConfig, RateLimitError
  from ..prompts import Message
  from .client import CrossEncoderClient
@@ -75,7 +75,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
  for passage in passages
  ]
  try:
- responses = await asyncio.gather(
+ responses = await semaphore_gather(
  *[
  self.client.chat.completions.create(
  model=DEFAULT_MODEL,
graphiti_core/graphiti.py
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- import asyncio
  import logging
  from datetime import datetime
  from time import time
@@ -27,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
  from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
  from graphiti_core.edges import EntityEdge, EpisodicEdge
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
- from graphiti_core.helpers import DEFAULT_DATABASE
+ from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
  from graphiti_core.llm_client import LLMClient, OpenAIClient
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
  from graphiti_core.search.search import SearchConfig, search
@@ -340,13 +339,13 @@ class Graphiti:

  # Calculate Embeddings

- await asyncio.gather(
+ await semaphore_gather(
  *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
  )

  # Find relevant nodes already in the graph
  existing_nodes_lists: list[list[EntityNode]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
  )
  )
@@ -354,7 +353,7 @@
  # Resolve extracted nodes with nodes already in the graph and extract facts
  logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

- (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
+ (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
  resolve_extracted_nodes(
  self.llm_client,
  extracted_nodes,
@@ -374,7 +373,7 @@
  )

  # calculate embeddings
- await asyncio.gather(
+ await semaphore_gather(
  *[
  edge.generate_embedding(self.embedder)
  for edge in extracted_edges_with_resolved_pointers
@@ -383,7 +382,7 @@

  # Resolve extracted edges with related edges already in the graph
  related_edges_list: list[list[EntityEdge]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[
  get_relevant_edges(
  self.driver,
@@ -404,7 +403,7 @@
  )

  existing_source_edges_list: list[list[EntityEdge]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[
  get_relevant_edges(
  self.driver,
@@ -419,7 +418,7 @@
  )

  existing_target_edges_list: list[list[EntityEdge]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[
  get_relevant_edges(
  self.driver,
@@ -468,7 +467,7 @@

  # Update any communities
  if update_communities:
- await asyncio.gather(
+ await semaphore_gather(
  *[
  update_community(self.driver, self.llm_client, self.embedder, node)
  for node in nodes
@@ -538,7 +537,7 @@
  ]

  # Save all the episodes
- await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
+ await semaphore_gather(*[episode.save(self.driver) for episode in episodes])

  # Get previous episode context for each episode
  episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
@@ -551,19 +550,19 @@
  ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)

  # Generate embeddings
- await asyncio.gather(
+ await semaphore_gather(
  *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
  *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
  )

  # Dedupe extracted nodes, compress extracted edges
- (nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
+ (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
  dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
  extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
  )

  # save nodes to KG
- await asyncio.gather(*[node.save(self.driver) for node in nodes])
+ await semaphore_gather(*[node.save(self.driver) for node in nodes])

  # re-map edge pointers so that they don't point to discard dupe nodes
  extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
@@ -574,7 +573,7 @@
  )

  # save episodic edges to KG
- await asyncio.gather(
+ await semaphore_gather(
  *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
  )

@@ -587,7 +586,7 @@
  # invalidate edges

  # save edges to KG
- await asyncio.gather(*[edge.save(self.driver) for edge in edges])
+ await semaphore_gather(*[edge.save(self.driver) for edge in edges])

  end = time()
  logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
@@ -610,12 +609,12 @@
  self.driver, self.llm_client, group_ids
  )

- await asyncio.gather(
+ await semaphore_gather(
  *[node.generate_name_embedding(self.embedder) for node in community_nodes]
  )

- await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
- await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
+ await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
+ await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])

  return community_nodes

@@ -698,7 +697,7 @@
  async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
  episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

- edges_list = await asyncio.gather(
+ edges_list = await semaphore_gather(
  *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
  )
graphiti_core/helpers.py
@@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

+ import asyncio
  import os
+ from collections.abc import Coroutine
  from datetime import datetime

  import numpy as np
@@ -25,6 +27,7 @@ load_dotenv()

  DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
  USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
+ SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
  MAX_REFLEXION_ITERATIONS = 2
  DEFAULT_PAGE_LIMIT = 20

@@ -70,13 +73,24 @@ def lucene_sanitize(query: str) -> str:
  return sanitized


- def normalize_l2(embedding: list[float]) -> list[float]:
+ def normalize_l2(embedding: list[float]):
  embedding_array = np.array(embedding)
  if embedding_array.ndim == 1:
  norm = np.linalg.norm(embedding_array)
  if norm == 0:
- return embedding_array.tolist()
+ return [0.0] * len(embedding)
  return (embedding_array / norm).tolist()
  else:
  norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
  return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
+
+
+ # Use this instead of asyncio.gather() to bound coroutines
+ async def semaphore_gather(*coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT):
+ semaphore = asyncio.Semaphore(max_coroutines)
+
+ async def _wrap_coroutine(coroutine):
+ async with semaphore:
+ return await coroutine
+
+ return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
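The new semaphore_gather helper is a drop-in replacement for asyncio.gather that caps in-flight coroutines at SEMAPHORE_LIMIT (read from the environment, defaulting to 20) or at an explicit max_coroutines override; every asyncio.gather call site in this release is switched to it. A minimal sketch of how it behaves; the embed_text coroutine and its inputs are hypothetical stand-ins for the rate-limited LLM and embedding calls graphiti makes:

    import asyncio

    from graphiti_core.helpers import semaphore_gather

    async def embed_text(text: str) -> list[float]:
        # Stand-in for a rate-limited API call (hypothetical).
        await asyncio.sleep(0.1)
        return [float(len(text))]

    async def main() -> None:
        texts = [f'episode {i}' for i in range(100)]
        # At most 5 embed_text calls are in flight at a time; results come
        # back in input order, exactly like asyncio.gather.
        embeddings = await semaphore_gather(
            *[embed_text(t) for t in texts],
            max_coroutines=5,
        )
        print(len(embeddings))

    asyncio.run(main())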
graphiti_core/llm_client/anthropic_client.py
@@ -48,7 +48,10 @@ class AnthropicClient(LLMClient):
  )

  async def _generate_response(
- self, messages: list[Message], response_model: type[BaseModel] | None = None
+ self,
+ messages: list[Message],
+ response_model: type[BaseModel] | None = None,
+ max_tokens: int = DEFAULT_MAX_TOKENS,
  ) -> dict[str, typing.Any]:
  system_message = messages[0]
  user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
@@ -59,7 +62,7 @@ class AnthropicClient(LLMClient):
  result = await self.client.messages.create(
  system='Only include JSON in the response. Do not include any additional text or explanation of the content.\n'
  + system_message.content,
- max_tokens=self.max_tokens,
+ max_tokens=max_tokens or self.max_tokens,
  temperature=self.temperature,
  messages=user_messages, # type: ignore
  model=self.model or DEFAULT_MODEL,
graphiti_core/llm_client/client.py
@@ -26,7 +26,7 @@ from pydantic import BaseModel
  from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential

  from ..prompts.models import Message
- from .config import LLMConfig
+ from .config import DEFAULT_MAX_TOKENS, LLMConfig
  from .errors import RateLimitError

  DEFAULT_TEMPERATURE = 0
@@ -56,7 +56,6 @@ class LLMClient(ABC):
  self.cache_enabled = cache
  self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory

-
  def _clean_input(self, input: str) -> str:
  """Clean input string of invalid unicode and control characters.

@@ -91,16 +90,22 @@ class LLMClient(ABC):
  reraise=True,
  )
  async def _generate_response_with_retry(
- self, messages: list[Message], response_model: type[BaseModel] | None = None
+ self,
+ messages: list[Message],
+ response_model: type[BaseModel] | None = None,
+ max_tokens: int = DEFAULT_MAX_TOKENS,
  ) -> dict[str, typing.Any]:
  try:
- return await self._generate_response(messages, response_model)
+ return await self._generate_response(messages, response_model, max_tokens)
  except (httpx.HTTPStatusError, RateLimitError) as e:
  raise e

  @abstractmethod
  async def _generate_response(
- self, messages: list[Message], response_model: type[BaseModel] | None = None
+ self,
+ messages: list[Message],
+ response_model: type[BaseModel] | None = None,
+ max_tokens: int = DEFAULT_MAX_TOKENS,
  ) -> dict[str, typing.Any]:
  pass

@@ -111,7 +116,10 @@ class LLMClient(ABC):
  return hashlib.md5(key_str.encode()).hexdigest()

  async def generate_response(
- self, messages: list[Message], response_model: type[BaseModel] | None = None
+ self,
+ messages: list[Message],
+ response_model: type[BaseModel] | None = None,
+ max_tokens: int = DEFAULT_MAX_TOKENS,
  ) -> dict[str, typing.Any]:
  if response_model is not None:
  serialized_model = json.dumps(response_model.model_json_schema())
@@ -132,7 +140,7 @@ class LLMClient(ABC):
  for message in messages:
  message.content = self._clean_input(message.content)

- response = await self._generate_response_with_retry(messages, response_model)
+ response = await self._generate_response_with_retry(messages, response_model, max_tokens)

  if self.cache_enabled:
  self.cache_dir.set(cache_key, response)
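generate_response, _generate_response_with_retry, and the abstract _generate_response now all accept a per-call max_tokens argument, so an individual prompt can request a larger completion budget than the client-wide default (the config.py hunk below drops DEFAULT_MAX_TOKENS from 16384 to 1024, and extraction-heavy call sites such as extract_edges opt back into 16384 later in this diff). A hedged sketch of the calling pattern; the prompt content is invented, but the signature is the one added here:

    from graphiti_core.llm_client import LLMConfig, OpenAIClient
    from graphiti_core.prompts import Message

    async def extract(client: OpenAIClient) -> dict:
        messages = [
            Message(role='system', content='Respond only with JSON.'),
            Message(role='user', content='List the entities in: ...'),
        ]
        # Most calls fall back to the smaller 1024-token default; prompts
        # that expect long structured output pass a larger budget per call.
        return await client.generate_response(messages, max_tokens=16384)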
graphiti_core/llm_client/config.py
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- DEFAULT_MAX_TOKENS = 16384
+ DEFAULT_MAX_TOKENS = 1024
  DEFAULT_TEMPERATURE = 0

graphiti_core/llm_client/groq_client.py
@@ -45,7 +45,10 @@ class GroqClient(LLMClient):
  self.client = AsyncGroq(api_key=config.api_key)

  async def _generate_response(
- self, messages: list[Message], response_model: type[BaseModel] | None = None
+ self,
+ messages: list[Message],
+ response_model: type[BaseModel] | None = None,
+ max_tokens: int = DEFAULT_MAX_TOKENS,
  ) -> dict[str, typing.Any]:
  msgs: list[ChatCompletionMessageParam] = []
  for m in messages:
@@ -58,7 +61,7 @@ class GroqClient(LLMClient):
  model=self.model or DEFAULT_MODEL,
  messages=msgs,
  temperature=self.temperature,
- max_tokens=self.max_tokens,
+ max_tokens=max_tokens or self.max_tokens,
  response_format={'type': 'json_object'},
  )
  result = response.choices[0].message.content or ''
graphiti_core/llm_client/openai_client.py
@@ -25,7 +25,7 @@ from pydantic import BaseModel

  from ..prompts.models import Message
  from .client import LLMClient
- from .config import LLMConfig
+ from .config import DEFAULT_MAX_TOKENS, LLMConfig
  from .errors import RateLimitError, RefusalError

  logger = logging.getLogger(__name__)
@@ -58,7 +58,11 @@ class OpenAIClient(LLMClient):
  MAX_RETRIES: ClassVar[int] = 2

  def __init__(
- self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
+ self,
+ config: LLMConfig | None = None,
+ cache: bool = False,
+ client: typing.Any = None,
+ max_tokens: int = DEFAULT_MAX_TOKENS,
  ):
  """
  Initialize the OpenAIClient with the provided configuration, cache setting, and client.
@@ -84,7 +88,10 @@ class OpenAIClient(LLMClient):
  self.client = client

  async def _generate_response(
- self, messages: list[Message], response_model: type[BaseModel] | None = None
+ self,
+ messages: list[Message],
+ response_model: type[BaseModel] | None = None,
+ max_tokens: int = DEFAULT_MAX_TOKENS,
  ) -> dict[str, typing.Any]:
  openai_messages: list[ChatCompletionMessageParam] = []
  for m in messages:
@@ -98,7 +105,7 @@ class OpenAIClient(LLMClient):
  model=self.model or DEFAULT_MODEL,
  messages=openai_messages,
  temperature=self.temperature,
- max_tokens=self.max_tokens,
+ max_tokens=max_tokens or self.max_tokens,
  response_format=response_model, # type: ignore
  )

@@ -119,14 +126,17 @@ class OpenAIClient(LLMClient):
  raise

  async def generate_response(
- self, messages: list[Message], response_model: type[BaseModel] | None = None
+ self,
+ messages: list[Message],
+ response_model: type[BaseModel] | None = None,
+ max_tokens: int = DEFAULT_MAX_TOKENS,
  ) -> dict[str, typing.Any]:
  retry_count = 0
  last_error = None

  while retry_count <= self.MAX_RETRIES:
  try:
- response = await self._generate_response(messages, response_model)
+ response = await self._generate_response(messages, response_model, max_tokens)
  return response
  except (RateLimitError, RefusalError):
  # These errors should not trigger retries
graphiti_core/llm_client/openai_generic_client.py (added)
@@ -0,0 +1,171 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import json
+ import logging
+ import typing
+ from typing import ClassVar
+
+ import openai
+ from openai import AsyncOpenAI
+ from openai.types.chat import ChatCompletionMessageParam
+ from pydantic import BaseModel
+
+ from ..prompts.models import Message
+ from .client import LLMClient
+ from .config import DEFAULT_MAX_TOKENS, LLMConfig
+ from .errors import RateLimitError, RefusalError
+
+ logger = logging.getLogger(__name__)
+
+ DEFAULT_MODEL = 'gpt-4o-mini'
+
+
+ class OpenAIGenericClient(LLMClient):
+ """
+ OpenAIClient is a client class for interacting with OpenAI's language models.
+
+ This class extends the LLMClient and provides methods to initialize the client,
+ get an embedder, and generate responses from the language model.
+
+ Attributes:
+ client (AsyncOpenAI): The OpenAI client used to interact with the API.
+ model (str): The model name to use for generating responses.
+ temperature (float): The temperature to use for generating responses.
+ max_tokens (int): The maximum number of tokens to generate in a response.
+
+ Methods:
+ __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
+ Initializes the OpenAIClient with the provided configuration, cache setting, and client.
+
+ _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
+ Generates a response from the language model based on the provided messages.
+ """
+
+ # Class-level constants
+ MAX_RETRIES: ClassVar[int] = 2
+
+ def __init__(
+ self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
+ ):
+ """
+ Initialize the OpenAIClient with the provided configuration, cache setting, and client.
+
+ Args:
+ config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
+ cache (bool): Whether to use caching for responses. Defaults to False.
+ client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+
+ """
+ # removed caching to simplify the `generate_response` override
+ if cache:
+ raise NotImplementedError('Caching is not implemented for OpenAI')
+
+ if config is None:
+ config = LLMConfig()
+
+ super().__init__(config, cache)
+
+ if client is None:
+ self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+ else:
+ self.client = client
+
+ async def _generate_response(
+ self,
+ messages: list[Message],
+ response_model: type[BaseModel] | None = None,
+ max_tokens: int = DEFAULT_MAX_TOKENS,
+ ) -> dict[str, typing.Any]:
+ openai_messages: list[ChatCompletionMessageParam] = []
+ for m in messages:
+ m.content = self._clean_input(m.content)
+ if m.role == 'user':
+ openai_messages.append({'role': 'user', 'content': m.content})
+ elif m.role == 'system':
+ openai_messages.append({'role': 'system', 'content': m.content})
+ try:
+ response = await self.client.chat.completions.create(
+ model=self.model or DEFAULT_MODEL,
+ messages=openai_messages,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ response_format={'type': 'json_object'},
+ )
+ result = response.choices[0].message.content or ''
+ return json.loads(result)
+ except openai.RateLimitError as e:
+ raise RateLimitError from e
+ except Exception as e:
+ logger.error(f'Error in generating LLM response: {e}')
+ raise
+
+ async def generate_response(
+ self,
+ messages: list[Message],
+ response_model: type[BaseModel] | None = None,
+ max_tokens: int = DEFAULT_MAX_TOKENS,
+ ) -> dict[str, typing.Any]:
+ retry_count = 0
+ last_error = None
+
+ if response_model is not None:
+ serialized_model = json.dumps(response_model.model_json_schema())
+ messages[
+ -1
+ ].content += (
+ f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
+ )
+
+ while retry_count <= self.MAX_RETRIES:
+ try:
+ response = await self._generate_response(
+ messages, response_model, max_tokens=max_tokens
+ )
+ return response
+ except (RateLimitError, RefusalError):
+ # These errors should not trigger retries
+ raise
+ except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
+ # Let OpenAI's client handle these retries
+ raise
+ except Exception as e:
+ last_error = e
+
+ # Don't retry if we've hit the max retries
+ if retry_count >= self.MAX_RETRIES:
+ logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
+ raise
+
+ retry_count += 1
+
+ # Construct a detailed error message for the LLM
+ error_context = (
+ f'The previous response attempt was invalid. '
+ f'Error type: {e.__class__.__name__}. '
+ f'Error details: {str(e)}. '
+ f'Please try again with a valid response, ensuring the output matches '
+ f'the expected format and constraints.'
+ )
+
+ error_message = Message(role='user', content=error_context)
+ messages.append(error_message)
+ logger.warning(
+ f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+ )
+
+ # If we somehow get here, raise the last error
+ raise last_error or Exception('Max retries exceeded with no specific error')
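The new OpenAIGenericClient targets OpenAI-compatible endpoints that lack OpenAI's structured-output support: it appends the response model's JSON schema to the final message, requests response_format={'type': 'json_object'}, and on a failed attempt feeds the error back to the model as an extra user message before retrying. A hedged usage sketch; the base_url, model, and api_key values are placeholders for whatever compatible server you run, not values from this diff:

    from graphiti_core.llm_client import LLMConfig
    from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient

    # Point the client at any OpenAI-compatible API; a local server is
    # shown purely as an example.
    llm_client = OpenAIGenericClient(
        config=LLMConfig(
            api_key='placeholder',
            model='my-local-model',
            base_url='http://localhost:8000/v1',
        )
    )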
graphiti_core/search/search.py
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- import asyncio
  import logging
  from collections import defaultdict
  from time import time
@@ -25,6 +24,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
  from graphiti_core.edges import EntityEdge
  from graphiti_core.embedder import EmbedderClient
  from graphiti_core.errors import SearchRerankerError
+ from graphiti_core.helpers import semaphore_gather
  from graphiti_core.nodes import CommunityNode, EntityNode
  from graphiti_core.search.search_config import (
  DEFAULT_SEARCH_LIMIT,
@@ -78,7 +78,7 @@ async def search(

  # if group_ids is empty, set it to None
  group_ids = group_ids if group_ids else None
- edges, nodes, communities = await asyncio.gather(
+ edges, nodes, communities = await semaphore_gather(
  edge_search(
  driver,
  cross_encoder,
@@ -141,7 +141,7 @@ async def edge_search(
  return []

  search_results: list[list[EntityEdge]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[
  edge_fulltext_search(driver, query, group_ids, 2 * limit),
  edge_similarity_search(
@@ -226,7 +226,7 @@ async def node_search(
  return []

  search_results: list[list[EntityNode]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[
  node_fulltext_search(driver, query, group_ids, 2 * limit),
  node_similarity_search(
@@ -295,7 +295,7 @@ async def community_search(
  return []

  search_results: list[list[CommunityNode]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[
  community_fulltext_search(driver, query, group_ids, 2 * limit),
  community_similarity_search(
graphiti_core/search/search_utils.py
@@ -14,10 +14,10 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- import asyncio
  import logging
  from collections import defaultdict
  from time import time
+ from typing import Any

  import numpy as np
  from neo4j import AsyncDriver, Query
@@ -29,6 +29,7 @@ from graphiti_core.helpers import (
  USE_PARALLEL_RUNTIME,
  lucene_sanitize,
  normalize_l2,
+ semaphore_gather,
  )
  from graphiti_core.nodes import (
  CommunityNode,
@@ -191,12 +192,27 @@ async def edge_similarity_search(
  'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
  )

- query: LiteralString = """
- MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
- WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
- AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
- AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
- WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
+ query_params: dict[str, Any] = {}
+
+ group_filter_query: LiteralString = ''
+ if group_ids is not None:
+ group_filter_query += 'WHERE r.group_id IN $group_ids'
+ query_params['group_ids'] = group_ids
+ query_params['source_node_uuid'] = source_node_uuid
+ query_params['target_node_uuid'] = target_node_uuid
+
+ if source_node_uuid is not None:
+ group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
+
+ if target_node_uuid is not None:
+ group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
+
+ query: LiteralString = (
+ """
+ MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+ """
+ + group_filter_query
+ + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
  WHERE score > $min_score
  RETURN
  r.uuid AS uuid,
@@ -214,9 +230,11 @@ async def edge_similarity_search(
  ORDER BY score DESC
  LIMIT $limit
  """
+ )

  records, _, _ = await driver.execute_query(
  runtime_query + query,
+ query_params,
  search_vector=search_vector,
  source_uuid=source_node_uuid,
  target_uuid=target_node_uuid,
@@ -325,11 +343,20 @@ async def node_similarity_search(
  'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
  )

+ query_params: dict[str, Any] = {}
+
+ group_filter_query: LiteralString = ''
+ if group_ids is not None:
+ group_filter_query += 'WHERE n.group_id IN $group_ids'
+ query_params['group_ids'] = group_ids
+
  records, _, _ = await driver.execute_query(
  runtime_query
  + """
  MATCH (n:Entity)
- WHERE $group_ids IS NULL OR n.group_id IN $group_ids
+ """
+ + group_filter_query
+ + """
  WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
  WHERE score > $min_score
  RETURN
@@ -342,6 +369,7 @@ async def node_similarity_search(
  ORDER BY score DESC
  LIMIT $limit
  """,
+ query_params,
  search_vector=search_vector,
  group_ids=group_ids,
  limit=limit,
@@ -436,11 +464,20 @@ async def community_similarity_search(
  'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
  )

+ query_params: dict[str, Any] = {}
+
+ group_filter_query: LiteralString = ''
+ if group_ids is not None:
+ group_filter_query += 'WHERE comm.group_id IN $group_ids'
+ query_params['group_ids'] = group_ids
+
  records, _, _ = await driver.execute_query(
  runtime_query
  + """
  MATCH (comm:Community)
- WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
+ """
+ + group_filter_query
+ + """
  WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
  WHERE score > $min_score
  RETURN
@@ -512,7 +549,7 @@ async def hybrid_node_search(

  start = time()
  results: list[list[EntityNode]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
  *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
  )
@@ -582,7 +619,7 @@ async def get_relevant_edges(
  relevant_edges: list[EntityEdge] = []
  relevant_edge_uuids = set()

- results = await asyncio.gather(
+ results = await semaphore_gather(
  *[
  edge_similarity_search(
  driver,
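Across the three similarity searches, the static `$group_ids IS NULL OR ...` guards are replaced by query text that is assembled only when a filter value is actually supplied, with the values passed in a query_params dict (Neo4j's execute_query accepts a parameter map alongside keyword parameters), so the planner sees a simpler query when no filter applies. A distilled, hypothetical sketch of the pattern, mirroring node_similarity_search above:

    from typing import Any

    from neo4j import AsyncDriver

    async def similar_entities(
        driver: AsyncDriver,
        search_vector: list[float],
        group_ids: list[str] | None = None,
        limit: int = 10,
    ):
        query_params: dict[str, Any] = {}
        group_filter_query = ''
        if group_ids is not None:
            # Emit the WHERE clause only when there is something to filter on.
            group_filter_query = 'WHERE n.group_id IN $group_ids'
            query_params['group_ids'] = group_ids

        query = (
            'MATCH (n:Entity)\n'
            + group_filter_query
            + '\nWITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score'
            + '\nRETURN n.uuid AS uuid ORDER BY score DESC LIMIT $limit'
        )
        records, _, _ = await driver.execute_query(
            query,
            query_params,
            search_vector=search_vector,
            limit=limit,
        )
        return records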
graphiti_core/utils/bulk_utils.py
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- import asyncio
  import logging
  import typing
  from collections import defaultdict
@@ -26,6 +25,7 @@ from numpy import dot, sqrt
  from pydantic import BaseModel

  from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
+ from graphiti_core.helpers import semaphore_gather
  from graphiti_core.llm_client import LLMClient
  from graphiti_core.models.edges.edge_db_queries import (
  ENTITY_EDGE_SAVE_BULK,
@@ -71,7 +71,7 @@ class RawEpisode(BaseModel):
  async def retrieve_previous_episodes_bulk(
  driver: AsyncDriver, episodes: list[EpisodicNode]
  ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
- previous_episodes_list = await asyncio.gather(
+ previous_episodes_list = await semaphore_gather(
  *[
  retrieve_episodes(
  driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
@@ -118,7 +118,7 @@ async def add_nodes_and_edges_bulk_tx(
  async def extract_nodes_and_edges_bulk(
  llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
  ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
- extracted_nodes_bulk = await asyncio.gather(
+ extracted_nodes_bulk = await semaphore_gather(
  *[
  extract_nodes(llm_client, episode, previous_episodes)
  for episode, previous_episodes in episode_tuples
@@ -130,7 +130,7 @@ async def extract_nodes_and_edges_bulk(
  [episode[1] for episode in episode_tuples],
  )

- extracted_edges_bulk = await asyncio.gather(
+ extracted_edges_bulk = await semaphore_gather(
  *[
  extract_edges(
  llm_client,
@@ -171,13 +171,13 @@ async def dedupe_nodes_bulk(
  node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]

  existing_nodes_chunks: list[list[EntityNode]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
  )
  )

  results: list[tuple[list[EntityNode], dict[str, str]]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[
  dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
  for i, node_chunk in enumerate(node_chunks)
@@ -205,13 +205,13 @@ async def dedupe_edges_bulk(
  ]

  relevant_edges_chunks: list[list[EntityEdge]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
  )
  )

  resolved_edge_chunks: list[list[EntityEdge]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[
  dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
  for i, edge_chunk in enumerate(edge_chunks)
@@ -292,7 +292,9 @@ async def compress_nodes(
  # add both nodes to the shortest chunk
  node_chunks[-1].extend([n, m])

- results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
+ results = await semaphore_gather(
+ *[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
+ )

  extended_map = dict(uuid_map)
  compressed_nodes: list[EntityNode] = []
@@ -315,7 +317,9 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list
  # We build a map of the edges based on their source and target nodes.
  edge_chunks = chunk_edges_by_nodes(edges)

- results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
+ results = await semaphore_gather(
+ *[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
+ )

  compressed_edges: list[EntityEdge] = []
  for edge_chunk in results:
@@ -368,7 +372,7 @@ async def extract_edge_dates_bulk(
  episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
  }

- results = await asyncio.gather(
+ results = await semaphore_gather(
  *[
  extract_edge_dates(
  llm_client,
graphiti_core/utils/maintenance/community_operations.py
@@ -7,7 +7,7 @@ from pydantic import BaseModel

  from graphiti_core.edges import CommunityEdge
  from graphiti_core.embedder import EmbedderClient
- from graphiti_core.helpers import DEFAULT_DATABASE
+ from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
  from graphiti_core.llm_client import LLMClient
  from graphiti_core.nodes import (
  CommunityNode,
@@ -71,7 +71,7 @@ async def get_community_clusters(

  community_clusters.extend(
  list(
- await asyncio.gather(
+ await semaphore_gather(
  *[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
  )
  )
@@ -164,7 +164,7 @@ async def build_community(
  odd_one_out = summaries.pop()
  length -= 1
  new_summaries: list[str] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[
  summarize_pair(llm_client, (str(left_summary), str(right_summary)))
  for left_summary, right_summary in zip(
@@ -207,7 +207,9 @@ async def build_communities(
  return await build_community(llm_client, cluster)

  communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
- await asyncio.gather(*[limited_build_community(cluster) for cluster in community_clusters])
+ await semaphore_gather(
+ *[limited_build_community(cluster) for cluster in community_clusters]
+ )
  )

  community_nodes: list[CommunityNode] = []
graphiti_core/utils/maintenance/edge_operations.py
@@ -14,13 +14,12 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- import asyncio
  import logging
  from datetime import datetime
  from time import time

  from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
- from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
+ from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
  from graphiti_core.llm_client import LLMClient
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
  from graphiti_core.prompts import prompt_library
@@ -80,6 +79,8 @@ async def extract_edges(
  ) -> list[EntityEdge]:
  start = time()

+ EXTRACT_EDGES_MAX_TOKENS = 16384
+
  node_uuids_by_name_map = {node.name: node.uuid for node in nodes}

  # Prepare context for LLM
@@ -94,7 +95,9 @@ async def extract_edges(
  reflexion_iterations = 0
  while facts_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS:
  llm_response = await llm_client.generate_response(
- prompt_library.extract_edges.edge(context), response_model=ExtractedEdges
+ prompt_library.extract_edges.edge(context),
+ response_model=ExtractedEdges,
+ max_tokens=EXTRACT_EDGES_MAX_TOKENS,
  )
  edges_data = llm_response.get('edges', [])

@@ -199,7 +202,7 @@ async def resolve_extracted_edges(
  ) -> tuple[list[EntityEdge], list[EntityEdge]]:
  # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
  results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[
  resolve_extracted_edge(
  llm_client,
@@ -266,7 +269,7 @@ async def resolve_extracted_edge(
  current_episode: EpisodicNode,
  previous_episodes: list[EpisodicNode],
  ) -> tuple[EntityEdge, list[EntityEdge]]:
- resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
+ resolved_edge, (valid_at, invalid_at), invalidation_candidates = await semaphore_gather(
  dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
  extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
  get_edge_contradictions(llm_client, extracted_edge, existing_edges),
graphiti_core/utils/maintenance/graph_data_operations.py
@@ -14,14 +14,13 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- import asyncio
  import logging
  from datetime import datetime, timezone

  from neo4j import AsyncDriver
  from typing_extensions import LiteralString

- from graphiti_core.helpers import DEFAULT_DATABASE
+ from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
  from graphiti_core.nodes import EpisodeType, EpisodicNode

  EPISODE_WINDOW_LEN = 3
@@ -38,7 +37,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
  database_=DEFAULT_DATABASE,
  )
  index_names = [record['name'] for record in records]
- await asyncio.gather(
+ await semaphore_gather(
  *[
  driver.execute_query(
  """DROP INDEX $name""",
@@ -82,7 +81,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo

  index_queries: list[LiteralString] = range_indices + fulltext_indices

- await asyncio.gather(
+ await semaphore_gather(
  *[
  driver.execute_query(
  query,
graphiti_core/utils/maintenance/node_operations.py
@@ -14,11 +14,10 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- import asyncio
  import logging
  from time import time

- from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
+ from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
  from graphiti_core.llm_client import LLMClient
  from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
  from graphiti_core.prompts import prompt_library
@@ -223,7 +222,7 @@ async def resolve_extracted_nodes(
  uuid_map: dict[str, str] = {}
  resolved_nodes: list[EntityNode] = []
  results: list[tuple[EntityNode, dict[str, str]]] = list(
- await asyncio.gather(
+ await semaphore_gather(
  *[
  resolve_extracted_node(
  llm_client, extracted_node, existing_nodes, episode, previous_episodes
@@ -275,7 +274,7 @@ async def resolve_extracted_node(
  else [],
  }

- llm_response, node_summary_response = await asyncio.gather(
+ llm_response, node_summary_response = await semaphore_gather(
  llm_client.generate_response(
  prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
  ),
graphiti_core/utils/maintenance/temporal_operations.py
@@ -55,7 +55,7 @@ async def extract_edge_dates(
  try:
  valid_at_datetime = ensure_utc(datetime.fromisoformat(valid_at.replace('Z', '+00:00')))
  except ValueError as e:
- logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')
+ logger.warning(f'WARNING: Error parsing valid_at date: {e}. Input: {valid_at}')

  if invalid_at:
  try:
@@ -63,7 +63,7 @@ async def extract_edge_dates(
  datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
  )
  except ValueError as e:
- logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')
+ logger.warning(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}')

  return valid_at_datetime, invalid_at_datetime

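These date-parse failures are demoted from errors to warnings, since a missing or malformed valid_at/invalid_at is an expected case. As a side note, the `.replace('Z', '+00:00')` is there because datetime.fromisoformat does not accept a trailing 'Z' before Python 3.11; a quick illustration:

    from datetime import datetime

    raw = '2024-08-01T12:30:00Z'
    # fromisoformat on Python < 3.11 rejects the literal 'Z' suffix, so it
    # is rewritten as an explicit UTC offset before parsing.
    parsed = datetime.fromisoformat(raw.replace('Z', '+00:00'))
    assert parsed.tzinfo is not None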
pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "graphiti-core"
- version = "0.5.0pre5"
+ version = "0.5.2"
  description = "A temporal graph building library"
  authors = [
  "Paul Paliychuk <paul@getzep.com>",
The remaining files listed above (entries 20-60, all marked +0 -0) are unchanged between the two versions.