graphiti-core 0.4.2 (tar.gz) → 0.5.0 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. See the registry's advisory page for this release for more details.

Files changed (61)
  1. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/PKG-INFO +1 -1
  2. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/bge_reranker_client.py +1 -2
  3. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/client.py +3 -4
  4. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/openai_reranker_client.py +2 -2
  5. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/edges.py +56 -7
  6. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/embedder/client.py +3 -3
  7. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/embedder/openai.py +2 -2
  8. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/embedder/voyage.py +3 -3
  9. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/graphiti.py +39 -37
  10. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/helpers.py +26 -0
  11. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/anthropic_client.py +4 -1
  12. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/client.py +45 -5
  13. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/errors.py +8 -0
  14. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/groq_client.py +4 -1
  15. graphiti_core-0.5.0/graphiti_core/llm_client/openai_client.py +163 -0
  16. graphiti_core-0.4.2/graphiti_core/llm_client/openai_client.py → graphiti_core-0.5.0/graphiti_core/llm_client/openai_generic_client.py +67 -3
  17. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/nodes.py +58 -8
  18. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/dedupe_edges.py +20 -17
  19. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/dedupe_nodes.py +15 -1
  20. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/eval.py +17 -14
  21. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/extract_edge_dates.py +15 -7
  22. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/extract_edges.py +18 -19
  23. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/extract_nodes.py +11 -21
  24. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/invalidate_edges.py +13 -25
  25. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/lib.py +5 -1
  26. graphiti_core-0.5.0/graphiti_core/prompts/prompt_helpers.py +1 -0
  27. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/summarize_nodes.py +17 -16
  28. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/search/search.py +5 -5
  29. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/search/search_utils.py +55 -14
  30. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/bulk_utils.py +22 -15
  31. graphiti_core-0.5.0/graphiti_core/utils/datetime_utils.py +42 -0
  32. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/community_operations.py +13 -9
  33. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/edge_operations.py +32 -26
  34. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/graph_data_operations.py +3 -4
  35. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/node_operations.py +19 -13
  36. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/temporal_operations.py +17 -9
  37. graphiti_core-0.5.0/graphiti_core/utils/maintenance/utils.py +0 -0
  38. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/pyproject.toml +1 -1
  39. graphiti_core-0.4.2/graphiti_core/utils/__init__.py +0 -15
  40. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/LICENSE +0 -0
  41. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/README.md +0 -0
  42. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/__init__.py +0 -0
  43. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/__init__.py +0 -0
  44. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/embedder/__init__.py +0 -0
  45. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/errors.py +0 -0
  46. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/__init__.py +0 -0
  47. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/config.py +0 -0
  48. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/utils.py +0 -0
  49. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/models/__init__.py +0 -0
  50. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/models/edges/__init__.py +0 -0
  51. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/models/edges/edge_db_queries.py +0 -0
  52. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/models/nodes/__init__.py +0 -0
  53. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/models/nodes/node_db_queries.py +0 -0
  54. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/__init__.py +0 -0
  55. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/models.py +0 -0
  56. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/py.typed +0 -0
  57. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/search/__init__.py +0 -0
  58. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/search/search_config.py +0 -0
  59. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/search/search_config_recipes.py +0 -0
  60. /graphiti_core-0.4.2/graphiti_core/utils/maintenance/utils.py → /graphiti_core-0.5.0/graphiti_core/utils/__init__.py +0 -0
  61. {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphiti-core
3
- Version: 0.4.2
3
+ Version: 0.5.0
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -15,7 +15,6 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import asyncio
18
- from typing import List, Tuple
19
18
 
20
19
  from sentence_transformers import CrossEncoder
21
20
 
@@ -26,7 +25,7 @@ class BGERerankerClient(CrossEncoderClient):
26
25
  def __init__(self):
27
26
  self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
28
27
 
29
- async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
28
+ async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
30
29
  if not passages:
31
30
  return []
32
31
 
@@ -15,7 +15,6 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  from abc import ABC, abstractmethod
18
- from typing import List, Tuple
19
18
 
20
19
 
21
20
  class CrossEncoderClient(ABC):
@@ -26,16 +25,16 @@ class CrossEncoderClient(ABC):
26
25
  """
27
26
 
28
27
  @abstractmethod
29
- async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
28
+ async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
30
29
  """
31
30
  Rank the given passages based on their relevance to the query.
32
31
 
33
32
  Args:
34
33
  query (str): The query string.
35
- passages (List[str]): A list of passages to rank.
34
+ passages (list[str]): A list of passages to rank.
36
35
 
37
36
  Returns:
38
- List[Tuple[str, float]]: A list of tuples containing the passage and its score,
37
+ list[tuple[str, float]]: A list of tuples containing the passage and its score,
39
38
  sorted in descending order of relevance.
40
39
  """
41
40
  pass
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- import asyncio
18
17
  import logging
19
18
  from typing import Any
20
19
 
@@ -22,6 +21,7 @@ import openai
22
21
  from openai import AsyncOpenAI
23
22
  from pydantic import BaseModel
24
23
 
24
+ from ..helpers import semaphore_gather
25
25
  from ..llm_client import LLMConfig, RateLimitError
26
26
  from ..prompts import Message
27
27
  from .client import CrossEncoderClient
@@ -75,7 +75,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
75
75
  for passage in passages
76
76
  ]
77
77
  try:
78
- responses = await asyncio.gather(
78
+ responses = await semaphore_gather(
79
79
  *[
80
80
  self.client.chat.completions.create(
81
81
  model=DEFAULT_MODEL,
@@ -23,6 +23,7 @@ from uuid import uuid4
23
23
 
24
24
  from neo4j import AsyncDriver
25
25
  from pydantic import BaseModel, Field
26
+ from typing_extensions import LiteralString
26
27
 
27
28
  from graphiti_core.embedder import EmbedderClient
28
29
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
@@ -50,7 +51,7 @@ class Edge(BaseModel, ABC):
50
51
  async def delete(self, driver: AsyncDriver):
51
52
  result = await driver.execute_query(
52
53
  """
53
- MATCH (n)-[e {uuid: $uuid}]->(m)
54
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
54
55
  DELETE e
55
56
  """,
56
57
  uuid=self.uuid,
@@ -137,19 +138,35 @@ class EpisodicEdge(Edge):
137
138
  return edges
138
139
 
139
140
  @classmethod
140
- async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
141
+ async def get_by_group_ids(
142
+ cls,
143
+ driver: AsyncDriver,
144
+ group_ids: list[str],
145
+ limit: int | None = None,
146
+ created_at: datetime | None = None,
147
+ ):
148
+ cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
149
+ limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
150
+
141
151
  records, _, _ = await driver.execute_query(
142
152
  """
143
153
  MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
144
154
  WHERE e.group_id IN $group_ids
155
+ """
156
+ + cursor_query
157
+ + """
145
158
  RETURN
146
159
  e.uuid As uuid,
147
160
  e.group_id AS group_id,
148
161
  n.uuid AS source_node_uuid,
149
162
  m.uuid AS target_node_uuid,
150
163
  e.created_at AS created_at
151
- """,
164
+ ORDER BY e.uuid DESC
165
+ """
166
+ + limit_query,
152
167
  group_ids=group_ids,
168
+ created_at=created_at,
169
+ limit=limit,
153
170
  database_=DEFAULT_DATABASE,
154
171
  routing_='r',
155
172
  )
