graphiti-core 0.4.2__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/PKG-INFO +1 -1
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/bge_reranker_client.py +1 -2
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/client.py +3 -4
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/openai_reranker_client.py +2 -2
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/edges.py +56 -7
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/embedder/client.py +3 -3
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/embedder/openai.py +2 -2
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/embedder/voyage.py +3 -3
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/graphiti.py +39 -37
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/helpers.py +26 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/anthropic_client.py +4 -1
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/client.py +45 -5
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/errors.py +8 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/groq_client.py +4 -1
- graphiti_core-0.5.0/graphiti_core/llm_client/openai_client.py +163 -0
- graphiti_core-0.4.2/graphiti_core/llm_client/openai_client.py → graphiti_core-0.5.0/graphiti_core/llm_client/openai_generic_client.py +67 -3
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/nodes.py +58 -8
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/dedupe_edges.py +20 -17
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/dedupe_nodes.py +15 -1
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/eval.py +17 -14
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/extract_edge_dates.py +15 -7
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/extract_edges.py +18 -19
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/extract_nodes.py +11 -21
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/invalidate_edges.py +13 -25
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/lib.py +5 -1
- graphiti_core-0.5.0/graphiti_core/prompts/prompt_helpers.py +1 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/summarize_nodes.py +17 -16
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/search/search.py +5 -5
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/search/search_utils.py +55 -14
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/bulk_utils.py +22 -15
- graphiti_core-0.5.0/graphiti_core/utils/datetime_utils.py +42 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/community_operations.py +13 -9
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/edge_operations.py +32 -26
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/graph_data_operations.py +3 -4
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/node_operations.py +19 -13
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/temporal_operations.py +17 -9
- graphiti_core-0.5.0/graphiti_core/utils/maintenance/utils.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/pyproject.toml +1 -1
- graphiti_core-0.4.2/graphiti_core/utils/__init__.py +0 -15
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/LICENSE +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/README.md +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/__init__.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/embedder/__init__.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/errors.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/models/__init__.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/models/edges/__init__.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/models/edges/edge_db_queries.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/models/nodes/__init__.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/models/nodes/node_db_queries.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/py.typed +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/search/search_config.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/search/search_config_recipes.py +0 -0
- /graphiti_core-0.4.2/graphiti_core/utils/maintenance/utils.py → /graphiti_core-0.5.0/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/__init__.py +0 -0
{graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/bge_reranker_client.py
RENAMED
|
@@ -15,7 +15,6 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
import asyncio
|
|
18
|
-
from typing import List, Tuple
|
|
19
18
|
|
|
20
19
|
from sentence_transformers import CrossEncoder
|
|
21
20
|
|
|
@@ -26,7 +25,7 @@ class BGERerankerClient(CrossEncoderClient):
|
|
|
26
25
|
def __init__(self):
|
|
27
26
|
self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')
|
|
28
27
|
|
|
29
|
-
async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
|
|
28
|
+
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
|
|
30
29
|
if not passages:
|
|
31
30
|
return []
|
|
32
31
|
|
|
@@ -15,7 +15,6 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
from abc import ABC, abstractmethod
|
|
18
|
-
from typing import List, Tuple
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
class CrossEncoderClient(ABC):
|
|
@@ -26,16 +25,16 @@ class CrossEncoderClient(ABC):
|
|
|
26
25
|
"""
|
|
27
26
|
|
|
28
27
|
@abstractmethod
|
|
29
|
-
async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
|
|
28
|
+
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
|
|
30
29
|
"""
|
|
31
30
|
Rank the given passages based on their relevance to the query.
|
|
32
31
|
|
|
33
32
|
Args:
|
|
34
33
|
query (str): The query string.
|
|
35
|
-
passages (List[str]): A list of passages to rank.
|
|
34
|
+
passages (list[str]): A list of passages to rank.
|
|
36
35
|
|
|
37
36
|
Returns:
|
|
38
|
-
|
|
37
|
+
list[tuple[str, float]]: A list of tuples containing the passage and its score,
|
|
39
38
|
sorted in descending order of relevance.
|
|
40
39
|
"""
|
|
41
40
|
pass
|
{graphiti_core-0.4.2 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/openai_reranker_client.py
RENAMED
|
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
import asyncio
|
|
18
17
|
import logging
|
|
19
18
|
from typing import Any
|
|
20
19
|
|
|
@@ -22,6 +21,7 @@ import openai
|
|
|
22
21
|
from openai import AsyncOpenAI
|
|
23
22
|
from pydantic import BaseModel
|
|
24
23
|
|
|
24
|
+
from ..helpers import semaphore_gather
|
|
25
25
|
from ..llm_client import LLMConfig, RateLimitError
|
|
26
26
|
from ..prompts import Message
|
|
27
27
|
from .client import CrossEncoderClient
|
|
@@ -75,7 +75,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|
|
75
75
|
for passage in passages
|
|
76
76
|
]
|
|
77
77
|
try:
|
|
78
|
-
responses = await asyncio.gather(
|
|
78
|
+
responses = await semaphore_gather(
|
|
79
79
|
*[
|
|
80
80
|
self.client.chat.completions.create(
|
|
81
81
|
model=DEFAULT_MODEL,
|
|
@@ -23,6 +23,7 @@ from uuid import uuid4
|
|
|
23
23
|
|
|
24
24
|
from neo4j import AsyncDriver
|
|
25
25
|
from pydantic import BaseModel, Field
|
|
26
|
+
from typing_extensions import LiteralString
|
|
26
27
|
|
|
27
28
|
from graphiti_core.embedder import EmbedderClient
|
|
28
29
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
|
@@ -50,7 +51,7 @@ class Edge(BaseModel, ABC):
|
|
|
50
51
|
async def delete(self, driver: AsyncDriver):
|
|
51
52
|
result = await driver.execute_query(
|
|
52
53
|
"""
|
|
53
|
-
MATCH (n)-[e {uuid: $uuid}]->(m)
|
|
54
|
+
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
|
54
55
|
DELETE e
|
|
55
56
|
""",
|
|
56
57
|
uuid=self.uuid,
|
|
@@ -137,19 +138,35 @@ class EpisodicEdge(Edge):
|
|
|
137
138
|
return edges
|
|
138
139
|
|
|
139
140
|
@classmethod
|
|
140
|
-
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
|
141
|
+
async def get_by_group_ids(
|
|
142
|
+
cls,
|
|
143
|
+
driver: AsyncDriver,
|
|
144
|
+
group_ids: list[str],
|
|
145
|
+
limit: int | None = None,
|
|
146
|
+
created_at: datetime | None = None,
|
|
147
|
+
):
|
|
148
|
+
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
|
|
149
|
+
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
|
150
|
+
|
|
141
151
|
records, _, _ = await driver.execute_query(
|
|
142
152
|
"""
|
|
143
153
|
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
|
144
154
|
WHERE e.group_id IN $group_ids
|
|
155
|
+
"""
|
|
156
|
+
+ cursor_query
|
|
157
|
+
+ """
|
|
145
158
|
RETURN
|
|
146
159
|
e.uuid As uuid,
|
|
147
160
|
e.group_id AS group_id,
|
|
148
161
|
n.uuid AS source_node_uuid,
|
|
149
162
|
m.uuid AS target_node_uuid,
|
|
150
163
|
e.created_at AS created_at
|
|
151
|
-
|
|
164
|
+
ORDER BY e.uuid DESC
|
|
165
|
+
"""
|
|
166
|
+
+ limit_query,
|
|
152
167
|
group_ids=group_ids,
|
|
168
|
+
created_at=created_at,
|
|
169
|
+
limit=limit,
|
|
153
170
|
database_=DEFAULT_DATABASE,
|
|
154
171
|
routing_='r',
|
|
155
172
|
)
|
|
@@ -274,11 +291,23 @@ class EntityEdge(Edge):
|
|
|
274
291
|
return edges
|
|
275
292
|
|
|
276
293
|
@classmethod
|
|
277
|
-
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
|
294
|
+
async def get_by_group_ids(
|
|
295
|
+
cls,
|
|
296
|
+
driver: AsyncDriver,
|
|
297
|
+
group_ids: list[str],
|
|
298
|
+
limit: int | None = None,
|
|
299
|
+
created_at: datetime | None = None,
|
|
300
|
+
):
|
|
301
|
+
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
|
|
302
|
+
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
|
303
|
+
|
|
278
304
|
records, _, _ = await driver.execute_query(
|
|
279
305
|
"""
|
|
280
306
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
281
307
|
WHERE e.group_id IN $group_ids
|
|
308
|
+
"""
|
|
309
|
+
+ cursor_query
|
|
310
|
+
+ """
|
|
282
311
|
RETURN
|
|
283
312
|
e.uuid AS uuid,
|
|
284
313
|
n.uuid AS source_node_uuid,
|
|
@@ -292,8 +321,12 @@ class EntityEdge(Edge):
|
|
|
292
321
|
e.expired_at AS expired_at,
|
|
293
322
|
e.valid_at AS valid_at,
|
|
294
323
|
e.invalid_at AS invalid_at
|
|
295
|
-
|
|
324
|
+
ORDER BY e.uuid DESC
|
|
325
|
+
"""
|
|
326
|
+
+ limit_query,
|
|
296
327
|
group_ids=group_ids,
|
|
328
|
+
created_at=created_at,
|
|
329
|
+
limit=limit,
|
|
297
330
|
database_=DEFAULT_DATABASE,
|
|
298
331
|
routing_='r',
|
|
299
332
|
)
|
|
@@ -365,19 +398,35 @@ class CommunityEdge(Edge):
|
|
|
365
398
|
return edges
|
|
366
399
|
|
|
367
400
|
@classmethod
|
|
368
|
-
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
|
401
|
+
async def get_by_group_ids(
|
|
402
|
+
cls,
|
|
403
|
+
driver: AsyncDriver,
|
|
404
|
+
group_ids: list[str],
|
|
405
|
+
limit: int | None = None,
|
|
406
|
+
created_at: datetime | None = None,
|
|
407
|
+
):
|
|
408
|
+
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
|
|
409
|
+
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
|
410
|
+
|
|
369
411
|
records, _, _ = await driver.execute_query(
|
|
370
412
|
"""
|
|
371
413
|
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
|
372
414
|
WHERE e.group_id IN $group_ids
|
|
415
|
+
"""
|
|
416
|
+
+ cursor_query
|
|
417
|
+
+ """
|
|
373
418
|
RETURN
|
|
374
419
|
e.uuid As uuid,
|
|
375
420
|
e.group_id AS group_id,
|
|
376
421
|
n.uuid AS source_node_uuid,
|
|
377
422
|
m.uuid AS target_node_uuid,
|
|
378
423
|
e.created_at AS created_at
|
|
379
|
-
|
|
424
|
+
ORDER BY e.uuid DESC
|
|
425
|
+
"""
|
|
426
|
+
+ limit_query,
|
|
380
427
|
group_ids=group_ids,
|
|
428
|
+
created_at=created_at,
|
|
429
|
+
limit=limit,
|
|
381
430
|
database_=DEFAULT_DATABASE,
|
|
382
431
|
routing_='r',
|
|
383
432
|
)
|
|
@@ -15,7 +15,7 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
from abc import ABC, abstractmethod
|
|
18
|
-
from
|
|
18
|
+
from collections.abc import Iterable
|
|
19
19
|
|
|
20
20
|
from pydantic import BaseModel, Field
|
|
21
21
|
|
|
@@ -23,12 +23,12 @@ EMBEDDING_DIM = 1024
|
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class EmbedderConfig(BaseModel):
|
|
26
|
-
embedding_dim:
|
|
26
|
+
embedding_dim: int = Field(default=EMBEDDING_DIM, frozen=True)
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class EmbedderClient(ABC):
|
|
30
30
|
@abstractmethod
|
|
31
31
|
async def create(
|
|
32
|
-
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
|
32
|
+
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
|
33
33
|
) -> list[float]:
|
|
34
34
|
pass
|
|
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
from typing import Iterable, List
|
|
17
|
+
from collections.abc import Iterable
|
|
18
18
|
|
|
19
19
|
from openai import AsyncOpenAI
|
|
20
20
|
from openai.types import EmbeddingModel
|
|
@@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
|
|
|
42
42
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
|
43
43
|
|
|
44
44
|
async def create(
|
|
45
|
-
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
|
45
|
+
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
|
46
46
|
) -> list[float]:
|
|
47
47
|
result = await self.client.embeddings.create(
|
|
48
48
|
input=input_data, model=self.config.embedding_model
|
|
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
from typing import Iterable, List
|
|
17
|
+
from collections.abc import Iterable
|
|
18
18
|
|
|
19
19
|
import voyageai # type: ignore
|
|
20
20
|
from pydantic import Field
|
|
@@ -41,11 +41,11 @@ class VoyageAIEmbedder(EmbedderClient):
|
|
|
41
41
|
self.client = voyageai.AsyncClient(api_key=config.api_key)
|
|
42
42
|
|
|
43
43
|
async def create(
|
|
44
|
-
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
|
|
44
|
+
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
|
45
45
|
) -> list[float]:
|
|
46
46
|
if isinstance(input_data, str):
|
|
47
47
|
input_list = [input_data]
|
|
48
|
-
elif isinstance(input_data, List):
|
|
48
|
+
elif isinstance(input_data, list):
|
|
49
49
|
input_list = [str(i) for i in input_data if i]
|
|
50
50
|
else:
|
|
51
51
|
input_list = [str(i) for i in input_data if i is not None]
|
|
@@ -14,9 +14,8 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
import asyncio
|
|
18
17
|
import logging
|
|
19
|
-
from datetime import datetime, timezone
|
|
18
|
+
from datetime import datetime
|
|
20
19
|
from time import time
|
|
21
20
|
|
|
22
21
|
from dotenv import load_dotenv
|
|
@@ -27,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
|
27
26
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
|
28
27
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
29
28
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
|
30
|
-
from graphiti_core.helpers import DEFAULT_DATABASE
|
|
29
|
+
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
|
31
30
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
|
32
31
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
|
33
32
|
from graphiti_core.search.search import SearchConfig, search
|
|
@@ -43,10 +42,6 @@ from graphiti_core.search.search_utils import (
|
|
|
43
42
|
get_relevant_edges,
|
|
44
43
|
get_relevant_nodes,
|
|
45
44
|
)
|
|
46
|
-
from graphiti_core.utils import (
|
|
47
|
-
build_episodic_edges,
|
|
48
|
-
retrieve_episodes,
|
|
49
|
-
)
|
|
50
45
|
from graphiti_core.utils.bulk_utils import (
|
|
51
46
|
RawEpisode,
|
|
52
47
|
add_nodes_and_edges_bulk,
|
|
@@ -57,12 +52,14 @@ from graphiti_core.utils.bulk_utils import (
|
|
|
57
52
|
resolve_edge_pointers,
|
|
58
53
|
retrieve_previous_episodes_bulk,
|
|
59
54
|
)
|
|
55
|
+
from graphiti_core.utils.datetime_utils import utc_now
|
|
60
56
|
from graphiti_core.utils.maintenance.community_operations import (
|
|
61
57
|
build_communities,
|
|
62
58
|
remove_communities,
|
|
63
59
|
update_community,
|
|
64
60
|
)
|
|
65
61
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
62
|
+
build_episodic_edges,
|
|
66
63
|
dedupe_extracted_edge,
|
|
67
64
|
extract_edges,
|
|
68
65
|
resolve_edge_contradictions,
|
|
@@ -71,6 +68,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
|
|
71
68
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
|
72
69
|
EPISODE_WINDOW_LEN,
|
|
73
70
|
build_indices_and_constraints,
|
|
71
|
+
retrieve_episodes,
|
|
74
72
|
)
|
|
75
73
|
from graphiti_core.utils.maintenance.node_operations import (
|
|
76
74
|
extract_nodes,
|
|
@@ -313,22 +311,26 @@ class Graphiti:
|
|
|
313
311
|
start = time()
|
|
314
312
|
|
|
315
313
|
entity_edges: list[EntityEdge] = []
|
|
316
|
-
now = datetime.now(timezone.utc)
|
|
314
|
+
now = utc_now()
|
|
317
315
|
|
|
318
316
|
previous_episodes = await self.retrieve_episodes(
|
|
319
317
|
reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
|
|
320
318
|
)
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
319
|
+
|
|
320
|
+
episode = (
|
|
321
|
+
await EpisodicNode.get_by_uuid(self.driver, uuid)
|
|
322
|
+
if uuid is not None
|
|
323
|
+
else EpisodicNode(
|
|
324
|
+
name=name,
|
|
325
|
+
group_id=group_id,
|
|
326
|
+
labels=[],
|
|
327
|
+
source=source,
|
|
328
|
+
content=episode_body,
|
|
329
|
+
source_description=source_description,
|
|
330
|
+
created_at=now,
|
|
331
|
+
valid_at=reference_time,
|
|
332
|
+
)
|
|
330
333
|
)
|
|
331
|
-
episode.uuid = uuid if uuid is not None else episode.uuid
|
|
332
334
|
|
|
333
335
|
# Extract entities as nodes
|
|
334
336
|
|
|
@@ -337,13 +339,13 @@ class Graphiti:
|
|
|
337
339
|
|
|
338
340
|
# Calculate Embeddings
|
|
339
341
|
|
|
340
|
-
await asyncio.gather(
|
|
342
|
+
await semaphore_gather(
|
|
341
343
|
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
|
|
342
344
|
)
|
|
343
345
|
|
|
344
346
|
# Find relevant nodes already in the graph
|
|
345
347
|
existing_nodes_lists: list[list[EntityNode]] = list(
|
|
346
|
-
await asyncio.gather(
|
|
348
|
+
await semaphore_gather(
|
|
347
349
|
*[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
|
|
348
350
|
)
|
|
349
351
|
)
|
|
@@ -351,7 +353,7 @@ class Graphiti:
|
|
|
351
353
|
# Resolve extracted nodes with nodes already in the graph and extract facts
|
|
352
354
|
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
|
353
355
|
|
|
354
|
-
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
|
|
356
|
+
(mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
|
|
355
357
|
resolve_extracted_nodes(
|
|
356
358
|
self.llm_client,
|
|
357
359
|
extracted_nodes,
|
|
@@ -371,7 +373,7 @@ class Graphiti:
|
|
|
371
373
|
)
|
|
372
374
|
|
|
373
375
|
# calculate embeddings
|
|
374
|
-
await asyncio.gather(
|
|
376
|
+
await semaphore_gather(
|
|
375
377
|
*[
|
|
376
378
|
edge.generate_embedding(self.embedder)
|
|
377
379
|
for edge in extracted_edges_with_resolved_pointers
|
|
@@ -380,7 +382,7 @@ class Graphiti:
|
|
|
380
382
|
|
|
381
383
|
# Resolve extracted edges with related edges already in the graph
|
|
382
384
|
related_edges_list: list[list[EntityEdge]] = list(
|
|
383
|
-
await asyncio.gather(
|
|
385
|
+
await semaphore_gather(
|
|
384
386
|
*[
|
|
385
387
|
get_relevant_edges(
|
|
386
388
|
self.driver,
|
|
@@ -401,7 +403,7 @@ class Graphiti:
|
|
|
401
403
|
)
|
|
402
404
|
|
|
403
405
|
existing_source_edges_list: list[list[EntityEdge]] = list(
|
|
404
|
-
await asyncio.gather(
|
|
406
|
+
await semaphore_gather(
|
|
405
407
|
*[
|
|
406
408
|
get_relevant_edges(
|
|
407
409
|
self.driver,
|
|
@@ -416,7 +418,7 @@ class Graphiti:
|
|
|
416
418
|
)
|
|
417
419
|
|
|
418
420
|
existing_target_edges_list: list[list[EntityEdge]] = list(
|
|
419
|
-
await asyncio.gather(
|
|
421
|
+
await semaphore_gather(
|
|
420
422
|
*[
|
|
421
423
|
get_relevant_edges(
|
|
422
424
|
self.driver,
|
|
@@ -465,7 +467,7 @@ class Graphiti:
|
|
|
465
467
|
|
|
466
468
|
# Update any communities
|
|
467
469
|
if update_communities:
|
|
468
|
-
await asyncio.gather(
|
|
470
|
+
await semaphore_gather(
|
|
469
471
|
*[
|
|
470
472
|
update_community(self.driver, self.llm_client, self.embedder, node)
|
|
471
473
|
for node in nodes
|
|
@@ -518,7 +520,7 @@ class Graphiti:
|
|
|
518
520
|
"""
|
|
519
521
|
try:
|
|
520
522
|
start = time()
|
|
521
|
-
now = datetime.now(timezone.utc)
|
|
523
|
+
now = utc_now()
|
|
522
524
|
|
|
523
525
|
episodes = [
|
|
524
526
|
EpisodicNode(
|
|
@@ -535,7 +537,7 @@ class Graphiti:
|
|
|
535
537
|
]
|
|
536
538
|
|
|
537
539
|
# Save all the episodes
|
|
538
|
-
await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
|
|
540
|
+
await semaphore_gather(*[episode.save(self.driver) for episode in episodes])
|
|
539
541
|
|
|
540
542
|
# Get previous episode context for each episode
|
|
541
543
|
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
|
@@ -548,19 +550,19 @@ class Graphiti:
|
|
|
548
550
|
) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
|
|
549
551
|
|
|
550
552
|
# Generate embeddings
|
|
551
|
-
await asyncio.gather(
|
|
553
|
+
await semaphore_gather(
|
|
552
554
|
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
|
|
553
555
|
*[edge.generate_embedding(self.embedder) for edge in extracted_edges],
|
|
554
556
|
)
|
|
555
557
|
|
|
556
558
|
# Dedupe extracted nodes, compress extracted edges
|
|
557
|
-
(nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
|
|
559
|
+
(nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
|
|
558
560
|
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
|
|
559
561
|
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
|
|
560
562
|
)
|
|
561
563
|
|
|
562
564
|
# save nodes to KG
|
|
563
|
-
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
|
565
|
+
await semaphore_gather(*[node.save(self.driver) for node in nodes])
|
|
564
566
|
|
|
565
567
|
# re-map edge pointers so that they don't point to discard dupe nodes
|
|
566
568
|
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
|
|
@@ -571,7 +573,7 @@ class Graphiti:
|
|
|
571
573
|
)
|
|
572
574
|
|
|
573
575
|
# save episodic edges to KG
|
|
574
|
-
await asyncio.gather(
|
|
576
|
+
await semaphore_gather(
|
|
575
577
|
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
|
|
576
578
|
)
|
|
577
579
|
|
|
@@ -584,7 +586,7 @@ class Graphiti:
|
|
|
584
586
|
# invalidate edges
|
|
585
587
|
|
|
586
588
|
# save edges to KG
|
|
587
|
-
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
|
|
589
|
+
await semaphore_gather(*[edge.save(self.driver) for edge in edges])
|
|
588
590
|
|
|
589
591
|
end = time()
|
|
590
592
|
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
|
|
@@ -607,12 +609,12 @@ class Graphiti:
|
|
|
607
609
|
self.driver, self.llm_client, group_ids
|
|
608
610
|
)
|
|
609
611
|
|
|
610
|
-
await asyncio.gather(
|
|
612
|
+
await semaphore_gather(
|
|
611
613
|
*[node.generate_name_embedding(self.embedder) for node in community_nodes]
|
|
612
614
|
)
|
|
613
615
|
|
|
614
|
-
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
|
615
|
-
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
|
616
|
+
await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
|
|
617
|
+
await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])
|
|
616
618
|
|
|
617
619
|
return community_nodes
|
|
618
620
|
|
|
@@ -695,7 +697,7 @@ class Graphiti:
|
|
|
695
697
|
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
|
|
696
698
|
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
|
|
697
699
|
|
|
698
|
-
edges_list = await asyncio.gather(
|
|
700
|
+
edges_list = await semaphore_gather(
|
|
699
701
|
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
|
|
700
702
|
)
|
|
701
703
|
|
|
@@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import asyncio
|
|
17
18
|
import os
|
|
19
|
+
from collections.abc import Coroutine
|
|
18
20
|
from datetime import datetime
|
|
19
21
|
|
|
20
22
|
import numpy as np
|
|
@@ -25,7 +27,9 @@ load_dotenv()
|
|
|
25
27
|
|
|
26
28
|
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
|
27
29
|
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
|
30
|
+
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
|
28
31
|
MAX_REFLEXION_ITERATIONS = 2
|
|
32
|
+
DEFAULT_PAGE_LIMIT = 20
|
|
29
33
|
|
|
30
34
|
|
|
31
35
|
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
|
@@ -56,6 +60,12 @@ def lucene_sanitize(query: str) -> str:
|
|
|
56
60
|
':': r'\:',
|
|
57
61
|
'\\': r'\\',
|
|
58
62
|
'/': r'\/',
|
|
63
|
+
'O': r'\O',
|
|
64
|
+
'R': r'\R',
|
|
65
|
+
'N': r'\N',
|
|
66
|
+
'T': r'\T',
|
|
67
|
+
'A': r'\A',
|
|
68
|
+
'D': r'\D',
|
|
59
69
|
}
|
|
60
70
|
)
|
|
61
71
|
|
|
@@ -73,3 +83,19 @@ def normalize_l2(embedding: list[float]) -> list[float]:
|
|
|
73
83
|
else:
|
|
74
84
|
norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
|
|
75
85
|
return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# Use this instead of asyncio.gather() to bound coroutines
|
|
89
|
+
async def semaphore_gather(
|
|
90
|
+
*coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT, return_exceptions=True
|
|
91
|
+
):
|
|
92
|
+
semaphore = asyncio.Semaphore(max_coroutines)
|
|
93
|
+
|
|
94
|
+
async def _wrap_coroutine(coroutine):
|
|
95
|
+
async with semaphore:
|
|
96
|
+
return await coroutine
|
|
97
|
+
|
|
98
|
+
return await asyncio.gather(
|
|
99
|
+
*(_wrap_coroutine(coroutine) for coroutine in coroutines),
|
|
100
|
+
return_exceptions=return_exceptions,
|
|
101
|
+
)
|
|
@@ -20,6 +20,7 @@ import typing
|
|
|
20
20
|
|
|
21
21
|
import anthropic
|
|
22
22
|
from anthropic import AsyncAnthropic
|
|
23
|
+
from pydantic import BaseModel
|
|
23
24
|
|
|
24
25
|
from ..prompts.models import Message
|
|
25
26
|
from .client import LLMClient
|
|
@@ -46,7 +47,9 @@ class AnthropicClient(LLMClient):
|
|
|
46
47
|
max_retries=1,
|
|
47
48
|
)
|
|
48
49
|
|
|
49
|
-
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
|
50
|
+
async def _generate_response(
|
|
51
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
52
|
+
) -> dict[str, typing.Any]:
|
|
50
53
|
system_message = messages[0]
|
|
51
54
|
user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
|
|
52
55
|
{'role': 'assistant', 'content': '{'}
|
|
@@ -22,6 +22,7 @@ from abc import ABC, abstractmethod
|
|
|
22
22
|
|
|
23
23
|
import httpx
|
|
24
24
|
from diskcache import Cache
|
|
25
|
+
from pydantic import BaseModel
|
|
25
26
|
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
|
|
26
27
|
|
|
27
28
|
from ..prompts.models import Message
|
|
@@ -55,6 +56,28 @@ class LLMClient(ABC):
|
|
|
55
56
|
self.cache_enabled = cache
|
|
56
57
|
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
|
|
57
58
|
|
|
59
|
+
def _clean_input(self, input: str) -> str:
|
|
60
|
+
"""Clean input string of invalid unicode and control characters.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
input: Raw input string to be cleaned
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
Cleaned string safe for LLM processing
|
|
67
|
+
"""
|
|
68
|
+
# Clean any invalid Unicode
|
|
69
|
+
cleaned = input.encode('utf-8', errors='ignore').decode('utf-8')
|
|
70
|
+
|
|
71
|
+
# Remove zero-width characters and other invisible unicode
|
|
72
|
+
zero_width = '\u200b\u200c\u200d\ufeff\u2060'
|
|
73
|
+
for char in zero_width:
|
|
74
|
+
cleaned = cleaned.replace(char, '')
|
|
75
|
+
|
|
76
|
+
# Remove control characters except newlines, returns, and tabs
|
|
77
|
+
cleaned = ''.join(char for char in cleaned if ord(char) >= 32 or char in '\n\r\t')
|
|
78
|
+
|
|
79
|
+
return cleaned
|
|
80
|
+
|
|
58
81
|
@retry(
|
|
59
82
|
stop=stop_after_attempt(4),
|
|
60
83
|
wait=wait_random_exponential(multiplier=10, min=5, max=120),
|
|
@@ -66,14 +89,18 @@ class LLMClient(ABC):
|
|
|
66
89
|
else None,
|
|
67
90
|
reraise=True,
|
|
68
91
|
)
|
|
69
|
-
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
|
|
92
|
+
async def _generate_response_with_retry(
|
|
93
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
94
|
+
) -> dict[str, typing.Any]:
|
|
70
95
|
try:
|
|
71
|
-
return await self._generate_response(messages)
|
|
96
|
+
return await self._generate_response(messages, response_model)
|
|
72
97
|
except (httpx.HTTPStatusError, RateLimitError) as e:
|
|
73
98
|
raise e
|
|
74
99
|
|
|
75
100
|
@abstractmethod
|
|
76
|
-
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
|
101
|
+
async def _generate_response(
|
|
102
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
103
|
+
) -> dict[str, typing.Any]:
|
|
77
104
|
pass
|
|
78
105
|
|
|
79
106
|
def _get_cache_key(self, messages: list[Message]) -> str:
|
|
@@ -82,7 +109,17 @@ class LLMClient(ABC):
|
|
|
82
109
|
key_str = f'{self.model}:{message_str}'
|
|
83
110
|
return hashlib.md5(key_str.encode()).hexdigest()
|
|
84
111
|
|
|
85
|
-
async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
|
112
|
+
async def generate_response(
|
|
113
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
114
|
+
) -> dict[str, typing.Any]:
|
|
115
|
+
if response_model is not None:
|
|
116
|
+
serialized_model = json.dumps(response_model.model_json_schema())
|
|
117
|
+
messages[
|
|
118
|
+
-1
|
|
119
|
+
].content += (
|
|
120
|
+
f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
|
|
121
|
+
)
|
|
122
|
+
|
|
86
123
|
if self.cache_enabled:
|
|
87
124
|
cache_key = self._get_cache_key(messages)
|
|
88
125
|
|
|
@@ -91,7 +128,10 @@ class LLMClient(ABC):
|
|
|
91
128
|
logger.debug(f'Cache hit for {cache_key}')
|
|
92
129
|
return cached_response
|
|
93
130
|
|
|
94
|
-
|
|
131
|
+
for message in messages:
|
|
132
|
+
message.content = self._clean_input(message.content)
|
|
133
|
+
|
|
134
|
+
response = await self._generate_response_with_retry(messages, response_model)
|
|
95
135
|
|
|
96
136
|
if self.cache_enabled:
|
|
97
137
|
self.cache_dir.set(cache_key, response)
|