graphiti-core 0.3.8__tar.gz → 0.3.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the package registry page for more details.

Files changed (49):
  1. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/PKG-INFO +2 -1
  2. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/edges.py +8 -8
  3. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/errors.py +8 -0
  4. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/graphiti.py +44 -24
  5. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/helpers.py +15 -1
  6. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/nodes.py +16 -8
  7. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/eval.py +28 -2
  8. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/extract_edge_dates.py +8 -9
  9. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/extract_edges.py +3 -2
  10. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/invalidate_edges.py +1 -1
  11. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/search/search.py +62 -46
  12. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/search/search_config.py +13 -3
  13. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/search/search_config_recipes.py +42 -1
  14. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/search/search_utils.py +53 -13
  15. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/maintenance/__init__.py +0 -2
  16. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/maintenance/community_operations.py +14 -26
  17. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/maintenance/edge_operations.py +7 -13
  18. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/maintenance/node_operations.py +5 -5
  19. graphiti_core-0.3.11/graphiti_core/utils/maintenance/temporal_operations.py +95 -0
  20. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/pyproject.toml +4 -3
  21. graphiti_core-0.3.8/graphiti_core/utils/maintenance/temporal_operations.py +0 -217
  22. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/LICENSE +0 -0
  23. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/README.md +0 -0
  24. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/__init__.py +0 -0
  25. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/embedder/__init__.py +0 -0
  26. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/embedder/client.py +0 -0
  27. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/embedder/openai.py +0 -0
  28. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/embedder/voyage.py +0 -0
  29. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/__init__.py +0 -0
  30. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/anthropic_client.py +0 -0
  31. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/client.py +0 -0
  32. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/config.py +0 -0
  33. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/errors.py +0 -0
  34. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/groq_client.py +0 -0
  35. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/openai_client.py +0 -0
  36. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/utils.py +0 -0
  37. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/__init__.py +0 -0
  38. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/dedupe_edges.py +0 -0
  39. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/dedupe_nodes.py +0 -0
  40. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/extract_nodes.py +0 -0
  41. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/lib.py +0 -0
  42. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/models.py +0 -0
  43. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/summarize_nodes.py +0 -0
  44. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/py.typed +0 -0
  45. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/search/__init__.py +0 -0
  46. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/__init__.py +0 -0
  47. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/bulk_utils.py +0 -0
  48. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
  49. {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/maintenance/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphiti-core
3
- Version: 0.3.8
3
+ Version: 0.3.11
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -17,6 +17,7 @@ Requires-Dist: numpy (>=1.0.0)
17
17
  Requires-Dist: openai (>=1.50.2,<2.0.0)
18
18
  Requires-Dist: pydantic (>=2.8.2,<3.0.0)
19
19
  Requires-Dist: tenacity (<9.0.0)
20
+ Requires-Dist: voyageai (>=0.2.3,<0.3.0)
20
21
  Description-Content-Type: text/markdown
21
22
 
22
23
  <div align="center">
@@ -51,7 +51,7 @@ class Edge(BaseModel, ABC):
51
51
  uuid=self.uuid,
52
52
  )
53
53
 
54
- logger.info(f'Deleted Edge: {self.uuid}')
54
+ logger.debug(f'Deleted Edge: {self.uuid}')
55
55
 
56
56
  return result
57
57
 
@@ -83,7 +83,7 @@ class EpisodicEdge(Edge):
83
83
  created_at=self.created_at,
84
84
  )
85
85
 
86
- logger.info(f'Saved edge to neo4j: {self.uuid}')
86
+ logger.debug(f'Saved edge to neo4j: {self.uuid}')
87
87
 
88
88
  return result
89
89
 
@@ -178,7 +178,7 @@ class EntityEdge(Edge):
178
178
  self.fact_embedding = await embedder.create(input=[text])
179
179
 
180
180
  end = time()
181
- logger.info(f'embedded {text} in {end - start} ms')
181
+ logger.debug(f'embedded {text} in {end - start} ms')
182
182
 
183
183
  return self.fact_embedding
184
184
 
@@ -188,9 +188,9 @@ class EntityEdge(Edge):
188
188
  MATCH (source:Entity {uuid: $source_uuid})
