graphiti-core 0.10.4__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (31)
  1. graphiti_core/edges.py +32 -57
  2. graphiti_core/embedder/client.py +3 -0
  3. graphiti_core/embedder/gemini.py +10 -0
  4. graphiti_core/embedder/openai.py +6 -0
  5. graphiti_core/embedder/voyage.py +7 -0
  6. graphiti_core/graphiti.py +42 -138
  7. graphiti_core/graphiti_types.py +31 -0
  8. graphiti_core/helpers.py +6 -1
  9. graphiti_core/llm_client/anthropic_client.py +4 -1
  10. graphiti_core/llm_client/client.py +4 -1
  11. graphiti_core/llm_client/gemini_client.py +4 -1
  12. graphiti_core/llm_client/openai_client.py +4 -1
  13. graphiti_core/llm_client/openai_generic_client.py +4 -1
  14. graphiti_core/models/edges/edge_db_queries.py +1 -1
  15. graphiti_core/nodes.py +10 -10
  16. graphiti_core/prompts/dedupe_edges.py +5 -7
  17. graphiti_core/prompts/dedupe_nodes.py +8 -21
  18. graphiti_core/prompts/extract_edges.py +61 -26
  19. graphiti_core/prompts/extract_nodes.py +89 -18
  20. graphiti_core/prompts/invalidate_edges.py +11 -11
  21. graphiti_core/search/search.py +13 -5
  22. graphiti_core/search/search_utils.py +206 -98
  23. graphiti_core/utils/bulk_utils.py +10 -7
  24. graphiti_core/utils/maintenance/edge_operations.py +88 -40
  25. graphiti_core/utils/maintenance/graph_data_operations.py +20 -6
  26. graphiti_core/utils/maintenance/node_operations.py +216 -223
  27. graphiti_core/utils/maintenance/temporal_operations.py +4 -11
  28. {graphiti_core-0.10.4.dist-info → graphiti_core-0.11.0.dist-info}/METADATA +25 -11
  29. {graphiti_core-0.10.4.dist-info → graphiti_core-0.11.0.dist-info}/RECORD +31 -30
  30. {graphiti_core-0.10.4.dist-info → graphiti_core-0.11.0.dist-info}/LICENSE +0 -0
  31. {graphiti_core-0.10.4.dist-info → graphiti_core-0.11.0.dist-info}/WHEEL +0 -0
graphiti_core/edges.py CHANGED
@@ -37,6 +37,21 @@ from graphiti_core.nodes import Node
37
37
 
38
38
  logger = logging.getLogger(__name__)
39
39
 
40
+ ENTITY_EDGE_RETURN: LiteralString = """
41
+ RETURN
42
+ e.uuid AS uuid,
43
+ startNode(e).uuid AS source_node_uuid,
44
+ endNode(e).uuid AS target_node_uuid,
45
+ e.created_at AS created_at,
46
+ e.name AS name,
47
+ e.group_id AS group_id,
48
+ e.fact AS fact,
49
+ e.fact_embedding AS fact_embedding,
50
+ e.episodes AS episodes,
51
+ e.expired_at AS expired_at,
52
+ e.valid_at AS valid_at,
53
+ e.invalid_at AS invalid_at"""
54
+
40
55
 
41
56
  class Edge(BaseModel, ABC):
42
57
  uuid: str = Field(default_factory=lambda: str(uuid4()))
