graphiti-core 0.3.8__tar.gz → 0.3.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; the original page links to further details.

Files changed (49)
  1. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/PKG-INFO +2 -1
  2. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/edges.py +3 -3
  3. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/embedder/openai.py +1 -1
  4. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/embedder/voyage.py +1 -1
  5. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/graphiti.py +33 -14
  6. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/helpers.py +15 -1
  7. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/nodes.py +4 -2
  8. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/prompts/eval.py +28 -2
  9. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/prompts/extract_edge_dates.py +8 -9
  10. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/prompts/extract_edges.py +3 -2
  11. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/prompts/invalidate_edges.py +1 -1
  12. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/search/search.py +61 -45
  13. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/search/search_config.py +13 -3
  14. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/search/search_config_recipes.py +40 -0
  15. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/search/search_utils.py +98 -53
  16. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/utils/maintenance/__init__.py +0 -2
  17. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/utils/maintenance/community_operations.py +13 -25
  18. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/utils/maintenance/edge_operations.py +2 -8
  19. graphiti_core-0.3.9/graphiti_core/utils/maintenance/temporal_operations.py +95 -0
  20. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/pyproject.toml +4 -3
  21. graphiti_core-0.3.8/graphiti_core/utils/maintenance/temporal_operations.py +0 -217
  22. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/LICENSE +0 -0
  23. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/README.md +0 -0
  24. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/__init__.py +0 -0
  25. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/embedder/__init__.py +0 -0
  26. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/embedder/client.py +0 -0
  27. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/errors.py +0 -0
  28. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/llm_client/__init__.py +0 -0
  29. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/llm_client/anthropic_client.py +0 -0
  30. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/llm_client/client.py +0 -0
  31. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/llm_client/config.py +0 -0
  32. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/llm_client/errors.py +0 -0
  33. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/llm_client/groq_client.py +0 -0
  34. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/llm_client/openai_client.py +0 -0
  35. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/llm_client/utils.py +0 -0
  36. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/prompts/__init__.py +0 -0
  37. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/prompts/dedupe_edges.py +0 -0
  38. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/prompts/dedupe_nodes.py +0 -0
  39. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/prompts/extract_nodes.py +0 -0
  40. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/prompts/lib.py +0 -0
  41. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/prompts/models.py +0 -0
  42. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/prompts/summarize_nodes.py +0 -0
  43. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/py.typed +0 -0
  44. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/search/__init__.py +0 -0
  45. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/utils/__init__.py +0 -0
  46. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/utils/bulk_utils.py +0 -0
  47. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
  48. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/utils/maintenance/node_operations.py +0 -0
  49. {graphiti_core-0.3.8 → graphiti_core-0.3.9}/graphiti_core/utils/maintenance/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphiti-core
3
- Version: 0.3.8
3
+ Version: 0.3.9
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -17,6 +17,7 @@ Requires-Dist: numpy (>=1.0.0)
17
17
  Requires-Dist: openai (>=1.50.2,<2.0.0)
18
18
  Requires-Dist: pydantic (>=2.8.2,<3.0.0)
19
19
  Requires-Dist: tenacity (<9.0.0)
20
+ Requires-Dist: voyageai (>=0.2.3,<0.3.0)
20
21
  Description-Content-Type: text/markdown
21
22
 
22
23
  <div align="center">
@@ -188,9 +188,9 @@ class EntityEdge(Edge):
188
188
  MATCH (source:Entity {uuid: $source_uuid})
189
189
  MATCH (target:Entity {uuid: $target_uuid})
190
190
  MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
