graphiti-core 0.3.8-py3-none-any.whl → 0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. (The original page linked to further details at this point; that link was lost when the page was extracted.)

graphiti_core/edges.py CHANGED
@@ -51,7 +51,7 @@ class Edge(BaseModel, ABC):
51
51
  uuid=self.uuid,
52
52
  )
53
53
 
54
- logger.info(f'Deleted Edge: {self.uuid}')
54
+ logger.debug(f'Deleted Edge: {self.uuid}')
55
55
 
56
56
  return result
57
57
 
@@ -83,7 +83,7 @@ class EpisodicEdge(Edge):
83
83
  created_at=self.created_at,
84
84
  )
85
85
 
86
- logger.info(f'Saved edge to neo4j: {self.uuid}')
86
+ logger.debug(f'Saved edge to neo4j: {self.uuid}')
87
87
 
88
88
  return result
89
89
 
@@ -178,7 +178,7 @@ class EntityEdge(Edge):
178
178
  self.fact_embedding = await embedder.create(input=[text])
179
179
 
180
180
  end = time()
181
- logger.info(f'embedded {text} in {end - start} ms')
181
+ logger.debug(f'embedded {text} in {end - start} ms')
182
182
 
183
183
  return self.fact_embedding
184
184
 
@@ -188,9 +188,9 @@ class EntityEdge(Edge):
188
188
  MATCH (source:Entity {uuid: $source_uuid})
189
189
  MATCH (target:Entity {uuid: $target_uuid})
190
190
  MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
191
- SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
192
- episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
193
- valid_at: $valid_at, invalid_at: $invalid_at}
191
+ SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
192
+ created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
193
+ WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
194
194
  RETURN r.uuid AS uuid""",
195
195
  source_uuid=self.source_node_uuid,
196
196
  target_uuid=self.target_node_uuid,
@@ -206,7 +206,7 @@ class EntityEdge(Edge):
206
206
  invalid_at=self.invalid_at,
207
207
  )
208
208
 
209
- logger.info(f'Saved edge to neo4j: {self.uuid}')
209
+ logger.debug(f'Saved edge to neo4j: {self.uuid}')
210
210
 
211
211
  return result
212
212
 
@@ -313,7 +313,7 @@ class CommunityEdge(Edge):
313
313
  created_at=self.created_at,
314
314
  )
315
315
 
316
- logger.info(f'Saved edge to neo4j: {self.uuid}')
316
+ logger.debug(f'Saved edge to neo4j: {self.uuid}')
317
317
 
318
318
  return result
319
319
 
graphiti_core/errors.py CHANGED
@@ -35,6 +35,14 @@ class GroupsEdgesNotFoundError(GraphitiError):
35
35
  super().__init__(self.message)
36
36
 
37
37
 
38
+ class GroupsNodesNotFoundError(GraphitiError):
39
+ """Raised when no nodes are found for a list of group ids."""
40
+
41
+ def __init__(self, group_ids: list[str]):
42
+ self.message = f'no nodes found for group ids {group_ids}'
43
+ super().__init__(self.message)
44
+
45
+
38
46
  class NodeNotFoundError(GraphitiError):
39
47
  """Raised when a node is not found."""
40
48
 
graphiti_core/graphiti.py CHANGED
@@ -21,11 +21,12 @@ from time import time
21
21
 
22
22
  from dotenv import load_dotenv
23
23
  from neo4j import AsyncGraphDatabase
24
+ from pydantic import BaseModel
24
25
 
25
26
  from graphiti_core.edges import EntityEdge, EpisodicEdge
26
27
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
27
28
  from graphiti_core.llm_client import LLMClient, OpenAIClient
28
- from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
29
+ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
29
30
  from graphiti_core.search.search import SearchConfig, search
30
31
  from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
31
32
  from graphiti_core.search.search_config_recipes import (
@@ -77,6 +78,12 @@ logger = logging.getLogger(__name__)
77
78
  load_dotenv()
78
79
 
79
80
 
81
+ class AddEpisodeResults(BaseModel):
82
+ episode: EpisodicNode
83
+ nodes: list[EntityNode]
84
+ edges: list[EntityEdge]
85
+
86
+
80
87
  class Graphiti:
81
88
  def __init__(
82
89
  self,
@@ -245,7 +252,7 @@ class Graphiti:
245
252
  group_id: str = '',
246
253
  uuid: str | None = None,
247
254
  update_communities: bool = False,
248
- ):
255
+ ) -> AddEpisodeResults:
249
256
  """