@@ -234,20 +249,8 @@ class EntityEdge(Edge):
234
249
  records, _, _ = await driver.execute_query(
235
250
  """
236
251
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
237
- RETURN
238
- e.uuid AS uuid,
239
- n.uuid AS source_node_uuid,
240
- m.uuid AS target_node_uuid,
241
- e.created_at AS created_at,
242
- e.name AS name,
243
- e.group_id AS group_id,
244
- e.fact AS fact,
245
- e.fact_embedding AS fact_embedding,
246
- e.episodes AS episodes,
247
- e.expired_at AS expired_at,
248
- e.valid_at AS valid_at,
249
- e.invalid_at AS invalid_at
250
- """,
252
+ """
253
+ + ENTITY_EDGE_RETURN,
251
254
  uuid=uuid,
252
255
  database_=DEFAULT_DATABASE,
253
256
  routing_='r',
@@ -268,20 +271,8 @@ class EntityEdge(Edge):
268
271
  """
269
272
  MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
270
273
  WHERE e.uuid IN $uuids
271
- RETURN
272
- e.uuid AS uuid,
273
- n.uuid AS source_node_uuid,
274
- m.uuid AS target_node_uuid,
275
- e.created_at AS created_at,
276
- e.name AS name,
277
- e.group_id AS group_id,
278
- e.fact AS fact,
279
- e.fact_embedding AS fact_embedding,
280
- e.episodes AS episodes,
281
- e.expired_at AS expired_at,
282
- e.valid_at AS valid_at,
283
- e.invalid_at AS invalid_at
284
- """,
274
+ """
275
+ + ENTITY_EDGE_RETURN,
285
276
  uuids=uuids,
286
277
  database_=DEFAULT_DATABASE,
287
278
  routing_='r',
@@ -308,20 +299,8 @@ class EntityEdge(Edge):
308
299
  WHERE e.group_id IN $group_ids
309
300
  """
310
301
  + cursor_query
302
+ + ENTITY_EDGE_RETURN
311
303
  + """
312
- RETURN
313
- e.uuid AS uuid,
314
- n.uuid AS source_node_uuid,
315
- m.uuid AS target_node_uuid,
316
- e.created_at AS created_at,
317
- e.name AS name,
318
- e.group_id AS group_id,
319
- e.fact AS fact,
320
- e.fact_embedding AS fact_embedding,
321
- e.episodes AS episodes,
322
- e.expired_at AS expired_at,
323
- e.valid_at AS valid_at,
324
- e.invalid_at AS invalid_at
325
304
  ORDER BY e.uuid DESC
326
305
  """
327
306
  + limit_query,
@@ -340,22 +319,12 @@ class EntityEdge(Edge):
340
319
 
341
320
  @classmethod
342
321
  async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
343
- query: LiteralString = """
344
- MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
345
- RETURN DISTINCT
346
- e.uuid AS uuid,
347
- n.uuid AS source_node_uuid,
348
- m.uuid AS target_node_uuid,
349
- e.created_at AS created_at,
350
- e.name AS name,
351
- e.group_id AS group_id,
352
- e.fact AS fact,
353
- e.fact_embedding AS fact_embedding,
354
- e.episodes AS episodes,
355
- e.expired_at AS expired_at,
356
- e.valid_at AS valid_at,
357
- e.invalid_at AS invalid_at
358
- """
322
+ query: LiteralString = (
323
+ """
324
+ MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
325
+ """
326
+ + ENTITY_EDGE_RETURN
327
+ )
359
328
  records, _, _ = await driver.execute_query(
360
329
  query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r'
361
330
  )
@@ -499,3 +468,9 @@ def get_community_edge_from_record(record: Any):
499
468
  target_node_uuid=record['target_node_uuid'],
500
469
  created_at=record['created_at'].to_native(),
501
470
  )
471
+
472
+
473
+ async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
474
+ fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
475
+ for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
476
+ edge.fact_embedding = fact_embedding
graphiti_core/embedder/client.py CHANGED
@@ -32,3 +32,6 @@ class EmbedderClient(ABC):
32
32
  self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
33
33
  ) -> list[float]:
34
34
  pass
35
+
36
+ async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
37
+ raise NotImplementedError()
graphiti_core/embedder/gemini.py CHANGED
@@ -66,3 +66,13 @@ class GeminiEmbedder(EmbedderClient):
66
66
  )
67
67
 
68
68
  return result.embeddings[0].values
69
+
70
+ async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
71
+ # Generate embeddings
72
+ result = await self.client.aio.models.embed_content(
73
+ model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
74
+ contents=input_data_list,
75
+ config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
76
+ )
77
+
78
+ return [embedding.values for embedding in result.embeddings]
graphiti_core/embedder/openai.py CHANGED
@@ -58,3 +58,9 @@ class OpenAIEmbedder(EmbedderClient):
58
58
  input=input_data, model=self.config.embedding_model
59
59
  )
60
60
  return result.data[0].embedding[: self.config.embedding_dim]
61
+
62
+ async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
63
+ result = await self.client.embeddings.create(
64
+ input=input_data_list, model=self.config.embedding_model
65
+ )
66
+ return [embedding.embedding[: self.config.embedding_dim] for embedding in result.data]
graphiti_core/embedder/voyage.py CHANGED
@@ -56,3 +56,10 @@ class VoyageAIEmbedder(EmbedderClient):
56
56
 
57
57
  result = await self.client.embed(input_list, model=self.config.embedding_model)
58
58
  return [float(x) for x in result.embeddings[0][: self.config.embedding_dim]]
59
+
60
+ async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
61
+ result = await self.client.embed(input_data_list, model=self.config.embedding_model)
62
+ return [
63
+ [float(x) for x in embedding[: self.config.embedding_dim]]
64
+ for embedding in result.embeddings
65
+ ]
graphiti_core/graphiti.py CHANGED
@@ -27,6 +27,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
27
27
  from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
28
28
  from graphiti_core.edges import EntityEdge, EpisodicEdge
29
29
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
30
+ from graphiti_core.graphiti_types import GraphitiClients
30
31
  from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
31
32
  from graphiti_core.llm_client import LLMClient, OpenAIClient
32
33
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
@@ -42,7 +43,6 @@ from graphiti_core.search.search_utils import (
42
43
  RELEVANT_SCHEMA_LIMIT,
43
44
  get_mentioned_nodes,
44
45
  get_relevant_edges,
45
- get_relevant_nodes,
46
46
  )
47
47
  from graphiti_core.utils.bulk_utils import (
48
48
  RawEpisode,
@@ -72,7 +72,11 @@ from graphiti_core.utils.maintenance.graph_data_operations import (
72
72
  build_indices_and_constraints,
73
73
  retrieve_episodes,
74
74
  )
75
- from graphiti_core.utils.maintenance.node_operations import extract_nodes, resolve_extracted_nodes
75
+ from graphiti_core.utils.maintenance.node_operations import (
76
+ extract_attributes_from_nodes,
77
+ extract_nodes,
78
+ resolve_extracted_nodes,
79
+ )
76
80
  from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
77
81
  from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
78
82
 
@@ -150,6 +154,13 @@ class Graphiti:
150
154
  else:
151
155
  self.cross_encoder = OpenAIRerankerClient()
152
156
 
157
+ self.clients = GraphitiClients(
158
+ driver=self.driver,
159
+ llm_client=self.llm_client,
160
+ embedder=self.embedder,
161
+ cross_encoder=self.cross_encoder,
162
+ )
163
+
153
164
  async def close(self):
154
165
  """
155
166
  Close the connection to the Neo4j database.
@@ -222,6 +233,7 @@ class Graphiti:
222
233
  reference_time: datetime,
223
234
  last_n: int = EPISODE_WINDOW_LEN,
224
235
  group_ids: list[str] | None = None,
236
+ source: EpisodeType | None = None,
225
237
  ) -> list[EpisodicNode]:
226
238
  """
227
239
  Retrieve the last n episodic nodes from the graph.
@@ -248,7 +260,7 @@ class Graphiti:
248
260
  The actual retrieval is performed by the `retrieve_episodes` function
249
261
  from the `graphiti_core.utils` module.
250
262
  """
251
- return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
263
+ return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
252
264
 
253
265
  async def add_episode(
254
266
  self,
@@ -314,15 +326,16 @@ class Graphiti:
314
326
  """
315
327
  try:
316
328
  start = time()
317
-
318
- entity_edges: list[EntityEdge] = []
319
329
  now = utc_now()
320
330
 
321
331
  validate_entity_types(entity_types)
322
332
 
323
333
  previous_episodes = (
324
334
  await self.retrieve_episodes(
325
- reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
335
+ reference_time,
336
+ last_n=RELEVANT_SCHEMA_LIMIT,
337
+ group_ids=[group_id],
338
+ source=source,
326
339
  )
327
340
  if previous_episode_uuids is None
328
341
  else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
@@ -346,132 +359,36 @@ class Graphiti:
346
359
  # Extract entities as nodes
347
360
 
348
361
  extracted_nodes = await extract_nodes(
349
- self.llm_client, episode, previous_episodes, entity_types
362
+ self.clients, episode, previous_episodes, entity_types
350
363
  )
351
- logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
352
-
353
- # Calculate Embeddings
354
364
 
355
- await semaphore_gather(
356
- *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
357
- )
358
-
359
- # Find relevant nodes already in the graph
360
- existing_nodes_lists: list[list[EntityNode]] = list(
361
- await semaphore_gather(
362
- *[
363
- get_relevant_nodes(self.driver, SearchFilters(), [node])
364
- for node in extracted_nodes
365
- ]
366
- )
367
- )
368
-
369
- # Resolve extracted nodes with nodes already in the graph and extract facts
370
- logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
371
-
372
- (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
365
+ # Extract edges and resolve nodes
366
+ (nodes, uuid_map), extracted_edges = await semaphore_gather(
373
367
  resolve_extracted_nodes(
374
- self.llm_client,
368
+ self.clients,
375
369
  extracted_nodes,
376
- existing_nodes_lists,
377
370
  episode,
378
371
  previous_episodes,
379
372
  entity_types,
380
373
  ),
381
- extract_edges(
382
- self.llm_client, episode, extracted_nodes, previous_episodes, group_id
383
- ),
384
- )
385
- logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
386
- nodes = mentioned_nodes
387
-
388
- extracted_edges_with_resolved_pointers = resolve_edge_pointers(
389
- extracted_edges, uuid_map
374
+ extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
390
375
  )
391
376
 
392
- # calculate embeddings
393
- await semaphore_gather(
394
- *[
395
- edge.generate_embedding(self.embedder)
396
- for edge in extracted_edges_with_resolved_pointers
397
- ]
398
- )
377
+ edges = resolve_edge_pointers(extracted_edges, uuid_map)
399
378
 
400
- # Resolve extracted edges with related edges already in the graph
401
- related_edges_list: list[list[EntityEdge]] = list(
402
- await semaphore_gather(
403
- *[
404
- get_relevant_edges(
405
- self.driver,
406
- [edge],
407
- edge.source_node_uuid,
408
- edge.target_node_uuid,
409
- RELEVANT_SCHEMA_LIMIT,
410
- )
411
- for edge in extracted_edges_with_resolved_pointers
412
- ]
413
- )
414
- )
415
- logger.debug(
416
- f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
417
- )
418
- logger.debug(
419
- f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
420
- )
421
-
422
- existing_source_edges_list: list[list[EntityEdge]] = list(
423
- await semaphore_gather(
424
- *[
425
- get_relevant_edges(
426
- self.driver,
427
- [edge],
428
- edge.source_node_uuid,
429
- None,
430
- RELEVANT_SCHEMA_LIMIT,
431
- )
432
- for edge in extracted_edges_with_resolved_pointers
433
- ]
434
- )
435
- )
436
-
437
- existing_target_edges_list: list[list[EntityEdge]] = list(
438
- await semaphore_gather(
439
- *[
440
- get_relevant_edges(
441
- self.driver,
442
- [edge],
443
- None,
444
- edge.target_node_uuid,
445
- RELEVANT_SCHEMA_LIMIT,
446
- )
447
- for edge in extracted_edges_with_resolved_pointers
448
- ]
449
- )
450
- )
451
-
452
- existing_edges_list: list[list[EntityEdge]] = [
453
- source_lst + target_lst
454
- for source_lst, target_lst in zip(
455
- existing_source_edges_list, existing_target_edges_list, strict=False
456
- )
457
- ]
458
-
459
- resolved_edges, invalidated_edges = await resolve_extracted_edges(
460
- self.llm_client,
461
- extracted_edges_with_resolved_pointers,
462
- related_edges_list,
463
- existing_edges_list,
464
- episode,
465
- previous_episodes,
379
+ (resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather(
380
+ resolve_extracted_edges(
381
+ self.clients,
382
+ edges,
383
+ ),
384
+ extract_attributes_from_nodes(
385
+ self.clients, nodes, episode, previous_episodes, entity_types
386
+ ),
466
387
  )
467
388
 
468
- entity_edges.extend(resolved_edges + invalidated_edges)
389
+ entity_edges = resolved_edges + invalidated_edges
469
390
 
470
- logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
471
-
472
- episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
473
-
474
- logger.debug(f'Built episodic edges: {episodic_edges}')
391
+ episodic_edges = build_episodic_edges(nodes, episode, now)
475
392
 
476
393
  episode.entity_edges = [edge.uuid for edge in entity_edges]
477
394
 
@@ -565,7 +482,7 @@ class Graphiti:
565
482
  extracted_nodes,
566
483
  extracted_edges,
567
484
  episodic_edges,
568
- ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
485
+ ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs)
569
486
 
570
487
  # Generate embeddings
571
488
  await semaphore_gather(
@@ -684,9 +601,7 @@ class Graphiti:
684
601
 
685
602
  edges = (
686
603
  await search(
687
- self.driver,
688
- self.embedder,
689
- self.cross_encoder,
604
+ self.clients,
690
605
  query,
691
606
  group_ids,
692
607
  search_config,
@@ -728,9 +643,7 @@ class Graphiti:
728
643
  """
729
644
 
730
645
  return await search(
731
- self.driver,
732
- self.embedder,
733
- self.cross_encoder,
646
+ self.clients,
734
647
  query,
735
648
  group_ids,
736
649
  config,
@@ -761,26 +674,17 @@ class Graphiti:
761
674
  await edge.generate_embedding(self.embedder)
762
675
 
763
676
  resolved_nodes, uuid_map = await resolve_extracted_nodes(
764
- self.llm_client,
677
+ self.clients,
765
678
  [source_node, target_node],
766
- [
767
- await get_relevant_nodes(self.driver, SearchFilters(), [source_node]),
768
- await get_relevant_nodes(self.driver, SearchFilters(), [target_node]),
769
- ],
770
679
  )
771
680
 
772
681
  updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
773
682
 
774
- related_edges = await get_relevant_edges(
775
- self.driver,
776
- [updated_edge],
777
- source_node_uuid=resolved_nodes[0].uuid,
778
- target_node_uuid=resolved_nodes[1].uuid,
779
- )
683
+ related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
780
684
 
781
- resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges)
685
+ resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges[0])
782
686
 
