graphiti-core: graphiti_core-0.15.0-py3-none-any.whl → graphiti_core-0.16.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/cross_encoder/__init__.py +1 -2
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +19 -4
- graphiti_core/cross_encoder/openai_reranker_client.py +5 -3
- graphiti_core/driver/__init__.py +1 -2
- graphiti_core/driver/falkordb_driver.py +15 -4
- graphiti_core/embedder/gemini.py +31 -7
- graphiti_core/embedder/voyage.py +12 -1
- graphiti_core/graphiti.py +107 -53
- graphiti_core/llm_client/anthropic_client.py +17 -4
- graphiti_core/llm_client/gemini_client.py +24 -8
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/nodes.py +4 -4
- graphiti_core/prompts/dedupe_edges.py +5 -4
- graphiti_core/prompts/dedupe_nodes.py +3 -3
- graphiti_core/search/search_utils.py +4 -2
- graphiti_core/utils/bulk_utils.py +211 -255
- graphiti_core/utils/maintenance/edge_operations.py +34 -120
- graphiti_core/utils/maintenance/graph_data_operations.py +2 -1
- graphiti_core/utils/maintenance/node_operations.py +11 -58
- {graphiti_core-0.15.0.dist-info → graphiti_core-0.16.0.dist-info}/METADATA +45 -4
- {graphiti_core-0.15.0.dist-info → graphiti_core-0.16.0.dist-info}/RECORD +24 -24
- {graphiti_core-0.15.0.dist-info → graphiti_core-0.16.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.15.0.dist-info → graphiti_core-0.16.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/cross_encoder/__init__.py
CHANGED
|
@@ -15,7 +15,6 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
from .client import CrossEncoderClient
|
|
18
|
-
from .gemini_reranker_client import GeminiRerankerClient
|
|
19
18
|
from .openai_reranker_client import OpenAIRerankerClient
|
|
20
19
|
|
|
21
|
-
__all__ = ['CrossEncoderClient', '
|
|
20
|
+
__all__ = ['CrossEncoderClient', 'OpenAIRerankerClient']
|
graphiti_core/cross_encoder/bge_reranker_client.py
CHANGED
|
@@ -15,8 +15,18 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
import asyncio
|
|
18
|
-
|
|
19
|
-
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from sentence_transformers import CrossEncoder
|
|
22
|
+
else:
|
|
23
|
+
try:
|
|
24
|
+
from sentence_transformers import CrossEncoder
|
|
25
|
+
except ImportError:
|
|
26
|
+
raise ImportError(
|
|
27
|
+
'sentence-transformers is required for BGERerankerClient. '
|
|
28
|
+
'Install it with: pip install graphiti-core[sentence-transformers]'
|
|
29
|
+
) from None
|
|
20
30
|
|
|
21
31
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
22
32
|
|
graphiti_core/cross_encoder/gemini_reranker_client.py
CHANGED
|
@@ -16,24 +16,39 @@ limitations under the License.
|
|
|
16
16
|
|
|
17
17
|
import logging
|
|
18
18
|
import re
|
|
19
|
-
|
|
20
|
-
from google import genai # type: ignore
|
|
21
|
-
from google.genai import types # type: ignore
|
|
19
|
+
from typing import TYPE_CHECKING
|
|
22
20
|
|
|
23
21
|
from ..helpers import semaphore_gather
|
|
24
22
|
from ..llm_client import LLMConfig, RateLimitError
|
|
25
23
|
from .client import CrossEncoderClient
|
|
26
24
|
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from google import genai
|
|
27
|
+
from google.genai import types
|
|
28
|
+
else:
|
|
29
|
+
try:
|
|
30
|
+
from google import genai
|
|
31
|
+
from google.genai import types
|
|
32
|
+
except ImportError:
|
|
33
|
+
raise ImportError(
|
|
34
|
+
'google-genai is required for GeminiRerankerClient. '
|
|
35
|
+
'Install it with: pip install graphiti-core[google-genai]'
|
|
36
|
+
) from None
|
|
37
|
+
|
|
27
38
|
logger = logging.getLogger(__name__)
|
|
28
39
|
|
|
29
40
|
DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
|
|
30
41
|
|
|
31
42
|
|
|
32
43
|
class GeminiRerankerClient(CrossEncoderClient):
|
|
44
|
+
"""
|
|
45
|
+
Google Gemini Reranker Client
|
|
46
|
+
"""
|
|
47
|
+
|
|
33
48
|
def __init__(
|
|
34
49
|
self,
|
|
35
50
|
config: LLMConfig | None = None,
|
|
36
|
-
client: genai.Client | None = None,
|
|
51
|
+
client: 'genai.Client | None' = None,
|
|
37
52
|
):
|
|
38
53
|
"""
|
|
39
54
|
Initialize the GeminiRerankerClient with the provided configuration and client.
|
graphiti_core/cross_encoder/openai_reranker_client.py
CHANGED
|
@@ -22,7 +22,7 @@ import openai
|
|
|
22
22
|
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
|
23
23
|
|
|
24
24
|
from ..helpers import semaphore_gather
|
|
25
|
-
from ..llm_client import LLMConfig, RateLimitError
|
|
25
|
+
from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
|
|
26
26
|
from ..prompts import Message
|
|
27
27
|
from .client import CrossEncoderClient
|
|
28
28
|
|
|
@@ -35,7 +35,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|
|
35
35
|
def __init__(
|
|
36
36
|
self,
|
|
37
37
|
config: LLMConfig | None = None,
|
|
38
|
-
client: AsyncOpenAI | AsyncAzureOpenAI | None = None,
|
|
38
|
+
client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
|
|
39
39
|
):
|
|
40
40
|
"""
|
|
41
41
|
Initialize the OpenAIRerankerClient with the provided configuration and client.
|
|
@@ -45,7 +45,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|
|
45
45
|
|
|
46
46
|
Args:
|
|
47
47
|
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
|
48
|
-
client (AsyncOpenAI | AsyncAzureOpenAI | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
|
48
|
+
client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
|
49
49
|
"""
|
|
50
50
|
if config is None:
|
|
51
51
|
config = LLMConfig()
|
|
@@ -53,6 +53,8 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|
|
53
53
|
self.config = config
|
|
54
54
|
if client is None:
|
|
55
55
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
|
56
|
+
elif isinstance(client, OpenAIClient):
|
|
57
|
+
self.client = client.client
|
|
56
58
|
else:
|
|
57
59
|
self.client = client
|
|
58
60
|
|
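graphiti_core/driver/__init__.py
CHANGED
[diff body for this file is not present in this extract]
graphiti_core/driver/falkordb_driver.py
CHANGED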
|
@@ -16,10 +16,21 @@ limitations under the License.
|
|
|
16
16
|
|
|
17
17
|
import logging
|
|
18
18
|
from datetime import datetime
|
|
19
|
-
from typing import Any
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
from falkordb
|
|
19
|
+
from typing import TYPE_CHECKING, Any
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from falkordb import Graph as FalkorGraph
|
|
23
|
+
from falkordb.asyncio import FalkorDB
|
|
24
|
+
else:
|
|
25
|
+
try:
|
|
26
|
+
from falkordb import Graph as FalkorGraph
|
|
27
|
+
from falkordb.asyncio import FalkorDB
|
|
28
|
+
except ImportError:
|
|
29
|
+
# If falkordb is not installed, raise an ImportError
|
|
30
|
+
raise ImportError(
|
|
31
|
+
'falkordb is required for FalkorDriver. '
|
|
32
|
+
'Install it with: pip install graphiti-core[falkordb]'
|
|
33
|
+
) from None
|
|
23
34
|
|
|
24
35
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
|
25
36
|
from graphiti_core.helpers import DEFAULT_DATABASE
|
graphiti_core/embedder/gemini.py
CHANGED
|
@@ -15,9 +15,21 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
from collections.abc import Iterable
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from google import genai
|
|
22
|
+
from google.genai import types
|
|
23
|
+
else:
|
|
24
|
+
try:
|
|
25
|
+
from google import genai
|
|
26
|
+
from google.genai import types
|
|
27
|
+
except ImportError:
|
|
28
|
+
raise ImportError(
|
|
29
|
+
'google-genai is required for GeminiEmbedder. '
|
|
30
|
+
'Install it with: pip install graphiti-core[google-genai]'
|
|
31
|
+
) from None
|
|
18
32
|
|
|
19
|
-
from google import genai # type: ignore
|
|
20
|
-
from google.genai import types # type: ignore
|
|
21
33
|
from pydantic import Field
|
|
22
34
|
|
|
23
35
|
from .client import EmbedderClient, EmbedderConfig
|
|
@@ -35,15 +47,27 @@ class GeminiEmbedder(EmbedderClient):
|
|
|
35
47
|
Google Gemini Embedder Client
|
|
36
48
|
"""
|
|
37
49
|
|
|
38
|
-
def __init__(
|
|
50
|
+
def __init__(
|
|
51
|
+
self,
|
|
52
|
+
config: GeminiEmbedderConfig | None = None,
|
|
53
|
+
client: 'genai.Client | None' = None,
|
|
54
|
+
):
|
|
55
|
+
"""
|
|
56
|
+
Initialize the GeminiEmbedder with the provided configuration and client.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
config (GeminiEmbedderConfig | None): The configuration for the GeminiEmbedder, including API key, model, base URL, temperature, and max tokens.
|
|
60
|
+
client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
|
|
61
|
+
"""
|
|
39
62
|
if config is None:
|
|
40
63
|
config = GeminiEmbedderConfig()
|
|
64
|
+
|
|
41
65
|
self.config = config
|
|
42
66
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
67
|
+
if client is None:
|
|
68
|
+
self.client = genai.Client(api_key=config.api_key)
|
|
69
|
+
else:
|
|
70
|
+
self.client = client
|
|
47
71
|
|
|
48
72
|
async def create(
|
|
49
73
|
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
graphiti_core/embedder/voyage.py
CHANGED
|
@@ -15,8 +15,19 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
from collections.abc import Iterable
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
import voyageai
|
|
22
|
+
else:
|
|
23
|
+
try:
|
|
24
|
+
import voyageai
|
|
25
|
+
except ImportError:
|
|
26
|
+
raise ImportError(
|
|
27
|
+
'voyageai is required for VoyageAIEmbedderClient. '
|
|
28
|
+
'Install it with: pip install graphiti-core[voyageai]'
|
|
29
|
+
) from None
|
|
18
30
|
|
|
19
|
-
import voyageai # type: ignore
|
|
20
31
|
from pydantic import Field
|
|
21
32
|
|
|
22
33
|
from .client import EmbedderClient, EmbedderConfig
|
graphiti_core/graphiti.py
CHANGED
|
@@ -57,7 +57,6 @@ from graphiti_core.utils.bulk_utils import (
|
|
|
57
57
|
add_nodes_and_edges_bulk,
|
|
58
58
|
dedupe_edges_bulk,
|
|
59
59
|
dedupe_nodes_bulk,
|
|
60
|
-
extract_edge_dates_bulk,
|
|
61
60
|
extract_nodes_and_edges_bulk,
|
|
62
61
|
resolve_edge_pointers,
|
|
63
62
|
retrieve_previous_episodes_bulk,
|
|
@@ -166,7 +165,7 @@ class Graphiti:
|
|
|
166
165
|
self.driver = graph_driver
|
|
167
166
|
else:
|
|
168
167
|
if uri is None:
|
|
169
|
-
raise ValueError(
|
|
168
|
+
raise ValueError('uri must be provided when graph_driver is None')
|
|
170
169
|
self.driver = Neo4jDriver(uri, user, password)
|
|
171
170
|
|
|
172
171
|
self.database = DEFAULT_DATABASE
|
|
@@ -508,7 +507,7 @@ class Graphiti:
|
|
|
508
507
|
|
|
509
508
|
entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
|
|
510
509
|
|
|
511
|
-
episodic_edges = build_episodic_edges(nodes, episode, now)
|
|
510
|
+
episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
|
|
512
511
|
|
|
513
512
|
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
|
514
513
|
|
|
@@ -536,8 +535,16 @@ class Graphiti:
|
|
|
536
535
|
except Exception as e:
|
|
537
536
|
raise e
|
|
538
537
|
|
|
539
|
-
|
|
540
|
-
async def add_episode_bulk(
|
|
538
|
+
##### EXPERIMENTAL #####
|
|
539
|
+
async def add_episode_bulk(
|
|
540
|
+
self,
|
|
541
|
+
bulk_episodes: list[RawEpisode],
|
|
542
|
+
group_id: str = '',
|
|
543
|
+
entity_types: dict[str, BaseModel] | None = None,
|
|
544
|
+
excluded_entity_types: list[str] | None = None,
|
|
545
|
+
edge_types: dict[str, BaseModel] | None = None,
|
|
546
|
+
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
|
547
|
+
):
|
|
541
548
|
"""
|
|
542
549
|
Process multiple episodes in bulk and update the graph.
|
|
543
550
|
|
|
@@ -580,8 +587,17 @@ class Graphiti:
|
|
|
580
587
|
|
|
581
588
|
validate_group_id(group_id)
|
|
582
589
|
|
|
590
|
+
# Create default edge type map
|
|
591
|
+
edge_type_map_default = (
|
|
592
|
+
{('Entity', 'Entity'): list(edge_types.keys())}
|
|
593
|
+
if edge_types is not None
|
|
594
|
+
else {('Entity', 'Entity'): []}
|
|
595
|
+
)
|
|
596
|
+
|
|
583
597
|
episodes = [
|
|
584
|
-
EpisodicNode(
|
|
598
|
+
await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
|
|
599
|
+
if episode.uuid is not None
|
|
600
|
+
else EpisodicNode(
|
|
585
601
|
name=episode.name,
|
|
586
602
|
labels=[],
|
|
587
603
|
source=episode.source,
|
|
@@ -594,68 +610,106 @@ class Graphiti:
|
|
|
594
610
|
for episode in bulk_episodes
|
|
595
611
|
]
|
|
596
612
|
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
613
|
+
episodes_by_uuid: dict[str, EpisodicNode] = {
|
|
614
|
+
episode.uuid: episode for episode in episodes
|
|
615
|
+
}
|
|
616
|
+
|
|
617
|
+
# Save all episodes
|
|
618
|
+
await add_nodes_and_edges_bulk(
|
|
619
|
+
driver=self.driver,
|
|
620
|
+
episodic_nodes=episodes,
|
|
621
|
+
episodic_edges=[],
|
|
622
|
+
entity_nodes=[],
|
|
623
|
+
entity_edges=[],
|
|
624
|
+
embedder=self.embedder,
|
|
601
625
|
)
|
|
602
626
|
|
|
603
627
|
# Get previous episode context for each episode
|
|
604
|
-
|
|
628
|
+
episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
|
605
629
|
|
|
606
|
-
# Extract all nodes and edges
|
|
607
|
-
(
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
await semaphore_gather(
|
|
615
|
-
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
|
|
616
|
-
*[edge.generate_embedding(self.embedder) for edge in extracted_edges],
|
|
617
|
-
max_coroutines=self.max_coroutines,
|
|
630
|
+
# Extract all nodes and edges for each episode
|
|
631
|
+
extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
|
|
632
|
+
self.clients,
|
|
633
|
+
episode_context,
|
|
634
|
+
edge_type_map=edge_type_map or edge_type_map_default,
|
|
635
|
+
edge_types=edge_types,
|
|
636
|
+
entity_types=entity_types,
|
|
637
|
+
excluded_entity_types=excluded_entity_types,
|
|
618
638
|
)
|
|
619
639
|
|
|
620
|
-
# Dedupe extracted nodes
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
|
|
624
|
-
max_coroutines=self.max_coroutines,
|
|
640
|
+
# Dedupe extracted nodes in memory
|
|
641
|
+
nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
|
|
642
|
+
self.clients, extracted_nodes_bulk, episode_context, entity_types
|
|
625
643
|
)
|
|
626
644
|
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
max_coroutines=self.max_coroutines,
|
|
631
|
-
)
|
|
645
|
+
episodic_edges: list[EpisodicEdge] = []
|
|
646
|
+
for episode_uuid, nodes in nodes_by_episode.items():
|
|
647
|
+
episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
|
|
632
648
|
|
|
633
649
|
# re-map edge pointers so that they don't point to discard dupe nodes
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
|
|
638
|
-
episodic_edges, uuid_map
|
|
639
|
-
)
|
|
650
|
+
extracted_edges_bulk_updated: list[list[EntityEdge]] = [
|
|
651
|
+
resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
|
|
652
|
+
]
|
|
640
653
|
|
|
641
|
-
#
|
|
642
|
-
await
|
|
643
|
-
|
|
644
|
-
|
|
654
|
+
# Dedupe extracted edges in memory
|
|
655
|
+
edges_by_episode = await dedupe_edges_bulk(
|
|
656
|
+
self.clients,
|
|
657
|
+
extracted_edges_bulk_updated,
|
|
658
|
+
episode_context,
|
|
659
|
+
[],
|
|
660
|
+
edge_types or {},
|
|
661
|
+
edge_type_map or edge_type_map_default,
|
|
645
662
|
)
|
|
646
663
|
|
|
647
|
-
#
|
|
648
|
-
|
|
649
|
-
|
|
664
|
+
# Extract node attributes
|
|
665
|
+
nodes_by_uuid: dict[str, EntityNode] = {
|
|
666
|
+
node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
|
|
667
|
+
}
|
|
668
|
+
|
|
669
|
+
extract_attributes_params: list[tuple[EntityNode, list[EpisodicNode]]] = []
|
|
670
|
+
for node in nodes_by_uuid.values():
|
|
671
|
+
episode_uuids: list[str] = []
|
|
672
|
+
for episode_uuid, mentioned_nodes in nodes_by_episode.items():
|
|
673
|
+
for mentioned_node in mentioned_nodes:
|
|
674
|
+
if node.uuid == mentioned_node.uuid:
|
|
675
|
+
episode_uuids.append(episode_uuid)
|
|
676
|
+
break
|
|
677
|
+
|
|
678
|
+
episode_mentions: list[EpisodicNode] = [
|
|
679
|
+
episodes_by_uuid[episode_uuid] for episode_uuid in episode_uuids
|
|
680
|
+
]
|
|
681
|
+
episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)
|
|
682
|
+
|
|
683
|
+
extract_attributes_params.append((node, episode_mentions))
|
|
684
|
+
|
|
685
|
+
new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
|
|
686
|
+
*[
|
|
687
|
+
extract_attributes_from_nodes(
|
|
688
|
+
self.clients,
|
|
689
|
+
[params[0]],
|
|
690
|
+
params[1][0],
|
|
691
|
+
params[1][0:],
|
|
692
|
+
entity_types,
|
|
693
|
+
)
|
|
694
|
+
for params in extract_attributes_params
|
|
695
|
+
]
|
|
650
696
|
)
|
|
651
|
-
logger.debug(f'extracted edge length: {len(edges)}')
|
|
652
697
|
|
|
653
|
-
|
|
698
|
+
hydrated_nodes = [node for nodes in new_hydrated_nodes for node in nodes]
|
|
654
699
|
|
|
655
|
-
#
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
700
|
+
# TODO: Resolve nodes and edges against the existing graph
|
|
701
|
+
edges_by_uuid: dict[str, EntityEdge] = {
|
|
702
|
+
edge.uuid: edge for edges in edges_by_episode.values() for edge in edges
|
|
703
|
+
}
|
|
704
|
+
|
|
705
|
+
# save data to KG
|
|
706
|
+
await add_nodes_and_edges_bulk(
|
|
707
|
+
self.driver,
|
|
708
|
+
episodes,
|
|
709
|
+
episodic_edges,
|
|
710
|
+
hydrated_nodes,
|
|
711
|
+
list(edges_by_uuid.values()),
|
|
712
|
+
self.embedder,
|
|
659
713
|
)
|
|
660
714
|
|
|
661
715
|
end = time()
|
|
@@ -828,7 +882,7 @@ class Graphiti:
|
|
|
828
882
|
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
|
|
829
883
|
)[0]
|
|
830
884
|
|
|
831
|
-
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
|
885
|
+
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
|
|
832
886
|
self.llm_client,
|
|
833
887
|
updated_edge,
|
|
834
888
|
related_edges,
|
graphiti_core/llm_client/anthropic_client.py
CHANGED
|
@@ -19,11 +19,8 @@ import logging
|
|
|
19
19
|
import os
|
|
20
20
|
import typing
|
|
21
21
|
from json import JSONDecodeError
|
|
22
|
-
from typing import Literal
|
|
22
|
+
from typing import TYPE_CHECKING, Literal
|
|
23
23
|
|
|
24
|
-
import anthropic
|
|
25
|
-
from anthropic import AsyncAnthropic
|
|
26
|
-
from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
|
|
27
24
|
from pydantic import BaseModel, ValidationError
|
|
28
25
|
|
|
29
26
|
from ..prompts.models import Message
|
|
@@ -31,6 +28,22 @@ from .client import LLMClient
|
|
|
31
28
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
|
|
32
29
|
from .errors import RateLimitError, RefusalError
|
|
33
30
|
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
import anthropic
|
|
33
|
+
from anthropic import AsyncAnthropic
|
|
34
|
+
from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
|
|
35
|
+
else:
|
|
36
|
+
try:
|
|
37
|
+
import anthropic
|
|
38
|
+
from anthropic import AsyncAnthropic
|
|
39
|
+
from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
|
|
40
|
+
except ImportError:
|
|
41
|
+
raise ImportError(
|
|
42
|
+
'anthropic is required for AnthropicClient. '
|
|
43
|
+
'Install it with: pip install graphiti-core[anthropic]'
|
|
44
|
+
) from None
|
|
45
|
+
|
|
46
|
+
|
|
34
47
|
logger = logging.getLogger(__name__)
|
|
35
48
|
|
|
36
49
|
AnthropicModel = Literal[
|
graphiti_core/llm_client/gemini_client.py
CHANGED
|
@@ -17,10 +17,8 @@ limitations under the License.
|
|
|
17
17
|
import json
|
|
18
18
|
import logging
|
|
19
19
|
import typing
|
|
20
|
-
from typing import ClassVar
|
|
20
|
+
from typing import TYPE_CHECKING, ClassVar
|
|
21
21
|
|
|
22
|
-
from google import genai # type: ignore
|
|
23
|
-
from google.genai import types # type: ignore
|
|
24
22
|
from pydantic import BaseModel
|
|
25
23
|
|
|
26
24
|
from ..prompts.models import Message
|
|
@@ -28,6 +26,21 @@ from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
|
|
|
28
26
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
|
|
29
27
|
from .errors import RateLimitError
|
|
30
28
|
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from google import genai
|
|
31
|
+
from google.genai import types
|
|
32
|
+
else:
|
|
33
|
+
try:
|
|
34
|
+
from google import genai
|
|
35
|
+
from google.genai import types
|
|
36
|
+
except ImportError:
|
|
37
|
+
# If gemini client is not installed, raise an ImportError
|
|
38
|
+
raise ImportError(
|
|
39
|
+
'google-genai is required for GeminiClient. '
|
|
40
|
+
'Install it with: pip install graphiti-core[google-genai]'
|
|
41
|
+
) from None
|
|
42
|
+
|
|
43
|
+
|
|
31
44
|
logger = logging.getLogger(__name__)
|
|
32
45
|
|
|
33
46
|
DEFAULT_MODEL = 'gemini-2.5-flash'
|
|
@@ -63,6 +76,7 @@ class GeminiClient(LLMClient):
|
|
|
63
76
|
cache: bool = False,
|
|
64
77
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
|
65
78
|
thinking_config: types.ThinkingConfig | None = None,
|
|
79
|
+
client: 'genai.Client | None' = None,
|
|
66
80
|
):
|
|
67
81
|
"""
|
|
68
82
|
Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
|
|
@@ -72,7 +86,7 @@ class GeminiClient(LLMClient):
|
|
|
72
86
|
cache (bool): Whether to use caching for responses. Defaults to False.
|
|
73
87
|
thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
|
|
74
88
|
Only use with models that support thinking (gemini-2.5+). Defaults to None.
|
|
75
|
-
|
|
89
|
+
client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
|
|
76
90
|
"""
|
|
77
91
|
if config is None:
|
|
78
92
|
config = LLMConfig()
|
|
@@ -80,10 +94,12 @@ class GeminiClient(LLMClient):
|
|
|
80
94
|
super().__init__(config, cache)
|
|
81
95
|
|
|
82
96
|
self.model = config.model
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
api_key=config.api_key
|
|
86
|
-
|
|
97
|
+
|
|
98
|
+
if client is None:
|
|
99
|
+
self.client = genai.Client(api_key=config.api_key)
|
|
100
|
+
else:
|
|
101
|
+
self.client = client
|
|
102
|
+
|
|
87
103
|
self.max_tokens = max_tokens
|
|
88
104
|
self.thinking_config = thinking_config
|
|
89
105
|
|
graphiti_core/llm_client/groq_client.py
CHANGED
|
@@ -17,10 +17,21 @@ limitations under the License.
|
|
|
17
17
|
import json
|
|
18
18
|
import logging
|
|
19
19
|
import typing
|
|
20
|
+
from typing import TYPE_CHECKING
|
|
20
21
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
from groq
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
import groq
|
|
24
|
+
from groq import AsyncGroq
|
|
25
|
+
from groq.types.chat import ChatCompletionMessageParam
|
|
26
|
+
else:
|
|
27
|
+
try:
|
|
28
|
+
import groq
|
|
29
|
+
from groq import AsyncGroq
|
|
30
|
+
from groq.types.chat import ChatCompletionMessageParam
|
|
31
|
+
except ImportError:
|
|
32
|
+
raise ImportError(
|
|
33
|
+
'groq is required for GroqClient. Install it with: pip install graphiti-core[groq]'
|
|
34
|
+
) from None
|
|
24
35
|
from pydantic import BaseModel
|
|
25
36
|
|
|
26
37
|
from ..prompts.models import Message
|
graphiti_core/nodes.py
CHANGED
|
@@ -542,12 +542,12 @@ class CommunityNode(Node):
|
|
|
542
542
|
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
543
543
|
created_at = parse_db_date(record['created_at'])
|
|
544
544
|
valid_at = parse_db_date(record['valid_at'])
|
|
545
|
-
|
|
545
|
+
|
|
546
546
|
if created_at is None:
|
|
547
|
-
raise ValueError(f
|
|
547
|
+
raise ValueError(f'created_at cannot be None for episode {record.get("uuid", "unknown")}')
|
|
548
548
|
if valid_at is None:
|
|
549
|
-
raise ValueError(f
|
|
550
|
-
|
|
549
|
+
raise ValueError(f'valid_at cannot be None for episode {record.get("uuid", "unknown")}')
|
|
550
|
+
|
|
551
551
|
return EpisodicNode(
|
|
552
552
|
content=record['content'],
|
|
553
553
|
created_at=created_at,
|
graphiti_core/prompts/dedupe_edges.py
CHANGED
|
@@ -23,9 +23,9 @@ from .models import Message, PromptFunction, PromptVersion
|
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class EdgeDuplicate(BaseModel):
|
|
26
|
-
|
|
26
|
+
duplicate_facts: list[int] = Field(
|
|
27
27
|
...,
|
|
28
|
-
description='
|
|
28
|
+
description='List of ids of any duplicate facts. If no duplicate facts are found, default to empty list.',
|
|
29
29
|
)
|
|
30
30
|
contradicted_facts: list[int] = Field(
|
|
31
31
|
...,
|
|
@@ -75,8 +75,9 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
|
|
75
75
|
</NEW EDGE>
|
|
76
76
|
|
|
77
77
|
Task:
|
|
78
|
-
If the New Edges represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact
|
|
79
|
-
|
|
78
|
+
If the New Edges represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact
|
|
79
|
+
as part of the list of duplicate_facts.
|
|
80
|
+
If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return an empty list.
|
|
80
81
|
|
|
81
82
|
Guidelines:
|
|
82
83
|
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
|
graphiti_core/prompts/dedupe_nodes.py
CHANGED
|
@@ -32,9 +32,9 @@ class NodeDuplicate(BaseModel):
|
|
|
32
32
|
...,
|
|
33
33
|
description='Name of the entity. Should be the most complete and descriptive name of the entity. Do not include any JSON formatting in the Entity name such as {}.',
|
|
34
34
|
)
|
|
35
|
-
|
|
35
|
+
duplicates: list[int] = Field(
|
|
36
36
|
...,
|
|
37
|
-
description='idx of
|
|
37
|
+
description='idx of all duplicate entities.',
|
|
38
38
|
)
|
|
39
39
|
|
|
40
40
|
|
|
@@ -94,7 +94,7 @@ def node(context: dict[str, Any]) -> list[Message]:
|
|
|
94
94
|
1. Compare `new_entity` against each item in `existing_entities`.
|
|
95
95
|
2. If it refers to the same real‐world object or concept, collect its index.
|
|
96
96
|
3. Let `duplicate_idx` = the *first* collected index, or –1 if none.
|
|
97
|
-
4. Let `
|
|
97
|
+
4. Let `duplicates` = the list of *all* collected indices (empty list if none).
|
|
98
98
|
|
|
99
99
|
Also return the full name of the NEW ENTITY (whether it is the name of the NEW ENTITY, a node it
|
|
100
100
|
is a duplicate of, or a combination of the two).
|
graphiti_core/search/search_utils.py
CHANGED
|
@@ -539,7 +539,8 @@ async def community_fulltext_search(
|
|
|
539
539
|
comm.group_id AS group_id,
|
|
540
540
|
comm.name AS name,
|
|
541
541
|
comm.created_at AS created_at,
|
|
542
|
-
comm.summary AS summary
|
|
542
|
+
comm.summary AS summary,
|
|
543
|
+
comm.name_embedding AS name_embedding
|
|
543
544
|
ORDER BY score DESC
|
|
544
545
|
LIMIT $limit
|
|
545
546
|
"""
|
|
@@ -589,7 +590,8 @@ async def community_similarity_search(
|
|
|
589
590
|
comm.group_id AS group_id,
|
|
590
591
|
comm.name AS name,
|
|
591
592
|
comm.created_at AS created_at,
|
|
592
|
-
comm.summary AS summary
|
|
593
|
+
comm.summary AS summary,
|
|
594
|
+
comm.name_embedding AS name_embedding
|
|
593
595
|
ORDER BY score DESC
|
|
594
596
|
LIMIT $limit
|
|
595
597
|
"""
|