graphiti-core 0.11.4__py3-none-any.whl → 0.11.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of graphiti-core might be problematic.
graphiti_core/edges.py CHANGED
@@ -46,7 +46,6 @@ ENTITY_EDGE_RETURN: LiteralString = """
         e.name AS name,
         e.group_id AS group_id,
         e.fact AS fact,
-        e.fact_embedding AS fact_embedding,
         e.episodes AS episodes,
         e.expired_at AS expired_at,
         e.valid_at AS valid_at,
@@ -222,6 +221,20 @@ class EntityEdge(Edge):
 
         return self.fact_embedding
 
+    async def load_fact_embedding(self, driver: AsyncDriver):
+        query: LiteralString = """
+            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
+            RETURN e.fact_embedding AS fact_embedding
+        """
+        records, _, _ = await driver.execute_query(
+            query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
+        )
+
+        if len(records) == 0:
+            raise EdgeNotFoundError(self.uuid)
+
+        self.fact_embedding = records[0]['fact_embedding']
+
     async def save(self, driver: AsyncDriver):
         result = await driver.execute_query(
             ENTITY_EDGE_SAVE,
@@ -321,8 +334,8 @@ class EntityEdge(Edge):
     async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
         query: LiteralString = (
             """
-            MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
-            """
+        MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
+        """
             + ENTITY_EDGE_RETURN
         )
         records, _, _ = await driver.execute_query(
@@ -452,7 +465,6 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
         name=record['name'],
         group_id=record['group_id'],
         episodes=record['episodes'],
-        fact_embedding=record['fact_embedding'],
         created_at=record['created_at'].to_native(),
         expired_at=parse_db_date(record['expired_at']),
         valid_at=parse_db_date(record['valid_at']),
@@ -471,6 +483,8 @@ def get_community_edge_from_record(record: Any):
 
 
 async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
+    if len(edges) == 0:
+        return
     fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
     for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
         edge.fact_embedding = fact_embedding
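
Note: with fact_embedding removed from ENTITY_EDGE_RETURN, edges read back from the graph no longer carry their vectors, and the new load_fact_embedding fetches one on demand. A minimal usage sketch; the driver setup and edge_uuid are illustrative assumptions, not part of this diff:

edge = await EntityEdge.get_by_uuid(driver, edge_uuid)  # fact_embedding now comes back as None
await edge.load_fact_embedding(driver)  # extra read-routed query; raises EdgeNotFoundError if the edge is gone
vector = edge.fact_embedding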
graphiti_core/graphiti.py CHANGED
@@ -41,6 +41,7 @@ from graphiti_core.search.search_config_recipes import (
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import (
     RELEVANT_SCHEMA_LIMIT,
+    get_edge_invalidation_candidates,
     get_mentioned_nodes,
     get_relevant_edges,
 )
@@ -62,9 +63,8 @@ from graphiti_core.utils.maintenance.community_operations import (
 )
 from graphiti_core.utils.maintenance.edge_operations import (
     build_episodic_edges,
-    dedupe_extracted_edge,
     extract_edges,
-    resolve_edge_contradictions,
+    resolve_extracted_edge,
     resolve_extracted_edges,
 )
 from graphiti_core.utils.maintenance.graph_data_operations import (
@@ -77,7 +77,6 @@ from graphiti_core.utils.maintenance.node_operations import (
     extract_nodes,
     resolve_extracted_nodes,
 )
-from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
 from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
 
 logger = logging.getLogger(__name__)
@@ -380,6 +379,7 @@ class Graphiti:
                 resolve_extracted_edges(
                     self.clients,
                     edges,
+                    episode,
                 ),
                 extract_attributes_from_nodes(
                     self.clients, nodes, episode, previous_episodes, entity_types
@@ -396,7 +396,7 @@ class Graphiti:
             episode.content = ''
 
         await add_nodes_and_edges_bulk(
-            self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges
+            self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
         )
 
         # Update any communities
@@ -680,15 +680,17 @@ class Graphiti:
 
         updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
 
-        related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
-
-        resolved_edge = await dedupe_extracted_edge(self.llm_client, updated_edge, related_edges[0])
+        related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
+        existing_edges = (
+            await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
+        )[0]
 
-        contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
-        invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
+        resolved_edge, invalidated_edges = await resolve_extracted_edge(
+            self.llm_client, updated_edge, related_edges, existing_edges
+        )
 
         await add_nodes_and_edges_bulk(
-            self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges
+            self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
         )
 
     async def remove_episode(self, episode_uuid: str):
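
Note: add_triplet's old three-step pipeline (dedupe_extracted_edge, then get_edge_contradictions and resolve_edge_contradictions) collapses into one resolve_extracted_edge call that returns both the resolved edge and any edges it invalidates, and the embedder is now threaded into add_nodes_and_edges_bulk, presumably so embeddings stripped from the read paths can be recreated at save time. A sketch of the new call shape, with surrounding setup elided:

# resolve_extracted_edge pairs duplicate resolution with contradiction detection in one LLM pass
resolved_edge, invalidated_edges = await resolve_extracted_edge(
    llm_client,      # drives the new resolve_edge prompt
    extracted_edge,  # candidate EntityEdge
    related_edges,   # duplicate candidates from get_relevant_edges
    existing_edges,  # invalidation candidates from get_edge_invalidation_candidates
)
await add_nodes_and_edges_bulk(driver, [], [], nodes, [resolved_edge] + invalidated_edges, embedder)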
graphiti_core/helpers.py CHANGED
@@ -22,6 +22,7 @@ from datetime import datetime
 import numpy as np
 from dotenv import load_dotenv
 from neo4j import time as neo4j_time
+from numpy._typing import NDArray
 from typing_extensions import LiteralString
 
 load_dotenv()
@@ -78,20 +79,17 @@ def lucene_sanitize(query: str) -> str:
     return sanitized
 
 
-def normalize_l2(embedding: list[float]):
+def normalize_l2(embedding: list[float]) -> NDArray:
     embedding_array = np.array(embedding)
-    if embedding_array.ndim == 1:
-        norm = np.linalg.norm(embedding_array)
-        if norm == 0:
-            return [0.0] * len(embedding)
-        return (embedding_array / norm).tolist()
-    else:
-        norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
-        return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
+    norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)
+    return np.where(norm == 0, embedding_array, embedding_array / norm)
 
 
 # Use this instead of asyncio.gather() to bound coroutines
-async def semaphore_gather(*coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT):
+async def semaphore_gather(
+    *coroutines: Coroutine,
+    max_coroutines: int = SEMAPHORE_LIMIT,
+):
     semaphore = asyncio.Semaphore(max_coroutines)
 
     async def _wrap_coroutine(coroutine):
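
Note: normalize_l2 now assumes a 1-D input, returns an NDArray rather than a Python list, and drops the old 2-D batch branch. A quick behavioral check under those assumptions (values illustrative):

import numpy as np

emb = normalize_l2([3.0, 4.0])
assert isinstance(emb, np.ndarray)    # callers now receive an array, not a list
assert np.allclose(emb, [0.6, 0.8])   # scaled to unit L2 norm
zero = normalize_l2([0.0, 0.0])
assert np.allclose(zero, [0.0, 0.0])  # zero vectors pass through unchanged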
graphiti_core/llm_client/anthropic_client.py CHANGED
@@ -139,15 +139,11 @@ class AnthropicClient(LLMClient):
             A list containing a single tool definition for use with the Anthropic API.
         """
         if response_model is not None:
-            # temporary debug log
-            logger.info(f'Creating tool for response_model: {response_model}')
             # Use the response_model to define the tool
             model_schema = response_model.model_json_schema()
             tool_name = response_model.__name__
             description = model_schema.get('description', f'Extract {tool_name} information')
         else:
-            # temporary debug log
-            logger.info('Creating generic JSON output tool')
             # Create a generic JSON output tool
             tool_name = 'generic_json_output'
             description = 'Output data in JSON format'
@@ -205,8 +201,6 @@ class AnthropicClient(LLMClient):
         try:
             # Create the appropriate tool based on whether response_model is provided
             tools, tool_choice = self._create_tool(response_model)
-            # temporary debug log
-            logger.info(f'using model: {self.model} with max_tokens: {self.max_tokens}')
             result = await self.client.messages.create(
                 system=system_message.content,
                 max_tokens=max_creation_tokens,
@@ -227,13 +221,6 @@ class AnthropicClient(LLMClient):
             return tool_args
 
         # If we didn't get a proper tool_use response, try to extract from text
-        # logger.debug(
-        #     f'Did not get a tool_use response, trying to extract json from text. Result: {result.content}'
-        # )
-        # temporary debug log
-        logger.info(
-            f'Did not get a tool_use response, trying to extract json from text. Result: {result.content}'
-        )
         for content_item in result.content:
             if content_item.type == 'text':
                 return self._extract_json_from_text(content_item.text)
graphiti_core/nodes.py CHANGED
@@ -42,7 +42,6 @@ ENTITY_NODE_RETURN: LiteralString = """
     RETURN
         n.uuid As uuid,
         n.name AS name,
-        n.name_embedding AS name_embedding,
        n.group_id AS group_id,
        n.created_at AS created_at,
        n.summary AS summary,
@@ -305,6 +304,20 @@ class EntityNode(Node):
 
         return self.name_embedding
 
+    async def load_name_embedding(self, driver: AsyncDriver):
+        query: LiteralString = """
+            MATCH (n:Entity {uuid: $uuid})
+            RETURN n.name_embedding AS name_embedding
+        """
+        records, _, _ = await driver.execute_query(
+            query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
+        )
+
+        if len(records) == 0:
+            raise NodeNotFoundError(self.uuid)
+
+        self.name_embedding = records[0]['name_embedding']
+
     async def save(self, driver: AsyncDriver):
         entity_data: dict[str, Any] = {
             'uuid': self.uuid,
@@ -332,8 +345,8 @@ class EntityNode(Node):
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
         query = (
             """
-            MATCH (n:Entity {uuid: $uuid})
-            """
+        MATCH (n:Entity {uuid: $uuid})
+        """
             + ENTITY_NODE_RETURN
         )
         records, _, _ = await driver.execute_query(
@@ -428,6 +441,20 @@ class CommunityNode(Node):
 
         return self.name_embedding
 
+    async def load_name_embedding(self, driver: AsyncDriver):
+        query: LiteralString = """
+            MATCH (c:Community {uuid: $uuid})
+            RETURN c.name_embedding AS name_embedding
+        """
+        records, _, _ = await driver.execute_query(
+            query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
+        )
+
+        if len(records) == 0:
+            raise NodeNotFoundError(self.uuid)
+
+        self.name_embedding = records[0]['name_embedding']
+
     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
         records, _, _ = await driver.execute_query(
@@ -436,7 +463,6 @@ class CommunityNode(Node):
         RETURN
             n.uuid As uuid,
             n.name AS name,
-            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
            n.summary AS summary
@@ -461,7 +487,6 @@ class CommunityNode(Node):
         RETURN
             n.uuid As uuid,
             n.name AS name,
-            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
            n.summary AS summary
@@ -495,7 +520,6 @@ class CommunityNode(Node):
         RETURN
             n.uuid As uuid,
             n.name AS name,
-            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
            n.summary AS summary
@@ -534,7 +558,6 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
         uuid=record['uuid'],
         name=record['name'],
         group_id=record['group_id'],
-        name_embedding=record['name_embedding'],
         labels=record['labels'],
         created_at=record['created_at'].to_native(),
         summary=record['summary'],
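
Note: name embeddings mirror the fact-embedding change above: they are excluded from every RETURN clause and hydrated on demand via load_name_embedding. Bulk hydration with the reworked semaphore_gather might look like this sketch (driver and nodes are illustrative assumptions):

# bound concurrency while loading vectors for many nodes
await semaphore_gather(*[node.load_name_embedding(driver) for node in nodes])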
graphiti_core/prompts/dedupe_edges.py CHANGED
@@ -27,6 +27,10 @@ class EdgeDuplicate(BaseModel):
         ...,
         description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
     )
+    contradicted_facts: list[int] = Field(
+        ...,
+        description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
+    )
 
 
 class UniqueFact(BaseModel):
@@ -41,11 +45,13 @@ class UniqueFacts(BaseModel):
 class Prompt(Protocol):
     edge: PromptVersion
     edge_list: PromptVersion
+    resolve_edge: PromptVersion
 
 
 class Versions(TypedDict):
     edge: PromptFunction
     edge_list: PromptFunction
+    resolve_edge: PromptFunction
 
 
 def edge(context: dict[str, Any]) -> list[Message]:
@@ -106,4 +112,41 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
     ]
 
 
-versions: Versions = {'edge': edge, 'edge_list': edge_list}
+def resolve_edge(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing '
+            'facts are contradicted by the new fact.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+        <NEW FACT>
+        {context['new_edge']}
+        </NEW FACT>
+
+        <EXISTING FACTS>
+        {context['existing_edges']}
+        </EXISTING FACTS>
+        <FACT INVALIDATION CANDIDATES>
+        {context['edge_invalidation_candidates']}
+        </FACT INVALIDATION CANDIDATES>
+
+
+        Task:
+        If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
+        If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
+
+        Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
+        Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
+        If there are no contradicted facts, return an empty list.
+
+        Guidelines:
+        1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
+        """,
+        ),
+    ]
+
+
+versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge}
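
Note: resolve_edge is built to be parsed into the extended EdgeDuplicate model above, so a single LLM call now answers both questions (which fact this duplicates, and which facts it contradicts). A parsing sketch; the duplicate-id field name is inferred from its description in this hunk, and the payload is invented for illustration:

raw = {'duplicate_fact_id': -1, 'contradicted_facts': [0, 2]}  # hypothetical structured output
resolution = EdgeDuplicate(**raw)
if resolution.duplicate_fact_id == -1:
    pass  # not a duplicate; keep the new edge
for i in resolution.contradicted_facts:
    pass  # invalidate the existing fact at idx i (e.g. set expired_at)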
graphiti_core/prompts/dedupe_nodes.py CHANGED
@@ -23,28 +23,38 @@ from .models import Message, PromptFunction, PromptVersion
 
 
 class NodeDuplicate(BaseModel):
-    duplicate_node_id: int = Field(
+    id: int = Field(..., description='integer id of the entity')
+    duplicate_idx: int = Field(
         ...,
-        description='id of the duplicate node. If no duplicate nodes are found, default to -1.',
+        description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
     )
-    name: str = Field(..., description='Name of the entity.')
+    name: str = Field(
+        ...,
+        description='Name of the entity. Should be the most complete and descriptive name possible.',
+    )
+
+
+class NodeResolutions(BaseModel):
+    entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
 
 
 class Prompt(Protocol):
     node: PromptVersion
     node_list: PromptVersion
+    nodes: PromptVersion
 
 
 class Versions(TypedDict):
     node: PromptFunction
     node_list: PromptFunction
+    nodes: PromptFunction
 
 
 def node(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
-            content='You are a helpful assistant that de-duplicates entities from entity lists.',
+            content='You are a helpful assistant that determines whether or not a NEW ENTITY is a duplicate of any EXISTING ENTITIES.',
         ),
         Message(
             role='user',
@@ -69,19 +79,82 @@ def node(context: dict[str, Any]) -> list[Message]:
         Given the above EXISTING ENTITIES and their attributes, MESSAGE, and PREVIOUS MESSAGES; Determine if the NEW ENTITY extracted from the conversation
         is a duplicate entity of one of the EXISTING ENTITIES.
 
-        The ENTITY TYPE DESCRIPTION gives more insight into what the entity type means for the NEW ENTITY.
+        Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
+
+        Do NOT mark entities as duplicates if:
+        - They are related but distinct.
+        - They have similar names or purposes but refer to separate instances or concepts.
 
         Task:
         If the NEW ENTITY represents a duplicate entity of any entity in EXISTING ENTITIES, set duplicate_entity_id to the
-        id of the EXISTING ENTITY that is the duplicate. If the NEW ENTITY is not a duplicate of any of the EXISTING ENTITIES,
+        id of the EXISTING ENTITY that is the duplicate.
+
+        If the NEW ENTITY is not a duplicate of any of the EXISTING ENTITIES,
         duplicate_entity_id should be set to -1.
 
-        Also return the most complete name for the entity.
+        Also return the name that best describes the NEW ENTITY (whether it is the name of the NEW ENTITY, a node it
+        is a duplicate of, or a combination of the two).
+        """,
+        ),
+    ]
 
-        Guidelines:
-        1. Entities with the same name should be considered duplicates
-        2. Duplicate entities may refer to the same real-world entity even if names differ. Use context clues from the MESSAGES
-        to determine if the NEW ENTITY represents a duplicate entity of one of the EXISTING ENTITIES.
+
+def nodes(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
+            'of existing entities.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+        <PREVIOUS MESSAGES>
+        {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
+        </PREVIOUS MESSAGES>
+        <CURRENT MESSAGE>
+        {context['episode_content']}
+        </CURRENT MESSAGE>
+
+
+        Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
+        Each entity in ENTITIES is represented as a JSON object with the following structure:
+        {{
+            id: integer id of the entity,
+            name: "name of the entity",
+            entity_type: "ontological classification of the entity",
+            entity_type_description: "Description of what the entity type represents",
+            duplication_candidates: [
+                {{
+                    idx: integer index of the candidate entity,
+                    name: "name of the candidate entity",
+                    entity_type: "ontological classification of the candidate entity",
+                    ...<additional attributes>
+                }}
+            ]
+        }}
+
+        <ENTITIES>
+        {json.dumps(context['extracted_nodes'], indent=2)}
+        </ENTITIES>
+
+        For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
+
+        Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
+
+        Do NOT mark entities as duplicates if:
+        - They are related but distinct.
+        - They have similar names or purposes but refer to separate instances or concepts.
+
+        Task:
+        Your response will be a list called entity_resolutions which contains one entry for each entity.
+
+        For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
+        as an integer.
+
+        - If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
+        duplicate of.
+        - If an entity is not a duplicate of one of its duplication candidates, return the -1 as the duplication_idx
         """,
         ),
     ]
@@ -124,4 +197,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
     ]
 
 
-versions: Versions = {'node': node, 'node_list': node_list}
+versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
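
Note: the batch nodes prompt pairs with the new NodeResolutions model, resolving every extracted entity in one call instead of one call per node. A minimal parse sketch (payload and the merge_or_create helper are invented for illustration):

raw = {
    'entity_resolutions': [
        {'id': 0, 'duplicate_idx': -1, 'name': 'Acme Corporation'},  # genuinely new entity
        {'id': 1, 'duplicate_idx': 2, 'name': 'Jane Doe'},  # duplicate of candidate at idx 2
    ]
}
resolutions = NodeResolutions(**raw)
for r in resolutions.entity_resolutions:
    merge_or_create(r.id, r.duplicate_idx, r.name)  # hypothetical helper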
graphiti_core/prompts/extract_nodes.py CHANGED
@@ -256,7 +256,7 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
         1. Do not hallucinate entity property values if they cannot be found in the current context.
         2. Only use the provided MESSAGES and ENTITY to set attribute values.
         3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES.
-            Summaries must be no longer than 500 words.
+            Summaries must be no longer than 250 words.
 
         <ENTITY>
         {context['node']}
graphiti_core/prompts/invalidate_edges.py CHANGED
@@ -24,7 +24,7 @@ from .models import Message, PromptFunction, PromptVersion
 class InvalidatedEdges(BaseModel):
     contradicted_facts: list[int] = Field(
         ...,
-        description='List of ids of facts that be should invalidated. If no facts should be invalidated, the list should be empty.',
+        description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
     )
 
 
graphiti_core/prompts/summarize_nodes.py CHANGED
@@ -25,7 +25,7 @@ from .models import Message, PromptFunction, PromptVersion
 class Summary(BaseModel):
     summary: str = Field(
         ...,
-        description='Summary containing the important information about the entity. Under 500 words',
+        description='Summary containing the important information about the entity. Under 250 words',
     )
 
 
@@ -56,7 +56,7 @@ def summarize_pair(context: dict[str, Any]) -> list[Message]:
         content=f"""
         Synthesize the information from the following two summaries into a single succinct summary.
 
-        Summaries must be under 500 words.
+        Summaries must be under 250 words.
 
         Summaries:
         {json.dumps(context['node_summaries'], indent=2)}
@@ -82,7 +82,7 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
 
         Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use
         information from the provided MESSAGES. Your summary should also only contain information relevant to the
-        provided ENTITY. Summaries must be under 500 words.
+        provided ENTITY. Summaries must be under 250 words.
 
         In addition, extract any values for the provided entity properties based on their descriptions.
         If the value of the entity property cannot be found in the current context, set the value of the property to the Python value None.
@@ -117,7 +117,7 @@ def summary_description(context: dict[str, Any]) -> list[Message]:
         role='user',
         content=f"""
         Create a short one sentence description of the summary that explains what kind of information is summarized.
-        Summaries must be under 500 words.
+        Summaries must be under 250 words.
 
         Summary:
         {json.dumps(context['summary'], indent=2)}
graphiti_core/search/search.py CHANGED
@@ -50,6 +50,9 @@ from graphiti_core.search.search_utils import (
     edge_similarity_search,
     episode_fulltext_search,
     episode_mentions_reranker,
+    get_embeddings_for_communities,
+    get_embeddings_for_edges,
+    get_embeddings_for_nodes,
     maximal_marginal_relevance,
     node_bfs_search,
     node_distance_reranker,
@@ -209,23 +212,17 @@ async def edge_search(
 
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == EdgeReranker.mmr:
-        search_result_uuids_and_vectors = [
-            (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
-            for result in search_results
-            for edge in result
-        ]
+        search_result_uuids_and_vectors = await get_embeddings_for_edges(
+            driver, list(edge_uuid_map.values())
+        )
         reranked_uuids = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
+            reranker_min_score,
         )
     elif config.reranker == EdgeReranker.cross_encoder:
-        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
-
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
-        rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
-
-        fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
+        fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
         reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
         reranked_uuids = [
             fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
@@ -308,27 +305,23 @@ async def node_search(
     if config.reranker == NodeReranker.rrf:
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == NodeReranker.mmr:
-        search_result_uuids_and_vectors = [
-            (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
-            for result in search_results
-            for node in result
-        ]
+        search_result_uuids_and_vectors = await get_embeddings_for_nodes(
+            driver, list(node_uuid_map.values())
+        )
+
         reranked_uuids = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
+            reranker_min_score,
         )
     elif config.reranker == NodeReranker.cross_encoder:
-        # use rrf as a preliminary reranker
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
-        rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
-
-        summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
+        name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
 
-        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+        reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
         reranked_uuids = [
-            summary_to_uuid_map[fact]
-            for fact, score in reranked_summaries
+            name_to_uuid_map[name]
+            for name, score in reranked_node_names
             if score >= reranker_min_score
         ]
     elif config.reranker == NodeReranker.episode_mentions:
@@ -431,28 +424,18 @@ async def community_search(
     if config.reranker == CommunityReranker.rrf:
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == CommunityReranker.mmr:
-        search_result_uuids_and_vectors = [
-            (
-                community.uuid,
-                community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
-            )
-            for result in search_results
-            for community in result
-        ]
+        search_result_uuids_and_vectors = await get_embeddings_for_communities(
+            driver, list(community_uuid_map.values())
+        )
+
         reranked_uuids = maximal_marginal_relevance(
-            query_vector,
-            search_result_uuids_and_vectors,
-            config.mmr_lambda,
+            query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
         )
     elif config.reranker == CommunityReranker.cross_encoder:
-        summary_to_uuid_map = {
-            node.summary: node.uuid for result in search_results for node in result
-        }
-        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+        name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
+        reranked_nodes = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
         reranked_uuids = [
-            summary_to_uuid_map[fact]
-            for fact, score in reranked_summaries
-            if score >= reranker_min_score
+            name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
         ]
 
     reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
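
Note: across all three search paths, MMR candidate vectors are now fetched from the database via the new get_embeddings_for_* helpers (since result objects no longer carry embeddings), and maximal_marginal_relevance takes the reranker's min_score as a fourth argument. A sketch of the edge path; the (uuid, vector) return shape of get_embeddings_for_edges is inferred from how its result feeds the reranker:

uuids_and_vectors = await get_embeddings_for_edges(driver, list(edge_uuid_map.values()))
reranked_uuids = maximal_marginal_relevance(
    query_vector,        # embedding of the search query
    uuids_and_vectors,   # candidate (uuid, fact_embedding) pairs
    config.mmr_lambda,   # relevance vs. diversity trade-off
    reranker_min_score,  # new: floor on the MMR score for kept candidates
)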