graphiti-core 0.15.1 → 0.17.0 (py3-none-any wheels)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; the original registry page links to further details.

@@ -41,6 +41,10 @@ DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
41
41
 
42
42
 
43
43
  class GeminiRerankerClient(CrossEncoderClient):
44
+ """
45
+ Google Gemini Reranker Client
46
+ """
47
+
44
48
  def __init__(
45
49
  self,
46
50
  config: LLMConfig | None = None,
@@ -57,7 +61,6 @@ class GeminiRerankerClient(CrossEncoderClient):
57
61
  config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
58
62
  client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
59
63
  """
60
-
61
64
  if config is None:
62
65
  config = LLMConfig()
63
66
 
@@ -19,8 +19,6 @@ from abc import ABC, abstractmethod
19
19
  from collections.abc import Coroutine
20
20
  from typing import Any
21
21
 
22
- from graphiti_core.helpers import DEFAULT_DATABASE
23
-
24
22
  logger = logging.getLogger(__name__)
25
23
 
26
24
 
@@ -54,7 +52,7 @@ class GraphDriver(ABC):
54
52
  raise NotImplementedError()
55
53
 
56
54
  @abstractmethod
57
- def session(self, database: str) -> GraphDriverSession:
55
+ def session(self, database: str | None = None) -> GraphDriverSession:
58
56
  raise NotImplementedError()
59
57
 
60
58
  @abstractmethod
@@ -62,5 +60,5 @@ class GraphDriver(ABC):
62
60
  raise NotImplementedError()
63
61
 
64
62
  @abstractmethod
65
- def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
63
+ def delete_all_indexes(self, database_: str | None = None) -> Coroutine:
66
64
  raise NotImplementedError()
@@ -33,7 +33,6 @@ else:
33
33
  ) from None
34
34
 
35
35
  from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
36
- from graphiti_core.helpers import DEFAULT_DATABASE
37
36
 
38
37
  logger = logging.getLogger(__name__)
39
38
 
@@ -81,6 +80,7 @@ class FalkorDriver(GraphDriver):
81
80
  username: str | None = None,
82
81
  password: str | None = None,
83
82
  falkor_db: FalkorDB | None = None,
83
+ database: str = 'default_db',
84
84
  ):
85
85
  """
86
86
  Initialize the FalkorDB driver.
@@ -95,15 +95,16 @@ class FalkorDriver(GraphDriver):
95
95
  self.client = falkor_db
96
96
  else:
97
97
  self.client = FalkorDB(host=host, port=port, username=username, password=password)
98
+ self._database = database
98
99
 
99
100
  def _get_graph(self, graph_name: str | None) -> FalkorGraph:
100
- # FalkorDB requires a non-None database name for multi-tenant graphs; the default is DEFAULT_DATABASE
101
+ # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
101
102
  if graph_name is None:
102
- graph_name = DEFAULT_DATABASE
103
+ graph_name = self._database
103
104
  return self.client.select_graph(graph_name)
104
105
 
105
106
  async def execute_query(self, cypher_query_, **kwargs: Any):
106
- graph_name = kwargs.pop('database_', DEFAULT_DATABASE)
107
+ graph_name = kwargs.pop('database_', self._database)
107
108
  graph = self._get_graph(graph_name)
108
109
 
109
110
  # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
@@ -136,7 +137,7 @@ class FalkorDriver(GraphDriver):
136
137
 
137
138
  return records, header, None
138
139
 
139
- def session(self, database: str | None) -> GraphDriverSession:
140
+ def session(self, database: str | None = None) -> GraphDriverSession:
140
141
  return FalkorDriverSession(self._get_graph(database))
141
142
 
142
143
  async def close(self) -> None:
@@ -148,10 +149,11 @@ class FalkorDriver(GraphDriver):
148
149
  elif hasattr(self.client.connection, 'close'):
149
150
  await self.client.connection.close()
150
151
 
151
- async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> None:
152
+ async def delete_all_indexes(self, database_: str | None = None) -> None:
153
+ database = database_ or self._database
152
154
  await self.execute_query(
153
155
  'CALL db.indexes() YIELD name DROP INDEX name',
154
- database_=database_,
156
+ database_=database,
155
157
  )
156
158
 
157
159
 
@@ -22,7 +22,6 @@ from neo4j import AsyncGraphDatabase, EagerResult
22
22
  from typing_extensions import LiteralString
23
23
 
24
24
  from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
25
- from graphiti_core.helpers import DEFAULT_DATABASE
26
25
 
27
26
  logger = logging.getLogger(__name__)
28
27
 
@@ -30,34 +29,36 @@ logger = logging.getLogger(__name__)
30
29
  class Neo4jDriver(GraphDriver):
31
30
  provider: str = 'neo4j'
32
31
 
33
- def __init__(
34
- self,
35
- uri: str,
36
- user: str | None,
37
- password: str | None,
38
- ):
32
+ def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
39
33
  super().__init__()
40
34
  self.client = AsyncGraphDatabase.driver(
41
35
  uri=uri,
42
36
  auth=(user or '', password or ''),
43
37
  )
38
+ self._database = database
44
39
 
45
40
  async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
41
+ # Check if database_ is provided in kwargs.
42
+ # If not populated, set the value to retain backwards compatibility
46
43
  params = kwargs.pop('params', None)
44
+ if params is None:
45
+ params = {}
46
+ params.setdefault('database_', self._database)
47
+
47
48
  result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
48
49
 
49
50
  return result
50
51
 
51
- def session(self, database: str) -> GraphDriverSession:
52
- return self.client.session(database=database) # type: ignore
52
+ def session(self, database: str | None = None) -> GraphDriverSession:
53
+ _database = database or self._database
54
+ return self.client.session(database=_database) # type: ignore
53
55
 
54
56
  async def close(self) -> None:
55
57
  return await self.client.close()
56
58
 
57
- def delete_all_indexes(
58
- self, database_: str = DEFAULT_DATABASE
59
- ) -> Coroutine[Any, Any, EagerResult]:
59
+ def delete_all_indexes(self, database_: str | None = None) -> Coroutine[Any, Any, EagerResult]:
60
+ database = database_ or self._database
60
61
  return self.client.execute_query(
61
62
  'CALL db.indexes() YIELD name DROP INDEX name',
62
- database_=database_,
63
+ database_=database,
63
64
  )
graphiti_core/edges.py CHANGED
@@ -27,7 +27,7 @@ from typing_extensions import LiteralString
27
27
  from graphiti_core.driver.driver import GraphDriver
28
28
  from graphiti_core.embedder import EmbedderClient
29
29
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
30
- from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
30
+ from graphiti_core.helpers import parse_db_date
31
31
  from graphiti_core.models.edges.edge_db_queries import (
32
32
  COMMUNITY_EDGE_SAVE,
33
33
  ENTITY_EDGE_SAVE,
@@ -71,7 +71,6 @@ class Edge(BaseModel, ABC):
71
71
  DELETE e
72
72
  """,
73
73
  uuid=self.uuid,
74
- database_=DEFAULT_DATABASE,
75
74
  )
76
75
 
77
76
  logger.debug(f'Deleted Edge: {self.uuid}')
@@ -99,7 +98,6 @@ class EpisodicEdge(Edge):
99
98
  uuid=self.uuid,
100
99
  group_id=self.group_id,
101
100
  created_at=self.created_at,
102
- database_=DEFAULT_DATABASE,
103
101
  )
104
102
 
105
103
  logger.debug(f'Saved edge to Graph: {self.uuid}')
@@ -119,7 +117,6 @@ class EpisodicEdge(Edge):
119
117
  e.created_at AS created_at
120
118
  """,
121
119
  uuid=uuid,
122
- database_=DEFAULT_DATABASE,
123
120
  routing_='r',
124
121
  )
125
122
 
@@ -143,7 +140,6 @@ class EpisodicEdge(Edge):
143
140
  e.created_at AS created_at
144
141
  """,
145
142
  uuids=uuids,
146
- database_=DEFAULT_DATABASE,
147
143
  routing_='r',
148
144
  )
149
145
 
@@ -183,7 +179,6 @@ class EpisodicEdge(Edge):
183
179
  group_ids=group_ids,
184
180
  uuid=uuid_cursor,
185
181
  limit=limit,
186
- database_=DEFAULT_DATABASE,
187
182
  routing_='r',
188
183
  )
189
184
 
@@ -231,9 +226,7 @@ class EntityEdge(Edge):
231
226
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
232
227
  RETURN e.fact_embedding AS fact_embedding
233
228
  """
234
- records, _, _ = await driver.execute_query(
235
- query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
236
- )
229
+ records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
237
230
 
238
231
  if len(records) == 0:
239
232
  raise EdgeNotFoundError(self.uuid)
@@ -261,7 +254,6 @@ class EntityEdge(Edge):
261
254
  result = await driver.execute_query(
262
255
  ENTITY_EDGE_SAVE,
263
256
  edge_data=edge_data,
264
- database_=DEFAULT_DATABASE,
265
257
  )
266
258
 
267
259
  logger.debug(f'Saved edge to Graph: {self.uuid}')
@@ -276,7 +268,6 @@ class EntityEdge(Edge):
276
268
  """
277
269
  + ENTITY_EDGE_RETURN,
278
270
  uuid=uuid,
279
- database_=DEFAULT_DATABASE,
280
271
  routing_='r',
281
272
  )
282
273
 
@@ -298,7 +289,6 @@ class EntityEdge(Edge):
298
289
  """
299
290
  + ENTITY_EDGE_RETURN,
300
291
  uuids=uuids,
301
- database_=DEFAULT_DATABASE,
302
292
  routing_='r',
303
293
  )