189
189
  MATCH (target:Entity {uuid: $target_uuid})
190
190
  MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
191
- SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
192
- episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
193
- valid_at: $valid_at, invalid_at: $invalid_at}
191
+ SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
192
+ created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
193
+ WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
194
194
  RETURN r.uuid AS uuid""",
195
195
  source_uuid=self.source_node_uuid,
196
196
  target_uuid=self.target_node_uuid,
@@ -206,7 +206,7 @@ class EntityEdge(Edge):
206
206
  invalid_at=self.invalid_at,
207
207
  )
208
208
 
209
- logger.info(f'Saved edge to neo4j: {self.uuid}')
209
+ logger.debug(f'Saved edge to neo4j: {self.uuid}')
210
210
 
211
211
  return result
212
212
 
@@ -313,7 +313,7 @@ class CommunityEdge(Edge):
313
313
  created_at=self.created_at,
314
314
  )
315
315
 
316
- logger.info(f'Saved edge to neo4j: {self.uuid}')
316
+ logger.debug(f'Saved edge to neo4j: {self.uuid}')
317
317
 
318
318
  return result
319
319
 
@@ -35,6 +35,14 @@ class GroupsEdgesNotFoundError(GraphitiError):
35
35
  super().__init__(self.message)
36
36
 
37
37
 
38
+ class GroupsNodesNotFoundError(GraphitiError):
39
+ """Raised when no nodes are found for a list of group ids."""
40
+
41
+ def __init__(self, group_ids: list[str]):
42
+ self.message = f'no nodes found for group ids {group_ids}'
43
+ super().__init__(self.message)
44
+
45
+
38
46
  class NodeNotFoundError(GraphitiError):
39
47
  """Raised when a node is not found."""
40
48
 
@@ -21,11 +21,12 @@ from time import time
21
21
 
22
22
  from dotenv import load_dotenv
23
23
  from neo4j import AsyncGraphDatabase
24
+ from pydantic import BaseModel
24
25
 
25
26
  from graphiti_core.edges import EntityEdge, EpisodicEdge
26
27
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
27
28
  from graphiti_core.llm_client import LLMClient, OpenAIClient
28
- from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
29
+ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
29
30
  from graphiti_core.search.search import SearchConfig, search
30
31
  from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
31
32
  from graphiti_core.search.search_config_recipes import (
@@ -77,6 +78,12 @@ logger = logging.getLogger(__name__)
77
78
  load_dotenv()
78
79
 
79
80
 
81
+ class AddEpisodeResults(BaseModel):
82
+ episode: EpisodicNode
83
+ nodes: list[EntityNode]
84
+ edges: list[EntityEdge]
85
+
86
+
80
87
  class Graphiti:
81
88
  def __init__(
82
89
  self,
@@ -245,7 +252,7 @@ class Graphiti:
245
252
  group_id: str = '',
246
253
  uuid: str | None = None,
247
254
  update_communities: bool = False,
248
- ):
255
+ ) -> AddEpisodeResults:
249
256
  """
250
257
  Process an episode and update the graph.
251
258
 
@@ -312,13 +319,11 @@ class Graphiti:
312
319
  valid_at=reference_time,
313
320
  )
314
321
  episode.uuid = uuid if uuid is not None else episode.uuid
315
- if not self.store_raw_episode_content:
316
- episode.content = ''
317
322
 
318
323
  # Extract entities as nodes
319
324
 
320
325
  extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
321
- logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
326
+ logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
322
327
 
323
328
  # Calculate Embeddings
324
329
 
@@ -333,7 +338,7 @@ class Graphiti:
333
338
  )
334
339
  )
335
340
 
336
- logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
341
+ logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
337
342
 
338
343
  (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
339
344
  resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
@@ -341,7 +346,7 @@ class Graphiti:
341
346
  self.llm_client, episode, extracted_nodes, previous_episodes, group_id
342
347
  ),
343
348
  )
344
- logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
349
+ logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
345
350
  nodes = mentioned_nodes
346
351
 
347
352
  extracted_edges_with_resolved_pointers = resolve_edge_pointers(
@@ -371,10 +376,10 @@ class Graphiti:
371
376
  ]
372
377
  )
373
378
  )