250
257
  Process an episode and update the graph.
251
258
 
@@ -312,13 +319,11 @@ class Graphiti:
312
319
  valid_at=reference_time,
313
320
  )
314
321
  episode.uuid = uuid if uuid is not None else episode.uuid
315
- if not self.store_raw_episode_content:
316
- episode.content = ''
317
322
 
318
323
  # Extract entities as nodes
319
324
 
320
325
  extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
321
- logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
326
+ logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
322
327
 
323
328
  # Calculate Embeddings
324
329
 
@@ -333,7 +338,7 @@ class Graphiti:
333
338
  )
334
339
  )
335
340
 
336
- logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
341
+ logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
337
342
 
338
343
  (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
339
344
  resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
@@ -341,7 +346,7 @@ class Graphiti:
341
346
  self.llm_client, episode, extracted_nodes, previous_episodes, group_id
342
347
  ),
343
348
  )
344
- logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
349
+ logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
345
350
  nodes = mentioned_nodes
346
351
 
347
352
  extracted_edges_with_resolved_pointers = resolve_edge_pointers(
@@ -371,10 +376,10 @@ class Graphiti:
371
376
  ]
372
377
  )
373
378
  )
374
- logger.info(
379
+ logger.debug(
375
380
  f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
376
381
  )
377
- logger.info(
382
+ logger.debug(
378
383
  f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
379
384
  )
380
385
 
@@ -426,15 +431,18 @@ class Graphiti:
426
431
 
427
432
  entity_edges.extend(resolved_edges + invalidated_edges)
428
433
 
429
- logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
434
+ logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
430
435
 
431
436
  episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
432
437
 
433
- logger.info(f'Built episodic edges: {episodic_edges}')
438
+ logger.debug(f'Built episodic edges: {episodic_edges}')
434
439
 
435
440
  episode.entity_edges = [edge.uuid for edge in entity_edges]
436
441
 
437
442
  # Future optimization would be using batch operations to save nodes and edges
443
+ if not self.store_raw_episode_content:
444
+ episode.content = ''
445
+
438
446
  await episode.save(self.driver)
439
447
  await asyncio.gather(*[node.save(self.driver) for node in nodes])
440
448
  await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
@@ -451,6 +459,8 @@ class Graphiti:
451
459
  end = time()
452
460
  logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
453
461
 
462
+ return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
463
+
454
464
  except Exception as e:
455
465
  raise e
456
466
 
@@ -554,7 +564,7 @@ class Graphiti:
554
564
  edges = await dedupe_edges_bulk(
555
565
  self.driver, self.llm_client, extracted_edges_with_resolved_pointers
556
566
  )
557
- logger.info(f'extracted edge length: {len(edges)}')
567
+ logger.debug(f'extracted edge length: {len(edges)}')
558
568
 
559
569
  # invalidate edges
560
570
 
@@ -567,11 +577,20 @@ class Graphiti:
567
577
  except Exception as e:
568
578
  raise e
569
579
 
570
- async def build_communities(self):
580
+ async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
581
+ """
582
+ Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
583
+ the content of these communities.
584
+ ----------
585
+ query : list[str] | None
586
+ Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
587
+ """
571
588
  # Clear existing communities
572
589
  await remove_communities(self.driver)
573
590
 
574
- community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
591
+ community_nodes, community_edges = await build_communities(
592
+ self.driver, self.llm_client, group_ids
593
+ )
575
594
 
576
595
  await asyncio.gather(
577
596
  *[node.generate_name_embedding(self.embedder) for node in community_nodes]
@@ -580,6 +599,8 @@ class Graphiti:
580
599
  await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
581
600
  await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
582
601
 
602
+ return community_nodes
603
+
583
604
  async def search(
584
605
  self,
585
606
  query: str,
@@ -700,18 +721,17 @@ class Graphiti:
700
721
  ).nodes
701
722
  return nodes
702
723
 
724
+ async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
725
+ episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
703
726
 
704
- async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
705
- episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
706
-
707
- edges_list = await asyncio.gather(
708
- *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
709
- )
727
+ edges_list = await asyncio.gather(
728
+ *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
729
+ )
710
730
 
711
- edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
731
+ edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
712
732
 
713
- nodes = await get_mentioned_nodes(self.driver, episodes)
733
+ nodes = await get_mentioned_nodes(self.driver, episodes)
714
734
 
715
- communities = await get_communities_by_nodes(self.driver, nodes)
735
+ communities = await get_communities_by_nodes(self.driver, nodes)
716
736
 
717
- return SearchResults(edges=edges, nodes=nodes, communities=communities)
737
+ return SearchResults(edges=edges, nodes=nodes, communities=communities)
graphiti_core/helpers.py CHANGED
@@ -16,6 +16,7 @@ limitations under the License.
16
16
 
17
17
  from datetime import datetime
18
18
 
19
+ import numpy as np
19
20
  from neo4j import time as neo4j_time
20
21
 
21
22
 
@@ -25,7 +26,7 @@ def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
25
26
 
26
27
  def lucene_sanitize(query: str) -> str:
27
28
  # Escape special characters from a query before passing into Lucene
28
- # + - && || ! ( ) { } [ ] ^ " ~ * ? : \
29
+ # + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
29
30
  escape_map = str.maketrans(
30
31
  {
31
32
  '+': r'\+',
@@ -46,8 +47,21 @@ def lucene_sanitize(query: str) -> str:
46
47
  '?': r'\?',
47
48
  ':': r'\:',
48
49
  '\\': r'\\',
50
+ '/': r'\/',
49
51
  }
50
52
  )
51
53
 
52
54
  sanitized = query.translate(escape_map)
53
55
  return sanitized
56
+
57
+
58
+ def normalize_l2(embedding: list[float]) -> list[float]:
59
+ embedding_array = np.array(embedding)
60
+ if embedding_array.ndim == 1:
61
+ norm = np.linalg.norm(embedding_array)
62
+ if norm == 0:
63
+ return embedding_array.tolist()
64
+ return (embedding_array / norm).tolist()
65
+ else:
66
+ norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
67
+ return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
graphiti_core/nodes.py CHANGED
@@ -86,7 +86,7 @@ class Node(BaseModel, ABC):
86
86
  uuid=self.uuid,
87
87
  )
88
88
 
89
- logger.info(f'Deleted Node: {self.uuid}')
89
+ logger.debug(f'Deleted Node: {self.uuid}')
90
90
 
91
91
  return result
92
92
 
@@ -135,7 +135,7 @@ class EpisodicNode(Node):
135
135
  source=self.source.value,
136
136
  )
137
137
 
138
- logger.info(f'Saved Node to neo4j: {self.uuid}')
138
+ logger.debug(f'Saved Node to neo4j: {self.uuid}')
139
139
 
140
140
  return result
141
141
 
@@ -217,7 +217,7 @@ class EntityNode(Node):
217
217
  text = self.name.replace('\n', ' ')
218
218
  self.name_embedding = await embedder.create(input=[text])
219
219
  end = time()
220
- logger.info(f'embedded {text} in {end - start} ms')
220
+ logger.debug(f'embedded {text} in {end - start} ms')
221
221
 
222
222
  return self.name_embedding
223
223
 
@@ -225,7 +225,8 @@ class EntityNode(Node):
225
225
  result = await driver.execute_query(
226
226
  """
227
227
  MERGE (n:Entity {uuid: $uuid})
228
- SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
228
+ SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
229
+ WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
229
230
  RETURN n.uuid AS uuid""",
230
231
  uuid=self.uuid,
231
232
  name=self.name,
@@ -235,7 +236,7 @@ class EntityNode(Node):
235
236
  created_at=self.created_at,
236
237
  )
237
238
 
238
- logger.info(f'Saved Node to neo4j: {self.uuid}')
239
+ logger.debug(f'Saved Node to neo4j: {self.uuid}')
239
240
 
240
241
  return result
241
242
 
@@ -257,6 +258,9 @@ class EntityNode(Node):
257
258
 
258
259
  nodes = [get_entity_node_from_record(record) for record in records]
259
260
 
261
+ if len(nodes) == 0:
262
+ raise NodeNotFoundError(uuid)
263
+
260
264
  return nodes[0]
261
265
 
262
266
  @classmethod
@@ -308,7 +312,8 @@ class CommunityNode(Node):
308
312
  result = await driver.execute_query(
309
313
  """
310
314
  MERGE (n:Community {uuid: $uuid})
311
- SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
315
+ SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
316
+ WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
312
317
  RETURN n.uuid AS uuid""",
313
318
  uuid=self.uuid,
314
319
  name=self.name,
@@ -318,7 +323,7 @@ class CommunityNode(Node):
318
323
  created_at=self.created_at,
319
324
  )
320
325
 
321
- logger.info(f'Saved Node to neo4j: {self.uuid}')
326
+ logger.debug(f'Saved Node to neo4j: {self.uuid}')
322
327
 
323
328
  return result
324
329
 
@@ -327,7 +332,7 @@ class CommunityNode(Node):
327
332
  text = self.name.replace('\n', ' ')
328
333
  self.name_embedding = await embedder.create(input=[text])
329
334
  end = time()
330
- logger.info(f'embedded {text} in {end - start} ms')
335
+ logger.debug(f'embedded {text} in {end - start} ms')
331
336
 
332
337
  return self.name_embedding
333
338
 
@@ -349,6 +354,9 @@ class CommunityNode(Node):
349
354
 
350
355
  nodes = [get_community_node_from_record(record) for record in records]
351
356
 
357
+ if len(nodes) == 0:
358
+ raise NodeNotFoundError(uuid)
359
+
352
360
  return nodes[0]
353
361
 
354
362
  @classmethod
@@ -23,11 +23,33 @@ from .models import Message, PromptFunction, PromptVersion
23
23
  class Prompt(Protocol):
24
24
  qa_prompt: PromptVersion
25
25
  eval_prompt: PromptVersion
26
+ query_expansion: PromptVersion
26
27
 
27
28
 
28
29
  class Versions(TypedDict):
29
30
  qa_prompt: PromptFunction
30
31
  eval_prompt: PromptFunction
32
+ query_expansion: PromptFunction
33
+
34
+
35
+ def query_expansion(context: dict[str, Any]) -> list[Message]:
36
+ sys_prompt = """You are an expert at rephrasing questions into queries used in a database retrieval system"""
37
+
38
+ user_prompt = f"""
39
+ Bob is asking Alice a question, are you able to rephrase the question into a simpler one about Alice in the third person
40
+ that maintains the relevant context?
41
+ <QUESTION>
42
+ {json.dumps(context['query'])}
43
+ </QUESTION>
44
+ respond with a JSON object in the following format:
45
+ {{
46
+ "query": "query optimized for database search"
47
+ }}
48
+ """
49
+ return [
50
+ Message(role='system', content=sys_prompt),
51
+ Message(role='user', content=user_prompt),
52
+ ]
31
53
 
32
54
 
33
55
  def qa_prompt(context: dict[str, Any]) -> list[Message]:
@@ -38,7 +60,7 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
38
60
  You are given the following entity summaries and facts to help you determine the answer to your question.
39
61
  <ENTITY_SUMMARIES>
40
62
  {json.dumps(context['entity_summaries'])}
41
- </ENTITY_SUMMARIES
63
+ </ENTITY_SUMMARIES>
42
64
  <FACTS>
43
65
  {json.dumps(context['facts'])}
44
66
  </FACTS>
@@ -87,4 +109,8 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
87
109
  ]
