graphiti-core 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.


graphiti_core/edges.py CHANGED
@@ -188,9 +188,9 @@ class EntityEdge(Edge):
      MATCH (source:Entity {uuid: $source_uuid})
      MATCH (target:Entity {uuid: $target_uuid})
      MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
-     SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
-         episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
-         valid_at: $valid_at, invalid_at: $invalid_at}
+     SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
+         created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
+     WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
      RETURN r.uuid AS uuid""",
      source_uuid=self.source_node_uuid,
      target_uuid=self.target_node_uuid,
graphiti_core/embedder/openai.py CHANGED
@@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
          self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)

      async def create(
-         self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+         self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
      ) -> list[float]:
          result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
          return result.data[0].embedding[: self.config.embedding_dim]
graphiti_core/embedder/voyage.py CHANGED
@@ -41,7 +41,7 @@ class VoyageAIEmbedder(EmbedderClient):
          self.client = voyageai.AsyncClient(api_key=config.api_key)

      async def create(
-         self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+         self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
      ) -> list[float]:
          result = await self.client.embed(input, model=self.config.embedding_model)
          return result.embeddings[0][: self.config.embedding_dim]
graphiti_core/graphiti.py CHANGED
@@ -21,11 +21,12 @@ from time import time

  from dotenv import load_dotenv
  from neo4j import AsyncGraphDatabase
+ from pydantic import BaseModel

  from graphiti_core.edges import EntityEdge, EpisodicEdge
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
  from graphiti_core.llm_client import LLMClient, OpenAIClient
- from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
  from graphiti_core.search.search import SearchConfig, search
  from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
  from graphiti_core.search.search_config_recipes import (
@@ -77,6 +78,12 @@ logger = logging.getLogger(__name__)
  load_dotenv()


+ class AddEpisodeResults(BaseModel):
+     episode: EpisodicNode
+     nodes: list[EntityNode]
+     edges: list[EntityEdge]
+
+
  class Graphiti:
      def __init__(
          self,
@@ -245,7 +252,7 @@ class Graphiti:
          group_id: str = '',
          uuid: str | None = None,
          update_communities: bool = False,
-     ):
+     ) -> AddEpisodeResults:
          """
          Process an episode and update the graph.

@@ -451,6 +458,8 @@ class Graphiti:
              end = time()
              logger.info(f'Completed add_episode in {(end - start) * 1000} ms')

+             return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
+
          except Exception as e:
              raise e

@@ -567,11 +576,20 @@ class Graphiti:
          except Exception as e:
              raise e

-     async def build_communities(self):
+     async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
+         """
+         Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
+         the content of these communities.
+         ----------
+         group_ids : list[str] | None
+             Optional. Create communities only for the listed group_ids. If blank, the entire graph will be used.
+         """
          # Clear existing communities
          await remove_communities(self.driver)

-         community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
+         community_nodes, community_edges = await build_communities(
+             self.driver, self.llm_client, group_ids
+         )

          await asyncio.gather(
              *[node.generate_name_embedding(self.embedder) for node in community_nodes]
@@ -580,6 +598,8 @@ class Graphiti:
          await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
          await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])

+         return community_nodes
+
      async def search(
          self,
          query: str,
@@ -700,18 +720,17 @@ class Graphiti:
          ).nodes
          return nodes

+     async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
+         episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

-     async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
-         episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
-
-         edges_list = await asyncio.gather(
-             *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
-         )
+         edges_list = await asyncio.gather(
+             *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
+         )

-         edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
+         edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]

-         nodes = await get_mentioned_nodes(self.driver, episodes)
+         nodes = await get_mentioned_nodes(self.driver, episodes)

-         communities = await get_communities_by_nodes(self.driver, nodes)
+         communities = await get_communities_by_nodes(self.driver, nodes)

-         return SearchResults(edges=edges, nodes=nodes, communities=communities)
+         return SearchResults(edges=edges, nodes=nodes, communities=communities)
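Taken together, the graphiti.py changes make the write paths return what they create. A minimal usage sketch, assuming a local Neo4j instance and this release's public constructor and add_episode parameters (connection details and names here are illustrative, not prescribed by the diff):

```python
import asyncio
from datetime import datetime, timezone

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType


async def main() -> None:
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

    # add_episode now returns AddEpisodeResults instead of None, so the
    # created episode, extracted nodes, and edges can be inspected directly.
    results = await graphiti.add_episode(
        name='note-1',
        episode_body='Alice moved to Berlin in 2023.',
        source=EpisodeType.text,
        source_description='user note',
        reference_time=datetime.now(timezone.utc),
        group_id='group-1',
    )
    print(len(results.nodes), len(results.edges))

    # build_communities accepts an optional group_ids scope and returns
    # the CommunityNode objects it created.
    communities = await graphiti.build_communities(group_ids=['group-1'])
    print([community.name for community in communities])


asyncio.run(main())
```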
graphiti_core/helpers.py CHANGED
@@ -16,6 +16,7 @@ limitations under the License.

  from datetime import datetime

+ import numpy as np
  from neo4j import time as neo4j_time


@@ -25,7 +26,7 @@ def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:

  def lucene_sanitize(query: str) -> str:
      # Escape special characters from a query before passing into Lucene
-     # + - && || ! ( ) { } [ ] ^ " ~ * ? : \
+     # + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
      escape_map = str.maketrans(
          {
              '+': r'\+',
@@ -46,8 +47,21 @@ def lucene_sanitize(query: str) -> str:
              '?': r'\?',
              ':': r'\:',
              '\\': r'\\',
+             '/': r'\/',
          }
      )

      sanitized = query.translate(escape_map)
      return sanitized
+
+
+ def normalize_l2(embedding: list[float]) -> list[float]:
+     embedding_array = np.array(embedding)
+     if embedding_array.ndim == 1:
+         norm = np.linalg.norm(embedding_array)
+         if norm == 0:
+             return embedding_array.tolist()
+         return (embedding_array / norm).tolist()
+     else:
+         norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
+         return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
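A quick sketch of the two helpers touched here, with illustrative values:

```python
from graphiti_core.helpers import lucene_sanitize, normalize_l2

# '/' is now escaped, so a path-like query no longer trips Lucene's
# regex syntax in the fulltext indexes.
print(lucene_sanitize('a/b AND 1+1'))  # a\/b AND 1\+1

# normalize_l2 rescales a vector to unit length, so a dot product of two
# normalized vectors equals their cosine similarity.
print(normalize_l2([3.0, 4.0]))  # [0.6, 0.8]
```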
graphiti_core/nodes.py CHANGED
@@ -225,7 +225,8 @@ class EntityNode(Node):
      result = await driver.execute_query(
          """
          MERGE (n:Entity {uuid: $uuid})
-         SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
+         SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
+         WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
          RETURN n.uuid AS uuid""",
          uuid=self.uuid,
          name=self.name,
@@ -308,7 +309,8 @@ class CommunityNode(Node):
      result = await driver.execute_query(
          """
          MERGE (n:Community {uuid: $uuid})
-         SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
+         SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
+         WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
          RETURN n.uuid AS uuid""",
          uuid=self.uuid,
          name=self.name,
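Both node types, and EntityEdge above, now write embeddings through Neo4j's db.create.setNodeVectorProperty / db.create.setRelationshipVectorProperty procedures instead of a plain SET, so the database stores them as typed vectors. A hedged sketch of the same pattern in isolation, assuming a Neo4j 5.13+ instance and the async driver (names are illustrative):

```python
from neo4j import AsyncGraphDatabase


async def save_name_embedding(uuid: str, name_embedding: list[float]) -> None:
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    # Writing the embedding via the procedure (instead of `SET n = {...}`)
    # stores it as a typed vector, which vector.similarity.cosine() expects.
    await driver.execute_query(
        """
        MATCH (n:Entity {uuid: $uuid})
        WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
        RETURN n.uuid AS uuid
        """,
        uuid=uuid,
        name_embedding=name_embedding,
    )
    await driver.close()
```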
graphiti_core/prompts/eval.py CHANGED
@@ -23,11 +23,33 @@ from .models import Message, PromptFunction, PromptVersion
  class Prompt(Protocol):
      qa_prompt: PromptVersion
      eval_prompt: PromptVersion
+     query_expansion: PromptVersion


  class Versions(TypedDict):
      qa_prompt: PromptFunction
      eval_prompt: PromptFunction
+     query_expansion: PromptFunction
+
+
+ def query_expansion(context: dict[str, Any]) -> list[Message]:
+     sys_prompt = """You are an expert at rephrasing questions into queries used in a database retrieval system"""
+
+     user_prompt = f"""
+     Bob is asking Alice a question, are you able to rephrase the question into a simpler one about Alice in the third person
+     that maintains the relevant context?
+     <QUESTION>
+     {json.dumps(context['query'])}
+     </QUESTION>
+     respond with a JSON object in the following format:
+     {{
+         "query": "query optimized for database search"
+     }}
+     """
+     return [
+         Message(role='system', content=sys_prompt),
+         Message(role='user', content=user_prompt),
+     ]


  def qa_prompt(context: dict[str, Any]) -> list[Message]:
@@ -38,7 +60,7 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
      You are given the following entity summaries and facts to help you determine the answer to your question.
      <ENTITY_SUMMARIES>
      {json.dumps(context['entity_summaries'])}
-     </ENTITY_SUMMARIES
+     </ENTITY_SUMMARIES>
      <FACTS>
      {json.dumps(context['facts'])}
      </FACTS>
@@ -87,4 +109,8 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
      ]


- versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt}
+ versions: Versions = {
+     'qa_prompt': qa_prompt,
+     'eval_prompt': eval_prompt,
+     'query_expansion': query_expansion,
+ }
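A sketch of rendering the new prompt, assuming the library exposes this module as prompt_library.eval in the same way other prompt modules are accessed (an assumption; check this release's prompts/lib.py):

```python
from graphiti_core.prompts import prompt_library

# Hypothetical access path; the versions dict above registers the
# function under the key 'query_expansion'.
messages = prompt_library.eval.query_expansion({'query': 'Where does Alice live now?'})
for message in messages:
    print(message.role, '->', message.content[:60])
```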
graphiti_core/prompts/extract_edge_dates.py CHANGED
@@ -37,7 +37,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
          role='user',
          content=f"""
          Edge:
-         Edge Name: {context['edge_name']}
          Fact: {context['edge_fact']}

          Current Episode: {context['current_episode']}
@@ -56,17 +55,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
          Guidelines:
          1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
          2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
-         3. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
-         4. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
-         5. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
-         6. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
-         7. If only a year is mentioned, use January 1st of that year at 00:00:00.
+         3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
+         4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
+         5. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
+         6. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
+         7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
+         8. If only a year is mentioned, use January 1st of that year at 00:00:00.
          9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
          Respond with a JSON object:
          {{
-             "valid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
-             "invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
-             "explanation": "Brief explanation of why these dates were chosen or why they were set to null"
+             "valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
+             "invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
          }}
          """,
      ),
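The prompt now asks for microsecond precision and drops the 'explanation' field. A small sketch of parsing a response in the new format, mirroring the date handling in temporal_operations further down (the raw dict here is illustrative):

```python
from datetime import datetime

raw = {'valid_at': '2023-06-01T00:00:00.000000Z', 'invalid_at': None}

# datetime.fromisoformat does not accept a trailing 'Z' on older Pythons,
# so it is rewritten as an explicit UTC offset first.
valid_at = (
    datetime.fromisoformat(raw['valid_at'].replace('Z', '+00:00'))
    if raw['valid_at']
    else None
)
print(valid_at)  # 2023-06-01 00:00:00+00:00
```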
graphiti_core/prompts/extract_edges.py CHANGED
@@ -113,8 +113,9 @@ def v2(context: dict[str, Any]) -> list[Message]:
      2. Each edge should represent a clear relationship between two DISTINCT nodes.
      3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
      4. Provide a more detailed fact describing the relationship.
-     5. Consider temporal aspects of relationships when relevant.
-     6. Avoid using the same node as the source and target of a relationship
+     5. The fact should include any specific relevant information, including numeric information
+     6. Consider temporal aspects of relationships when relevant.
+     7. Avoid using the same node as the source and target of a relationship

      Respond with a JSON object in the following format:
      {{
graphiti_core/prompts/invalidate_edges.py CHANGED
@@ -82,7 +82,7 @@ def v2(context: dict[str, Any]) -> list[Message]:
      Message(
          role='user',
          content=f"""
-         Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.
+         Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to invalidations with the New Edge.

          Existing Edges:
          {context['existing_edges']}
graphiti_core/search/search.py CHANGED
@@ -29,13 +29,10 @@ from graphiti_core.search.search_config import (
      DEFAULT_SEARCH_LIMIT,
      CommunityReranker,
      CommunitySearchConfig,
-     CommunitySearchMethod,
      EdgeReranker,
      EdgeSearchConfig,
-     EdgeSearchMethod,
      NodeReranker,
      NodeSearchConfig,
-     NodeSearchMethod,
      SearchConfig,
      SearchResults,
  )
@@ -45,6 +42,7 @@ from graphiti_core.search.search_utils import (
      edge_fulltext_search,
      edge_similarity_search,
      episode_mentions_reranker,
+     maximal_marginal_relevance,
      node_distance_reranker,
      node_fulltext_search,
      node_similarity_search,
@@ -120,22 +118,18 @@ async def edge_search(
      if config is None:
          return []

-     search_results: list[list[EntityEdge]] = []
+     query_vector = await embedder.create(input=[query])

-     if EdgeSearchMethod.bm25 in config.search_methods:
-         text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
-         search_results.append(text_search)
-
-     if EdgeSearchMethod.cosine_similarity in config.search_methods:
-         search_vector = await embedder.create(input=[query])
-
-         similarity_search = await edge_similarity_search(
-             driver, search_vector, None, None, group_ids, 2 * limit
+     search_results: list[list[EntityEdge]] = list(
+         await asyncio.gather(
+             *[
+                 edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit),
+                 edge_similarity_search(
+                     driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
+                 ),
+             ]
          )
-         search_results.append(similarity_search)
-
-     if len(search_results) > 1 and config.reranker is None:
-         raise SearchRerankerError('Multiple edge searches enabled without a reranker')
+     )

      edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

@@ -144,6 +138,15 @@ async def edge_search(
      search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

          reranked_uuids = rrf(search_result_uuids)
+     elif config.reranker == EdgeReranker.mmr:
+         search_result_uuids_and_vectors = [
+             (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
+             for result in search_results
+             for edge in result
+         ]
+         reranked_uuids = maximal_marginal_relevance(
+             query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+         )
      elif config.reranker == EdgeReranker.node_distance:
          if center_node_uuid is None:
              raise SearchRerankerError('No center node provided for Node Distance reranker')
@@ -184,22 +187,18 @@ async def node_search(
      if config is None:
          return []

-     search_results: list[list[EntityNode]] = []
-
-     if NodeSearchMethod.bm25 in config.search_methods:
-         text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
-         search_results.append(text_search)
+     query_vector = await embedder.create(input=[query])

-     if NodeSearchMethod.cosine_similarity in config.search_methods:
-         search_vector = await embedder.create(input=[query])
-
-         similarity_search = await node_similarity_search(
-             driver, search_vector, group_ids, 2 * limit
+     search_results: list[list[EntityNode]] = list(
+         await asyncio.gather(
+             *[
+                 node_fulltext_search(driver, query, group_ids, 2 * limit),
+                 node_similarity_search(
+                     driver, query_vector, group_ids, 2 * limit, config.sim_min_score
+                 ),
+             ]
          )
-         search_results.append(similarity_search)
-
-     if len(search_results) > 1 and config.reranker is None:
-         raise SearchRerankerError('Multiple node searches enabled without a reranker')
+     )

      search_result_uuids = [[node.uuid for node in result] for result in search_results]
      node_uuid_map = {node.uuid: node for result in search_results for node in result}
@@ -207,6 +206,15 @@ async def node_search(
      reranked_uuids: list[str] = []
      if config.reranker == NodeReranker.rrf:
          reranked_uuids = rrf(search_result_uuids)
+     elif config.reranker == NodeReranker.mmr:
+         search_result_uuids_and_vectors = [
+             (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
+             for result in search_results
+             for node in result
+         ]
+         reranked_uuids = maximal_marginal_relevance(
+             query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+         )
      elif config.reranker == NodeReranker.episode_mentions:
          reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
      elif config.reranker == NodeReranker.node_distance:
@@ -232,22 +240,18 @@ async def community_search(
      if config is None:
          return []

-     search_results: list[list[CommunityNode]] = []
-
-     if CommunitySearchMethod.bm25 in config.search_methods:
-         text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
-         search_results.append(text_search)
-
-     if CommunitySearchMethod.cosine_similarity in config.search_methods:
-         search_vector = await embedder.create(input=[query])
+     query_vector = await embedder.create(input=[query])

-         similarity_search = await community_similarity_search(
-             driver, search_vector, group_ids, 2 * limit
+     search_results: list[list[CommunityNode]] = list(
+         await asyncio.gather(
+             *[
+                 community_fulltext_search(driver, query, group_ids, 2 * limit),
+                 community_similarity_search(
+                     driver, query_vector, group_ids, 2 * limit, config.sim_min_score
+                 ),
+             ]
          )
-         search_results.append(similarity_search)
-
-     if len(search_results) > 1 and config.reranker is None:
-         raise SearchRerankerError('Multiple node searches enabled without a reranker')
+     )

      search_result_uuids = [[community.uuid for community in result] for result in search_results]
      community_uuid_map = {
@@ -257,6 +261,18 @@ async def community_search(
      reranked_uuids: list[str] = []
      if config.reranker == CommunityReranker.rrf:
          reranked_uuids = rrf(search_result_uuids)
+     elif config.reranker == CommunityReranker.mmr:
+         search_result_uuids_and_vectors = [
+             (
+                 community.uuid,
+                 community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
+             )
+             for result in search_results
+             for community in result
+         ]
+         reranked_uuids = maximal_marginal_relevance(
+             query_vector, search_result_uuids_and_vectors, config.mmr_lambda
+         )

      reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
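With this change, both retrieval methods always run concurrently, so a reranker is always applied; the default is reciprocal rank fusion. For orientation, a minimal standalone sketch of RRF, matching the signature of the package's rrf helper shown later in this diff (the internals here mirror it in spirit, not necessarily line for line):

```python
from collections import defaultdict


def rrf(results: list[list[str]], rank_const: int = 1) -> list[str]:
    # Each inner list is ordered best-first; an item ranked highly by
    # several searches accumulates a larger fused score.
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for rank, uuid in enumerate(result):
            scores[uuid] += 1.0 / (rank + rank_const)
    return sorted(scores, key=lambda uuid: scores[uuid], reverse=True)


print(rrf([['a', 'b', 'c'], ['b', 'c']]))  # ['b', 'c', 'a']
```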
graphiti_core/search/search_config.py CHANGED
@@ -20,6 +20,7 @@ from pydantic import BaseModel, Field

  from graphiti_core.edges import EntityEdge
  from graphiti_core.nodes import CommunityNode, EntityNode
+ from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA

  DEFAULT_SEARCH_LIMIT = 10

@@ -43,31 +44,40 @@ class EdgeReranker(Enum):
      rrf = 'reciprocal_rank_fusion'
      node_distance = 'node_distance'
      episode_mentions = 'episode_mentions'
+     mmr = 'mmr'


  class NodeReranker(Enum):
      rrf = 'reciprocal_rank_fusion'
      node_distance = 'node_distance'
      episode_mentions = 'episode_mentions'
+     mmr = 'mmr'


  class CommunityReranker(Enum):
      rrf = 'reciprocal_rank_fusion'
+     mmr = 'mmr'


  class EdgeSearchConfig(BaseModel):
      search_methods: list[EdgeSearchMethod]
-     reranker: EdgeReranker | None
+     reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
+     sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
+     mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)


  class NodeSearchConfig(BaseModel):
      search_methods: list[NodeSearchMethod]
-     reranker: NodeReranker | None
+     reranker: NodeReranker = Field(default=NodeReranker.rrf)
+     sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
+     mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)


  class CommunitySearchConfig(BaseModel):
      search_methods: list[CommunitySearchMethod]
-     reranker: CommunityReranker | None
+     reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
+     sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
+     mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)


  class SearchConfig(BaseModel):
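A sketch of building a custom edge-only config with the new fields: candidates at or below sim_min_score are dropped by the similarity search, and mmr_lambda trades relevance (toward 1.0) against diversity (toward 0.0) in the MMR reranker. The values chosen here are illustrative:

```python
from graphiti_core.search.search_config import (
    EdgeReranker,
    EdgeSearchConfig,
    EdgeSearchMethod,
    SearchConfig,
)

config = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=EdgeReranker.mmr,
        sim_min_score=0.7,  # stricter than the 0.6 default
        mmr_lambda=0.8,     # favor relevance over diversity
    )
)
```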
graphiti_core/search/search_config_recipes.py CHANGED
@@ -43,6 +43,22 @@ COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
      ),
  )

+ # Performs a hybrid search with mmr reranking over edges, nodes, and communities
+ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
+     edge_config=EdgeSearchConfig(
+         search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
+         reranker=EdgeReranker.mmr,
+     ),
+     node_config=NodeSearchConfig(
+         search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
+         reranker=NodeReranker.mmr,
+     ),
+     community_config=CommunitySearchConfig(
+         search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
+         reranker=CommunityReranker.mmr,
+     ),
+ )
+
  # performs a hybrid search over edges with rrf reranking
  EDGE_HYBRID_SEARCH_RRF = SearchConfig(
      edge_config=EdgeSearchConfig(
@@ -51,6 +67,14 @@ EDGE_HYBRID_SEARCH_RRF = SearchConfig(
      )
  )

+ # performs a hybrid search over edges with mmr reranking
+ EDGE_HYBRID_SEARCH_mmr = SearchConfig(
+     edge_config=EdgeSearchConfig(
+         search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
+         reranker=EdgeReranker.mmr,
+     )
+ )
+
  # performs a hybrid search over edges with node distance reranking
  EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
      edge_config=EdgeSearchConfig(
@@ -75,6 +99,14 @@ NODE_HYBRID_SEARCH_RRF = SearchConfig(
      )
  )

+ # performs a hybrid search over nodes with mmr reranking
+ NODE_HYBRID_SEARCH_MMR = SearchConfig(
+     node_config=NodeSearchConfig(
+         search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
+         reranker=NodeReranker.mmr,
+     )
+ )
+
  # performs a hybrid search over nodes with node distance reranking
  NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
      node_config=NodeSearchConfig(
@@ -98,3 +130,11 @@ COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
          reranker=CommunityReranker.rrf,
      )
  )
+
+ # performs a hybrid search over communities with mmr reranking
+ COMMUNITY_HYBRID_SEARCH_MMR = SearchConfig(
+     community_config=CommunitySearchConfig(
+         search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
+         reranker=CommunityReranker.mmr,
+     )
+ )
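(Note the lowercase EDGE_HYBRID_SEARCH_mmr is the identifier as shipped, inconsistent with its siblings.) The recipes are plain SearchConfig instances, so they can be copied and tweaked before being handed to whichever search entry point you use; a small sketch using pydantic v2's model_copy:

```python
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_MMR

# Copy the recipe rather than mutating the shared module-level constant.
my_config = NODE_HYBRID_SEARCH_MMR.model_copy(deep=True)
my_config.node_config.mmr_lambda = 0.3  # lean toward diversity over relevance
```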
graphiti_core/search/search_utils.py CHANGED
@@ -19,10 +19,11 @@ import logging
  from collections import defaultdict
  from time import time

+ import numpy as np
  from neo4j import AsyncDriver, Query

  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
- from graphiti_core.helpers import lucene_sanitize
+ from graphiti_core.helpers import lucene_sanitize, normalize_l2
  from graphiti_core.nodes import (
      CommunityNode,
      EntityNode,
@@ -34,6 +35,8 @@ from graphiti_core.nodes import (
  logger = logging.getLogger(__name__)

  RELEVANT_SCHEMA_LIMIT = 3
+ DEFAULT_MIN_SCORE = 0.6
+ DEFAULT_MMR_LAMBDA = 0.5


  def fulltext_query(query: str, group_ids: list[str] | None = None):
@@ -52,8 +55,23 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
      return full_query


+ async def get_episodes_by_mentions(
+     driver: AsyncDriver,
+     nodes: list[EntityNode],
+     edges: list[EntityEdge],
+     limit: int = RELEVANT_SCHEMA_LIMIT,
+ ) -> list[EpisodicNode]:
+     episode_uuids: list[str] = []
+     for edge in edges:
+         episode_uuids.extend(edge.episodes)
+
+     episodes = await EpisodicNode.get_by_uuids(driver, episode_uuids[:limit])
+
+     return episodes
+
+
  async def get_mentioned_nodes(
-     driver: AsyncDriver, episodes: list[EpisodicNode]
+     driver: AsyncDriver, episodes: list[EpisodicNode]
  ) -> list[EntityNode]:
      episode_uuids = [episode.uuid for episode in episodes]
      records, _, _ = await driver.execute_query(
@@ -76,7 +94,7 @@ async def get_mentioned_nodes(


  async def get_communities_by_nodes(
-     driver: AsyncDriver, nodes: list[EntityNode]
+     driver: AsyncDriver, nodes: list[EntityNode]
  ) -> list[CommunityNode]:
      node_uuids = [node.uuid for node in nodes]
      records, _, _ = await driver.execute_query(
@@ -99,12 +117,12 @@ async def get_communities_by_nodes(


  async def edge_fulltext_search(
-     driver: AsyncDriver,
-     query: str,
-     source_node_uuid: str | None,
-     target_node_uuid: str | None,
-     group_ids: list[str] | None = None,
-     limit=RELEVANT_SCHEMA_LIMIT,
+     driver: AsyncDriver,
+     query: str,
+     source_node_uuid: str | None,
+     target_node_uuid: str | None,
+     group_ids: list[str] | None = None,
+     limit=RELEVANT_SCHEMA_LIMIT,
  ) -> list[EntityEdge]:
      # fulltext search over facts
      fuzzy_query = fulltext_query(query, group_ids)
@@ -113,9 +131,6 @@ async def edge_fulltext_search(
      CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
      YIELD relationship AS rel, score
      MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
-     WHERE ($source_uuid IS NULL OR n.uuid = $source_uuid)
-     AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
-     AND ($group_ids IS NULL OR n.group_id IN $group_ids)
      RETURN
          r.uuid AS uuid,
          r.group_id AS group_id,
@@ -147,21 +162,24 @@ async def edge_fulltext_search(


  async def edge_similarity_search(
-     driver: AsyncDriver,
-     search_vector: list[float],
-     source_node_uuid: str | None,
-     target_node_uuid: str | None,
-     group_ids: list[str] | None = None,
-     limit: int = RELEVANT_SCHEMA_LIMIT,
+     driver: AsyncDriver,
+     search_vector: list[float],
+     source_node_uuid: str | None,
+     target_node_uuid: str | None,
+     group_ids: list[str] | None = None,
+     limit: int = RELEVANT_SCHEMA_LIMIT,
+     min_score: float = DEFAULT_MIN_SCORE,
  ) -> list[EntityEdge]:
      # vector similarity search over embedded facts
      query = Query("""
+         CYPHER runtime = parallel parallelRuntimeSupport=all
          MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
          WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
          AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
          AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
+         WITH n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
+         WHERE score > $min_score
          RETURN
-             vector.similarity.cosine(r.fact_embedding, $search_vector) AS score,
              r.uuid AS uuid,
              r.group_id AS group_id,
              n.uuid AS source_node_uuid,
@@ -185,6 +203,7 @@ async def edge_similarity_search(
          target_uuid=target_node_uuid,
          group_ids=group_ids,
          limit=limit,
+         min_score=min_score,
      )

      edges = [get_entity_edge_from_record(record) for record in records]
@@ -193,10 +212,10 @@ async def edge_similarity_search(


  async def node_fulltext_search(
-     driver: AsyncDriver,
-     query: str,
-     group_ids: list[str] | None = None,
-     limit=RELEVANT_SCHEMA_LIMIT,
+     driver: AsyncDriver,
+     query: str,
+     group_ids: list[str] | None = None,
+     limit=RELEVANT_SCHEMA_LIMIT,
  ) -> list[EntityNode]:
      # BM25 search to get top nodes
      fuzzy_query = fulltext_query(query, group_ids)
@@ -205,7 +224,6 @@ async def node_fulltext_search(
      """
      CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
      YIELD node AS n, score
-     WHERE $group_ids IS NULL OR n.group_id IN $group_ids
      RETURN
          n.uuid AS uuid,
          n.group_id AS group_id,
@@ -226,18 +244,21 @@ async def node_fulltext_search(


  async def node_similarity_search(
-     driver: AsyncDriver,
-     search_vector: list[float],
-     group_ids: list[str] | None = None,
-     limit=RELEVANT_SCHEMA_LIMIT,
+     driver: AsyncDriver,
+     search_vector: list[float],
+     group_ids: list[str] | None = None,
+     limit=RELEVANT_SCHEMA_LIMIT,
+     min_score: float = DEFAULT_MIN_SCORE,
  ) -> list[EntityNode]:
      # vector similarity search over entity names
      records, _, _ = await driver.execute_query(
          """
+         CYPHER runtime = parallel parallelRuntimeSupport=all
          MATCH (n:Entity)
          WHERE $group_ids IS NULL OR n.group_id IN $group_ids
+         WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
+         WHERE score > $min_score
          RETURN
-             vector.similarity.cosine(n.name_embedding, $search_vector) AS score,
              n.uuid As uuid,
              n.group_id AS group_id,
              n.name AS name,
@@ -250,6 +271,7 @@ async def node_similarity_search(
          search_vector=search_vector,
          group_ids=group_ids,
          limit=limit,
+         min_score=min_score,
      )
      nodes = [get_entity_node_from_record(record) for record in records]

@@ -257,10 +279,10 @@ async def node_similarity_search(


  async def community_fulltext_search(
-     driver: AsyncDriver,
-     query: str,
-     group_ids: list[str] | None = None,
-     limit=RELEVANT_SCHEMA_LIMIT,
+     driver: AsyncDriver,
+     query: str,
+     group_ids: list[str] | None = None,
+     limit=RELEVANT_SCHEMA_LIMIT,
  ) -> list[CommunityNode]:
      # BM25 search to get top communities
      fuzzy_query = fulltext_query(query, group_ids)
@@ -269,8 +291,6 @@ async def community_fulltext_search(
      """
      CALL db.index.fulltext.queryNodes("community_name", $query)
      YIELD node AS comm, score
-     MATCH (comm:Community)
-     WHERE $group_ids IS NULL OR comm.group_id in $group_ids
      RETURN
          comm.uuid AS uuid,
          comm.group_id AS group_id,
@@ -291,18 +311,21 @@ async def community_fulltext_search(


  async def community_similarity_search(
-     driver: AsyncDriver,
-     search_vector: list[float],
-     group_ids: list[str] | None = None,
-     limit=RELEVANT_SCHEMA_LIMIT,
+     driver: AsyncDriver,
+     search_vector: list[float],
+     group_ids: list[str] | None = None,
+     limit=RELEVANT_SCHEMA_LIMIT,
+     min_score=DEFAULT_MIN_SCORE,
  ) -> list[CommunityNode]:
      # vector similarity search over entity names
      records, _, _ = await driver.execute_query(
          """
+         CYPHER runtime = parallel parallelRuntimeSupport=all
          MATCH (comm:Community)
          WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
+         WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
+         WHERE score > $min_score
          RETURN
-             vector.similarity.cosine(comm.name_embedding, $search_vector) AS score,
              comm.uuid As uuid,
              comm.group_id AS group_id,
              comm.name AS name,
@@ -315,6 +338,7 @@ async def community_similarity_search(
          search_vector=search_vector,
          group_ids=group_ids,
          limit=limit,
+         min_score=min_score,
      )
      communities = [get_community_node_from_record(record) for record in records]

@@ -322,11 +346,11 @@ async def community_similarity_search(


  async def hybrid_node_search(
-     queries: list[str],
-     embeddings: list[list[float]],
-     driver: AsyncDriver,
-     group_ids: list[str] | None = None,
-     limit: int = RELEVANT_SCHEMA_LIMIT,
+     queries: list[str],
+     embeddings: list[list[float]],
+     driver: AsyncDriver,
+     group_ids: list[str] | None = None,
+     limit: int = RELEVANT_SCHEMA_LIMIT,
  ) -> list[EntityNode]:
      """
      Perform a hybrid search for nodes using both text queries and embeddings.
@@ -389,8 +413,8 @@ async def hybrid_node_search(


  async def get_relevant_nodes(
-     nodes: list[EntityNode],
-     driver: AsyncDriver,
+     nodes: list[EntityNode],
+     driver: AsyncDriver,
  ) -> list[EntityNode]:
      """
      Retrieve relevant nodes based on the provided list of EntityNodes.
@@ -427,11 +451,11 @@ async def get_relevant_nodes(


  async def get_relevant_edges(
-     driver: AsyncDriver,
-     edges: list[EntityEdge],
-     source_node_uuid: str | None,
-     target_node_uuid: str | None,
-     limit: int = RELEVANT_SCHEMA_LIMIT,
+     driver: AsyncDriver,
+     edges: list[EntityEdge],
+     source_node_uuid: str | None,
+     target_node_uuid: str | None,
+     limit: int = RELEVANT_SCHEMA_LIMIT,
  ) -> list[EntityEdge]:
      start = time()
      relevant_edges: list[EntityEdge] = []
@@ -488,7 +512,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:


  async def node_distance_reranker(
-     driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
+     driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
  ) -> list[str]:
      # filter out node_uuid center node node uuid
      filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
@@ -555,3 +579,24 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]
      sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])

      return sorted_uuids
+
+
+ def maximal_marginal_relevance(
+     query_vector: list[float],
+     candidates: list[tuple[str, list[float]]],
+     mmr_lambda: float = DEFAULT_MMR_LAMBDA,
+ ):
+     candidates_with_mmr: list[tuple[str, float]] = []
+     for candidate in candidates:
+         max_sim = max(
+             [
+                 np.dot(normalize_l2(candidate[1]), normalize_l2(c[1]))
+                 for c in candidates
+             ]
+         )
+         mmr = mmr_lambda * np.dot(candidate[1], query_vector) + (1 - mmr_lambda) * max_sim
+         candidates_with_mmr.append((candidate[0], mmr))
+
+     candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
+
+     return [candidate[0] for candidate in candidates_with_mmr]
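For reference, maximal marginal relevance is conventionally computed greedily and incrementally, penalizing similarity to items already selected rather than to the whole candidate pool. A minimal sketch of that textbook variant (not the package's implementation above, which scores in a single pass) looks like this:

```python
import numpy as np


def mmr_rank(query: np.ndarray, candidates: dict[str, np.ndarray], lam: float = 0.5) -> list[str]:
    # Textbook MMR: score(c) = lam * sim(query, c) - (1 - lam) * max sim(c, s)
    # over items s already selected; pick the best remaining item each round.
    def cos(a: np.ndarray, b: np.ndarray) -> float:
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    remaining = dict(candidates)
    selected: list[str] = []
    while remaining:
        best = max(
            remaining,
            key=lambda cid: lam * cos(query, remaining[cid])
            - (1 - lam) * max((cos(remaining[cid], candidates[s]) for s in selected), default=0.0),
        )
        selected.append(best)
        del remaining[best]
    return selected


q = np.array([1.0, 0.0])
cands = {'a': np.array([1.0, 0.0]), 'b': np.array([1.0, 0.01]), 'c': np.array([0.0, 1.0])}
print(mmr_rank(q, cands, lam=0.3))  # ['a', 'c', 'b']: the near-duplicate 'b' drops last
```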
graphiti_core/utils/maintenance/__init__.py CHANGED
@@ -4,7 +4,6 @@ from .graph_data_operations import (
      retrieve_episodes,
  )
  from .node_operations import extract_nodes
- from .temporal_operations import invalidate_edges

  __all__ = [
      'extract_edges',
@@ -12,5 +11,4 @@ __all__ = [
      'extract_nodes',
      'clear_data',
      'retrieve_episodes',
-     'invalidate_edges',
  ]
graphiti_core/utils/maintenance/community_operations.py CHANGED
@@ -15,7 +15,6 @@ from graphiti_core.utils.maintenance.edge_operations import build_community_edge

  MAX_COMMUNITY_BUILD_CONCURRENCY = 10

-
  logger = logging.getLogger(__name__)


@@ -24,31 +23,20 @@ class Neighbor(BaseModel):
      edge_count: int


- async def build_community_projection(driver: AsyncDriver) -> str:
-     records, _, _ = await driver.execute_query("""
-     CALL gds.graph.project("communities", "Entity",
-         {RELATES_TO: {
-             type: "RELATES_TO",
-             orientation: "UNDIRECTED",
-             properties: {weight: {property: "*", aggregation: "COUNT"}}
-         }}
-     )
-     YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges
-     """)
-
-     return records[0]['graph']
-
-
- async def get_community_clusters(driver: AsyncDriver) -> list[list[EntityNode]]:
+ async def get_community_clusters(
+     driver: AsyncDriver, group_ids: list[str] | None
+ ) -> list[list[EntityNode]]:
      community_clusters: list[list[EntityNode]] = []

-     group_id_values, _, _ = await driver.execute_query("""
-     MATCH (n:Entity WHERE n.group_id IS NOT NULL)
-     RETURN
-         collect(DISTINCT n.group_id) AS group_ids
-     """)
+     if group_ids is None:
+         group_id_values, _, _ = await driver.execute_query("""
+         MATCH (n:Entity WHERE n.group_id IS NOT NULL)
+         RETURN
+             collect(DISTINCT n.group_id) AS group_ids
+         """)
+
+         group_ids = group_id_values[0]['group_ids']

-     group_ids = group_id_values[0]['group_ids']
      for group_id in group_ids:
          projection: dict[str, list[Neighbor]] = {}
          nodes = await EntityNode.get_by_group_ids(driver, [group_id])
@@ -197,9 +185,9 @@ async def build_community(


  async def build_communities(
-     driver: AsyncDriver, llm_client: LLMClient
+     driver: AsyncDriver, llm_client: LLMClient, group_ids: list[str] | None
  ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
-     community_clusters = await get_community_clusters(driver)
+     community_clusters = await get_community_clusters(driver, group_ids)

      semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
graphiti_core/utils/maintenance/edge_operations.py CHANGED
@@ -122,12 +122,6 @@ async def extract_edges(
      return edges


- def create_edge_identifier(
-     source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
- ) -> str:
-     return f'{source_node.name}-{edge.name}-{target_node.name}'
-
-
  async def dedupe_extracted_edges(
      llm_client: LLMClient,
      extracted_edges: list[EntityEdge],
@@ -251,11 +245,11 @@ async def resolve_extracted_edge(
      if (
          edge.invalid_at is not None
          and resolved_edge.valid_at is not None
-         and edge.invalid_at < resolved_edge.valid_at
+         and edge.invalid_at <= resolved_edge.valid_at
      ) or (
          edge.valid_at is not None
          and resolved_edge.invalid_at is not None
-         and resolved_edge.invalid_at < edge.valid_at
+         and resolved_edge.invalid_at <= edge.valid_at
      ):
          continue
      # New edge invalidates edge
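The comparison change above means two edges whose validity intervals merely touch at a single instant are now treated as non-overlapping and skipped. A quick illustration, with hypothetical timestamps:

```python
from datetime import datetime, timezone

edge_invalid_at = datetime(2024, 1, 1, tzinfo=timezone.utc)     # existing edge ends here
resolved_valid_at = datetime(2024, 1, 1, tzinfo=timezone.utc)   # new edge starts here

# 0.3.7 used `<`, so equal timestamps fell through to the invalidation
# logic below; 0.3.9 uses `<=`, so this pair now hits `continue` instead.
print(edge_invalid_at <= resolved_valid_at)  # True -> skip invalidation
```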
graphiti_core/utils/maintenance/temporal_operations.py CHANGED
@@ -21,129 +21,11 @@ from typing import List

  from graphiti_core.edges import EntityEdge
  from graphiti_core.llm_client import LLMClient
- from graphiti_core.nodes import EntityNode, EpisodicNode
+ from graphiti_core.nodes import EpisodicNode
  from graphiti_core.prompts import prompt_library

  logger = logging.getLogger(__name__)

- NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode]
-
-
- def extract_node_and_edge_triplets(
-     edges: list[EntityEdge], nodes: list[EntityNode]
- ) -> list[NodeEdgeNodeTriplet]:
-     return [extract_node_edge_node_triplet(edge, nodes) for edge in edges]
-
-
- def extract_node_edge_node_triplet(
-     edge: EntityEdge, nodes: list[EntityNode]
- ) -> NodeEdgeNodeTriplet:
-     source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
-     target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
-     if not source_node or not target_node:
-         raise ValueError(f'Source or target node not found for edge {edge.uuid}')
-     return (source_node, edge, target_node)
-
-
- def prepare_edges_for_invalidation(
-     existing_edges: list[EntityEdge],
-     new_edges: list[EntityEdge],
-     nodes: list[EntityNode],
- ) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
-     existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = []
-     new_edges_with_nodes: list[NodeEdgeNodeTriplet] = []
-
-     for edge_list, result_list in [
-         (existing_edges, existing_edges_pending_invalidation),
-         (new_edges, new_edges_with_nodes),
-     ]:
-         for edge in edge_list:
-             source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
-             target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
-
-             if source_node and target_node:
-                 result_list.append((source_node, edge, target_node))
-
-     return existing_edges_pending_invalidation, new_edges_with_nodes
-
-
- async def invalidate_edges(
-     llm_client: LLMClient,
-     existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet],
-     new_edges: list[NodeEdgeNodeTriplet],
-     current_episode: EpisodicNode,
-     previous_episodes: list[EpisodicNode],
- ) -> list[EntityEdge]:
-     invalidated_edges = []  # TODO: this is not yet used?
-
-     context = prepare_invalidation_context(
-         existing_edges_pending_invalidation,
-         new_edges,
-         current_episode,
-         previous_episodes,
-     )
-     llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))
-
-     edges_to_invalidate = llm_response.get('invalidated_edges', [])
-     invalidated_edges = process_edge_invalidation_llm_response(
-         edges_to_invalidate, existing_edges_pending_invalidation
-     )
-
-     return invalidated_edges
-
-
- def extract_date_strings_from_edge(edge: EntityEdge) -> str:
-     start = edge.valid_at
-     end = edge.invalid_at
-     date_string = f'Start Date: {start.isoformat()}' if start else ''
-     if end:
-         date_string += f' (End Date: {end.isoformat()})'
-     return date_string
-
-
- def prepare_invalidation_context(
-     existing_edges: list[NodeEdgeNodeTriplet],
-     new_edges: list[NodeEdgeNodeTriplet],
-     current_episode: EpisodicNode,
-     previous_episodes: list[EpisodicNode],
- ) -> dict:
-     return {
-         'existing_edges': [
-             f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
-             for source_node, edge, target_node in sorted(
-                 existing_edges, key=lambda x: (x[1].created_at), reverse=True
-             )
-         ],
-         'new_edges': [
-             f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
-             for source_node, edge, target_node in sorted(
-                 new_edges, key=lambda x: (x[1].created_at), reverse=True
-             )
-         ],
-         'current_episode': current_episode.content,
-         'previous_episodes': [episode.content for episode in previous_episodes],
-     }
-
-
- def process_edge_invalidation_llm_response(
-     edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet]
- ) -> List[EntityEdge]:
-     invalidated_edges = []
-     for edge_to_invalidate in edges_to_invalidate:
-         edge_uuid = edge_to_invalidate['edge_uuid']
-         edge_to_update = next(
-             (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid),
-             None,
-         )
-         if edge_to_update:
-             edge_to_update.expired_at = datetime.now()
-             edge_to_update.fact = edge_to_invalidate['fact']
-             invalidated_edges.append(edge_to_update)
-             logger.info(
-                 f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}"
-             )
-     return invalidated_edges
-

  async def extract_edge_dates(
      llm_client: LLMClient,
@@ -152,7 +34,6 @@ async def extract_edge_dates(
      previous_episodes: List[EpisodicNode],
  ) -> tuple[datetime | None, datetime | None]:
      context = {
-         'edge_name': edge.name,
          'edge_fact': edge.fact,
          'current_episode': current_episode.content,
          'previous_episodes': [ep.content for ep in previous_episodes],
@@ -162,25 +43,22 @@ async def extract_edge_dates(

      valid_at = llm_response.get('valid_at')
      invalid_at = llm_response.get('invalid_at')
-     explanation = llm_response.get('explanation', '')

      valid_at_datetime = None
      invalid_at_datetime = None

-     if valid_at and valid_at != '':
+     if valid_at:
          try:
              valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
          except ValueError as e:
              logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')

-     if invalid_at and invalid_at != '':
+     if invalid_at:
          try:
              invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
          except ValueError as e:
              logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')

-     logger.info(f'Edge date extraction explanation: {explanation}')
-
      return valid_at_datetime, invalid_at_datetime
graphiti_core-0.3.7.dist-info/METADATA → graphiti_core-0.3.9.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: graphiti-core
- Version: 0.3.7
+ Version: 0.3.9
  Summary: A temporal graph building library
  License: Apache-2.0
  Author: Paul Paliychuk
@@ -14,9 +14,10 @@ Classifier: Programming Language :: Python :: 3.12
  Requires-Dist: diskcache (>=5.6.3,<6.0.0)
  Requires-Dist: neo4j (>=5.23.0,<6.0.0)
  Requires-Dist: numpy (>=1.0.0)
- Requires-Dist: openai (>=1.38.0,<2.0.0)
+ Requires-Dist: openai (>=1.50.2,<2.0.0)
  Requires-Dist: pydantic (>=2.8.2,<3.0.0)
  Requires-Dist: tenacity (<9.0.0)
+ Requires-Dist: voyageai (>=0.2.3,<0.3.0)
  Description-Content-Type: text/markdown

  <div align="center">
graphiti_core-0.3.7.dist-info/RECORD → graphiti_core-0.3.9.dist-info/RECORD CHANGED
@@ -1,12 +1,12 @@
  graphiti_core/__init__.py,sha256=e5SWFkRiaUwfprYIeIgVIh7JDedNiloZvd3roU-0aDY,55
- graphiti_core/edges.py,sha256=lLuRKjSHTk1GvTS06OUw2lSMiDAB4TQSXgnLq1fU3n8,13378
+ graphiti_core/edges.py,sha256=IKWe6nRxg749RD7o5AgbbH6blCPBtNREBRSU_oMt4tM,13434
  graphiti_core/embedder/__init__.py,sha256=eWd-0sPxflnYXLoWNT9sxwCIFun5JNO9Fk4E-ZXXf8Y,164
  graphiti_core/embedder/client.py,sha256=Sd9CyYXaqRazdOH8opKackrTx-y9y-T54M78XTVMzxs,1006
- graphiti_core/embedder/openai.py,sha256=28cl4qQCQeu6EGxVVPw3lPesA-Z_Cpvuhozyc1jdqVg,1586
- graphiti_core/embedder/voyage.py,sha256=pGrSquGnSiYl4nXGnutbdWchtYgZb0Fi_yW3c90dPlI,1497
+ graphiti_core/embedder/openai.py,sha256=_FVpmdgEBgbeXGQjivhiA7qxEUuDNwCxI_l-2k_95QA,1590
+ graphiti_core/embedder/voyage.py,sha256=jLf43hIzeAnSZSy0P4jitVacWLYiKn3o8qZ9w10-r6E,1501
  graphiti_core/errors.py,sha256=iJrkk5sTgc2z16ABS6TziPylEabdBJcpk0x9KyBUmxs,1527
- graphiti_core/graphiti.py,sha256=5E2UbYlbl65D3MZyagEUPgoPrb_kVYDIqIw7KVlU_NM,26162
- graphiti_core/helpers.py,sha256=_wTSDcYmeXT3u0AwX15iSLuTRa_SR4jJdT10rxfl1_E,1484
+ graphiti_core/graphiti.py,sha256=a4ECdZ9-Zx-KxCctIKLnrLBnpgdwTu6jm0DlyDJLnpk,26936
+ graphiti_core/helpers.py,sha256=gS0BU5OOL1S6ByV2ogFlGpBiryyBHyM3ZnLSukbl6_4,1996
  graphiti_core/llm_client/__init__.py,sha256=PA80TSMeX-sUXITXEAxMDEt3gtfZgcJrGJUcyds1mSo,207
  graphiti_core/llm_client/anthropic_client.py,sha256=4l2PbCjIoeRr7UJ2DUh2grYLTtE2vNaWlo72IIRQDeI,2405
  graphiti_core/llm_client/client.py,sha256=WAnX0e4EuCFHXdFHeq_O1HZsW1STSByvDCFUHMAHEFU,3394
@@ -15,34 +15,34 @@ graphiti_core/llm_client/errors.py,sha256=-qlWwv1X-UjfsFIiNl-7yJIYvPwi7z8srVRfX4
  graphiti_core/llm_client/groq_client.py,sha256=5uGWeQ903EuNxuRiaeH-_J1U2Le_b7Q1UGV_K8bQAiw,2329
  graphiti_core/llm_client/openai_client.py,sha256=xLkbpusRVFRK0zPr3kOqY31HK_XCXrpO5rqUSpcEqEU,3825
  graphiti_core/llm_client/utils.py,sha256=Ms-QhA5X9rps7NBdJeQZUgQLD3vaZRWPiTlhJa6BjXM,995
- graphiti_core/nodes.py,sha256=wIYeRspoRErcX0vvesk_fxhdXKCYn4rpgjgm3PdwSkI,13669
+ graphiti_core/nodes.py,sha256=Jcn9LFr22NHVvyh0eSqUk_zTSY0dU3192MRKPUPSR4c,13783
  graphiti_core/prompts/__init__.py,sha256=EA-x9xUki9l8wnu2l8ek_oNf75-do5tq5hVq7Zbv8Kw,101
  graphiti_core/prompts/dedupe_edges.py,sha256=DUNHdIudj50FAjkla4nc68tSFSD2yjmYHBw-Bb7ph20,6529
  graphiti_core/prompts/dedupe_nodes.py,sha256=BZ9S-PB9SSGjc5Oo8ivdgA6rZx3OGOFhKtwrBlQ0bm0,7269
- graphiti_core/prompts/eval.py,sha256=fYLY2nKwgE9dB7mtYMNKyn1tQXM8B-tOeYmSzB5Bxk8,2844
- graphiti_core/prompts/extract_edge_dates.py,sha256=oOCR8mC_3gI1bumrmIjUbkNO-WTuLTXXAalPDYnDXeM,3655
- graphiti_core/prompts/extract_edges.py,sha256=AQ8xYbAv_RKXAT6WMwXs1_GvUdLtM_lhLNbt3SkOAmk,5348
+ graphiti_core/prompts/eval.py,sha256=9gavc4SKAPdsrhpN8NEUTc632erkaifyOf0hevmdeKY,3657
+ graphiti_core/prompts/extract_edge_dates.py,sha256=pb5Oe5WTZ468REmWNR2NAEHHYMt5GpiJVUAqpVI3aBI,3622
+ graphiti_core/prompts/extract_edges.py,sha256=pGmYcl1zKIuu-HmHUkbkThJ5QKED3efMtDlKoT0wBRM,5448
  graphiti_core/prompts/extract_nodes.py,sha256=VIr0Nh0mSiodI3iGOQFszh7DOni4mufOKJDuGkMysl8,6889
- graphiti_core/prompts/invalidate_edges.py,sha256=8SHt3iPTdmqk8A52LxgdMtI39w4USKqVDMOS2i6lRQ4,4342
+ graphiti_core/prompts/invalidate_edges.py,sha256=2vhi9TsL9poAHqApfk_Us0VveG0-T8cZymfBwOgA8tc,4341
  graphiti_core/prompts/lib.py,sha256=ZOE6nNoI_wQ12Sufx7rQkQtkIm_eTAL7pCiYGU2hcMI,4054
  graphiti_core/prompts/models.py,sha256=cvx_Bv5RMFUD_5IUawYrbpOKLPHogai7_bm7YXrSz84,867
  graphiti_core/prompts/summarize_nodes.py,sha256=FLuZpGTABgcxuIDkx_IKH115nHEw0rIaFhcGlWveAMc,2357
  graphiti_core/py.typed,sha256=vlmmzQOt7bmeQl9L3XJP4W6Ry0iiELepnOrinKz5KQg,79
  graphiti_core/search/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- graphiti_core/search/search.py,sha256=odxpm6MJw5ihEDjbBQ2Icvtr5Mf2oG8Yj6LpNqO3gFw,8620
- graphiti_core/search/search_config.py,sha256=d8w9RDO55G2bwbjYQBaD6gXqEWK1-NsDANrNibYB6t8,2165
- graphiti_core/search/search_config_recipes.py,sha256=_VJqvYB70e8Jke3hsbeQF3Bdogn2MubpYeAQe15M2Jo,3450
- graphiti_core/search/search_utils.py,sha256=WE-iVPI92AWR13aM3JQxtHaYoiPzDMtOOo8rEob8QEI,17844
+ graphiti_core/search/search.py,sha256=bFCHscRU4V_blPlvuoM4ugRUdeZ6smGAnTMaQulvcjU,9024
+ graphiti_core/search/search_config.py,sha256=dWcanEmMoL42RHF-jcZO9C2G9BdqjkI9w-5xe9Wd2Xg,2737
+ graphiti_core/search/search_config_recipes.py,sha256=8kr3oeXQG4L_j1IrceOVeE7IGNtUSLTSe3p89-NGwWM,4892
+ graphiti_core/search/search_utils.py,sha256=dJ5vYC0U7JyjlritDSNPY3bbFyDqwuV0fDnOZ7H37hk,19421
  graphiti_core/utils/__init__.py,sha256=cJAcMnBZdHBQmWrZdU1PQ1YmaL75bhVUkyVpIPuOyns,260
  graphiti_core/utils/bulk_utils.py,sha256=JtoYTZPCigPa3n2E43Oe7QhFZRTA_QKNGy1jVgklHag,12614
- graphiti_core/utils/maintenance/__init__.py,sha256=4b9sfxqyFZMLwxxS2lnQ6_wBr3xrJRIqfAWOidK8EK0,388
- graphiti_core/utils/maintenance/community_operations.py,sha256=Z2lVrTmUh42sEPqSDZq4fXbcj507BuZrHZKV1vJk6tU,9875
- graphiti_core/utils/maintenance/edge_operations.py,sha256=lSeesSnWQ3vpeD2dIY0tSiHEHRMK6fiirEhNNT-s5os,11438
+ graphiti_core/utils/maintenance/__init__.py,sha256=TRY3wWWu5kn3Oahk_KKhltrWnh0NACw0FskjqF6OtlA,314
+ graphiti_core/utils/maintenance/community_operations.py,sha256=BiL2LTuGSbyZNg65FmgeZ3HSfAl3OuWgbmMlsSoQgk4,9505
+ graphiti_core/utils/maintenance/edge_operations.py,sha256=rlB88mQ5WFr1gcSefdUuACLP_mwRbnFohbdJwnh03uo,11265
  graphiti_core/utils/maintenance/graph_data_operations.py,sha256=RgdqYSau9Mr-f7IUSD1sSPztxlyO0C80C3MPPmPBRi0,6100
  graphiti_core/utils/maintenance/node_operations.py,sha256=QAg4KQkSAOXx9QRaUp7t6DCaztZlzeOBC3__57FCs_o,9025
- graphiti_core/utils/maintenance/temporal_operations.py,sha256=BzfGDm96w4HcUEsaWTHUBt5S8dNmDQL1eX6AuBL-XFM,8135
+ graphiti_core/utils/maintenance/temporal_operations.py,sha256=wWLSWqcB3AQWs0YFiVH6avP7RC6Zy_Bua7dBLeUX_V4,3366
  graphiti_core/utils/maintenance/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- graphiti_core-0.3.7.dist-info/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
- graphiti_core-0.3.7.dist-info/METADATA,sha256=6NPJcK3qV8rcVDjopZyTqyc8WlcFXKYDVEOdOzco1KI,9395
- graphiti_core-0.3.7.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- graphiti_core-0.3.7.dist-info/RECORD,,
+ graphiti_core-0.3.9.dist-info/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
+ graphiti_core-0.3.9.dist-info/METADATA,sha256=leK_2yXwo_vRcQeTVS8F3awiQUbaIQJ0m4TH9RG656o,9436
+ graphiti_core-0.3.9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ graphiti_core-0.3.9.dist-info/RECORD,,