191
- SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
192
- episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
193
- valid_at: $valid_at, invalid_at: $invalid_at}
191
+ SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
192
+ created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
193
+ WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
194
194
  RETURN r.uuid AS uuid""",
195
195
  source_uuid=self.source_node_uuid,
196
196
  target_uuid=self.target_node_uuid,
@@ -42,7 +42,7 @@ class OpenAIEmbedder(EmbedderClient):
42
42
  self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
43
43
 
44
44
  async def create(
45
- self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
45
+ self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
46
46
  ) -> list[float]:
47
47
  result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
48
48
  return result.data[0].embedding[: self.config.embedding_dim]
@@ -41,7 +41,7 @@ class VoyageAIEmbedder(EmbedderClient):
41
41
  self.client = voyageai.AsyncClient(api_key=config.api_key)
42
42
 
43
43
  async def create(
44
- self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
44
+ self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
45
45
  ) -> list[float]:
46
46
  result = await self.client.embed(input, model=self.config.embedding_model)
47
47
  return result.embeddings[0][: self.config.embedding_dim]
@@ -21,11 +21,12 @@ from time import time
21
21
 
22
22
  from dotenv import load_dotenv
23
23
  from neo4j import AsyncGraphDatabase
24
+ from pydantic import BaseModel
24
25
 
25
26
  from graphiti_core.edges import EntityEdge, EpisodicEdge
26
27
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
27
28
  from graphiti_core.llm_client import LLMClient, OpenAIClient
28
- from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
29
+ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
29
30
  from graphiti_core.search.search import SearchConfig, search
30
31
  from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
31
32
  from graphiti_core.search.search_config_recipes import (
@@ -77,6 +78,12 @@ logger = logging.getLogger(__name__)
77
78
  load_dotenv()
78
79
 
79
80
 
81
+ class AddEpisodeResults(BaseModel):
82
+ episode: EpisodicNode
83
+ nodes: list[EntityNode]
84
+ edges: list[EntityEdge]
85
+
86
+
80
87
  class Graphiti:
81
88
  def __init__(
82
89
  self,
@@ -245,7 +252,7 @@ class Graphiti:
245
252
  group_id: str = '',
246
253
  uuid: str | None = None,
247
254
  update_communities: bool = False,
248
- ):
255
+ ) -> AddEpisodeResults:
249
256
  """
250
257
  Process an episode and update the graph.
251
258
 
@@ -451,6 +458,8 @@ class Graphiti:
451
458
  end = time()
452
459
  logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
453
460
 
461
+ return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
462
+
454
463
  except Exception as e:
455
464
  raise e
456
465
 
@@ -567,11 +576,20 @@ class Graphiti:
567
576
  except Exception as e:
568
577
  raise e
569
578
 
570
- async def build_communities(self):
579
+ async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
580
+ """
581
+ Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
582
+ the content of these communities.
583
+ ----------
584
+ query : list[str] | None
585
+ Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
586
+ """
571
587
  # Clear existing communities
572
588
  await remove_communities(self.driver)
573
589
 
574
- community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
590
+ community_nodes, community_edges = await build_communities(
591
+ self.driver, self.llm_client, group_ids
592
+ )
575
593
 
