graphiti-core 0.15.1__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of graphiti-core might be problematic.
- graphiti_core/cross_encoder/gemini_reranker_client.py +4 -1
- graphiti_core/driver/driver.py +2 -4
- graphiti_core/driver/falkordb_driver.py +9 -7
- graphiti_core/driver/neo4j_driver.py +14 -13
- graphiti_core/edges.py +3 -20
- graphiti_core/embedder/gemini.py +17 -5
- graphiti_core/graphiti.py +107 -57
- graphiti_core/helpers.py +0 -1
- graphiti_core/llm_client/gemini_client.py +8 -5
- graphiti_core/nodes.py +3 -22
- graphiti_core/prompts/dedupe_edges.py +5 -4
- graphiti_core/prompts/dedupe_nodes.py +3 -3
- graphiti_core/search/search_utils.py +1 -20
- graphiti_core/utils/bulk_utils.py +212 -256
- graphiti_core/utils/maintenance/community_operations.py +1 -6
- graphiti_core/utils/maintenance/edge_operations.py +35 -122
- graphiti_core/utils/maintenance/graph_data_operations.py +2 -6
- graphiti_core/utils/maintenance/node_operations.py +11 -58
- {graphiti_core-0.15.1.dist-info → graphiti_core-0.17.0.dist-info}/METADATA +19 -1
- {graphiti_core-0.15.1.dist-info → graphiti_core-0.17.0.dist-info}/RECORD +22 -22
- {graphiti_core-0.15.1.dist-info → graphiti_core-0.17.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.15.1.dist-info → graphiti_core-0.17.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/cross_encoder/gemini_reranker_client.py
CHANGED

@@ -41,6 +41,10 @@ DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
 
 
 class GeminiRerankerClient(CrossEncoderClient):
+    """
+    Google Gemini Reranker Client
+    """
+
     def __init__(
         self,
         config: LLMConfig | None = None,
@@ -57,7 +61,6 @@ class GeminiRerankerClient(CrossEncoderClient):
            config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
        """
-
        if config is None:
            config = LLMConfig()
 
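The new docstring documents that the reranker accepts an optional pre-built genai.Client, the same client-injection pattern added to the embedder and LLM client below. A minimal sketch of that usage (the LLMConfig import path and the API-key handling are assumptions for illustration, not taken from this diff):

from google import genai

from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.llm_client.config import LLMConfig

# Build one genai.Client and hand it to the reranker instead of letting the
# reranker construct its own from config.api_key.
shared_genai_client = genai.Client(api_key='YOUR_GEMINI_API_KEY')
reranker = GeminiRerankerClient(config=LLMConfig(), client=shared_genai_client)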
graphiti_core/driver/driver.py
CHANGED
@@ -19,8 +19,6 @@ from abc import ABC, abstractmethod
 from collections.abc import Coroutine
 from typing import Any
 
-from graphiti_core.helpers import DEFAULT_DATABASE
-
 logger = logging.getLogger(__name__)
 
 
@@ -54,7 +52,7 @@ class GraphDriver(ABC):
        raise NotImplementedError()
 
    @abstractmethod
-    def session(self, database: str) -> GraphDriverSession:
+    def session(self, database: str | None = None) -> GraphDriverSession:
        raise NotImplementedError()
 
    @abstractmethod
@@ -62,5 +60,5 @@ class GraphDriver(ABC):
        raise NotImplementedError()
 
    @abstractmethod
-    def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
+    def delete_all_indexes(self, database_: str | None = None) -> Coroutine:
        raise NotImplementedError()
graphiti_core/driver/falkordb_driver.py
CHANGED

@@ -33,7 +33,6 @@ else:
    ) from None
 
 from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
-from graphiti_core.helpers import DEFAULT_DATABASE
 
 logger = logging.getLogger(__name__)
 
@@ -81,6 +80,7 @@ class FalkorDriver(GraphDriver):
        username: str | None = None,
        password: str | None = None,
        falkor_db: FalkorDB | None = None,
+        database: str = 'default_db',
    ):
        """
        Initialize the FalkorDB driver.
@@ -95,15 +95,16 @@ class FalkorDriver(GraphDriver):
            self.client = falkor_db
        else:
            self.client = FalkorDB(host=host, port=port, username=username, password=password)
+        self._database = database
 
    def _get_graph(self, graph_name: str | None) -> FalkorGraph:
-        # FalkorDB requires a non-None database name for multi-tenant graphs; the default is DEFAULT_DATABASE
+        # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
        if graph_name is None:
-            graph_name = DEFAULT_DATABASE
+            graph_name = self._database
        return self.client.select_graph(graph_name)
 
    async def execute_query(self, cypher_query_, **kwargs: Any):
-        graph_name = kwargs.pop('database_', DEFAULT_DATABASE)
+        graph_name = kwargs.pop('database_', self._database)
        graph = self._get_graph(graph_name)
 
        # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
@@ -136,7 +137,7 @@ class FalkorDriver(GraphDriver):
 
        return records, header, None
 
-    def session(self, database: str | None) -> GraphDriverSession:
+    def session(self, database: str | None = None) -> GraphDriverSession:
        return FalkorDriverSession(self._get_graph(database))
 
    async def close(self) -> None:
@@ -148,10 +149,11 @@ class FalkorDriver(GraphDriver):
        elif hasattr(self.client.connection, 'close'):
            await self.client.connection.close()
 
-    async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> None:
+    async def delete_all_indexes(self, database_: str | None = None) -> None:
+        database = database_ or self._database
        await self.execute_query(
            'CALL db.indexes() YIELD name DROP INDEX name',
-            database_=database_,
+            database_=database,
        )
 
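With the DEFAULT_DATABASE module constant gone, the FalkorDB driver now carries its own default database name. A rough sketch of constructing it against a specific graph (host, port, and graph name here are placeholders, not values from the diff):

from graphiti_core.driver.falkordb_driver import FalkorDriver

# 'database' now lives on the driver (default 'default_db') instead of the
# removed DEFAULT_DATABASE environment-driven constant in helpers.py.
driver = FalkorDriver(host='localhost', port=6379, database='my_graph')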
graphiti_core/driver/neo4j_driver.py
CHANGED

@@ -22,7 +22,6 @@ from neo4j import AsyncGraphDatabase, EagerResult
 from typing_extensions import LiteralString
 
 from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
-from graphiti_core.helpers import DEFAULT_DATABASE
 
 logger = logging.getLogger(__name__)
 
@@ -30,34 +29,36 @@ logger = logging.getLogger(__name__)
 class Neo4jDriver(GraphDriver):
    provider: str = 'neo4j'
 
-    def __init__(
-        self,
-        uri: str,
-        user: str | None,
-        password: str | None,
-    ):
+    def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
        super().__init__()
        self.client = AsyncGraphDatabase.driver(
            uri=uri,
            auth=(user or '', password or ''),
        )
+        self._database = database
 
    async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
+        # Check if database_ is provided in kwargs.
+        # If not populated, set the value to retain backwards compatibility
        params = kwargs.pop('params', None)
+        if params is None:
+            params = {}
+        params.setdefault('database_', self._database)
+
        result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
 
        return result
 
-    def session(self, database: str) -> GraphDriverSession:
-        return self.client.session(database=database)  # type: ignore
+    def session(self, database: str | None = None) -> GraphDriverSession:
+        _database = database or self._database
+        return self.client.session(database=_database)  # type: ignore
 
    async def close(self) -> None:
        return await self.client.close()
 
-    def delete_all_indexes(
-        self, database_: str = DEFAULT_DATABASE
-    ) -> Coroutine[Any, Any, EagerResult]:
+    def delete_all_indexes(self, database_: str | None = None) -> Coroutine[Any, Any, EagerResult]:
+        database = database_ or self._database
        return self.client.execute_query(
            'CALL db.indexes() YIELD name DROP INDEX name',
-            database_=database_,
+            database_=database,
        )
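The Neo4j driver gets the same treatment: the target database is now a constructor argument (defaulting to 'neo4j') and is injected into execute_query, session, and delete_all_indexes when the caller does not pass one. A minimal sketch using the new signature (URI and credentials are placeholders):

from graphiti_core.driver.neo4j_driver import Neo4jDriver

# Every query issued through this driver now targets the 'graphiti' database
# unless an explicit database_ value overrides it.
driver = Neo4jDriver('bolt://localhost:7687', 'neo4j', 'password', database='graphiti')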
graphiti_core/edges.py
CHANGED
@@ -27,7 +27,7 @@ from typing_extensions import LiteralString
 from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
-from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
+from graphiti_core.helpers import parse_db_date
 from graphiti_core.models.edges.edge_db_queries import (
    COMMUNITY_EDGE_SAVE,
    ENTITY_EDGE_SAVE,
@@ -71,7 +71,6 @@ class Edge(BaseModel, ABC):
            DELETE e
            """,
            uuid=self.uuid,
-            database_=DEFAULT_DATABASE,
        )
 
        logger.debug(f'Deleted Edge: {self.uuid}')
@@ -99,7 +98,6 @@ class EpisodicEdge(Edge):
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
-            database_=DEFAULT_DATABASE,
        )
 
        logger.debug(f'Saved edge to Graph: {self.uuid}')
@@ -119,7 +117,6 @@
            e.created_at AS created_at
            """,
            uuid=uuid,
-            database_=DEFAULT_DATABASE,
            routing_='r',
        )
 
@@ -143,7 +140,6 @@
            e.created_at AS created_at
            """,
            uuids=uuids,
-            database_=DEFAULT_DATABASE,
            routing_='r',
        )
 
@@ -183,7 +179,6 @@
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
-            database_=DEFAULT_DATABASE,
            routing_='r',
        )
 
@@ -231,9 +226,7 @@ class EntityEdge(Edge):
        MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
        RETURN e.fact_embedding AS fact_embedding
        """
-        records, _, _ = await driver.execute_query(
-            query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
-        )
+        records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
 
        if len(records) == 0:
            raise EdgeNotFoundError(self.uuid)
@@ -261,7 +254,6 @@
        result = await driver.execute_query(
            ENTITY_EDGE_SAVE,
            edge_data=edge_data,
-            database_=DEFAULT_DATABASE,
        )
 
        logger.debug(f'Saved edge to Graph: {self.uuid}')
@@ -276,7 +268,6 @@
            """
            + ENTITY_EDGE_RETURN,
            uuid=uuid,
-            database_=DEFAULT_DATABASE,
            routing_='r',
        )
 
@@ -298,7 +289,6 @@
            """
            + ENTITY_EDGE_RETURN,
            uuids=uuids,
-            database_=DEFAULT_DATABASE,
            routing_='r',
        )
 
@@ -331,7 +321,6 @@
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
-            database_=DEFAULT_DATABASE,
            routing_='r',
        )
 
@@ -349,9 +338,7 @@
            """
            + ENTITY_EDGE_RETURN
        )
-        records, _, _ = await driver.execute_query(
-            query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r'
-        )
+        records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
 
        edges = [get_entity_edge_from_record(record) for record in records]
 
@@ -367,7 +354,6 @@ class CommunityEdge(Edge):
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
-            database_=DEFAULT_DATABASE,
        )
 
        logger.debug(f'Saved edge to Graph: {self.uuid}')
@@ -387,7 +373,6 @@
            e.created_at AS created_at
            """,
            uuid=uuid,
-            database_=DEFAULT_DATABASE,
            routing_='r',
        )
 
@@ -409,7 +394,6 @@
            e.created_at AS created_at
            """,
            uuids=uuids,
-            database_=DEFAULT_DATABASE,
            routing_='r',
        )
 
@@ -447,7 +431,6 @@
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
-            database_=DEFAULT_DATABASE,
            routing_='r',
        )
 
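Every removed database_=DEFAULT_DATABASE argument above means edge queries now run against whatever database the driver was configured with. A small sketch of the resulting call pattern, assuming get_by_uuid keeps its existing (driver, uuid) signature and using placeholder connection values:

from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.edges import EntityEdge


async def load_edge(edge_uuid: str) -> EntityEdge:
    # No database_ argument anymore: the lookup targets the database the driver
    # was constructed with rather than the old DEFAULT_DATABASE constant.
    driver = Neo4jDriver('bolt://localhost:7687', 'neo4j', 'password', database='graphiti')
    return await EntityEdge.get_by_uuid(driver, edge_uuid)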
graphiti_core/embedder/gemini.py
CHANGED
@@ -47,15 +47,27 @@ class GeminiEmbedder(EmbedderClient):
    Google Gemini Embedder Client
    """
 
-    def __init__(self, config: GeminiEmbedderConfig | None = None):
+    def __init__(
+        self,
+        config: GeminiEmbedderConfig | None = None,
+        client: 'genai.Client | None' = None,
+    ):
+        """
+        Initialize the GeminiEmbedder with the provided configuration and client.
+
+        Args:
+            config (GeminiEmbedderConfig | None): The configuration for the GeminiEmbedder, including API key, model, base URL, temperature, and max tokens.
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
+        """
        if config is None:
            config = GeminiEmbedderConfig()
+
        self.config = config
 
-
-
-
-
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client
 
    async def create(
        self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
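GeminiEmbedder now mirrors the reranker and LLM client: the genai.Client can be supplied rather than always being built from config.api_key. A sketch of injecting a shared client (the API-key value is a placeholder):

from google import genai

from graphiti_core.embedder.gemini import GeminiEmbedder

# The injected client is used as-is; config is optional and falls back to
# GeminiEmbedderConfig().
shared_genai_client = genai.Client(api_key='YOUR_GEMINI_API_KEY')
embedder = GeminiEmbedder(client=shared_genai_client)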
graphiti_core/graphiti.py
CHANGED
@@ -30,7 +30,6 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import (
-    DEFAULT_DATABASE,
    semaphore_gather,
    validate_excluded_entity_types,
    validate_group_id,
@@ -57,7 +56,6 @@ from graphiti_core.utils.bulk_utils import (
    add_nodes_and_edges_bulk,
    dedupe_edges_bulk,
    dedupe_nodes_bulk,
-    extract_edge_dates_bulk,
    extract_nodes_and_edges_bulk,
    resolve_edge_pointers,
    retrieve_previous_episodes_bulk,
@@ -169,7 +167,6 @@ class Graphiti:
                raise ValueError('uri must be provided when graph_driver is None')
            self.driver = Neo4jDriver(uri, user, password)
 
-        self.database = DEFAULT_DATABASE
        self.store_raw_episode_content = store_raw_episode_content
        self.max_coroutines = max_coroutines
        if llm_client:
@@ -508,7 +505,7 @@
 
        entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
 
-        episodic_edges = build_episodic_edges(nodes, episode, now)
+        episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
 
        episode.entity_edges = [edge.uuid for edge in entity_edges]
 
@@ -536,8 +533,16 @@
        except Exception as e:
            raise e
 
-
-    async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
+    ##### EXPERIMENTAL #####
+    async def add_episode_bulk(
+        self,
+        bulk_episodes: list[RawEpisode],
+        group_id: str = '',
+        entity_types: dict[str, BaseModel] | None = None,
+        excluded_entity_types: list[str] | None = None,
+        edge_types: dict[str, BaseModel] | None = None,
+        edge_type_map: dict[tuple[str, str], list[str]] | None = None,
+    ):
        """
        Process multiple episodes in bulk and update the graph.
 
@@ -580,8 +585,17 @@
 
        validate_group_id(group_id)
 
+        # Create default edge type map
+        edge_type_map_default = (
+            {('Entity', 'Entity'): list(edge_types.keys())}
+            if edge_types is not None
+            else {('Entity', 'Entity'): []}
+        )
+
        episodes = [
-            EpisodicNode(
+            await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
+            if episode.uuid is not None
+            else EpisodicNode(
                name=episode.name,
                labels=[],
                source=episode.source,
@@ -594,68 +608,106 @@
            for episode in bulk_episodes
        ]
 
-
-
-
-
+        episodes_by_uuid: dict[str, EpisodicNode] = {
+            episode.uuid: episode for episode in episodes
+        }
+
+        # Save all episodes
+        await add_nodes_and_edges_bulk(
+            driver=self.driver,
+            episodic_nodes=episodes,
+            episodic_edges=[],
+            entity_nodes=[],
+            entity_edges=[],
+            embedder=self.embedder,
        )
 
        # Get previous episode context for each episode
-
+        episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
 
-        # Extract all nodes and edges
-        (
-
-
-
-
-
-
-        await semaphore_gather(
-            *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
-            *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
-            max_coroutines=self.max_coroutines,
+        # Extract all nodes and edges for each episode
+        extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
+            self.clients,
+            episode_context,
+            edge_type_map=edge_type_map or edge_type_map_default,
+            edge_types=edge_types,
+            entity_types=entity_types,
+            excluded_entity_types=excluded_entity_types,
        )
 
-        # Dedupe extracted nodes
-
-
-            extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
-            max_coroutines=self.max_coroutines,
+        # Dedupe extracted nodes in memory
+        nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
+            self.clients, extracted_nodes_bulk, episode_context, entity_types
        )
 
-
-
-
-            max_coroutines=self.max_coroutines,
-        )
+        episodic_edges: list[EpisodicEdge] = []
+        for episode_uuid, nodes in nodes_by_episode.items():
+            episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
 
        # re-map edge pointers so that they don't point to discard dupe nodes
-
-
-
-        episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
-            episodic_edges, uuid_map
-        )
+        extracted_edges_bulk_updated: list[list[EntityEdge]] = [
+            resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
+        ]
 
-        #
-        await
-
-
+        # Dedupe extracted edges in memory
+        edges_by_episode = await dedupe_edges_bulk(
+            self.clients,
+            extracted_edges_bulk_updated,
+            episode_context,
+            [],
+            edge_types or {},
+            edge_type_map or edge_type_map_default,
        )
 
-        #
-
-
+        # Extract node attributes
+        nodes_by_uuid: dict[str, EntityNode] = {
+            node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
+        }
+
+        extract_attributes_params: list[tuple[EntityNode, list[EpisodicNode]]] = []
+        for node in nodes_by_uuid.values():
+            episode_uuids: list[str] = []
+            for episode_uuid, mentioned_nodes in nodes_by_episode.items():
+                for mentioned_node in mentioned_nodes:
+                    if node.uuid == mentioned_node.uuid:
+                        episode_uuids.append(episode_uuid)
+                        break
+
+            episode_mentions: list[EpisodicNode] = [
+                episodes_by_uuid[episode_uuid] for episode_uuid in episode_uuids
+            ]
+            episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)
+
+            extract_attributes_params.append((node, episode_mentions))
+
+        new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
+            *[
+                extract_attributes_from_nodes(
+                    self.clients,
+                    [params[0]],
+                    params[1][0],
+                    params[1][0:],
+                    entity_types,
+                )
+                for params in extract_attributes_params
+            ]
        )
-        logger.debug(f'extracted edge length: {len(edges)}')
 
-
+        hydrated_nodes = [node for nodes in new_hydrated_nodes for node in nodes]
 
-        #
-
-
-
+        # TODO: Resolve nodes and edges against the existing graph
+        edges_by_uuid: dict[str, EntityEdge] = {
+            edge.uuid: edge for edges in edges_by_episode.values() for edge in edges
+        }
+
+        # save data to KG
+        await add_nodes_and_edges_bulk(
+            self.driver,
+            episodes,
+            episodic_edges,
+            hydrated_nodes,
+            list(edges_by_uuid.values()),
+            self.embedder,
        )
 
        end = time()
@@ -828,7 +880,7 @@
            await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
        )[0]
 
-        resolved_edge, invalidated_edges = await resolve_extracted_edge(
+        resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
            self.llm_client,
            updated_edge,
            related_edges,
@@ -867,9 +919,7 @@
        nodes_to_delete: list[EntityNode] = []
        for node in nodes:
            query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count'
-            records, _, _ = await self.driver.execute_query(
-                query, uuid=node.uuid, database_=DEFAULT_DATABASE, routing_='r'
-            )
+            records, _, _ = await self.driver.execute_query(query, uuid=node.uuid, routing_='r')
 
            for record in records:
                if record['episode_count'] == 1:
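The reworked add_episode_bulk is flagged experimental and now takes the same typed-entity and typed-edge parameters as add_episode. A hedged sketch of a call against the new signature; the RawEpisode field names and the Person entity type below are assumptions based on the existing models, not values taken from this diff:

from datetime import datetime, timezone

from pydantic import BaseModel

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode


class Person(BaseModel):
    """Illustrative entity type; the field is an assumption."""

    occupation: str | None = None


async def ingest(graphiti: Graphiti) -> None:
    episodes = [
        RawEpisode(
            name='meeting-notes-1',
            content='Alice met Bob to discuss the Q3 roadmap.',
            source=EpisodeType.text,
            source_description='meeting notes',
            reference_time=datetime.now(timezone.utc),
        )
    ]
    # Experimental bulk path: nodes and edges are extracted and deduplicated
    # in memory per episode, then written once via add_nodes_and_edges_bulk.
    await graphiti.add_episode_bulk(
        episodes,
        group_id='demo',
        entity_types={'Person': Person},
    )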
graphiti_core/helpers.py
CHANGED
@@ -32,7 +32,6 @@ from graphiti_core.errors import GroupIdValidationError
 
 load_dotenv()
 
-DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'default_db')
 USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
 SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
 MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))

graphiti_core/llm_client/gemini_client.py
CHANGED

@@ -76,6 +76,7 @@ class GeminiClient(LLMClient):
        cache: bool = False,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        thinking_config: types.ThinkingConfig | None = None,
+        client: 'genai.Client | None' = None,
    ):
        """
        Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
@@ -85,7 +86,7 @@
            cache (bool): Whether to use caching for responses. Defaults to False.
            thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
                Only use with models that support thinking (gemini-2.5+). Defaults to None.
-
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
        """
        if config is None:
            config = LLMConfig()
@@ -93,10 +94,12 @@
        super().__init__(config, cache)
 
        self.model = config.model
-
-
-            api_key=config.api_key
-
+
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client
+
        self.max_tokens = max_tokens
        self.thinking_config = thinking_config
 
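GeminiClient follows the same pattern as the embedder and reranker: an optional pre-built genai.Client replaces the unconditional construction from config.api_key. A short sketch, assuming the usual LLMConfig import path; the API key and model name are placeholders:

from google import genai

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient

# One genai.Client can back the LLM client, embedder, and reranker alike.
shared_genai_client = genai.Client(api_key='YOUR_GEMINI_API_KEY')
llm_client = GeminiClient(config=LLMConfig(model='gemini-2.5-flash'), client=shared_genai_client)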