783
- contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges)
687
+ contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
784
688
  invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
785
689
 
786
690
  await add_nodes_and_edges_bulk(
graphiti_core/graphiti_types.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from neo4j import AsyncDriver
18
+ from pydantic import BaseModel, ConfigDict
19
+
20
+ from graphiti_core.cross_encoder import CrossEncoderClient
21
+ from graphiti_core.embedder import EmbedderClient
22
+ from graphiti_core.llm_client import LLMClient
23
+
24
+
25
+ class GraphitiClients(BaseModel):
26
+ driver: AsyncDriver
27
+ llm_client: LLMClient
28
+ embedder: EmbedderClient
29
+ cross_encoder: CrossEncoderClient
30
+
31
+ model_config = ConfigDict(arbitrary_types_allowed=True)
graphiti_core/helpers.py CHANGED
@@ -22,15 +22,20 @@ from datetime import datetime
22
22
  import numpy as np
23
23
  from dotenv import load_dotenv
24
24
  from neo4j import time as neo4j_time
25
+ from typing_extensions import LiteralString
25
26
 
26
27
  load_dotenv()
27
28
 
28
29
  DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
29
30
  USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
30
31
  SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
31
- MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 2))
32
+ MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
32
33
  DEFAULT_PAGE_LIMIT = 20