576
594
  await asyncio.gather(
577
595
  *[node.generate_name_embedding(self.embedder) for node in community_nodes]
@@ -580,6 +598,8 @@ class Graphiti:
580
598
  await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
581
599
  await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
582
600
 
601
+ return community_nodes
602
+
583
603
  async def search(
584
604
  self,
585
605
  query: str,
@@ -700,18 +720,17 @@ class Graphiti:
700
720
  ).nodes
701
721
  return nodes
702
722
 
723
+ async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
724
+ episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
703
725
 
704
- async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
705
- episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
706
-
707
- edges_list = await asyncio.gather(
708
- *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
709
- )
726
+ edges_list = await asyncio.gather(
727
+ *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
728
+ )
710
729
 
711
- edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
730
+ edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
712
731
 
713
- nodes = await get_mentioned_nodes(self.driver, episodes)
732
+ nodes = await get_mentioned_nodes(self.driver, episodes)
714
733
 
715
- communities = await get_communities_by_nodes(self.driver, nodes)
734
+ communities = await get_communities_by_nodes(self.driver, nodes)
716
735
 
717
- return SearchResults(edges=edges, nodes=nodes, communities=communities)
736
+ return SearchResults(edges=edges, nodes=nodes, communities=communities)
@@ -16,6 +16,7 @@ limitations under the License.
16
16
 
17
17
  from datetime import datetime
18
18
 
19
+ import numpy as np
19
20
  from neo4j import time as neo4j_time
20
21
 
21
22
 
@@ -25,7 +26,7 @@ def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
25
26
 
26
27
  def lucene_sanitize(query: str) -> str:
27
28
  # Escape special characters from a query before passing into Lucene
28
- # + - && || ! ( ) { } [ ] ^ " ~ * ? : \
29
+ # + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
29
30
  escape_map = str.maketrans(
30
31
  {
31
32
  '+': r'\+',
@@ -46,8 +47,21 @@ def lucene_sanitize(query: str) -> str:
46
47
  '?': r'\?',
47
48
  ':': r'\:',
48
49
  '\\': r'\\',
50
+ '/': r'\/',
49
51
  }
50
52
  )
51
53
 
52
54
  sanitized = query.translate(escape_map)
53
55
  return sanitized
56
+
57
+
58
+ def normalize_l2(embedding: list[float]) -> list[float]:
59
+ embedding_array = np.array(embedding)
60
+ if embedding_array.ndim == 1:
61
+ norm = np.linalg.norm(embedding_array)
62
+ if norm == 0:
63
+ return embedding_array.tolist()
64
+ return (embedding_array / norm).tolist()
65
+ else:
66
+ norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
67
+ return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
@@ -225,7 +225,8 @@ class EntityNode(Node):
225
225
  result = await driver.execute_query(
226
226
  """
227
227
  MERGE (n:Entity {uuid: $uuid})
228
- SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
228
+ SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
229
+ WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
229
230
  RETURN n.uuid AS uuid""",
230
231
  uuid=self.uuid,
231
232
  name=self.name,
@@ -308,7 +309,8 @@ class CommunityNode(Node):
308
309
  result = await driver.execute_query(
309
310
  """
310
311
  MERGE (n:Community {uuid: $uuid})
311
- SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
312
+ SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
313
+ WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
312
314
  RETURN n.uuid AS uuid""",
313
315
  uuid=self.uuid,
314
316
  name=self.name,
@@ -23,11 +23,33 @@ from .models import Message, PromptFunction, PromptVersion
23
23
  class Prompt(Protocol):
24
24
  qa_prompt: PromptVersion
25
25
  eval_prompt: PromptVersion
26
+ query_expansion: PromptVersion
26
27
 
27
28
 
28
29
  class Versions(TypedDict):
29
30
  qa_prompt: PromptFunction
30
31
  eval_prompt: PromptFunction
32
+ query_expansion: PromptFunction
33
+
34
+
35
+ def query_expansion(context: dict[str, Any]) -> list[Message]:
36
+ sys_prompt = """You are an expert at rephrasing questions into queries used in a database retrieval system"""
37
+
38
+ user_prompt = f"""
39
+ Bob is asking Alice a question, are you able to rephrase the question into a simpler one about Alice in the third person
40
+ that maintains the relevant context?
41
+ <QUESTION>
42
+ {json.dumps(context['query'])}
43
+ </QUESTION>
44
+ respond with a JSON object in the following format:
45
+ {{
46
+ "query": "query optimized for database search"
47
+ }}
48
+ """
49
+ return [
50
+ Message(role='system', content=sys_prompt),
51
+ Message(role='user', content=user_prompt),
52
+ ]
31
53
 
32
54
 
33
55
  def qa_prompt(context: dict[str, Any]) -> list[Message]:
@@ -38,7 +60,7 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
38
60
  You are given the following entity summaries and facts to help you determine the answer to your question.
39
61
  <ENTITY_SUMMARIES>
40
62
  {json.dumps(context['entity_summaries'])}
41
- </ENTITY_SUMMARIES
63
+ </ENTITY_SUMMARIES>
42
64
  <FACTS>
43
65
  {json.dumps(context['facts'])}
44
66
  </FACTS>
@@ -87,4 +109,8 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
87
109
  ]
88
110
 
89
111
 
90
- versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt}
112
+ versions: Versions = {
113
+ 'qa_prompt': qa_prompt,
114
+ 'eval_prompt': eval_prompt,
115
+ 'query_expansion': query_expansion,
116
+ }
@@ -37,7 +37,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
37
37
  role='user',
38
38
  content=f"""
39
39
  Edge:
40
- Edge Name: {context['edge_name']}
41
40
  Fact: {context['edge_fact']}
42
41
 
43
42
  Current Episode: {context['current_episode']}
@@ -56,17 +55,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
56
55
  Guidelines:
57
56
  1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
58
57
  2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
59
- 3. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
60
- 4. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
61
- 5. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
62
- 6. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
63
- 7. If only a year is mentioned, use January 1st of that year at 00:00:00.
58
+ 3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
59
+ 4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
60
+ 5. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
61
+ 6. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
62
+ 7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
63
+ 8. If only a year is mentioned, use January 1st of that year at 00:00:00.
64
64
  9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
65
65
  Respond with a JSON object:
66
66
  {{
67
- "valid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
68
- "invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
69
- "explanation": "Brief explanation of why these dates were chosen or why they were set to null"
67
+ "valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
68
+ "invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
70
69
  }}
71
70
  """,
72
71
  ),