@@ -274,11 +291,23 @@ class EntityEdge(Edge):
274
291
  return edges
275
292
 
276
293
  @classmethod
277
- async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
294
+ async def get_by_group_ids(
295
+ cls,
296
+ driver: AsyncDriver,
297
+ group_ids: list[str],
298
+ limit: int | None = None,
299
+ created_at: datetime | None = None,
300
+ ):
301
+ cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
302
+ limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
303
+
278
304
  records, _, _ = await driver.execute_query(
279
305
  """
280
306
  MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
281
307
  WHERE e.group_id IN $group_ids
308
+ """
309
+ + cursor_query
310
+ + """
282
311
  RETURN
283
312
  e.uuid AS uuid,
284
313
  n.uuid AS source_node_uuid,
@@ -292,8 +321,12 @@ class EntityEdge(Edge):
292
321
  e.expired_at AS expired_at,
293
322
  e.valid_at AS valid_at,
294
323
  e.invalid_at AS invalid_at
295
- """,
324
+ ORDER BY e.uuid DESC
325
+ """
326
+ + limit_query,
296
327
  group_ids=group_ids,
328
+ created_at=created_at,
329
+ limit=limit,
297
330
  database_=DEFAULT_DATABASE,
298
331
  routing_='r',
299
332
  )
@@ -365,19 +398,35 @@ class CommunityEdge(Edge):
365
398
  return edges
366
399
 
367
400
  @classmethod
368
- async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
401
+ async def get_by_group_ids(
402
+ cls,
403
+ driver: AsyncDriver,
404
+ group_ids: list[str],
405
+ limit: int | None = None,
406
+ created_at: datetime | None = None,
407
+ ):
408
+ cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
409
+ limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
410
+
369
411
  records, _, _ = await driver.execute_query(
370
412
  """
371
413
  MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
372
414
  WHERE e.group_id IN $group_ids
415
+ """
416
+ + cursor_query
417
+ + """
373
418
  RETURN
374
419
  e.uuid As uuid,
375
420
  e.group_id AS group_id,
376
421
  n.uuid AS source_node_uuid,
377
422
  m.uuid AS target_node_uuid,
378
423
  e.created_at AS created_at
379
- """,
424
+ ORDER BY e.uuid DESC
425
+ """
426
+ + limit_query,
380
427
  group_ids=group_ids,
428
+ created_at=created_at,
429
+ limit=limit,
381
430
  database_=DEFAULT_DATABASE,
382
431
  routing_='r',
383
432
  )