304
294
 
@@ -331,7 +321,6 @@ class EntityEdge(Edge):
331
321
  group_ids=group_ids,
332
322
  uuid=uuid_cursor,
333
323
  limit=limit,
334
- database_=DEFAULT_DATABASE,
335
324
  routing_='r',
336
325
  )
337
326
 
@@ -349,9 +338,7 @@ class EntityEdge(Edge):
349
338
  """
350
339
  + ENTITY_EDGE_RETURN
351
340
  )
352
- records, _, _ = await driver.execute_query(
353
- query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r'
354
- )
341
+ records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
355
342
 
356
343
  edges = [get_entity_edge_from_record(record) for record in records]
357
344
 
@@ -367,7 +354,6 @@ class CommunityEdge(Edge):
367
354
  uuid=self.uuid,
368
355
  group_id=self.group_id,
369
356
  created_at=self.created_at,
370
- database_=DEFAULT_DATABASE,
371
357
  )
372
358
 
373
359
  logger.debug(f'Saved edge to Graph: {self.uuid}')
@@ -387,7 +373,6 @@ class CommunityEdge(Edge):
387
373
  e.created_at AS created_at
388
374
  """,
389
375
  uuid=uuid,
390
- database_=DEFAULT_DATABASE,
391
376
  routing_='r',