33
34
 
35
+ RUNTIME_QUERY: LiteralString = (
36
+ 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
37
+ )
38
+
34
39
 
35
40
  def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
36
41
  return neo_date.to_native() if neo_date else None
graphiti_core/llm_client/anthropic_client.py CHANGED
@@ -262,7 +262,7 @@ class AnthropicClient(LLMClient):
262
262
  self,
263
263
  messages: list[Message],
264
264
  response_model: type[BaseModel] | None = None,
265
- max_tokens: int = DEFAULT_MAX_TOKENS,
265
+ max_tokens: int | None = None,
266
266
  ) -> dict[str, typing.Any]:
267
267
  """
268
268
  Generate a response from the LLM.
@@ -280,6 +280,9 @@ class AnthropicClient(LLMClient):
280
280
  RefusalError: If the LLM refuses to respond.
281
281
  Exception: If an error occurs during the generation process.
282
282
  """
283
+ if max_tokens is None:
284
+ max_tokens = self.max_tokens
285
+
283
286
  retry_count = 0
284
287
  max_retries = 2
285
288
  last_error: Exception | None = None
graphiti_core/llm_client/client.py CHANGED
@@ -127,8 +127,11 @@ class LLMClient(ABC):
127
127
  self,
128
128
  messages: list[Message],
129
129
  response_model: type[BaseModel] | None = None,
130
- max_tokens: int = DEFAULT_MAX_TOKENS,
130
+ max_tokens: int | None = None,
131
131
  ) -> dict[str, typing.Any]:
132
+ if max_tokens is None:
133
+ max_tokens = self.max_tokens
134
+
132
135
  if response_model is not None:
133
136
  serialized_model = json.dumps(response_model.model_json_schema())
134
137
  messages[
graphiti_core/llm_client/gemini_client.py CHANGED
@@ -166,7 +166,7 @@ class GeminiClient(LLMClient):
166
166
  self,
167
167
  messages: list[Message],
168
168
  response_model: type[BaseModel] | None = None,
169
- max_tokens: int = DEFAULT_MAX_TOKENS,
169
+ max_tokens: int | None = None,
170
170
  ) -> dict[str, typing.Any]:
171
171
  """
172
172
  Generate a response from the Gemini language model.
@@ -180,6 +180,9 @@ class GeminiClient(LLMClient):
180
180
  Returns:
181
181
  dict[str, typing.Any]: The response from the language model.
182
182
  """
183
+ if max_tokens is None:
184
+ max_tokens = self.max_tokens
185
+
183
186
  # Call the internal _generate_response method
184
187
  return await self._generate_response(
185
188
  messages=messages, response_model=response_model, max_tokens=max_tokens
graphiti_core/llm_client/openai_client.py CHANGED
@@ -131,8 +131,11 @@ class OpenAIClient(LLMClient):
131
131
  self,
132
132
  messages: list[Message],
133
133
  response_model: type[BaseModel] | None = None,
134
- max_tokens: int = DEFAULT_MAX_TOKENS,
134
+ max_tokens: int | None = None,
135
135
  ) -> dict[str, typing.Any]:
136
+ if max_tokens is None:
137
+ max_tokens = self.max_tokens
138
+
136
139
  retry_count = 0
137
140
  last_error = None
138
141
 
graphiti_core/llm_client/openai_generic_client.py CHANGED
@@ -117,8 +117,11 @@ class OpenAIGenericClient(LLMClient):
117
117
  self,
118
118
  messages: list[Message],
119
119
  response_model: type[BaseModel] | None = None,
120
- max_tokens: int = DEFAULT_MAX_TOKENS,
120
+ max_tokens: int | None = None,
121
121
  ) -> dict[str, typing.Any]:
122
+ if max_tokens is None:
123
+ max_tokens = self.max_tokens
124
+
122
125
  retry_count = 0
123
126
  last_error = None
124
127
 
graphiti_core/models/edges/edge_db_queries.py CHANGED
@@ -47,7 +47,7 @@ ENTITY_EDGE_SAVE_BULK = """
47
47
  SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
48
48
  created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
49
49
  WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
50
- RETURN r.uuid AS uuid
50
+ RETURN edge.uuid AS uuid
51
51
  """
52
52
 
53
53
  COMMUNITY_EDGE_SAVE = """
graphiti_core/nodes.py CHANGED
@@ -39,8 +39,6 @@ from graphiti_core.utils.datetime_utils import utc_now
39
39
  logger = logging.getLogger(__name__)
40
40
 
41
41
  ENTITY_NODE_RETURN: LiteralString = """
42
- OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n)
43
- WITH n, collect(e.uuid) AS episodes
44
42
  RETURN
45
43
  n.uuid As uuid,
46
44
  n.name AS name,
@@ -49,8 +47,8 @@ ENTITY_NODE_RETURN: LiteralString = """
49
47
  n.created_at AS created_at,
50
48
  n.summary AS summary,
51
49
  labels(n) AS labels,
52
- properties(n) AS attributes,
53
- episodes"""
50
+ properties(n) AS attributes
51
+ """
54
52
 
55
53
 
56
54
  class EpisodeType(Enum):
@@ -294,9 +292,6 @@ class EpisodicNode(Node):
294
292
  class EntityNode(Node):
295
293
  name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
296
294
  summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
297
- episodes: list[str] | None = Field(
298
- default=None, description='List of episode uuids that mention this node.'
299
- )
300
295
  attributes: dict[str, Any] = Field(
301
296
  default={}, description='Additional attributes of the node. Dependent on node labels'
302
297
  )
@@ -337,8 +332,8 @@ class EntityNode(Node):
337
332
  async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
338
333
  query = (
339
334
  """
340
- MATCH (n:Entity {uuid: $uuid})
341
- """
335
+ MATCH (n:Entity {uuid: $uuid})
336
+ """
342
337
  + ENTITY_NODE_RETURN
343
338
  )
344
339
  records, _, _ = await driver.execute_query(
@@ -544,7 +539,6 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
544
539
  created_at=record['created_at'].to_native(),
545
540
  summary=record['summary'],
546
541
  attributes=record['attributes'],
547
- episodes=record['episodes'],
548
542
  )
549
543
 
550
544
  entity_node.attributes.pop('uuid', None)
@@ -566,3 +560,9 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
566
560
  created_at=record['created_at'].to_native(),
567
561
  summary=record['summary'],
568
562
  )
563
+
564
+
565
+ async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
566
+ name_embeddings = await embedder.create_batch([node.name for node in nodes])
567
+ for node, name_embedding in zip(nodes, name_embeddings, strict=True):
568
+ node.name_embedding = name_embedding