88
110
 
89
111
 
90
- versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt}
112
+ versions: Versions = {
113
+ 'qa_prompt': qa_prompt,
114
+ 'eval_prompt': eval_prompt,
115
+ 'query_expansion': query_expansion,
116
+ }
@@ -37,7 +37,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
37
37
  role='user',
38
38
  content=f"""
39
39
  Edge:
40
- Edge Name: {context['edge_name']}
41
40
  Fact: {context['edge_fact']}
42
41
 
43
42
  Current Episode: {context['current_episode']}
@@ -56,17 +55,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
56
55
  Guidelines:
57
56
  1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
58
57
  2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
59
- 3. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
60
- 4. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
61
- 5. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
62
- 6. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
63
- 7. If only a year is mentioned, use January 1st of that year at 00:00:00.
58
+ 3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
59
+ 4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
60
+ 5. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
61
+ 6. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
62
+ 7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
63
+ 8. If only a year is mentioned, use January 1st of that year at 00:00:00.
64
64
  9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
65
65
  Respond with a JSON object:
66
66
  {{
67
- "valid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
68
- "invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
69
- "explanation": "Brief explanation of why these dates were chosen or why they were set to null"
67
+ "valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
68
+ "invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
70
69
  }}
71
70
  """,
72
71
  ),