@@ -15,7 +15,7 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  from abc import ABC, abstractmethod
18
- from typing import Iterable, List, Literal
18
+ from collections.abc import Iterable
19
19
 
20
20
  from pydantic import BaseModel, Field
21
21
 
@@ -23,12 +23,12 @@ EMBEDDING_DIM = 1024
23
23
 
24
24
 
25
25
  class EmbedderConfig(BaseModel):
26
- embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True)
26
+ embedding_dim: int = Field(default=EMBEDDING_DIM, frozen=True)
27
27
 
28
28
 
29
29
  class EmbedderClient(ABC):
30
30
  @abstractmethod
31
31
  async def create(
32
- self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
32
+ self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
33
33
  ) -> list[float]:
34
34
  pass
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from typing import Iterable, List
17
+ from collections.abc import Iterable
18
18
 
19
19
  from openai import AsyncOpenAI
20
20
  from openai.types import EmbeddingModel
@@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
42
42
  self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
43
43
 
44
44
  async def create(
45
- self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
45
+ self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
46
46
  ) -> list[float]:
47
47
  result = await self.client.embeddings.create(
48
48
  input=input_data, model=self.config.embedding_model
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from typing import Iterable, List
17
+ from collections.abc import Iterable
18
18
 
19
19
  import voyageai # type: ignore
20
20
  from pydantic import Field
@@ -41,11 +41,11 @@ class VoyageAIEmbedder(EmbedderClient):
41
41
  self.client = voyageai.AsyncClient(api_key=config.api_key)
42
42
 
43
43
  async def create(
44
- self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
44
+ self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
45
45
  ) -> list[float]:
46
46
  if isinstance(input_data, str):
47
47
  input_list = [input_data]
48
- elif isinstance(input_data, List):
48
+ elif isinstance(input_data, list):
49
49
  input_list = [str(i) for i in input_data if i]
50
50
  else:
51
51
  input_list = [str(i) for i in input_data if i is not None]
@@ -14,9 +14,8 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- import asyncio
18
17
  import logging
19
- from datetime import datetime, timezone
18
+ from datetime import datetime
20
19
  from time import time
21
20
 
22
21
  from dotenv import load_dotenv
@@ -27,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
27
26
  from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
28
27
  from graphiti_core.edges import EntityEdge, EpisodicEdge
29
28
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
30
- from graphiti_core.helpers import DEFAULT_DATABASE
29
+ from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
31
30
  from graphiti_core.llm_client import LLMClient, OpenAIClient
32
31
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
33
32
  from graphiti_core.search.search import SearchConfig, search
@@ -43,10 +42,6 @@ from graphiti_core.search.search_utils import (
43
42
  get_relevant_edges,
44
43
  get_relevant_nodes,
45
44
  )
46
- from graphiti_core.utils import (
47
- build_episodic_edges,
48
- retrieve_episodes,
49
- )
50
45
  from graphiti_core.utils.bulk_utils import (
51
46
  RawEpisode,
52
47
  add_nodes_and_edges_bulk,
@@ -57,12 +52,14 @@ from graphiti_core.utils.bulk_utils import (
57
52
  resolve_edge_pointers,
58
53
  retrieve_previous_episodes_bulk,
59
54
  )
55
+ from graphiti_core.utils.datetime_utils import utc_now
60
56
  from graphiti_core.utils.maintenance.community_operations import (
61
57
  build_communities,
62
58
  remove_communities,
63
59
  update_community,
64
60
  )
65
61
  from graphiti_core.utils.maintenance.edge_operations import (
62
+ build_episodic_edges,
66
63
  dedupe_extracted_edge,
67
64
  extract_edges,
68
65
  resolve_edge_contradictions,
@@ -71,6 +68,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
71
68
  from graphiti_core.utils.maintenance.graph_data_operations import (
72
69
  EPISODE_WINDOW_LEN,
73
70
  build_indices_and_constraints,
71
+ retrieve_episodes,
74
72
  )
75
73
  from graphiti_core.utils.maintenance.node_operations import (
76
74
  extract_nodes,
@@ -313,22 +311,26 @@ class Graphiti:
313
311
  start = time()
314
312
 
315
313
  entity_edges: list[EntityEdge] = []
316
- now = datetime.now(timezone.utc)
314
+ now = utc_now()
317
315
 
318
316
  previous_episodes = await self.retrieve_episodes(
319
317
  reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
320
318
  )
321
- episode = EpisodicNode(
322
- name=name,
323
- group_id=group_id,
324
- labels=[],
325
- source=source,
326
- content=episode_body,
327
- source_description=source_description,
328
- created_at=now,
329
- valid_at=reference_time,
319
+
320
+ episode = (
321
+ await EpisodicNode.get_by_uuid(self.driver, uuid)
322
+ if uuid is not None
323
+ else EpisodicNode(
324
+ name=name,
325
+ group_id=group_id,
326
+ labels=[],
327
+ source=source,
328
+ content=episode_body,
329
+ source_description=source_description,
330
+ created_at=now,
331
+ valid_at=reference_time,
332
+ )
330
333
  )
331
- episode.uuid = uuid if uuid is not None else episode.uuid
332
334
 
333
335
  # Extract entities as nodes
334
336
 
@@ -337,13 +339,13 @@ class Graphiti:
337
339
 
338
340
  # Calculate Embeddings
339
341
 
340
- await asyncio.gather(
342
+ await semaphore_gather(
341
343
  *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
342
344
  )
343
345
 
344
346
  # Find relevant nodes already in the graph
345
347
  existing_nodes_lists: list[list[EntityNode]] = list(
346
- await asyncio.gather(
348
+ await semaphore_gather(
347
349
  *[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
348
350
  )
349
351
  )
@@ -351,7 +353,7 @@ class Graphiti:
351
353
  # Resolve extracted nodes with nodes already in the graph and extract facts
352
354
  logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
353
355
 
354
- (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
356
+ (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
355
357
  resolve_extracted_nodes(
356
358
  self.llm_client,
357
359
  extracted_nodes,
@@ -371,7 +373,7 @@ class Graphiti:
371
373
  )
372
374
 
373
375
  # calculate embeddings
374
- await asyncio.gather(
376
+ await semaphore_gather(
375
377
  *[
376
378
  edge.generate_embedding(self.embedder)
377
379
  for edge in extracted_edges_with_resolved_pointers
@@ -380,7 +382,7 @@ class Graphiti:
380
382
 
381
383
  # Resolve extracted edges with related edges already in the graph
382
384
  related_edges_list: list[list[EntityEdge]] = list(
383
- await asyncio.gather(
385
+ await semaphore_gather(
384
386
  *[
385
387
  get_relevant_edges(
386
388
  self.driver,
@@ -401,7 +403,7 @@ class Graphiti:
401
403
  )
402
404
 
403
405
  existing_source_edges_list: list[list[EntityEdge]] = list(
404
- await asyncio.gather(
406
+ await semaphore_gather(
405
407
  *[
406
408
  get_relevant_edges(
407
409
  self.driver,
@@ -416,7 +418,7 @@ class Graphiti:
416
418
  )
417
419
 
418
420
  existing_target_edges_list: list[list[EntityEdge]] = list(
419
- await asyncio.gather(
421
+ await semaphore_gather(
420
422
  *[
421
423
  get_relevant_edges(
422
424
  self.driver,
@@ -465,7 +467,7 @@ class Graphiti:
465
467
 
466
468
  # Update any communities
467
469
  if update_communities:
468
- await asyncio.gather(
470
+ await semaphore_gather(
469
471
  *[
470
472
  update_community(self.driver, self.llm_client, self.embedder, node)
471
473
  for node in nodes
@@ -518,7 +520,7 @@ class Graphiti:
518
520
  """
519
521
  try:
520
522
  start = time()
521
- now = datetime.now(timezone.utc)
523
+ now = utc_now()
522
524
 
523
525
  episodes = [
524
526
  EpisodicNode(
@@ -535,7 +537,7 @@ class Graphiti:
535
537
  ]
536
538
 
537
539
  # Save all the episodes
538
- await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
540
+ await semaphore_gather(*[episode.save(self.driver) for episode in episodes])
539
541
 
540
542
  # Get previous episode context for each episode
541
543
  episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
@@ -548,19 +550,19 @@ class Graphiti:
548
550
  ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
549
551
 
550
552
  # Generate embeddings
551
- await asyncio.gather(
553
+ await semaphore_gather(
552
554
  *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
553
555
  *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
554
556
  )
555
557
 
556
558
  # Dedupe extracted nodes, compress extracted edges
557
- (nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
559
+ (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
558
560
  dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
559
561
  extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
560
562
  )
561
563
 
562
564
  # save nodes to KG
563
- await asyncio.gather(*[node.save(self.driver) for node in nodes])
565
+ await semaphore_gather(*[node.save(self.driver) for node in nodes])
564
566
 
565
567
  # re-map edge pointers so that they don't point to discard dupe nodes
566
568
  extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
@@ -571,7 +573,7 @@ class Graphiti:
571
573
  )
572
574
 
573
575
  # save episodic edges to KG
574
- await asyncio.gather(
576
+ await semaphore_gather(
575
577
  *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
576
578
  )
577
579
 
@@ -584,7 +586,7 @@ class Graphiti:
584
586
  # invalidate edges
585
587
 
586
588
  # save edges to KG
587
- await asyncio.gather(*[edge.save(self.driver) for edge in edges])
589
+ await semaphore_gather(*[edge.save(self.driver) for edge in edges])
588
590
 
589
591
  end = time()
590
592
  logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
@@ -607,12 +609,12 @@ class Graphiti:
607
609
  self.driver, self.llm_client, group_ids
608
610
  )
609
611
 
610
- await asyncio.gather(
612
+ await semaphore_gather(
611
613
  *[node.generate_name_embedding(self.embedder) for node in community_nodes]
612
614
  )
613
615
 
614
- await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
615
- await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
616
+ await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
617
+ await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])
616
618
 
617
619
  return community_nodes
618
620
 
@@ -695,7 +697,7 @@ class Graphiti:
695
697
  async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
696
698
  episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
697
699
 
698
- edges_list = await asyncio.gather(
700
+ edges_list = await semaphore_gather(
699
701
  *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
700
702
  )
701
703
 
@@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import asyncio
17
18
  import os
19
+ from collections.abc import Coroutine
18
20
  from datetime import datetime
19
21
 
20
22
  import numpy as np
@@ -25,7 +27,9 @@ load_dotenv()
25
27
 
26
28
  DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
27
29
  USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
30
+ SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
28
31
  MAX_REFLEXION_ITERATIONS = 2
32
+ DEFAULT_PAGE_LIMIT = 20
29
33
 
30
34
 
31
35
  def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
@@ -56,6 +60,12 @@ def lucene_sanitize(query: str) -> str:
56
60
  ':': r'\:',
57
61
  '\\': r'\\',
58
62
  '/': r'\/',
63
+ 'O': r'\O',
64
+ 'R': r'\R',
65
+ 'N': r'\N',
66
+ 'T': r'\T',
67
+ 'A': r'\A',
68
+ 'D': r'\D',
59
69
  }
60
70
  )
61
71
 
@@ -73,3 +83,19 @@ def normalize_l2(embedding: list[float]) -> list[float]:
73
83
  else:
74
84
  norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
75
85
  return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
86
+
87
+
88
+ # Use this instead of asyncio.gather() to bound coroutines
89
+ async def semaphore_gather(
90
+ *coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT, return_exceptions=True
91
+ ):
92
+ semaphore = asyncio.Semaphore(max_coroutines)
93
+
94
+ async def _wrap_coroutine(coroutine):
95
+ async with semaphore:
96
+ return await coroutine
97
+
98
+ return await asyncio.gather(
99
+ *(_wrap_coroutine(coroutine) for coroutine in coroutines),
100
+ return_exceptions=return_exceptions,
101
+ )
@@ -20,6 +20,7 @@ import typing
20
20
 
21
21
  import anthropic
22
22
  from anthropic import AsyncAnthropic
23
+ from pydantic import BaseModel
23
24
 
24
25
  from ..prompts.models import Message
25
26
  from .client import LLMClient
@@ -46,7 +47,9 @@ class AnthropicClient(LLMClient):
46
47
  max_retries=1,
47
48
  )
48
49
 
49
- async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
50
+ async def _generate_response(
51
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
52
+ ) -> dict[str, typing.Any]:
50
53
  system_message = messages[0]
51
54
  user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
52
55
  {'role': 'assistant', 'content': '{'}
@@ -22,6 +22,7 @@ from abc import ABC, abstractmethod
22
22
 
23
23
  import httpx
24
24
  from diskcache import Cache
25
+ from pydantic import BaseModel
25
26
  from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
26
27
 
27
28
  from ..prompts.models import Message
@@ -55,6 +56,28 @@ class LLMClient(ABC):
55
56
  self.cache_enabled = cache
56
57
  self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
57
58
 
59
+ def _clean_input(self, input: str) -> str:
60
+ """Clean input string of invalid unicode and control characters.
61
+
62
+ Args:
63
+ input: Raw input string to be cleaned
64
+
65
+ Returns:
66
+ Cleaned string safe for LLM processing
67
+ """
68
+ # Clean any invalid Unicode
69
+ cleaned = input.encode('utf-8', errors='ignore').decode('utf-8')
70
+
71
+ # Remove zero-width characters and other invisible unicode
72
+ zero_width = '\u200b\u200c\u200d\ufeff\u2060'
73
+ for char in zero_width:
74
+ cleaned = cleaned.replace(char, '')
75
+
76
+ # Remove control characters except newlines, returns, and tabs
77
+ cleaned = ''.join(char for char in cleaned if ord(char) >= 32 or char in '\n\r\t')
78
+
79
+ return cleaned
80
+
58
81
  @retry(
59
82
  stop=stop_after_attempt(4),
60
83
  wait=wait_random_exponential(multiplier=10, min=5, max=120),
@@ -66,14 +89,18 @@ class LLMClient(ABC):
66
89
  else None,
67
90
  reraise=True,
68
91
  )
69
- async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
92
+ async def _generate_response_with_retry(
93
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
94
+ ) -> dict[str, typing.Any]:
70
95
  try:
71
- return await self._generate_response(messages)
96
+ return await self._generate_response(messages, response_model)
72
97
  except (httpx.HTTPStatusError, RateLimitError) as e:
73
98
  raise e
74
99
 
75
100
  @abstractmethod
76
- async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
101
+ async def _generate_response(
102
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
103
+ ) -> dict[str, typing.Any]:
77
104
  pass
78
105
 
79
106
  def _get_cache_key(self, messages: list[Message]) -> str:
@@ -82,7 +109,17 @@ class LLMClient(ABC):
82
109
  key_str = f'{self.model}:{message_str}'
83
110
  return hashlib.md5(key_str.encode()).hexdigest()
84
111
 
85
- async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
112
+ async def generate_response(
113
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
114
+ ) -> dict[str, typing.Any]:
115
+ if response_model is not None:
116
+ serialized_model = json.dumps(response_model.model_json_schema())
117
+ messages[
118
+ -1
119
+ ].content += (
120
+ f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
121
+ )
122
+
86
123
  if self.cache_enabled:
87
124
  cache_key = self._get_cache_key(messages)
88
125
 
@@ -91,7 +128,10 @@ class LLMClient(ABC):
91
128
  logger.debug(f'Cache hit for {cache_key}')
92
129
  return cached_response
93
130
 
94
- response = await self._generate_response_with_retry(messages)
131
+ for message in messages:
132
+ message.content = self._clean_input(message.content)
133
+
134
+ response = await self._generate_response_with_retry(messages, response_model)
95
135
 
96
136
  if self.cache_enabled:
97
137
  self.cache_dir.set(cache_key, response)