374
- logger.info(
379
+ logger.debug(
375
380
  f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
376
381
  )
377
- logger.info(
382
+ logger.debug(
378
383
  f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
379
384
  )
380
385
 
@@ -426,15 +431,18 @@ class Graphiti:
426
431
 
427
432
  entity_edges.extend(resolved_edges + invalidated_edges)
428
433
 
429
- logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
434
+ logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
430
435
 
431
436
  episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
432
437
 
433
- logger.info(f'Built episodic edges: {episodic_edges}')
438
+ logger.debug(f'Built episodic edges: {episodic_edges}')
434
439
 
435
440
  episode.entity_edges = [edge.uuid for edge in entity_edges]
436
441
 
437
442
  # Future optimization would be using batch operations to save nodes and edges
443
+ if not self.store_raw_episode_content:
444
+ episode.content = ''
445
+
438
446
  await episode.save(self.driver)
439
447
  await asyncio.gather(*[node.save(self.driver) for node in nodes])
440
448
  await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
@@ -451,6 +459,8 @@ class Graphiti:
451
459
  end = time()
452
460
  logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
453
461
 
462
+ return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
463
+
454
464
  except Exception as e:
455
465
  raise e
456
466
 
@@ -554,7 +564,7 @@ class Graphiti:
554
564
  edges = await dedupe_edges_bulk(
555
565
  self.driver, self.llm_client, extracted_edges_with_resolved_pointers
556
566
  )
557
- logger.info(f'extracted edge length: {len(edges)}')
567
+ logger.debug(f'extracted edge length: {len(edges)}')
558
568
 
559
569
  # invalidate edges
560
570
 
@@ -567,11 +577,20 @@ class Graphiti:
567
577
  except Exception as e:
568
578
  raise e
569
579
 
570
- async def build_communities(self):
580
+ async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
581
+ """
582
+ Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
583
+ the content of these communities.
584
+ ----------
585
+ query : list[str] | None
586
+ Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
587
+ """
571
588
  # Clear existing communities
572
589
  await remove_communities(self.driver)
573
590
 
574
- community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
591
+ community_nodes, community_edges = await build_communities(
592
+ self.driver, self.llm_client, group_ids
593
+ )
575
594
 
576
595
  await asyncio.gather(
577
596
  *[node.generate_name_embedding(self.embedder) for node in community_nodes]
@@ -580,6 +599,8 @@ class Graphiti:
580
599
  await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
581
600
  await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
582
601
 
602
+ return community_nodes
603
+
583
604
  async def search(
584
605
  self,
585
606
  query: str,
@@ -700,18 +721,17 @@ class Graphiti:
700
721
  ).nodes
701
722
  return nodes
702
723
 
724
+ async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
725
+ episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
703
726
 
704
- async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
705
- episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
706
-
707
- edges_list = await asyncio.gather(
708
- *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
709
- )
727
+ edges_list = await asyncio.gather(
728
+ *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
729
+ )
710
730
 
711
- edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
731
+ edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
712
732
 
713
- nodes = await get_mentioned_nodes(self.driver, episodes)
733
+ nodes = await get_mentioned_nodes(self.driver, episodes)
714
734
 
715
- communities = await get_communities_by_nodes(self.driver, nodes)
735
+ communities = await get_communities_by_nodes(self.driver, nodes)
716
736
 
717
- return SearchResults(edges=edges, nodes=nodes, communities=communities)
737
+ return SearchResults(edges=edges, nodes=nodes, communities=communities)
@@ -16,6 +16,7 @@ limitations under the License.
16
16
 
17
17
  from datetime import datetime
18
18
 
19
+ import numpy as np
19
20
  from neo4j import time as neo4j_time
20
21
 
21
22
 
@@ -25,7 +26,7 @@ def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
25
26
 
26
27
  def lucene_sanitize(query: str) -> str:
27
28
  # Escape special characters from a query before passing into Lucene
28
- # + - && || ! ( ) { } [ ] ^ " ~ * ? : \
29
+ # + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
29
30
  escape_map = str.maketrans(
30
31
  {
31
32
  '+': r'\+',
@@ -46,8 +47,21 @@ def lucene_sanitize(query: str) -> str:
46
47
  '?': r'\?',
47
48
  ':': r'\:',
48
49
  '\\': r'\\',
50
+ '/': r'\/',
49
51
  }
50
52
  )
51
53
 
52
54
  sanitized = query.translate(escape_map)
53
55
  return sanitized
56
+
57
+
58
+ def normalize_l2(embedding: list[float]) -> list[float]:
59
+ embedding_array = np.array(embedding)
60
+ if embedding_array.ndim == 1:
61
+ norm = np.linalg.norm(embedding_array)
62
+ if norm == 0:
63
+ return embedding_array.tolist()
64
+ return (embedding_array / norm).tolist()
65
+ else:
66
+ norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
67
+ return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
@@ -86,7 +86,7 @@ class Node(BaseModel, ABC):
86
86
  uuid=self.uuid,
87
87
  )
88
88
 
89
- logger.info(f'Deleted Node: {self.uuid}')
89
+ logger.debug(f'Deleted Node: {self.uuid}')
90
90
 
91
91
  return result
92
92
 
@@ -135,7 +135,7 @@ class EpisodicNode(Node):
135
135
  source=self.source.value,
136
136
  )
137
137
 
138
- logger.info(f'Saved Node to neo4j: {self.uuid}')
138
+ logger.debug(f'Saved Node to neo4j: {self.uuid}')
139
139
 
140
140
  return result
141
141
 
@@ -217,7 +217,7 @@ class EntityNode(Node):
217
217
  text = self.name.replace('\n', ' ')
218
218
  self.name_embedding = await embedder.create(input=[text])
219
219
  end = time()
220
- logger.info(f'embedded {text} in {end - start} ms')
220
+ logger.debug(f'embedded {text} in {end - start} ms')
221
221
 
222
222
  return self.name_embedding
223
223
 
@@ -225,7 +225,8 @@ class EntityNode(Node):
225
225
  result = await driver.execute_query(
226
226
  """
227
227
  MERGE (n:Entity {uuid: $uuid})
228
- SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
228
+ SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
229
+ WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
229
230
  RETURN n.uuid AS uuid""",
230
231
  uuid=self.uuid,
231
232
  name=self.name,
@@ -235,7 +236,7 @@ class EntityNode(Node):
235
236
  created_at=self.created_at,
236
237
  )
237
238
 
238
- logger.info(f'Saved Node to neo4j: {self.uuid}')
239
+ logger.debug(f'Saved Node to neo4j: {self.uuid}')
239
240
 
240
241
  return result
241
242
 
@@ -257,6 +258,9 @@ class EntityNode(Node):
257
258
 
258
259
  nodes = [get_entity_node_from_record(record) for record in records]
259
260
 
261
+ if len(nodes) == 0:
262
+ raise NodeNotFoundError(uuid)
263
+
260
264
  return nodes[0]
261
265
 
262
266
  @classmethod
@@ -308,7 +312,8 @@ class CommunityNode(Node):
308
312
  result = await driver.execute_query(
309
313
  """
310
314
  MERGE (n:Community {uuid: $uuid})
311
- SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
315
+ SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
316
+ WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
312
317
  RETURN n.uuid AS uuid""",
313
318
  uuid=self.uuid,
314
319
  name=self.name,
@@ -318,7 +323,7 @@ class CommunityNode(Node):
318
323
  created_at=self.created_at,
319
324
  )
320
325
 
321
- logger.info(f'Saved Node to neo4j: {self.uuid}')
326
+ logger.debug(f'Saved Node to neo4j: {self.uuid}')
322
327
 
323
328
  return result
324
329
 
@@ -327,7 +332,7 @@ class CommunityNode(Node):
327
332
  text = self.name.replace('\n', ' ')
328
333
  self.name_embedding = await embedder.create(input=[text])
329
334
  end = time()
330
- logger.info(f'embedded {text} in {end - start} ms')
335
+ logger.debug(f'embedded {text} in {end - start} ms')
331
336
 
332
337
  return self.name_embedding
333
338
 
@@ -349,6 +354,9 @@ class CommunityNode(Node):
349
354
 
350
355
  nodes = [get_community_node_from_record(record) for record in records]
351
356
 
357
+ if len(nodes) == 0:
358
+ raise NodeNotFoundError(uuid)
359
+
352
360
  return nodes[0]
353
361
 
354
362
  @classmethod
@@ -23,11 +23,33 @@ from .models import Message, PromptFunction, PromptVersion
23
23
  class Prompt(Protocol):
24
24
  qa_prompt: PromptVersion
25
25
  eval_prompt: PromptVersion
26
+ query_expansion: PromptVersion
26
27
 
27
28
 
28
29
  class Versions(TypedDict):
29
30
  qa_prompt: PromptFunction
30
31
  eval_prompt: PromptFunction
32
+ query_expansion: PromptFunction
33
+
34
+
35
+ def query_expansion(context: dict[str, Any]) -> list[Message]:
36
+ sys_prompt = """You are an expert at rephrasing questions into queries used in a database retrieval system"""
37
+
38
+ user_prompt = f"""
39
+ Bob is asking Alice a question, are you able to rephrase the question into a simpler one about Alice in the third person
40
+ that maintains the relevant context?
41
+ <QUESTION>
42
+ {json.dumps(context['query'])}
43
+ </QUESTION>
44
+ respond with a JSON object in the following format:
45
+ {{
46
+ "query": "query optimized for database search"
47
+ }}
48
+ """
49
+ return [
50
+ Message(role='system', content=sys_prompt),
51
+ Message(role='user', content=user_prompt),
52
+ ]
31
53
 
32
54
 
33
55
  def qa_prompt(context: dict[str, Any]) -> list[Message]:
@@ -38,7 +60,7 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
38
60
  You are given the following entity summaries and facts to help you determine the answer to your question.
39
61
  <ENTITY_SUMMARIES>
40
62
  {json.dumps(context['entity_summaries'])}
41
- </ENTITY_SUMMARIES
63
+ </ENTITY_SUMMARIES>
42
64
  <FACTS>
43
65
  {json.dumps(context['facts'])}
44
66
  </FACTS>
@@ -87,4 +109,8 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
87
109
  ]