@@ -113,8 +113,9 @@ def v2(context: dict[str, Any]) -> list[Message]:
113
113
  2. Each edge should represent a clear relationship between two DISTINCT nodes.
114
114
  3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
115
115
  4. Provide a more detailed fact describing the relationship.
116
- 5. Consider temporal aspects of relationships when relevant.
117
- 6. Avoid using the same node as the source and target of a relationship
116
+ 5. The fact should include any specific relevant information, including numeric information
117
+ 6. Consider temporal aspects of relationships when relevant.
118
+ 7. Avoid using the same node as the source and target of a relationship
118
119
 
119
120
  Respond with a JSON object in the following format:
120
121
  {{
@@ -82,7 +82,7 @@ def v2(context: dict[str, Any]) -> list[Message]:
82
82
  Message(
83
83
  role='user',
84
84
  content=f"""
85
- Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.
85
+ Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to invalidations with the New Edge.
86
86
 
87
87
  Existing Edges:
88
88
  {context['existing_edges']}
@@ -29,13 +29,10 @@ from graphiti_core.search.search_config import (
29
29
  DEFAULT_SEARCH_LIMIT,
30
30
  CommunityReranker,
31
31
  CommunitySearchConfig,
32
- CommunitySearchMethod,
33
32
  EdgeReranker,
34
33
  EdgeSearchConfig,
35
- EdgeSearchMethod,
36
34
  NodeReranker,
37
35
  NodeSearchConfig,
38
- NodeSearchMethod,
39
36
  SearchConfig,
40
37
  SearchResults,
41
38
  )
@@ -45,6 +42,7 @@ from graphiti_core.search.search_utils import (
45
42
  edge_fulltext_search,
46
43
  edge_similarity_search,
47
44
  episode_mentions_reranker,
45
+ maximal_marginal_relevance,
48
46
  node_distance_reranker,
49
47
  node_fulltext_search,
50
48
  node_similarity_search,
@@ -120,22 +118,18 @@ async def edge_search(
120
118
  if config is None:
121
119
  return []
122
120
 
123
- search_results: list[list[EntityEdge]] = []
121
+ query_vector = await embedder.create(input=[query])
124
122
 
125
- if EdgeSearchMethod.bm25 in config.search_methods:
126
- text_search = await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
127
- search_results.append(text_search)
128
-
129
- if EdgeSearchMethod.cosine_similarity in config.search_methods:
130
- search_vector = await embedder.create(input=[query])
131
-
132
- similarity_search = await edge_similarity_search(
133
- driver, search_vector, None, None, group_ids, 2 * limit
123
+ search_results: list[list[EntityEdge]] = list(
124
+ await asyncio.gather(
125
+ *[
126
+ edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit),
127
+ edge_similarity_search(
128
+ driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
129
+ ),
130
+ ]
134
131
  )
135
- search_results.append(similarity_search)
136
-
137
- if len(search_results) > 1 and config.reranker is None:
138
- raise SearchRerankerError('Multiple edge searches enabled without a reranker')
132
+ )
139
133
 
140
134
  edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
141
135
 
@@ -144,6 +138,15 @@ async def edge_search(
144
138
  search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
145
139
 
146
140
  reranked_uuids = rrf(search_result_uuids)
141
+ elif config.reranker == EdgeReranker.mmr:
142
+ search_result_uuids_and_vectors = [
143
+ (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
144
+ for result in search_results
145
+ for edge in result
146
+ ]
147
+ reranked_uuids = maximal_marginal_relevance(
148
+ query_vector, search_result_uuids_and_vectors, config.mmr_lambda
149
+ )
147
150
  elif config.reranker == EdgeReranker.node_distance:
148
151
  if center_node_uuid is None:
149
152
  raise SearchRerankerError('No center node provided for Node Distance reranker')
@@ -184,22 +187,18 @@ async def node_search(
184
187
  if config is None:
185
188
  return []
186
189
 
187
- search_results: list[list[EntityNode]] = []
188
-
189
- if NodeSearchMethod.bm25 in config.search_methods:
190
- text_search = await node_fulltext_search(driver, query, group_ids, 2 * limit)
191
- search_results.append(text_search)
190
+ query_vector = await embedder.create(input=[query])
192
191
 
193
- if NodeSearchMethod.cosine_similarity in config.search_methods:
194
- search_vector = await embedder.create(input=[query])
195
-
196
- similarity_search = await node_similarity_search(
197
- driver, search_vector, group_ids, 2 * limit
192
+ search_results: list[list[EntityNode]] = list(
193
+ await asyncio.gather(
194
+ *[
195
+ node_fulltext_search(driver, query, group_ids, 2 * limit),
196
+ node_similarity_search(
197
+ driver, query_vector, group_ids, 2 * limit, config.sim_min_score
198
+ ),
199
+ ]
198
200
  )
199
- search_results.append(similarity_search)
200
-
201
- if len(search_results) > 1 and config.reranker is None:
202
- raise SearchRerankerError('Multiple node searches enabled without a reranker')
201
+ )
203
202
 
204
203
  search_result_uuids = [[node.uuid for node in result] for result in search_results]
205
204
  node_uuid_map = {node.uuid: node for result in search_results for node in result}
@@ -207,6 +206,15 @@ async def node_search(
207
206
  reranked_uuids: list[str] = []
208
207
  if config.reranker == NodeReranker.rrf:
209
208
  reranked_uuids = rrf(search_result_uuids)
209
+ elif config.reranker == NodeReranker.mmr:
210
+ search_result_uuids_and_vectors = [
211
+ (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
212
+ for result in search_results
213
+ for node in result
214
+ ]
215
+ reranked_uuids = maximal_marginal_relevance(
216
+ query_vector, search_result_uuids_and_vectors, config.mmr_lambda
217
+ )
210
218
  elif config.reranker == NodeReranker.episode_mentions:
211
219
  reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
212
220
  elif config.reranker == NodeReranker.node_distance:
@@ -232,22 +240,18 @@ async def community_search(
232
240
  if config is None:
233
241
  return []
234
242
 
235
- search_results: list[list[CommunityNode]] = []
236
-
237
- if CommunitySearchMethod.bm25 in config.search_methods:
238
- text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
239
- search_results.append(text_search)
240
-
241
- if CommunitySearchMethod.cosine_similarity in config.search_methods:
242
- search_vector = await embedder.create(input=[query])
243
+ query_vector = await embedder.create(input=[query])
243
244
 
244
- similarity_search = await community_similarity_search(
245
- driver, search_vector, group_ids, 2 * limit
245
+ search_results: list[list[CommunityNode]] = list(
246
+ await asyncio.gather(
247
+ *[
248
+ community_fulltext_search(driver, query, group_ids, 2 * limit),
249
+ community_similarity_search(
250
+ driver, query_vector, group_ids, 2 * limit, config.sim_min_score
251
+ ),
252
+ ]
246
253
  )
247
- search_results.append(similarity_search)
248
-
249
- if len(search_results) > 1 and config.reranker is None:
250
- raise SearchRerankerError('Multiple node searches enabled without a reranker')
254
+ )
251
255
 
252
256
  search_result_uuids = [[community.uuid for community in result] for result in search_results]
253
257
  community_uuid_map = {
@@ -257,6 +261,18 @@ async def community_search(
257
261
  reranked_uuids: list[str] = []
258
262
  if config.reranker == CommunityReranker.rrf:
259
263
  reranked_uuids = rrf(search_result_uuids)
264
+ elif config.reranker == CommunityReranker.mmr:
265
+ search_result_uuids_and_vectors = [
266
+ (
267
+ community.uuid,
268
+ community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
269
+ )
270
+ for result in search_results
271
+ for community in result
272
+ ]
273
+ reranked_uuids = maximal_marginal_relevance(
274
+ query_vector, search_result_uuids_and_vectors, config.mmr_lambda
275
+ )
260
276
 
261
277
  reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
262
278
 
@@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
20
20
 
21
21
  from graphiti_core.edges import EntityEdge
22
22
  from graphiti_core.nodes import CommunityNode, EntityNode
23
+ from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA
23
24
 
24
25
  DEFAULT_SEARCH_LIMIT = 10
25
26
 
@@ -43,31 +44,40 @@ class EdgeReranker(Enum):
43
44
  rrf = 'reciprocal_rank_fusion'
44
45
  node_distance = 'node_distance'
45
46
  episode_mentions = 'episode_mentions'
47
+ mmr = 'mmr'
46
48
 
47
49
 
48
50
  class NodeReranker(Enum):
49
51
  rrf = 'reciprocal_rank_fusion'
50
52
  node_distance = 'node_distance'
51
53
  episode_mentions = 'episode_mentions'
54
+ mmr = 'mmr'
52
55
 
53
56
 
54
57
  class CommunityReranker(Enum):
55
58
  rrf = 'reciprocal_rank_fusion'
59
+ mmr = 'mmr'
56
60
 
57
61
 
58
62
  class EdgeSearchConfig(BaseModel):
59
63
  search_methods: list[EdgeSearchMethod]
60
- reranker: EdgeReranker | None
64
+ reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
65
+ sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
66
+ mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
61
67
 
62
68
 
63
69
  class NodeSearchConfig(BaseModel):
64
70
  search_methods: list[NodeSearchMethod]
65
- reranker: NodeReranker | None
71
+ reranker: NodeReranker = Field(default=NodeReranker.rrf)
72
+ sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
73
+ mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
66
74
 
67
75
 
68
76
  class CommunitySearchConfig(BaseModel):
69
77
  search_methods: list[CommunitySearchMethod]
70
- reranker: CommunityReranker | None
78
+ reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
79
+ sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
80
+ mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
71
81
 
72
82
 
73
83
  class SearchConfig(BaseModel):
@@ -43,6 +43,22 @@ COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
43
43
  ),
44
44
  )
45
45
 
46
+ # Performs a hybrid search with mmr reranking over edges, nodes, and communities
47
+ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
48
+ edge_config=EdgeSearchConfig(
49
+ search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
50
+ reranker=EdgeReranker.mmr,
51
+ ),
52
+ node_config=NodeSearchConfig(
53
+ search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
54
+ reranker=NodeReranker.mmr,
55
+ ),
56
+ community_config=CommunitySearchConfig(
57
+ search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
58
+ reranker=CommunityReranker.mmr,
59
+ ),
60
+ )
61
+
46
62
  # performs a hybrid search over edges with rrf reranking
47
63
  EDGE_HYBRID_SEARCH_RRF = SearchConfig(
48
64
  edge_config=EdgeSearchConfig(
@@ -51,6 +67,14 @@ EDGE_HYBRID_SEARCH_RRF = SearchConfig(
51
67
  )
52
68
  )
53
69
 
70
+ # performs a hybrid search over edges with mmr reranking
71
+ EDGE_HYBRID_SEARCH_mmr = SearchConfig(
72
+ edge_config=EdgeSearchConfig(
73
+ search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
74
+ reranker=EdgeReranker.mmr,
75
+ )
76
+ )
77
+
54
78
  # performs a hybrid search over edges with node distance reranking
55
79
  EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
56
80
  edge_config=EdgeSearchConfig(
@@ -75,6 +99,14 @@ NODE_HYBRID_SEARCH_RRF = SearchConfig(
75
99
  )
76
100
  )
77
101
 
102
+ # performs a hybrid search over nodes with mmr reranking
103
+ NODE_HYBRID_SEARCH_MMR = SearchConfig(
104
+ node_config=NodeSearchConfig(
105
+ search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
106
+ reranker=NodeReranker.mmr,
107
+ )
108
+ )
109
+
78
110
  # performs a hybrid search over nodes with node distance reranking
79
111
  NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
80
112
  node_config=NodeSearchConfig(
@@ -98,3 +130,11 @@ COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
98
130
  reranker=CommunityReranker.rrf,
99
131
  )
100
132
  )
133
+
134
+ # performs a hybrid search over communities with mmr reranking
135
+ COMMUNITY_HYBRID_SEARCH_MMR = SearchConfig(
136
+ community_config=CommunitySearchConfig(
137
+ search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
138
+ reranker=CommunityReranker.mmr,
139
+ )
140
+ )