392
377
  )
393
378
 
@@ -409,7 +394,6 @@ class CommunityEdge(Edge):
409
394
  e.created_at AS created_at
410
395
  """,
411
396
  uuids=uuids,
412
- database_=DEFAULT_DATABASE,
413
397
  routing_='r',
414
398
  )
415
399
 
@@ -447,7 +431,6 @@ class CommunityEdge(Edge):
447
431
  group_ids=group_ids,
448
432
  uuid=uuid_cursor,
449
433
  limit=limit,
450
- database_=DEFAULT_DATABASE,
451
434
  routing_='r',
452
435
  )
453
436
 
@@ -47,15 +47,27 @@ class GeminiEmbedder(EmbedderClient):
47
47
  Google Gemini Embedder Client
48
48
  """
49
49
 
50
- def __init__(self, config: GeminiEmbedderConfig | None = None):
50
+ def __init__(
51
+ self,
52
+ config: GeminiEmbedderConfig | None = None,
53
+ client: 'genai.Client | None' = None,
54
+ ):
55
+ """
56
+ Initialize the GeminiEmbedder with the provided configuration and client.
57
+
58
+ Args:
59
+ config (GeminiEmbedderConfig | None): The configuration for the GeminiEmbedder, including API key, model, base URL, temperature, and max tokens.
60
+ client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
61
+ """
51
62
  if config is None:
