graphiti-core 0.5.0rc5__tar.gz → 0.5.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/PKG-INFO +1 -1
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/cross_encoder/openai_reranker_client.py +2 -2
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/graphiti.py +19 -20
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/helpers.py +16 -2
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/anthropic_client.py +5 -2
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/client.py +15 -7
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/config.py +1 -1
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/groq_client.py +5 -2
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/openai_client.py +16 -6
- graphiti_core-0.5.2/graphiti_core/llm_client/openai_generic_client.py +171 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/search/search.py +5 -5
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/search/search_utils.py +48 -11
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/bulk_utils.py +15 -11
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/community_operations.py +6 -4
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/edge_operations.py +8 -5
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/graph_data_operations.py +3 -4
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/node_operations.py +3 -4
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/temporal_operations.py +2 -2
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/pyproject.toml +1 -1
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/LICENSE +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/README.md +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/cross_encoder/__init__.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/cross_encoder/client.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/edges.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/embedder/__init__.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/embedder/client.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/embedder/openai.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/embedder/voyage.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/errors.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/errors.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/models/__init__.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/models/edges/__init__.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/models/edges/edge_db_queries.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/models/nodes/__init__.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/models/nodes/node_db_queries.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/nodes.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/eval.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/extract_edge_dates.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/extract_edges.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/extract_nodes.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/invalidate_edges.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/lib.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/prompt_helpers.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/prompts/summarize_nodes.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/py.typed +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/search/search_config.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/search/search_config_recipes.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/datetime_utils.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/utils.py +0 -0
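
In summary, 0.5.2 makes four substantive changes visible in the hunks below: every unbounded asyncio.gather fan-out is replaced with a new bounded semaphore_gather helper (concurrency capped by the SEMAPHORE_LIMIT environment variable, default 20); a per-call max_tokens parameter is threaded through generate_response in all LLM clients; a new OpenAIGenericClient adds support for OpenAI-compatible APIs that only offer plain JSON output; and the similarity searches now build their group-id filters conditionally and pass the values through a query-parameter dict.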
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/cross_encoder/openai_reranker_client.py
RENAMED

@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 from typing import Any
 
@@ -22,6 +21,7 @@ import openai
 from openai import AsyncOpenAI
 from pydantic import BaseModel
 
+from ..helpers import semaphore_gather
 from ..llm_client import LLMConfig, RateLimitError
 from ..prompts import Message
 from .client import CrossEncoderClient
@@ -75,7 +75,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
             for passage in passages
         ]
         try:
-            responses = await asyncio.gather(
+            responses = await semaphore_gather(
                 *[
                     self.client.chat.completions.create(
                         model=DEFAULT_MODEL,
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/graphiti.py
RENAMED

@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 from datetime import datetime
 from time import time
@@ -27,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
 from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
-from graphiti_core.helpers import DEFAULT_DATABASE
+from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient, OpenAIClient
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search
@@ -340,13 +339,13 @@ class Graphiti:
 
             # Calculate Embeddings
 
-            await asyncio.gather(
+            await semaphore_gather(
                 *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
             )
 
             # Find relevant nodes already in the graph
             existing_nodes_lists: list[list[EntityNode]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
                 )
             )
@@ -354,7 +353,7 @@ class Graphiti:
             # Resolve extracted nodes with nodes already in the graph and extract facts
             logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
 
-            (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
+            (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
                 resolve_extracted_nodes(
                     self.llm_client,
                     extracted_nodes,
@@ -374,7 +373,7 @@ class Graphiti:
             )
 
             # calculate embeddings
-            await asyncio.gather(
+            await semaphore_gather(
                 *[
                     edge.generate_embedding(self.embedder)
                     for edge in extracted_edges_with_resolved_pointers
@@ -383,7 +382,7 @@ class Graphiti:
 
             # Resolve extracted edges with related edges already in the graph
             related_edges_list: list[list[EntityEdge]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         get_relevant_edges(
                             self.driver,
@@ -404,7 +403,7 @@ class Graphiti:
             )
 
             existing_source_edges_list: list[list[EntityEdge]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         get_relevant_edges(
                             self.driver,
@@ -419,7 +418,7 @@ class Graphiti:
             )
 
             existing_target_edges_list: list[list[EntityEdge]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         get_relevant_edges(
                             self.driver,
@@ -468,7 +467,7 @@ class Graphiti:
 
             # Update any communities
             if update_communities:
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         update_community(self.driver, self.llm_client, self.embedder, node)
                         for node in nodes
@@ -538,7 +537,7 @@ class Graphiti:
             ]
 
             # Save all the episodes
-            await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
+            await semaphore_gather(*[episode.save(self.driver) for episode in episodes])
 
             # Get previous episode context for each episode
             episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
@@ -551,19 +550,19 @@ class Graphiti:
             ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
 
             # Generate embeddings
-            await asyncio.gather(
+            await semaphore_gather(
                 *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
                 *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
             )
 
             # Dedupe extracted nodes, compress extracted edges
-            (nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
+            (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
                 dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
                 extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
             )
 
             # save nodes to KG
-            await asyncio.gather(*[node.save(self.driver) for node in nodes])
+            await semaphore_gather(*[node.save(self.driver) for node in nodes])
 
             # re-map edge pointers so that they don't point to discard dupe nodes
             extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
@@ -574,7 +573,7 @@ class Graphiti:
             )
 
             # save episodic edges to KG
-            await asyncio.gather(
+            await semaphore_gather(
                 *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
             )
 
@@ -587,7 +586,7 @@ class Graphiti:
             # invalidate edges
 
             # save edges to KG
-            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
+            await semaphore_gather(*[edge.save(self.driver) for edge in edges])
 
             end = time()
             logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
@@ -610,12 +609,12 @@ class Graphiti:
             self.driver, self.llm_client, group_ids
         )
 
-        await asyncio.gather(
+        await semaphore_gather(
            *[node.generate_name_embedding(self.embedder) for node in community_nodes]
         )
 
-        await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
-        await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
+        await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
+        await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])
 
         return community_nodes
 
@@ -698,7 +697,7 @@ class Graphiti:
    async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
        episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
 
-        edges_list = await asyncio.gather(
+        edges_list = await semaphore_gather(
            *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
        )
 
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/helpers.py
RENAMED

@@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+import asyncio
 import os
+from collections.abc import Coroutine
 from datetime import datetime
 
 import numpy as np
@@ -25,6 +27,7 @@ load_dotenv()
 
 DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
 USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
+SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
 MAX_REFLEXION_ITERATIONS = 2
 DEFAULT_PAGE_LIMIT = 20
 
@@ -70,13 +73,24 @@ def lucene_sanitize(query: str) -> str:
     return sanitized
 
 
-def normalize_l2(embedding: list[float])
+def normalize_l2(embedding: list[float]):
     embedding_array = np.array(embedding)
     if embedding_array.ndim == 1:
         norm = np.linalg.norm(embedding_array)
         if norm == 0:
-            return
+            return [0.0] * len(embedding)
         return (embedding_array / norm).tolist()
     else:
         norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
         return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
+
+
+# Use this instead of asyncio.gather() to bound coroutines
+async def semaphore_gather(*coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT):
+    semaphore = asyncio.Semaphore(max_coroutines)
+
+    async def _wrap_coroutine(coroutine):
+        async with semaphore:
+            return await coroutine
+
+    return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
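
The semaphore_gather helper above is a drop-in replacement for asyncio.gather that bounds how many coroutines run at once; the default cap comes from the SEMAPHORE_LIMIT environment variable and can be overridden per call. A minimal usage sketch (the fetch coroutine is a hypothetical stand-in for an LLM or database call):

import asyncio

from graphiti_core.helpers import semaphore_gather


async def fetch(i: int) -> int:
    # hypothetical stand-in for an LLM or database call
    await asyncio.sleep(0.1)
    return i


async def main() -> None:
    # At most 5 of the 100 coroutines are in flight at any moment; results
    # come back in input order, matching asyncio.gather semantics.
    results = await semaphore_gather(*[fetch(i) for i in range(100)], max_coroutines=5)
    print(len(results))  # 100


asyncio.run(main())

Because the semaphore is created inside semaphore_gather, each call gets its own independent limit rather than sharing one global pool.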
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/anthropic_client.py
RENAMED

@@ -48,7 +48,10 @@ class AnthropicClient(LLMClient):
         )
 
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         system_message = messages[0]
         user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
@@ -59,7 +62,7 @@ class AnthropicClient(LLMClient):
         result = await self.client.messages.create(
             system='Only include JSON in the response. Do not include any additional text or explanation of the content.\n'
             + system_message.content,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             temperature=self.temperature,
             messages=user_messages,  # type: ignore
             model=self.model or DEFAULT_MODEL,
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/client.py
RENAMED

@@ -26,7 +26,7 @@ from pydantic import BaseModel
 from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
 
 from ..prompts.models import Message
-from .config import LLMConfig
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
 from .errors import RateLimitError
 
 DEFAULT_TEMPERATURE = 0
@@ -56,7 +56,6 @@ class LLMClient(ABC):
         self.cache_enabled = cache
         self.cache_dir = Cache(DEFAULT_CACHE_DIR)  # Create a cache directory
 
-
     def _clean_input(self, input: str) -> str:
         """Clean input string of invalid unicode and control characters.
 
@@ -91,16 +90,22 @@ class LLMClient(ABC):
         reraise=True,
     )
     async def _generate_response_with_retry(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         try:
-            return await self._generate_response(messages, response_model)
+            return await self._generate_response(messages, response_model, max_tokens)
         except (httpx.HTTPStatusError, RateLimitError) as e:
             raise e
 
     @abstractmethod
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         pass
 
@@ -111,7 +116,10 @@ class LLMClient(ABC):
         return hashlib.md5(key_str.encode()).hexdigest()
 
     async def generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         if response_model is not None:
             serialized_model = json.dumps(response_model.model_json_schema())
@@ -132,7 +140,7 @@ class LLMClient(ABC):
         for message in messages:
             message.content = self._clean_input(message.content)
 
-        response = await self._generate_response_with_retry(messages, response_model)
+        response = await self._generate_response_with_retry(messages, response_model, max_tokens)
 
         if self.cache_enabled:
             self.cache_dir.set(cache_key, response)
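
With the widened signature, callers can raise or lower the token budget per request instead of per client, and the base class forwards the value through the retry wrapper to each concrete _generate_response. A sketch of a call site, assuming an OpenAIClient configured elsewhere and a hypothetical Summary response model. One subtlety: because max_tokens defaults to DEFAULT_MAX_TOKENS rather than None, the `max_tokens or self.max_tokens` fallback in the concrete clients only reaches the client-level value when the per-call value is falsy.

from pydantic import BaseModel

from graphiti_core.llm_client import OpenAIClient
from graphiti_core.prompts import Message


class Summary(BaseModel):
    # hypothetical response model for illustration
    summary: str


async def summarize(client: OpenAIClient, text: str) -> dict:
    messages = [
        Message(role='system', content='Summarize the user text as JSON.'),
        Message(role='user', content=text),
    ]
    # The larger budget applies to this request only; other calls keep
    # using the client default.
    return await client.generate_response(messages, response_model=Summary, max_tokens=4096)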
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/groq_client.py
RENAMED

@@ -45,7 +45,10 @@ class GroqClient(LLMClient):
         self.client = AsyncGroq(api_key=config.api_key)
 
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         msgs: list[ChatCompletionMessageParam] = []
         for m in messages:
@@ -58,7 +61,7 @@ class GroqClient(LLMClient):
             model=self.model or DEFAULT_MODEL,
             messages=msgs,
             temperature=self.temperature,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             response_format={'type': 'json_object'},
         )
         result = response.choices[0].message.content or ''
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/llm_client/openai_client.py
RENAMED

@@ -25,7 +25,7 @@ from pydantic import BaseModel
 
 from ..prompts.models import Message
 from .client import LLMClient
-from .config import LLMConfig
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
 from .errors import RateLimitError, RefusalError
 
 logger = logging.getLogger(__name__)
@@ -58,7 +58,11 @@ class OpenAIClient(LLMClient):
     MAX_RETRIES: ClassVar[int] = 2
 
     def __init__(
-        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
+        self,
+        config: LLMConfig | None = None,
+        cache: bool = False,
+        client: typing.Any = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ):
         """
         Initialize the OpenAIClient with the provided configuration, cache setting, and client.
@@ -84,7 +88,10 @@ class OpenAIClient(LLMClient):
             self.client = client
 
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         openai_messages: list[ChatCompletionMessageParam] = []
         for m in messages:
@@ -98,7 +105,7 @@ class OpenAIClient(LLMClient):
                 model=self.model or DEFAULT_MODEL,
                 messages=openai_messages,
                 temperature=self.temperature,
-                max_tokens=self.max_tokens,
+                max_tokens=max_tokens or self.max_tokens,
                 response_format=response_model,  # type: ignore
             )
 
@@ -119,14 +126,17 @@ class OpenAIClient(LLMClient):
             raise
 
     async def generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> dict[str, typing.Any]:
        retry_count = 0
        last_error = None
 
        while retry_count <= self.MAX_RETRIES:
            try:
-                response = await self._generate_response(messages, response_model)
+                response = await self._generate_response(messages, response_model, max_tokens)
                return response
            except (RateLimitError, RefusalError):
                # These errors should not trigger retries
graphiti_core-0.5.2/graphiti_core/llm_client/openai_generic_client.py
ADDED

@@ -0,0 +1,171 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+import logging
+import typing
+from typing import ClassVar
+
+import openai
+from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel
+
+from ..prompts.models import Message
+from .client import LLMClient
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
+from .errors import RateLimitError, RefusalError
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MODEL = 'gpt-4o-mini'
+
+
+class OpenAIGenericClient(LLMClient):
+    """
+    OpenAIClient is a client class for interacting with OpenAI's language models.
+
+    This class extends the LLMClient and provides methods to initialize the client,
+    get an embedder, and generate responses from the language model.
+
+    Attributes:
+        client (AsyncOpenAI): The OpenAI client used to interact with the API.
+        model (str): The model name to use for generating responses.
+        temperature (float): The temperature to use for generating responses.
+        max_tokens (int): The maximum number of tokens to generate in a response.
+
+    Methods:
+        __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
+            Initializes the OpenAIClient with the provided configuration, cache setting, and client.
+
+        _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
+            Generates a response from the language model based on the provided messages.
+    """
+
+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
+
+    def __init__(
+        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
+    ):
+        """
+        Initialize the OpenAIClient with the provided configuration, cache setting, and client.
+
+        Args:
+            config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
+            cache (bool): Whether to use caching for responses. Defaults to False.
+            client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+
+        """
+        # removed caching to simplify the `generate_response` override
+        if cache:
+            raise NotImplementedError('Caching is not implemented for OpenAI')
+
+        if config is None:
+            config = LLMConfig()
+
+        super().__init__(config, cache)
+
+        if client is None:
+            self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+        else:
+            self.client = client
+
+    async def _generate_response(
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
+    ) -> dict[str, typing.Any]:
+        openai_messages: list[ChatCompletionMessageParam] = []
+        for m in messages:
+            m.content = self._clean_input(m.content)
+            if m.role == 'user':
+                openai_messages.append({'role': 'user', 'content': m.content})
+            elif m.role == 'system':
+                openai_messages.append({'role': 'system', 'content': m.content})
+        try:
+            response = await self.client.chat.completions.create(
+                model=self.model or DEFAULT_MODEL,
+                messages=openai_messages,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                response_format={'type': 'json_object'},
+            )
+            result = response.choices[0].message.content or ''
+            return json.loads(result)
+        except openai.RateLimitError as e:
+            raise RateLimitError from e
+        except Exception as e:
+            logger.error(f'Error in generating LLM response: {e}')
+            raise
+
+    async def generate_response(
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
+    ) -> dict[str, typing.Any]:
+        retry_count = 0
+        last_error = None
+
+        if response_model is not None:
+            serialized_model = json.dumps(response_model.model_json_schema())
+            messages[
+                -1
+            ].content += (
+                f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
+            )
+
+        while retry_count <= self.MAX_RETRIES:
+            try:
+                response = await self._generate_response(
+                    messages, response_model, max_tokens=max_tokens
+                )
+                return response
+            except (RateLimitError, RefusalError):
+                # These errors should not trigger retries
+                raise
+            except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
+                # Let OpenAI's client handle these retries
+                raise
+            except Exception as e:
+                last_error = e
+
+                # Don't retry if we've hit the max retries
+                if retry_count >= self.MAX_RETRIES:
+                    logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
+                    raise
+
+                retry_count += 1
+
+                # Construct a detailed error message for the LLM
+                error_context = (
+                    f'The previous response attempt was invalid. '
+                    f'Error type: {e.__class__.__name__}. '
+                    f'Error details: {str(e)}. '
+                    f'Please try again with a valid response, ensuring the output matches '
+                    f'the expected format and constraints.'
+                )
+
+                error_message = Message(role='user', content=error_context)
+                messages.append(error_message)
+                logger.warning(
+                    f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+                )
+
+        # If we somehow get here, raise the last error
+        raise last_error or Exception('Max retries exceeded with no specific error')
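
OpenAIGenericClient targets OpenAI-compatible servers that lack OpenAI's structured-output endpoint: it requests response_format={'type': 'json_object'}, appends the response model's JSON schema to the final message, and on a failed parse feeds the error back to the model before retrying. A usage sketch; the endpoint URL and model name are placeholders, and this assumes LLMConfig accepts the api_key, model, and base_url keywords referenced in the docstring above:

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient

# Any OpenAI-compatible server works, e.g. a local vLLM or Ollama endpoint.
config = LLMConfig(
    api_key='placeholder',                  # many local servers ignore the key
    model='llama-3.1-70b-instruct',         # placeholder model name
    base_url='http://localhost:11434/v1',   # placeholder endpoint
)
llm_client = OpenAIGenericClient(config=config)

Unlike OpenAIClient, this client rejects cache=True and always parses the completion as a plain JSON object rather than using structured outputs.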
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/search/search.py
RENAMED

@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 from collections import defaultdict
 from time import time
@@ -25,6 +24,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
 from graphiti_core.edges import EntityEdge
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import SearchRerankerError
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.nodes import CommunityNode, EntityNode
 from graphiti_core.search.search_config import (
     DEFAULT_SEARCH_LIMIT,
@@ -78,7 +78,7 @@ async def search(
 
     # if group_ids is empty, set it to None
     group_ids = group_ids if group_ids else None
-    edges, nodes, communities = await asyncio.gather(
+    edges, nodes, communities = await semaphore_gather(
         edge_search(
             driver,
             cross_encoder,
@@ -141,7 +141,7 @@ async def edge_search(
         return []
 
     search_results: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 edge_fulltext_search(driver, query, group_ids, 2 * limit),
                 edge_similarity_search(
@@ -226,7 +226,7 @@ async def node_search(
         return []
 
     search_results: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 node_fulltext_search(driver, query, group_ids, 2 * limit),
                 node_similarity_search(
@@ -295,7 +295,7 @@ async def community_search(
         return []
 
     search_results: list[list[CommunityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 community_fulltext_search(driver, query, group_ids, 2 * limit),
                 community_similarity_search(
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/search/search_utils.py
RENAMED

@@ -14,10 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 from collections import defaultdict
 from time import time
+from typing import Any
 
 import numpy as np
 from neo4j import AsyncDriver, Query
@@ -29,6 +29,7 @@ from graphiti_core.helpers import (
     USE_PARALLEL_RUNTIME,
     lucene_sanitize,
     normalize_l2,
+    semaphore_gather,
 )
 from graphiti_core.nodes import (
     CommunityNode,
@@ -191,12 +192,27 @@ async def edge_similarity_search(
         'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
     )
 
-
-
-
-
-
-
+    query_params: dict[str, Any] = {}
+
+    group_filter_query: LiteralString = ''
+    if group_ids is not None:
+        group_filter_query += 'WHERE r.group_id IN $group_ids'
+        query_params['group_ids'] = group_ids
+        query_params['source_node_uuid'] = source_node_uuid
+        query_params['target_node_uuid'] = target_node_uuid
+
+    if source_node_uuid is not None:
+        group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
+
+    if target_node_uuid is not None:
+        group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
+
+    query: LiteralString = (
+        """
+        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+        """
+        + group_filter_query
+        + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
         WHERE score > $min_score
         RETURN
             r.uuid AS uuid,
@@ -214,9 +230,11 @@ async def edge_similarity_search(
         ORDER BY score DESC
         LIMIT $limit
         """
+    )
 
     records, _, _ = await driver.execute_query(
         runtime_query + query,
+        query_params,
         search_vector=search_vector,
         source_uuid=source_node_uuid,
         target_uuid=target_node_uuid,
@@ -325,11 +343,20 @@ async def node_similarity_search(
         'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
     )
 
+    query_params: dict[str, Any] = {}
+
+    group_filter_query: LiteralString = ''
+    if group_ids is not None:
+        group_filter_query += 'WHERE n.group_id IN $group_ids'
+        query_params['group_ids'] = group_ids
+
     records, _, _ = await driver.execute_query(
         runtime_query
         + """
         MATCH (n:Entity)
-
+        """
+        + group_filter_query
+        + """
         WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
         WHERE score > $min_score
         RETURN
@@ -342,6 +369,7 @@ async def node_similarity_search(
         ORDER BY score DESC
         LIMIT $limit
         """,
+        query_params,
         search_vector=search_vector,
         group_ids=group_ids,
         limit=limit,
@@ -436,11 +464,20 @@ async def community_similarity_search(
         'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
     )
 
+    query_params: dict[str, Any] = {}
+
+    group_filter_query: LiteralString = ''
+    if group_ids is not None:
+        group_filter_query += 'WHERE comm.group_id IN $group_ids'
+        query_params['group_ids'] = group_ids
+
     records, _, _ = await driver.execute_query(
         runtime_query
         + """
         MATCH (comm:Community)
-
+        """
+        + group_filter_query
+        + """
         WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
         WHERE score > $min_score
         RETURN
@@ -512,7 +549,7 @@ async def hybrid_node_search(
 
     start = time()
     results: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
             *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
         )
@@ -582,7 +619,7 @@ async def get_relevant_edges(
     relevant_edges: list[EntityEdge] = []
     relevant_edge_uuids = set()
 
-    results = await asyncio.gather(
+    results = await semaphore_gather(
         *[
             edge_similarity_search(
                 driver,
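
The rewritten similarity searches stop hard-coding the group filter: the WHERE clause is appended only when group_ids is provided, and its values travel in the explicit query_params dict passed to driver.execute_query. A sketch of the Cypher the node case composes when group_ids is set (illustrative, with the RETURN columns elided):

MATCH (n:Entity)
WHERE n.group_id IN $group_ids
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN ...
ORDER BY score DESC
LIMIT $limit

Without group_ids, the first WHERE line is simply absent, so the query scans all groups rather than matching against a null parameter.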
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/bulk_utils.py
RENAMED

@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 import typing
 from collections import defaultdict
@@ -26,6 +25,7 @@ from numpy import dot, sqrt
 from pydantic import BaseModel
 
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.models.edges.edge_db_queries import (
     ENTITY_EDGE_SAVE_BULK,
@@ -71,7 +71,7 @@ class RawEpisode(BaseModel):
 async def retrieve_previous_episodes_bulk(
     driver: AsyncDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
-    previous_episodes_list = await asyncio.gather(
+    previous_episodes_list = await semaphore_gather(
         *[
             retrieve_episodes(
                 driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
@@ -118,7 +118,7 @@ async def add_nodes_and_edges_bulk_tx(
 async def extract_nodes_and_edges_bulk(
     llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
 ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
-    extracted_nodes_bulk = await asyncio.gather(
+    extracted_nodes_bulk = await semaphore_gather(
         *[
             extract_nodes(llm_client, episode, previous_episodes)
             for episode, previous_episodes in episode_tuples
@@ -130,7 +130,7 @@ async def extract_nodes_and_edges_bulk(
         [episode[1] for episode in episode_tuples],
     )
 
-    extracted_edges_bulk = await asyncio.gather(
+    extracted_edges_bulk = await semaphore_gather(
         *[
             extract_edges(
                 llm_client,
@@ -171,13 +171,13 @@ async def dedupe_nodes_bulk(
     node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
 
     existing_nodes_chunks: list[list[EntityNode]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
         )
     )
 
     results: list[tuple[list[EntityNode], dict[str, str]]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
                 for i, node_chunk in enumerate(node_chunks)
@@ -205,13 +205,13 @@ async def dedupe_edges_bulk(
     ]
 
     relevant_edges_chunks: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
         )
    )
 
    resolved_edge_chunks: list[list[EntityEdge]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[
                dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
                for i, edge_chunk in enumerate(edge_chunks)
@@ -292,7 +292,9 @@ async def compress_nodes(
            # add both nodes to the shortest chunk
            node_chunks[-1].extend([n, m])
 
-    results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
+    results = await semaphore_gather(
+        *[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
+    )
 
    extended_map = dict(uuid_map)
    compressed_nodes: list[EntityNode] = []
@@ -315,7 +317,9 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list
    # We build a map of the edges based on their source and target nodes.
    edge_chunks = chunk_edges_by_nodes(edges)
 
-    results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
+    results = await semaphore_gather(
+        *[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
+    )
 
    compressed_edges: list[EntityEdge] = []
    for edge_chunk in results:
@@ -368,7 +372,7 @@ async def extract_edge_dates_bulk(
        episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
    }
 
-    results = await asyncio.gather(
+    results = await semaphore_gather(
        *[
            extract_edge_dates(
                llm_client,
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/community_operations.py
RENAMED

@@ -7,7 +7,7 @@ from pydantic import BaseModel
 
 from graphiti_core.edges import CommunityEdge
 from graphiti_core.embedder import EmbedderClient
-from graphiti_core.helpers import DEFAULT_DATABASE
+from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.nodes import (
     CommunityNode,
@@ -71,7 +71,7 @@ async def get_community_clusters(
 
     community_clusters.extend(
         list(
-            await asyncio.gather(
+            await semaphore_gather(
                 *[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
             )
         )
@@ -164,7 +164,7 @@ async def build_community(
         odd_one_out = summaries.pop()
         length -= 1
     new_summaries: list[str] = list(
-        await asyncio.gather(
+        await semaphore_gather(
            *[
                summarize_pair(llm_client, (str(left_summary), str(right_summary)))
                for left_summary, right_summary in zip(
@@ -207,7 +207,9 @@ async def build_communities(
        return await build_community(llm_client, cluster)
 
    communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
-        await asyncio.gather(*[limited_build_community(cluster) for cluster in community_clusters])
+        await semaphore_gather(
+            *[limited_build_community(cluster) for cluster in community_clusters]
+        )
    )
 
    community_nodes: list[CommunityNode] = []
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/edge_operations.py
RENAMED

@@ -14,13 +14,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 from datetime import datetime
 from time import time
 
 from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
-from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
+from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
@@ -80,6 +79,8 @@ async def extract_edges(
 ) -> list[EntityEdge]:
     start = time()
 
+    EXTRACT_EDGES_MAX_TOKENS = 16384
+
     node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
 
     # Prepare context for LLM
@@ -94,7 +95,9 @@ async def extract_edges(
     reflexion_iterations = 0
     while facts_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS:
         llm_response = await llm_client.generate_response(
-            prompt_library.extract_edges.edge(context), response_model=ExtractedEdges
+            prompt_library.extract_edges.edge(context),
+            response_model=ExtractedEdges,
+            max_tokens=EXTRACT_EDGES_MAX_TOKENS,
         )
         edges_data = llm_response.get('edges', [])
 
@@ -199,7 +202,7 @@ async def resolve_extracted_edges(
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
     # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
     results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 resolve_extracted_edge(
                     llm_client,
@@ -266,7 +269,7 @@ async def resolve_extracted_edge(
     current_episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
 ) -> tuple[EntityEdge, list[EntityEdge]]:
-    resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
+    resolved_edge, (valid_at, invalid_at), invalidation_candidates = await semaphore_gather(
        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
        extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/graph_data_operations.py
RENAMED

@@ -14,14 +14,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 from datetime import datetime, timezone
 
 from neo4j import AsyncDriver
 from typing_extensions import LiteralString
 
-from graphiti_core.helpers import DEFAULT_DATABASE
+from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.nodes import EpisodeType, EpisodicNode
 
 EPISODE_WINDOW_LEN = 3
@@ -38,7 +37,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
         database_=DEFAULT_DATABASE,
     )
     index_names = [record['name'] for record in records]
-    await asyncio.gather(
+    await semaphore_gather(
         *[
             driver.execute_query(
                 """DROP INDEX $name""",
@@ -82,7 +81,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
 
     index_queries: list[LiteralString] = range_indices + fulltext_indices
 
-    await asyncio.gather(
+    await semaphore_gather(
         *[
             driver.execute_query(
                 query,
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/node_operations.py
RENAMED

@@ -14,11 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import logging
 from time import time
 
-from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
+from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.prompts import prompt_library
@@ -223,7 +222,7 @@ async def resolve_extracted_nodes(
     uuid_map: dict[str, str] = {}
     resolved_nodes: list[EntityNode] = []
     results: list[tuple[EntityNode, dict[str, str]]] = list(
-        await asyncio.gather(
+        await semaphore_gather(
             *[
                 resolve_extracted_node(
                     llm_client, extracted_node, existing_nodes, episode, previous_episodes
@@ -275,7 +274,7 @@ async def resolve_extracted_node(
         else [],
     }
 
-    llm_response, node_summary_response = await asyncio.gather(
+    llm_response, node_summary_response = await semaphore_gather(
         llm_client.generate_response(
             prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
         ),
{graphiti_core-0.5.0rc5 → graphiti_core-0.5.2}/graphiti_core/utils/maintenance/temporal_operations.py
RENAMED

@@ -55,7 +55,7 @@ async def extract_edge_dates(
     try:
         valid_at_datetime = ensure_utc(datetime.fromisoformat(valid_at.replace('Z', '+00:00')))
     except ValueError as e:
-        logger.error(f'WARNING: Error parsing valid_at date: {e}. Input: {valid_at}')
+        logger.warning(f'WARNING: Error parsing valid_at date: {e}. Input: {valid_at}')
 
     if invalid_at:
         try:
@@ -63,7 +63,7 @@ async def extract_edge_dates(
                 datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
             )
         except ValueError as e:
-            logger.error(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}')
+            logger.warning(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}')
 
     return valid_at_datetime, invalid_at_datetime
 