88
110
 
89
111
 
90
- versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt}
112
+ versions: Versions = {
113
+ 'qa_prompt': qa_prompt,
114
+ 'eval_prompt': eval_prompt,
115
+ 'query_expansion': query_expansion,
116
+ }
@@ -37,7 +37,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
37
37
  role='user',
38
38
  content=f"""
39
39
  Edge:
40
- Edge Name: {context['edge_name']}
41
40
  Fact: {context['edge_fact']}
42
41
 
43
42
  Current Episode: {context['current_episode']}
@@ -56,17 +55,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
56
55
  Guidelines:
57
56
  1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
58
57
  2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
59
- 3. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
60
- 4. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
61
- 5. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
62
- 6. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
63
- 7. If only a year is mentioned, use January 1st of that year at 00:00:00.
58
+ 3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
59
+ 4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
60
+ 5. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
61
+ 6. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
62
+ 7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
63
+ 8. If only a year is mentioned, use January 1st of that year at 00:00:00.
64
64
  9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
65
65
  Respond with a JSON object:
66
66
  {{
67
- "valid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
68
- "invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
69
- "explanation": "Brief explanation of why these dates were chosen or why they were set to null"
67
+ "valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
68
+ "invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
70
69
  }}
71
70
  """,
72
71
  ),
@@ -113,8 +113,9 @@ def v2(context: dict[str, Any]) -> list[Message]:
113
113
  2. Each edge should represent a clear relationship between two DISTINCT nodes.
114
114
  3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
115
115
  4. Provide a more detailed fact describing the relationship.
116
- 5. Consider temporal aspects of relationships when relevant.
117
- 6. Avoid using the same node as the source and target of a relationship
116
+ 5. The fact should include any specific relevant information, including numeric information
117
+ 6. Consider temporal aspects of relationships when relevant.
118
+ 7. Avoid using the same node as the source and target of a relationship
118
119
 
119
120
  Respond with a JSON object in the following format:
120
121
  {{
@@ -82,7 +82,7 @@ def v2(context: dict[str, Any]) -> list[Message]:
82
82
  Message(
83
83
  role='user',
84
84
  content=f"""
85
- Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.
85
+ Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to invalidations with the New Edge.
86
86
 
87
87
  Existing Edges:
88
88
  {context['existing_edges']}