52
63
  config = GeminiEmbedderConfig()
64
+
53
65
  self.config = config
54
66
 
55
- # Configure the Gemini API
56
- self.client = genai.Client(
57
- api_key=config.api_key,
58
- )
67
+ if client is None:
68
+ self.client = genai.Client(api_key=config.api_key)
69
+ else:
70
+ self.client = client
59
71
 
60
72
  async def create(
61
73
  self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
graphiti_core/graphiti.py CHANGED
@@ -30,7 +30,6 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
30
30
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
31
31
  from graphiti_core.graphiti_types import GraphitiClients
32
32
  from graphiti_core.helpers import (
33
- DEFAULT_DATABASE,
34
33
  semaphore_gather,
35
34
  validate_excluded_entity_types,
36
35
  validate_group_id,
@@ -57,7 +56,6 @@ from graphiti_core.utils.bulk_utils import (
57
56
  add_nodes_and_edges_bulk,
58
57
  dedupe_edges_bulk,
59
58
  dedupe_nodes_bulk,
60
- extract_edge_dates_bulk,
61
59
  extract_nodes_and_edges_bulk,
62
60
  resolve_edge_pointers,
63
61
  retrieve_previous_episodes_bulk,
@@ -169,7 +167,6 @@ class Graphiti:
169
167
  raise ValueError('uri must be provided when graph_driver is None')
170
168
  self.driver = Neo4jDriver(uri, user, password)
171
169
 
172
- self.database = DEFAULT_DATABASE
173
170
  self.store_raw_episode_content = store_raw_episode_content
174
171
  self.max_coroutines = max_coroutines
175
172
  if llm_client:
@@ -508,7 +505,7 @@ class Graphiti:
508
505
 
509
506
  entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
510
507
 
511
- episodic_edges = build_episodic_edges(nodes, episode, now)
508
+ episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
512
509
 
513
510
  episode.entity_edges = [edge.uuid for edge in entity_edges]
514
511
 
@@ -536,8 +533,16 @@ class Graphiti:
536
533
  except Exception as e:
537
534
  raise e
538
535
 
539
- #### WIP: USE AT YOUR OWN RISK ####
540
- async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
536
+ ##### EXPERIMENTAL #####
537
+ async def add_episode_bulk(
538
+ self,
539
+ bulk_episodes: list[RawEpisode],
540
+ group_id: str = '',
541
+ entity_types: dict[str, BaseModel] | None = None,
542
+ excluded_entity_types: list[str] | None = None,
543
+ edge_types: dict[str, BaseModel] | None = None,
544
+ edge_type_map: dict[tuple[str, str], list[str]] | None = None,
545
+ ):
541
546
  """
542
547
  Process multiple episodes in bulk and update the graph.
543
548
 
@@ -580,8 +585,17 @@ class Graphiti:
580
585
 
581
586
  validate_group_id(group_id)
582
587
 
588
+ # Create default edge type map
589
+ edge_type_map_default = (
590
+ {('Entity', 'Entity'): list(edge_types.keys())}
591
+ if edge_types is not None
592
+ else {('Entity', 'Entity'): []}
593
+ )
594
+
583
595
  episodes = [
584
- EpisodicNode(
596
+ await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
597
+ if episode.uuid is not None
598
+ else EpisodicNode(
585
599
  name=episode.name,
586
600
  labels=[],
587
601
  source=episode.source,
@@ -594,68 +608,106 @@ class Graphiti:
594
608
  for episode in bulk_episodes
595
609
  ]
596
610
 
597
- # Save all the episodes
598
- await semaphore_gather(
599
- *[episode.save(self.driver) for episode in episodes],
600
- max_coroutines=self.max_coroutines,
611
+ episodes_by_uuid: dict[str, EpisodicNode] = {
612
+ episode.uuid: episode for episode in episodes
613
+ }
614
+
615
+ # Save all episodes
616
+ await add_nodes_and_edges_bulk(
617
+ driver=self.driver,
618
+ episodic_nodes=episodes,
619
+ episodic_edges=[],
620
+ entity_nodes=[],
621
+ entity_edges=[],
622
+ embedder=self.embedder,
601
623
  )
602
624
 
603
625
  # Get previous episode context for each episode
604
- episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
626
+ episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
605
627
 
606
- # Extract all nodes and edges
607
- (
608
- extracted_nodes,
609
- extracted_edges,
610
- episodic_edges,
611
- ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs, None, None)
612
-
613
- # Generate embeddings
614
- await semaphore_gather(
615
- *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
616
- *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
617
- max_coroutines=self.max_coroutines,
628
+ # Extract all nodes and edges for each episode
629
+ extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
630
+ self.clients,
631
+ episode_context,
632
+ edge_type_map=edge_type_map or edge_type_map_default,
633
+ edge_types=edge_types,
634
+ entity_types=entity_types,
635
+ excluded_entity_types=excluded_entity_types,
618
636
  )
619
637
 
620
- # Dedupe extracted nodes, compress extracted edges
621
- (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
622
- dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
623
- extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
624
- max_coroutines=self.max_coroutines,
638
+ # Dedupe extracted nodes in memory
639
+ nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
640
+ self.clients, extracted_nodes_bulk, episode_context, entity_types
625
641
  )
626
642
 
627
- # save nodes to KG
628
- await semaphore_gather(
629
- *[node.save(self.driver) for node in nodes],
630
- max_coroutines=self.max_coroutines,
631
- )
643
+ episodic_edges: list[EpisodicEdge] = []
644
+ for episode_uuid, nodes in nodes_by_episode.items():
645
+ episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
632
646
 
633
647
  # re-map edge pointers so that they don't point to discard dupe nodes
634
- extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
635
- extracted_edges_timestamped, uuid_map
636
- )
637
- episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
638
- episodic_edges, uuid_map
639
- )
648
+ extracted_edges_bulk_updated: list[list[EntityEdge]] = [
649
+ resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
650
+ ]
640
651
 
641
- # save episodic edges to KG
642
- await semaphore_gather(
643
- *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers],
644
- max_coroutines=self.max_coroutines,
652
+ # Dedupe extracted edges in memory
653
+ edges_by_episode = await dedupe_edges_bulk(
654
+ self.clients,
655
+ extracted_edges_bulk_updated,
656
+ episode_context,
657
+ [],
658
+ edge_types or {},
659
+ edge_type_map or edge_type_map_default,
645
660
  )
646
661
 
647
- # Dedupe extracted edges
648
- edges = await dedupe_edges_bulk(
649
- self.driver, self.llm_client, extracted_edges_with_resolved_pointers
662
+ # Extract node attributes
663
+ nodes_by_uuid: dict[str, EntityNode] = {
664
+ node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
665
+ }
666
+
667
+ extract_attributes_params: list[tuple[EntityNode, list[EpisodicNode]]] = []
668
+ for node in nodes_by_uuid.values():
669
+ episode_uuids: list[str] = []
670
+ for episode_uuid, mentioned_nodes in nodes_by_episode.items():
671
+ for mentioned_node in mentioned_nodes:
672
+ if node.uuid == mentioned_node.uuid:
673
+ episode_uuids.append(episode_uuid)
674
+ break
675
+
676
+ episode_mentions: list[EpisodicNode] = [
677
+ episodes_by_uuid[episode_uuid] for episode_uuid in episode_uuids
678
+ ]
679
+ episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)
680
+
681
+ extract_attributes_params.append((node, episode_mentions))
682
+
683
+ new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
684
+ *[
685
+ extract_attributes_from_nodes(
686
+ self.clients,
687
+ [params[0]],
688
+ params[1][0],
689
+ params[1][0:],
690
+ entity_types,
691
+ )
692
+ for params in extract_attributes_params
693
+ ]
650
694
  )
651
- logger.debug(f'extracted edge length: {len(edges)}')
652
695
 
653
- # invalidate edges
696
+ hydrated_nodes = [node for nodes in new_hydrated_nodes for node in nodes]
654
697
 
655
- # save edges to KG
656
- await semaphore_gather(
657
- *[edge.save(self.driver) for edge in edges],
658
- max_coroutines=self.max_coroutines,
698
+ # TODO: Resolve nodes and edges against the existing graph
699
+ edges_by_uuid: dict[str, EntityEdge] = {
700
+ edge.uuid: edge for edges in edges_by_episode.values() for edge in edges
701
+ }
702
+
703
+ # save data to KG
704
+ await add_nodes_and_edges_bulk(
705
+ self.driver,
706
+ episodes,
707
+ episodic_edges,
708
+ hydrated_nodes,
709
+ list(edges_by_uuid.values()),
710
+ self.embedder,
659
711
  )
660
712
 
661
713
  end = time()
@@ -828,7 +880,7 @@ class Graphiti:
828
880
  await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
829
881
  )[0]
830
882
 
831
- resolved_edge, invalidated_edges = await resolve_extracted_edge(
883
+ resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
832
884
  self.llm_client,
833
885
  updated_edge,
834
886
  related_edges,
@@ -867,9 +919,7 @@ class Graphiti:
867
919
  nodes_to_delete: list[EntityNode] = []
868
920
  for node in nodes:
869
921
  query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count'
870
- records, _, _ = await self.driver.execute_query(
871
- query, uuid=node.uuid, database_=DEFAULT_DATABASE, routing_='r'
872
- )
922
+ records, _, _ = await self.driver.execute_query(query, uuid=node.uuid, routing_='r')
873
923
 
874
924
  for record in records:
875
925
  if record['episode_count'] == 1:
graphiti_core/helpers.py CHANGED
@@ -32,7 +32,6 @@ from graphiti_core.errors import GroupIdValidationError
32
32
 
33
33
  load_dotenv()
34
34
 
35
- DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'default_db')
36
35
  USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
37
36
  SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
38
37
  MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
@@ -76,6 +76,7 @@ class GeminiClient(LLMClient):
76
76
  cache: bool = False,
77
77
  max_tokens: int = DEFAULT_MAX_TOKENS,
78
78
  thinking_config: types.ThinkingConfig | None = None,
79
+ client: 'genai.Client | None' = None,
79
80
  ):
80
81
  """
81
82
  Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
@@ -85,7 +86,7 @@ class GeminiClient(LLMClient):
85
86
  cache (bool): Whether to use caching for responses. Defaults to False.
86
87
  thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
87
88
  Only use with models that support thinking (gemini-2.5+). Defaults to None.
88
-
89
+ client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
89
90
  """
90
91
  if config is None:
91
92
  config = LLMConfig()
@@ -93,10 +94,12 @@ class GeminiClient(LLMClient):
93
94
  super().__init__(config, cache)
94
95
 
95
96
  self.model = config.model
96
- # Configure the Gemini API
97
- self.client = genai.Client(
98
- api_key=config.api_key,
99
- )
97
+
98
+ if client is None:
99
+ self.client = genai.Client(api_key=config.api_key)
100
+ else:
101
+ self.client = client
102
+
100
103
  self.max_tokens = max_tokens
101
104
  self.thinking_config = thinking_config
102
105