@@ -113,8 +113,9 @@ def v2(context: dict[str, Any]) -> list[Message]:
113
113
  2. Each edge should represent a clear relationship between two DISTINCT nodes.
114
114
  3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
115
115
  4. Provide a more detailed fact describing the relationship.
116
- 5. Consider temporal aspects of relationships when relevant.
117
- 6. Avoid using the same node as the source and target of a relationship
116
+ 5. The fact should include any specific relevant information, including numeric information
117
+ 6. Consider temporal aspects of relationships when relevant.
118
+ 7. Avoid using the same node as the source and target of a relationship
118
119
 
119
120
  Respond with a JSON object in the following format:
120
121
  {{
@@ -82,7 +82,7 @@ def v2(context: dict[str, Any]) -> list[Message]:
82
82
  Message(
83
83
  role='user',
84
84
  content=f"""
85
- Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.
85
+ Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to invalidations with the New Edge.
86
86
 
87
87
  Existing Edges:
88
88
  {context['existing_edges']}
@@ -29,13 +29,10 @@ from graphiti_core.search.search_config import (
29
29
  DEFAULT_SEARCH_LIMIT,
30
30
  CommunityReranker,
31
31
  CommunitySearchConfig,
32
- CommunitySearchMethod,
33
32
  EdgeReranker,
34
33
  EdgeSearchConfig,
35
- EdgeSearchMethod,
36
34
  NodeReranker,
37
35
  NodeSearchConfig,
38
- NodeSearchMethod,
39
36
  SearchConfig,
40
37
  SearchResults,
41
38
  )
@@ -45,6 +42,7 @@ from graphiti_core.search.search_utils import (
45
42
  edge_fulltext_search,
46
43
  edge_similarity_search,
47
44
  episode_mentions_reranker,
45
+ maximal_marginal_relevance,
48
46
  node_distance_reranker,
49
47
  node_fulltext_search,
50
48
  node_similarity_search,
@@ -120,22 +118,18 @@ async def edge_search(
120
118
  if config is None:
121
119
  return []
122
120
 
123
- search_results: list[list[EntityEdge]] = []
121
+ query_vector = await embedder.create(input=[query])
124
122
 
125
- if EdgeSearchMethod.bm25 in config.search_methods:
126
- text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
127
- search_results.append(text_search)
128
-
129
- if EdgeSearchMethod.cosine_similarity in config.search_methods:
130
- search_vector = await embedder.create(input=[query])
131
-
132
- similarity_search = await edge_similarity_search(
133
- driver, search_vector, None, None, group_ids, 2 * limit
123
+ search_results: list[list[EntityEdge]] = list(
124
+ await asyncio.gather(
125
+ *[
126
+ edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit),
127
+ edge_similarity_search(
128
+ driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
129
+ ),
130
+ ]
134
131
  )
135
- search_results.append(similarity_search)
136
-
137
- if len(search_results) > 1 and config.reranker is None:
138
- raise SearchRerankerError('Multiple edge searches enabled without a reranker')
132
+ )
139
133
 
140
134
  edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
141
135
 
@@ -144,6 +138,15 @@ async def edge_search(
144
138
  search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
145
139
 
146
140
  reranked_uuids = rrf(search_result_uuids)
141
+ elif config.reranker == EdgeReranker.mmr:
142
+ search_result_uuids_and_vectors = [
143
+ (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
144
+ for result in search_results
145
+ for edge in result
146
+ ]
147
+ reranked_uuids = maximal_marginal_relevance(
148
+ query_vector, search_result_uuids_and_vectors, config.mmr_lambda
149
+ )
147
150
  elif config.reranker == EdgeReranker.node_distance:
148
151
  if center_node_uuid is None:
149
152
  raise SearchRerankerError('No center node provided for Node Distance reranker')
@@ -157,7 +160,7 @@ async def edge_search(
157
160
  for edge in sorted_results:
158
161
  source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
159
162
 
160
- source_uuids = [edge.source_node_uuid for edge in sorted_results]
163
+ source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
161
164
 
162
165
  reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
163
166
 
@@ -184,22 +187,18 @@ async def node_search(
184
187
  if config is None:
185
188
  return []
186
189
 
187
- search_results: list[list[EntityNode]] = []
188
-
189
- if NodeSearchMethod.bm25 in config.search_methods:
190
- text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
191
- search_results.append(text_search)
190
+ query_vector = await embedder.create(input=[query])
192
191
 
193
- if NodeSearchMethod.cosine_similarity in config.search_methods:
194
- search_vector = await embedder.create(input=[query])
195
-
196
- similarity_search = await node_similarity_search(
197
- driver, search_vector, group_ids, 2 * limit
192
+ search_results: list[list[EntityNode]] = list(
193
+ await asyncio.gather(
194
+ *[
195
+ node_fulltext_search(driver, query, group_ids, 2 * limit),
196
+ node_similarity_search(
197
+ driver, query_vector, group_ids, 2 * limit, config.sim_min_score
198
+ ),
199
+ ]
198
200
  )
199
- search_results.append(similarity_search)
200
-
201
- if len(search_results) > 1 and config.reranker is None:
202
- raise SearchRerankerError('Multiple node searches enabled without a reranker')
201
+ )
203
202
 
204
203
  search_result_uuids = [[node.uuid for node in result] for result in search_results]
205
204
  node_uuid_map = {node.uuid: node for result in search_results for node in result}
@@ -207,6 +206,15 @@ async def node_search(
207
206
  reranked_uuids: list[str] = []
208
207
  if config.reranker == NodeReranker.rrf:
209
208
  reranked_uuids = rrf(search_result_uuids)
209
+ elif config.reranker == NodeReranker.mmr:
210
+ search_result_uuids_and_vectors = [
211
+ (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
212
+ for result in search_results
213
+ for node in result
214
+ ]
215
+ reranked_uuids = maximal_marginal_relevance(
216
+ query_vector, search_result_uuids_and_vectors, config.mmr_lambda
217
+ )
210
218
  elif config.reranker == NodeReranker.episode_mentions:
211
219
  reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
212
220
  elif config.reranker == NodeReranker.node_distance:
@@ -232,22 +240,18 @@ async def community_search(
232
240
  if config is None:
233
241
  return []
234
242
 
235
- search_results: list[list[CommunityNode]] = []
236
-
237
- if CommunitySearchMethod.bm25 in config.search_methods:
238
- text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
239
- search_results.append(text_search)
240
-
241
- if CommunitySearchMethod.cosine_similarity in config.search_methods:
242
- search_vector = await embedder.create(input=[query])
243
+ query_vector = await embedder.create(input=[query])
243
244
 
244
- similarity_search = await community_similarity_search(
245
- driver, search_vector, group_ids, 2 * limit
245
+ search_results: list[list[CommunityNode]] = list(
246
+ await asyncio.gather(
247
+ *[
248
+ community_fulltext_search(driver, query, group_ids, 2 * limit),
249
+ community_similarity_search(
250
+ driver, query_vector, group_ids, 2 * limit, config.sim_min_score
251
+ ),
252
+ ]
246
253
  )
247
- search_results.append(similarity_search)
248
-
249
- if len(search_results) > 1 and config.reranker is None:
250
- raise SearchRerankerError('Multiple node searches enabled without a reranker')
254
+ )
251
255
 
252
256
  search_result_uuids = [[community.uuid for community in result] for result in search_results]
253
257
  community_uuid_map = {
@@ -257,6 +261,18 @@ async def community_search(
257
261
  reranked_uuids: list[str] = []
258
262
  if config.reranker == CommunityReranker.rrf:
259
263
  reranked_uuids = rrf(search_result_uuids)
264
+ elif config.reranker == CommunityReranker.mmr:
265
+ search_result_uuids_and_vectors = [
266
+ (
267
+ community.uuid,
268
+ community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
269
+ )
270
+ for result in search_results
271
+ for community in result
272
+ ]
273
+ reranked_uuids = maximal_marginal_relevance(
274
+ query_vector, search_result_uuids_and_vectors, config.mmr_lambda
275
+ )
260
276
 
261
277
  reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
262
278