graphiti-core 0.3.8__tar.gz → 0.3.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/PKG-INFO +2 -1
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/edges.py +8 -8
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/errors.py +8 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/graphiti.py +44 -24
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/helpers.py +15 -1
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/nodes.py +16 -8
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/eval.py +28 -2
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/extract_edge_dates.py +8 -9
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/extract_edges.py +3 -2
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/invalidate_edges.py +1 -1
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/search/search.py +62 -46
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/search/search_config.py +13 -3
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/search/search_config_recipes.py +42 -1
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/search/search_utils.py +53 -13
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/maintenance/__init__.py +0 -2
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/maintenance/community_operations.py +14 -26
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/maintenance/edge_operations.py +7 -13
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/maintenance/node_operations.py +5 -5
- graphiti_core-0.3.11/graphiti_core/utils/maintenance/temporal_operations.py +95 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/pyproject.toml +4 -3
- graphiti_core-0.3.8/graphiti_core/utils/maintenance/temporal_operations.py +0 -217
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/LICENSE +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/README.md +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/embedder/__init__.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/embedder/client.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/embedder/openai.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/embedder/voyage.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/anthropic_client.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/client.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/errors.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/groq_client.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/openai_client.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/extract_nodes.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/lib.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/prompts/summarize_nodes.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/py.typed +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/bulk_utils.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
- {graphiti_core-0.3.8 → graphiti_core-0.3.11}/graphiti_core/utils/maintenance/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: graphiti-core
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.11
|
|
4
4
|
Summary: A temporal graph building library
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Paul Paliychuk
|
|
@@ -17,6 +17,7 @@ Requires-Dist: numpy (>=1.0.0)
|
|
|
17
17
|
Requires-Dist: openai (>=1.50.2,<2.0.0)
|
|
18
18
|
Requires-Dist: pydantic (>=2.8.2,<3.0.0)
|
|
19
19
|
Requires-Dist: tenacity (<9.0.0)
|
|
20
|
+
Requires-Dist: voyageai (>=0.2.3,<0.3.0)
|
|
20
21
|
Description-Content-Type: text/markdown
|
|
21
22
|
|
|
22
23
|
<div align="center">
|
|
@@ -51,7 +51,7 @@ class Edge(BaseModel, ABC):
|
|
|
51
51
|
uuid=self.uuid,
|
|
52
52
|
)
|
|
53
53
|
|
|
54
|
-
logger.
|
|
54
|
+
logger.debug(f'Deleted Edge: {self.uuid}')
|
|
55
55
|
|
|
56
56
|
return result
|
|
57
57
|
|
|
@@ -83,7 +83,7 @@ class EpisodicEdge(Edge):
|
|
|
83
83
|
created_at=self.created_at,
|
|
84
84
|
)
|
|
85
85
|
|
|
86
|
-
logger.
|
|
86
|
+
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
|
87
87
|
|
|
88
88
|
return result
|
|
89
89
|
|
|
@@ -178,7 +178,7 @@ class EntityEdge(Edge):
|
|
|
178
178
|
self.fact_embedding = await embedder.create(input=[text])
|
|
179
179
|
|
|
180
180
|
end = time()
|
|
181
|
-
logger.
|
|
181
|
+
logger.debug(f'embedded {text} in {end - start} ms')
|
|
182
182
|
|
|
183
183
|
return self.fact_embedding
|
|
184
184
|
|
|
@@ -188,9 +188,9 @@ class EntityEdge(Edge):
|
|
|
188
188
|
MATCH (source:Entity {uuid: $source_uuid})
|
|
189
189
|
MATCH (target:Entity {uuid: $target_uuid})
|
|
190
190
|
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
|
191
|
-
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact,
|
|
192
|
-
|
|
193
|
-
|
|
191
|
+
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
|
|
192
|
+
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
|
|
193
|
+
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
|
194
194
|
RETURN r.uuid AS uuid""",
|
|
195
195
|
source_uuid=self.source_node_uuid,
|
|
196
196
|
target_uuid=self.target_node_uuid,
|
|
@@ -206,7 +206,7 @@ class EntityEdge(Edge):
|
|
|
206
206
|
invalid_at=self.invalid_at,
|
|
207
207
|
)
|
|
208
208
|
|
|
209
|
-
logger.
|
|
209
|
+
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
|
210
210
|
|
|
211
211
|
return result
|
|
212
212
|
|
|
@@ -313,7 +313,7 @@ class CommunityEdge(Edge):
|
|
|
313
313
|
created_at=self.created_at,
|
|
314
314
|
)
|
|
315
315
|
|
|
316
|
-
logger.
|
|
316
|
+
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
|
317
317
|
|
|
318
318
|
return result
|
|
319
319
|
|
|
@@ -35,6 +35,14 @@ class GroupsEdgesNotFoundError(GraphitiError):
|
|
|
35
35
|
super().__init__(self.message)
|
|
36
36
|
|
|
37
37
|
|
|
38
|
+
class GroupsNodesNotFoundError(GraphitiError):
|
|
39
|
+
"""Raised when no nodes are found for a list of group ids."""
|
|
40
|
+
|
|
41
|
+
def __init__(self, group_ids: list[str]):
|
|
42
|
+
self.message = f'no nodes found for group ids {group_ids}'
|
|
43
|
+
super().__init__(self.message)
|
|
44
|
+
|
|
45
|
+
|
|
38
46
|
class NodeNotFoundError(GraphitiError):
|
|
39
47
|
"""Raised when a node is not found."""
|
|
40
48
|
|
|
@@ -21,11 +21,12 @@ from time import time
|
|
|
21
21
|
|
|
22
22
|
from dotenv import load_dotenv
|
|
23
23
|
from neo4j import AsyncGraphDatabase
|
|
24
|
+
from pydantic import BaseModel
|
|
24
25
|
|
|
25
26
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
26
27
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
|
27
28
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
|
28
|
-
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
|
29
|
+
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
|
29
30
|
from graphiti_core.search.search import SearchConfig, search
|
|
30
31
|
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
|
31
32
|
from graphiti_core.search.search_config_recipes import (
|
|
@@ -77,6 +78,12 @@ logger = logging.getLogger(__name__)
|
|
|
77
78
|
load_dotenv()
|
|
78
79
|
|
|
79
80
|
|
|
81
|
+
class AddEpisodeResults(BaseModel):
|
|
82
|
+
episode: EpisodicNode
|
|
83
|
+
nodes: list[EntityNode]
|
|
84
|
+
edges: list[EntityEdge]
|
|
85
|
+
|
|
86
|
+
|
|
80
87
|
class Graphiti:
|
|
81
88
|
def __init__(
|
|
82
89
|
self,
|
|
@@ -245,7 +252,7 @@ class Graphiti:
|
|
|
245
252
|
group_id: str = '',
|
|
246
253
|
uuid: str | None = None,
|
|
247
254
|
update_communities: bool = False,
|
|
248
|
-
):
|
|
255
|
+
) -> AddEpisodeResults:
|
|
249
256
|
"""
|
|
250
257
|
Process an episode and update the graph.
|
|
251
258
|
|
|
@@ -312,13 +319,11 @@ class Graphiti:
|
|
|
312
319
|
valid_at=reference_time,
|
|
313
320
|
)
|
|
314
321
|
episode.uuid = uuid if uuid is not None else episode.uuid
|
|
315
|
-
if not self.store_raw_episode_content:
|
|
316
|
-
episode.content = ''
|
|
317
322
|
|
|
318
323
|
# Extract entities as nodes
|
|
319
324
|
|
|
320
325
|
extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
|
|
321
|
-
logger.
|
|
326
|
+
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
|
322
327
|
|
|
323
328
|
# Calculate Embeddings
|
|
324
329
|
|
|
@@ -333,7 +338,7 @@ class Graphiti:
|
|
|
333
338
|
)
|
|
334
339
|
)
|
|
335
340
|
|
|
336
|
-
logger.
|
|
341
|
+
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
|
337
342
|
|
|
338
343
|
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
|
|
339
344
|
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
|
|
@@ -341,7 +346,7 @@ class Graphiti:
|
|
|
341
346
|
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
|
|
342
347
|
),
|
|
343
348
|
)
|
|
344
|
-
logger.
|
|
349
|
+
logger.debug(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
|
345
350
|
nodes = mentioned_nodes
|
|
346
351
|
|
|
347
352
|
extracted_edges_with_resolved_pointers = resolve_edge_pointers(
|
|
@@ -371,10 +376,10 @@ class Graphiti:
|
|
|
371
376
|
]
|
|
372
377
|
)
|
|
373
378
|
)
|
|
374
|
-
logger.
|
|
379
|
+
logger.debug(
|
|
375
380
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
|
|
376
381
|
)
|
|
377
|
-
logger.
|
|
382
|
+
logger.debug(
|
|
378
383
|
f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
|
|
379
384
|
)
|
|
380
385
|
|
|
@@ -426,15 +431,18 @@ class Graphiti:
|
|
|
426
431
|
|
|
427
432
|
entity_edges.extend(resolved_edges + invalidated_edges)
|
|
428
433
|
|
|
429
|
-
logger.
|
|
434
|
+
logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
|
|
430
435
|
|
|
431
436
|
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
|
|
432
437
|
|
|
433
|
-
logger.
|
|
438
|
+
logger.debug(f'Built episodic edges: {episodic_edges}')
|
|
434
439
|
|
|
435
440
|
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
|
436
441
|
|
|
437
442
|
# Future optimization would be using batch operations to save nodes and edges
|
|
443
|
+
if not self.store_raw_episode_content:
|
|
444
|
+
episode.content = ''
|
|
445
|
+
|
|
438
446
|
await episode.save(self.driver)
|
|
439
447
|
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
|
440
448
|
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
|
|
@@ -451,6 +459,8 @@ class Graphiti:
|
|
|
451
459
|
end = time()
|
|
452
460
|
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
453
461
|
|
|
462
|
+
return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
|
|
463
|
+
|
|
454
464
|
except Exception as e:
|
|
455
465
|
raise e
|
|
456
466
|
|
|
@@ -554,7 +564,7 @@ class Graphiti:
|
|
|
554
564
|
edges = await dedupe_edges_bulk(
|
|
555
565
|
self.driver, self.llm_client, extracted_edges_with_resolved_pointers
|
|
556
566
|
)
|
|
557
|
-
logger.
|
|
567
|
+
logger.debug(f'extracted edge length: {len(edges)}')
|
|
558
568
|
|
|
559
569
|
# invalidate edges
|
|
560
570
|
|
|
@@ -567,11 +577,20 @@ class Graphiti:
|
|
|
567
577
|
except Exception as e:
|
|
568
578
|
raise e
|
|
569
579
|
|
|
570
|
-
async def build_communities(self):
|
|
580
|
+
async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
|
|
581
|
+
"""
|
|
582
|
+
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
|
583
|
+
the content of these communities.
|
|
584
|
+
----------
|
|
585
|
+
query : list[str] | None
|
|
586
|
+
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
|
|
587
|
+
"""
|
|
571
588
|
# Clear existing communities
|
|
572
589
|
await remove_communities(self.driver)
|
|
573
590
|
|
|
574
|
-
community_nodes, community_edges = await build_communities(
|
|
591
|
+
community_nodes, community_edges = await build_communities(
|
|
592
|
+
self.driver, self.llm_client, group_ids
|
|
593
|
+
)
|
|
575
594
|
|
|
576
595
|
await asyncio.gather(
|
|
577
596
|
*[node.generate_name_embedding(self.embedder) for node in community_nodes]
|
|
@@ -580,6 +599,8 @@ class Graphiti:
|
|
|
580
599
|
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
|
581
600
|
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
|
582
601
|
|
|
602
|
+
return community_nodes
|
|
603
|
+
|
|
583
604
|
async def search(
|
|
584
605
|
self,
|
|
585
606
|
query: str,
|
|
@@ -700,18 +721,17 @@ class Graphiti:
|
|
|
700
721
|
).nodes
|
|
701
722
|
return nodes
|
|
702
723
|
|
|
724
|
+
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
|
|
725
|
+
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
|
|
703
726
|
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
edges_list = await asyncio.gather(
|
|
708
|
-
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
|
|
709
|
-
)
|
|
727
|
+
edges_list = await asyncio.gather(
|
|
728
|
+
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
|
|
729
|
+
)
|
|
710
730
|
|
|
711
|
-
|
|
731
|
+
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
|
|
712
732
|
|
|
713
|
-
|
|
733
|
+
nodes = await get_mentioned_nodes(self.driver, episodes)
|
|
714
734
|
|
|
715
|
-
|
|
735
|
+
communities = await get_communities_by_nodes(self.driver, nodes)
|
|
716
736
|
|
|
717
|
-
|
|
737
|
+
return SearchResults(edges=edges, nodes=nodes, communities=communities)
|
|
@@ -16,6 +16,7 @@ limitations under the License.
|
|
|
16
16
|
|
|
17
17
|
from datetime import datetime
|
|
18
18
|
|
|
19
|
+
import numpy as np
|
|
19
20
|
from neo4j import time as neo4j_time
|
|
20
21
|
|
|
21
22
|
|
|
@@ -25,7 +26,7 @@ def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
|
|
25
26
|
|
|
26
27
|
def lucene_sanitize(query: str) -> str:
|
|
27
28
|
# Escape special characters from a query before passing into Lucene
|
|
28
|
-
# + - && || ! ( ) { } [ ] ^ " ~ * ? : \
|
|
29
|
+
# + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
|
|
29
30
|
escape_map = str.maketrans(
|
|
30
31
|
{
|
|
31
32
|
'+': r'\+',
|
|
@@ -46,8 +47,21 @@ def lucene_sanitize(query: str) -> str:
|
|
|
46
47
|
'?': r'\?',
|
|
47
48
|
':': r'\:',
|
|
48
49
|
'\\': r'\\',
|
|
50
|
+
'/': r'\/',
|
|
49
51
|
}
|
|
50
52
|
)
|
|
51
53
|
|
|
52
54
|
sanitized = query.translate(escape_map)
|
|
53
55
|
return sanitized
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def normalize_l2(embedding: list[float]) -> list[float]:
|
|
59
|
+
embedding_array = np.array(embedding)
|
|
60
|
+
if embedding_array.ndim == 1:
|
|
61
|
+
norm = np.linalg.norm(embedding_array)
|
|
62
|
+
if norm == 0:
|
|
63
|
+
return embedding_array.tolist()
|
|
64
|
+
return (embedding_array / norm).tolist()
|
|
65
|
+
else:
|
|
66
|
+
norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
|
|
67
|
+
return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
|
|
@@ -86,7 +86,7 @@ class Node(BaseModel, ABC):
|
|
|
86
86
|
uuid=self.uuid,
|
|
87
87
|
)
|
|
88
88
|
|
|
89
|
-
logger.
|
|
89
|
+
logger.debug(f'Deleted Node: {self.uuid}')
|
|
90
90
|
|
|
91
91
|
return result
|
|
92
92
|
|
|
@@ -135,7 +135,7 @@ class EpisodicNode(Node):
|
|
|
135
135
|
source=self.source.value,
|
|
136
136
|
)
|
|
137
137
|
|
|
138
|
-
logger.
|
|
138
|
+
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
|
139
139
|
|
|
140
140
|
return result
|
|
141
141
|
|
|
@@ -217,7 +217,7 @@ class EntityNode(Node):
|
|
|
217
217
|
text = self.name.replace('\n', ' ')
|
|
218
218
|
self.name_embedding = await embedder.create(input=[text])
|
|
219
219
|
end = time()
|
|
220
|
-
logger.
|
|
220
|
+
logger.debug(f'embedded {text} in {end - start} ms')
|
|
221
221
|
|
|
222
222
|
return self.name_embedding
|
|
223
223
|
|
|
@@ -225,7 +225,8 @@ class EntityNode(Node):
|
|
|
225
225
|
result = await driver.execute_query(
|
|
226
226
|
"""
|
|
227
227
|
MERGE (n:Entity {uuid: $uuid})
|
|
228
|
-
SET n = {uuid: $uuid, name: $name,
|
|
228
|
+
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
|
229
|
+
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
|
229
230
|
RETURN n.uuid AS uuid""",
|
|
230
231
|
uuid=self.uuid,
|
|
231
232
|
name=self.name,
|
|
@@ -235,7 +236,7 @@ class EntityNode(Node):
|
|
|
235
236
|
created_at=self.created_at,
|
|
236
237
|
)
|
|
237
238
|
|
|
238
|
-
logger.
|
|
239
|
+
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
|
239
240
|
|
|
240
241
|
return result
|
|
241
242
|
|
|
@@ -257,6 +258,9 @@ class EntityNode(Node):
|
|
|
257
258
|
|
|
258
259
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
259
260
|
|
|
261
|
+
if len(nodes) == 0:
|
|
262
|
+
raise NodeNotFoundError(uuid)
|
|
263
|
+
|
|
260
264
|
return nodes[0]
|
|
261
265
|
|
|
262
266
|
@classmethod
|
|
@@ -308,7 +312,8 @@ class CommunityNode(Node):
|
|
|
308
312
|
result = await driver.execute_query(
|
|
309
313
|
"""
|
|
310
314
|
MERGE (n:Community {uuid: $uuid})
|
|
311
|
-
SET n = {uuid: $uuid, name: $name,
|
|
315
|
+
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
|
316
|
+
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
|
312
317
|
RETURN n.uuid AS uuid""",
|
|
313
318
|
uuid=self.uuid,
|
|
314
319
|
name=self.name,
|
|
@@ -318,7 +323,7 @@ class CommunityNode(Node):
|
|
|
318
323
|
created_at=self.created_at,
|
|
319
324
|
)
|
|
320
325
|
|
|
321
|
-
logger.
|
|
326
|
+
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
|
322
327
|
|
|
323
328
|
return result
|
|
324
329
|
|
|
@@ -327,7 +332,7 @@ class CommunityNode(Node):
|
|
|
327
332
|
text = self.name.replace('\n', ' ')
|
|
328
333
|
self.name_embedding = await embedder.create(input=[text])
|
|
329
334
|
end = time()
|
|
330
|
-
logger.
|
|
335
|
+
logger.debug(f'embedded {text} in {end - start} ms')
|
|
331
336
|
|
|
332
337
|
return self.name_embedding
|
|
333
338
|
|
|
@@ -349,6 +354,9 @@ class CommunityNode(Node):
|
|
|
349
354
|
|
|
350
355
|
nodes = [get_community_node_from_record(record) for record in records]
|
|
351
356
|
|
|
357
|
+
if len(nodes) == 0:
|
|
358
|
+
raise NodeNotFoundError(uuid)
|
|
359
|
+
|
|
352
360
|
return nodes[0]
|
|
353
361
|
|
|
354
362
|
@classmethod
|
|
@@ -23,11 +23,33 @@ from .models import Message, PromptFunction, PromptVersion
|
|
|
23
23
|
class Prompt(Protocol):
|
|
24
24
|
qa_prompt: PromptVersion
|
|
25
25
|
eval_prompt: PromptVersion
|
|
26
|
+
query_expansion: PromptVersion
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
class Versions(TypedDict):
|
|
29
30
|
qa_prompt: PromptFunction
|
|
30
31
|
eval_prompt: PromptFunction
|
|
32
|
+
query_expansion: PromptFunction
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def query_expansion(context: dict[str, Any]) -> list[Message]:
|
|
36
|
+
sys_prompt = """You are an expert at rephrasing questions into queries used in a database retrieval system"""
|
|
37
|
+
|
|
38
|
+
user_prompt = f"""
|
|
39
|
+
Bob is asking Alice a question, are you able to rephrase the question into a simpler one about Alice in the third person
|
|
40
|
+
that maintains the relevant context?
|
|
41
|
+
<QUESTION>
|
|
42
|
+
{json.dumps(context['query'])}
|
|
43
|
+
</QUESTION>
|
|
44
|
+
respond with a JSON object in the following format:
|
|
45
|
+
{{
|
|
46
|
+
"query": "query optimized for database search"
|
|
47
|
+
}}
|
|
48
|
+
"""
|
|
49
|
+
return [
|
|
50
|
+
Message(role='system', content=sys_prompt),
|
|
51
|
+
Message(role='user', content=user_prompt),
|
|
52
|
+
]
|
|
31
53
|
|
|
32
54
|
|
|
33
55
|
def qa_prompt(context: dict[str, Any]) -> list[Message]:
|
|
@@ -38,7 +60,7 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
|
|
|
38
60
|
You are given the following entity summaries and facts to help you determine the answer to your question.
|
|
39
61
|
<ENTITY_SUMMARIES>
|
|
40
62
|
{json.dumps(context['entity_summaries'])}
|
|
41
|
-
</ENTITY_SUMMARIES
|
|
63
|
+
</ENTITY_SUMMARIES>
|
|
42
64
|
<FACTS>
|
|
43
65
|
{json.dumps(context['facts'])}
|
|
44
66
|
</FACTS>
|
|
@@ -87,4 +109,8 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
|
|
|
87
109
|
]
|
|
88
110
|
|
|
89
111
|
|
|
90
|
-
versions: Versions = {
|
|
112
|
+
versions: Versions = {
|
|
113
|
+
'qa_prompt': qa_prompt,
|
|
114
|
+
'eval_prompt': eval_prompt,
|
|
115
|
+
'query_expansion': query_expansion,
|
|
116
|
+
}
|
|
@@ -37,7 +37,6 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
|
|
37
37
|
role='user',
|
|
38
38
|
content=f"""
|
|
39
39
|
Edge:
|
|
40
|
-
Edge Name: {context['edge_name']}
|
|
41
40
|
Fact: {context['edge_fact']}
|
|
42
41
|
|
|
43
42
|
Current Episode: {context['current_episode']}
|
|
@@ -56,17 +55,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
|
|
56
55
|
Guidelines:
|
|
57
56
|
1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
|
|
58
57
|
2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
|
|
59
|
-
3. If
|
|
60
|
-
4.
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
7. If only a
|
|
58
|
+
3. If the fact is written in the present tense, use the Reference Timestamp for the valid_at date
|
|
59
|
+
4. If no temporal information is found that establishes or changes the relationship, leave the fields as null.
|
|
60
|
+
5. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
|
|
61
|
+
6. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
|
|
62
|
+
7. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
|
|
63
|
+
8. If only a year is mentioned, use January 1st of that year at 00:00:00.
|
|
64
64
|
9. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
|
|
65
65
|
Respond with a JSON object:
|
|
66
66
|
{{
|
|
67
|
-
"valid_at": "YYYY-MM-DDTHH:MM:
|
|
68
|
-
"invalid_at": "YYYY-MM-DDTHH:MM:
|
|
69
|
-
"explanation": "Brief explanation of why these dates were chosen or why they were set to null"
|
|
67
|
+
"valid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
|
|
68
|
+
"invalid_at": "YYYY-MM-DDTHH:MM:SS.SSSSSSZ or null",
|
|
70
69
|
}}
|
|
71
70
|
""",
|
|
72
71
|
),
|
|
@@ -113,8 +113,9 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
|
|
113
113
|
2. Each edge should represent a clear relationship between two DISTINCT nodes.
|
|
114
114
|
3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
|
|
115
115
|
4. Provide a more detailed fact describing the relationship.
|
|
116
|
-
5.
|
|
117
|
-
6.
|
|
116
|
+
5. The fact should include any specific relevant information, including numeric information
|
|
117
|
+
6. Consider temporal aspects of relationships when relevant.
|
|
118
|
+
7. Avoid using the same node as the source and target of a relationship
|
|
118
119
|
|
|
119
120
|
Respond with a JSON object in the following format:
|
|
120
121
|
{{
|
|
@@ -82,7 +82,7 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
|
|
82
82
|
Message(
|
|
83
83
|
role='user',
|
|
84
84
|
content=f"""
|
|
85
|
-
Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to
|
|
85
|
+
Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to invalidations with the New Edge.
|
|
86
86
|
|
|
87
87
|
Existing Edges:
|
|
88
88
|
{context['existing_edges']}
|