graphiti_core-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of graphiti-core might be problematic.

Files changed (37)
  1. graphiti_core/__init__.py +3 -0
  2. graphiti_core/edges.py +232 -0
  3. graphiti_core/graphiti.py +618 -0
  4. graphiti_core/helpers.py +7 -0
  5. graphiti_core/llm_client/__init__.py +5 -0
  6. graphiti_core/llm_client/anthropic_client.py +63 -0
  7. graphiti_core/llm_client/client.py +96 -0
  8. graphiti_core/llm_client/config.py +58 -0
  9. graphiti_core/llm_client/groq_client.py +64 -0
  10. graphiti_core/llm_client/openai_client.py +65 -0
  11. graphiti_core/llm_client/utils.py +22 -0
  12. graphiti_core/nodes.py +250 -0
  13. graphiti_core/prompts/__init__.py +4 -0
  14. graphiti_core/prompts/dedupe_edges.py +154 -0
  15. graphiti_core/prompts/dedupe_nodes.py +151 -0
  16. graphiti_core/prompts/extract_edge_dates.py +60 -0
  17. graphiti_core/prompts/extract_edges.py +138 -0
  18. graphiti_core/prompts/extract_nodes.py +145 -0
  19. graphiti_core/prompts/invalidate_edges.py +74 -0
  20. graphiti_core/prompts/lib.py +122 -0
  21. graphiti_core/prompts/models.py +31 -0
  22. graphiti_core/search/__init__.py +0 -0
  23. graphiti_core/search/search.py +142 -0
  24. graphiti_core/search/search_utils.py +454 -0
  25. graphiti_core/utils/__init__.py +15 -0
  26. graphiti_core/utils/bulk_utils.py +227 -0
  27. graphiti_core/utils/maintenance/__init__.py +16 -0
  28. graphiti_core/utils/maintenance/edge_operations.py +170 -0
  29. graphiti_core/utils/maintenance/graph_data_operations.py +133 -0
  30. graphiti_core/utils/maintenance/node_operations.py +199 -0
  31. graphiti_core/utils/maintenance/temporal_operations.py +184 -0
  32. graphiti_core/utils/maintenance/utils.py +0 -0
  33. graphiti_core/utils/utils.py +39 -0
  34. graphiti_core-0.1.0.dist-info/LICENSE +201 -0
  35. graphiti_core-0.1.0.dist-info/METADATA +199 -0
  36. graphiti_core-0.1.0.dist-info/RECORD +37 -0
  37. graphiti_core-0.1.0.dist-info/WHEEL +4 -0
graphiti_core/utils/maintenance/edge_operations.py
@@ -0,0 +1,170 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import logging
+ from datetime import datetime
+ from time import time
+ from typing import List
+
+ from graphiti_core.edges import EntityEdge, EpisodicEdge
+ from graphiti_core.llm_client import LLMClient
+ from graphiti_core.nodes import EntityNode, EpisodicNode
+ from graphiti_core.prompts import prompt_library
+
+ logger = logging.getLogger(__name__)
+
+
+ def build_episodic_edges(
+     entity_nodes: List[EntityNode],
+     episode: EpisodicNode,
+     created_at: datetime,
+ ) -> List[EpisodicEdge]:
+     edges: List[EpisodicEdge] = []
+
+     for node in entity_nodes:
+         edge = EpisodicEdge(
+             source_node_uuid=episode.uuid,
+             target_node_uuid=node.uuid,
+             created_at=created_at,
+         )
+         edges.append(edge)
+
+     return edges
+
+
+ async def extract_edges(
+     llm_client: LLMClient,
+     episode: EpisodicNode,
+     nodes: list[EntityNode],
+     previous_episodes: list[EpisodicNode],
+ ) -> list[EntityEdge]:
+     start = time()
+
+     # Prepare context for LLM
+     context = {
+         'episode_content': episode.content,
+         'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
+         'nodes': [
+             {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in nodes
+         ],
+         'previous_episodes': [
+             {
+                 'content': ep.content,
+                 'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
+             }
+             for ep in previous_episodes
+         ],
+     }
+
+     llm_response = await llm_client.generate_response(prompt_library.extract_edges.v2(context))
+     logger.debug(llm_response)
+     edges_data = llm_response.get('edges', [])
+
+     end = time()
+     logger.info(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')
+
+     # Convert the extracted data into EntityEdge objects
+     edges = []
+     for edge_data in edges_data:
+         if edge_data['target_node_uuid'] and edge_data['source_node_uuid']:
+             edge = EntityEdge(
+                 source_node_uuid=edge_data['source_node_uuid'],
+                 target_node_uuid=edge_data['target_node_uuid'],
+                 name=edge_data['relation_type'],
+                 fact=edge_data['fact'],
+                 episodes=[episode.uuid],
+                 created_at=datetime.now(),
+                 valid_at=None,
+                 invalid_at=None,
+             )
+             edges.append(edge)
+             logger.info(
+                 f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
+             )
+
+     return edges
+
+
+ def create_edge_identifier(
+     source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
+ ) -> str:
+     return f'{source_node.name}-{edge.name}-{target_node.name}'
+
+
+ async def dedupe_extracted_edges(
+     llm_client: LLMClient,
+     extracted_edges: list[EntityEdge],
+     existing_edges: list[EntityEdge],
+ ) -> list[EntityEdge]:
+     # Create edge map
+     edge_map = {}
+     for edge in extracted_edges:
+         edge_map[edge.uuid] = edge
+
+     # Prepare context for LLM
+     context = {
+         'extracted_edges': [
+             {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
+         ],
+         'existing_edges': [
+             {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
+         ],
+     }
+
+     llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
+     unique_edge_data = llm_response.get('unique_facts', [])
+     logger.info(f'Extracted unique edges: {unique_edge_data}')
+
+     # Get full edge data
+     edges = []
+     for unique_edge in unique_edge_data:
+         edge = edge_map[unique_edge['uuid']]
+         edges.append(edge)
+
+     return edges
+
+
+ async def dedupe_edge_list(
+     llm_client: LLMClient,
+     edges: list[EntityEdge],
+ ) -> list[EntityEdge]:
+     start = time()
+
+     # Create edge map
+     edge_map = {}
+     for edge in edges:
+         edge_map[edge.uuid] = edge
+
+     # Prepare context for LLM
+     context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
+
+     llm_response = await llm_client.generate_response(
+         prompt_library.dedupe_edges.edge_list(context)
+     )
+     unique_edges_data = llm_response.get('unique_facts', [])
+
+     end = time()
+     logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms')
+
+     # Get full edge data
+     unique_edges = []
+     for edge_data in unique_edges_data:
+         uuid = edge_data['uuid']
+         edge = edge_map[uuid]
+         edge.fact = edge_data['fact']
+         unique_edges.append(edge)
+
+     return unique_edges
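Taken together, these helpers form a small pipeline: build_episodic_edges links an episode to the entities it mentions, extract_edges asks the LLM for entity-to-entity relations, and the dedupe functions collapse duplicates. A minimal composition sketch follows; it is hypothetical, not part of the package, and assumes an already-constructed LLMClient and already-loaded EpisodicNode/EntityNode objects:

from datetime import datetime, timezone

from graphiti_core.utils.maintenance.edge_operations import (
    build_episodic_edges,
    dedupe_extracted_edges,
    extract_edges,
)

async def process_episode_edges(llm_client, episode, nodes, previous_episodes, existing_edges):
    # Hypothetical driver function: connect the episode to every entity it mentions.
    episodic_edges = build_episodic_edges(nodes, episode, datetime.now(timezone.utc))
    # Ask the LLM for entity-entity relations, then drop edges that restate known facts.
    entity_edges = await extract_edges(llm_client, episode, nodes, previous_episodes)
    entity_edges = await dedupe_extracted_edges(llm_client, entity_edges, existing_edges)
    return episodic_edges, entity_edges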
graphiti_core/utils/maintenance/graph_data_operations.py
@@ -0,0 +1,133 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import asyncio
+ import logging
+ from datetime import datetime, timezone
+
+ from neo4j import AsyncDriver
+ from typing_extensions import LiteralString
+
+ from graphiti_core.nodes import EpisodeType, EpisodicNode
+
+ EPISODE_WINDOW_LEN = 3
+
+ logger = logging.getLogger(__name__)
+
+
+ async def build_indices_and_constraints(driver: AsyncDriver):
+     range_indices: list[LiteralString] = [
+         'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
+         'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
+         'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
+         'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
+         'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
+         'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
+         'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
+         'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
+         'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
+         'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
+         'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
+         'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
+         'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
+     ]
+
+     fulltext_indices: list[LiteralString] = [
+         'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]',
+         'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]',
+     ]
+
+     vector_indices: list[LiteralString] = [
+         """
+         CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
+         FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
+         OPTIONS {indexConfig: {
+             `vector.dimensions`: 1024,
+             `vector.similarity_function`: 'cosine'
+         }}
+         """,
+         """
+         CREATE VECTOR INDEX name_embedding IF NOT EXISTS
+         FOR (n:Entity) ON (n.name_embedding)
+         OPTIONS {indexConfig: {
+             `vector.dimensions`: 1024,
+             `vector.similarity_function`: 'cosine'
+         }}
+         """,
+     ]
+     index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
+
+     await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
+
+
+ async def clear_data(driver: AsyncDriver):
+     async with driver.session() as session:
+
+         async def delete_all(tx):
+             await tx.run('MATCH (n) DETACH DELETE n')
+
+         await session.execute_write(delete_all)
+
+
+ async def retrieve_episodes(
+     driver: AsyncDriver,
+     reference_time: datetime,
+     last_n: int = EPISODE_WINDOW_LEN,
+ ) -> list[EpisodicNode]:
+     """
+     Retrieve the last n episodic nodes from the graph.
+
+     Args:
+         driver (AsyncDriver): The Neo4j driver instance.
+         reference_time (datetime): The reference time to filter episodes. Only episodes with a valid_at timestamp
+             less than or equal to this reference_time will be retrieved. This allows for
+             querying the graph's state at a specific point in time.
+         last_n (int, optional): The number of most recent episodes to retrieve, relative to the reference_time.
+
+     Returns:
+         list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
+     """
+     result = await driver.execute_query(
+         """
+         MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
+         RETURN e.content as content,
+             e.created_at as created_at,
+             e.valid_at as valid_at,
+             e.uuid as uuid,
+             e.name as name,
+             e.source_description as source_description,
+             e.source as source
+         ORDER BY e.created_at DESC
+         LIMIT $num_episodes
+         """,
+         reference_time=reference_time,
+         num_episodes=last_n,
+     )
+     episodes = [
+         EpisodicNode(
+             content=record['content'],
+             created_at=datetime.fromtimestamp(
+                 record['created_at'].to_native().timestamp(), timezone.utc
+             ),
+             valid_at=record['valid_at'].to_native(),
+             uuid=record['uuid'],
+             source=EpisodeType.from_str(record['source']),
+             name=record['name'],
+             source_description=record['source_description'],
+         )
+         for record in result.records
+     ]
+     return list(reversed(episodes))  # Return in chronological order
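A minimal wiring sketch for this module, assuming a local Neo4j instance; the URI and credentials are placeholders, and the driver construction uses the standard neo4j AsyncGraphDatabase API rather than anything defined in this package:

import asyncio
from datetime import datetime, timezone

from neo4j import AsyncGraphDatabase

from graphiti_core.utils.maintenance.graph_data_operations import (
    build_indices_and_constraints,
    retrieve_episodes,
)

async def main():
    # Placeholder connection details; substitute your deployment's values.
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    try:
        await build_indices_and_constraints(driver)
        # Fetch the EPISODE_WINDOW_LEN most recent episodes as of now, oldest first.
        episodes = await retrieve_episodes(driver, datetime.now(timezone.utc))
        for episode in episodes:
            print(episode.name, episode.valid_at)
    finally:
        await driver.close()

asyncio.run(main())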
graphiti_core/utils/maintenance/node_operations.py
@@ -0,0 +1,199 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import logging
+ from datetime import datetime
+ from time import time
+ from typing import Any
+
+ from graphiti_core.llm_client import LLMClient
+ from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+ from graphiti_core.prompts import prompt_library
+
+ logger = logging.getLogger(__name__)
+
+
+ async def extract_message_nodes(
+     llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
+ ) -> list[dict[str, Any]]:
+     # Prepare context for LLM
+     context = {
+         'episode_content': episode.content,
+         'episode_timestamp': episode.valid_at.isoformat(),
+         'previous_episodes': [
+             {
+                 'content': ep.content,
+                 'timestamp': ep.valid_at.isoformat(),
+             }
+             for ep in previous_episodes
+         ],
+     }
+
+     llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v2(context))
+     extracted_node_data = llm_response.get('extracted_nodes', [])
+     return extracted_node_data
+
+
+ async def extract_json_nodes(
+     llm_client: LLMClient,
+     episode: EpisodicNode,
+ ) -> list[dict[str, Any]]:
+     # Prepare context for LLM
+     context = {
+         'episode_content': episode.content,
+         'episode_timestamp': episode.valid_at.isoformat(),
+         'source_description': episode.source_description,
+     }
+
+     llm_response = await llm_client.generate_response(
+         prompt_library.extract_nodes.extract_json(context)
+     )
+     extracted_node_data = llm_response.get('extracted_nodes', [])
+     return extracted_node_data
+
+
+ async def extract_nodes(
+     llm_client: LLMClient,
+     episode: EpisodicNode,
+     previous_episodes: list[EpisodicNode],
+ ) -> list[EntityNode]:
+     start = time()
+     extracted_node_data: list[dict[str, Any]] = []
+     if episode.source in [EpisodeType.message, EpisodeType.text]:
+         extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
+     elif episode.source == EpisodeType.json:
+         extracted_node_data = await extract_json_nodes(llm_client, episode)
+
+     end = time()
+     logger.info(f'Extracted new nodes: {extracted_node_data} in {(end - start) * 1000} ms')
+     # Convert the extracted data into EntityNode objects
+     new_nodes = []
+     for node_data in extracted_node_data:
+         new_node = EntityNode(
+             name=node_data['name'],
+             labels=node_data['labels'],
+             summary=node_data['summary'],
+             created_at=datetime.now(),
+         )
+         new_nodes.append(new_node)
+         logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
+
+     return new_nodes
+
+
+ async def dedupe_extracted_nodes(
+     llm_client: LLMClient,
+     extracted_nodes: list[EntityNode],
+     existing_nodes: list[EntityNode],
+ ) -> tuple[list[EntityNode], dict[str, str], list[EntityNode]]:
+     start = time()
+
+     # build existing node map
+     node_map: dict[str, EntityNode] = {}
+     for node in existing_nodes:
+         node_map[node.name] = node
+
+     # Temp hack
+     new_nodes_map: dict[str, EntityNode] = {}
+     for node in extracted_nodes:
+         new_nodes_map[node.name] = node
+
+     # Prepare context for LLM
+     existing_nodes_context = [
+         {'name': node.name, 'summary': node.summary} for node in existing_nodes
+     ]
+
+     extracted_nodes_context = [
+         {'name': node.name, 'summary': node.summary} for node in extracted_nodes
+     ]
+
+     context = {
+         'existing_nodes': existing_nodes_context,
+         'extracted_nodes': extracted_nodes_context,
+     }
+
+     llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.v2(context))
+
+     duplicate_data = llm_response.get('duplicates', [])
+
+     end = time()
+     logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
+
+     uuid_map: dict[str, str] = {}
+     for duplicate in duplicate_data:
+         uuid = new_nodes_map[duplicate['name']].uuid
+         uuid_value = node_map[duplicate['duplicate_of']].uuid
+         uuid_map[uuid] = uuid_value
+
+     nodes: list[EntityNode] = []
+     brand_new_nodes: list[EntityNode] = []
+     for node in extracted_nodes:
+         if node.uuid in uuid_map:
+             existing_uuid = uuid_map[node.uuid]
+             # TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes,
+             # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
+             # find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
+             existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
+             if existing_node:
+                 nodes.append(existing_node)
+
+             continue
+         brand_new_nodes.append(node)
+         nodes.append(node)
+
+     return nodes, uuid_map, brand_new_nodes
+
+
+ async def dedupe_node_list(
+     llm_client: LLMClient,
+     nodes: list[EntityNode],
+ ) -> tuple[list[EntityNode], dict[str, str]]:
+     start = time()
+
+     # build node map
+     node_map = {}
+     for node in nodes:
+         node_map[node.name] = node
+
+     # Prepare context for LLM
+     nodes_context = [{'name': node.name, 'summary': node.summary} for node in nodes]
+
+     context = {
+         'nodes': nodes_context,
+     }
+
+     llm_response = await llm_client.generate_response(
+         prompt_library.dedupe_nodes.node_list(context)
+     )
+
+     nodes_data = llm_response.get('nodes', [])
+
+     end = time()
+     logger.info(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
+
+     # Get full node data
+     unique_nodes = []
+     uuid_map: dict[str, str] = {}
+     for node_data in nodes_data:
+         node = node_map[node_data['names'][0]]
+         unique_nodes.append(node)
+
+         for name in node_data['names'][1:]:
+             uuid = node_map[name].uuid
+             uuid_value = node_map[node_data['names'][0]].uuid
+             uuid_map[uuid] = uuid_value
+
+     return unique_nodes, uuid_map
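The uuid_map returned by dedupe_extracted_nodes translates a provisional node uuid to its canonical counterpart, which matters when edges were created against the extracted nodes. A hypothetical composition sketch (not part of the package; assumes an existing LLMClient and node lists):

from graphiti_core.utils.maintenance.node_operations import (
    dedupe_extracted_nodes,
    extract_nodes,
)

async def resolve_episode_nodes(llm_client, episode, previous_episodes, existing_nodes):
    # Hypothetical driver function: extract entities, then fold duplicates
    # into the nodes already present in the graph.
    extracted = await extract_nodes(llm_client, episode, previous_episodes)
    nodes, uuid_map, brand_new = await dedupe_extracted_nodes(llm_client, extracted, existing_nodes)
    # Re-point any edge that references a provisional uuid before persisting.
    return nodes, uuid_map, brand_new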
graphiti_core/utils/maintenance/temporal_operations.py
@@ -0,0 +1,184 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import logging
+ from datetime import datetime
+ from typing import List
+
+ from graphiti_core.edges import EntityEdge
+ from graphiti_core.llm_client import LLMClient
+ from graphiti_core.nodes import EntityNode, EpisodicNode
+ from graphiti_core.prompts import prompt_library
+
+ logger = logging.getLogger(__name__)
+
+ NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode]
+
+
+ def extract_node_and_edge_triplets(
+     edges: list[EntityEdge], nodes: list[EntityNode]
+ ) -> list[NodeEdgeNodeTriplet]:
+     return [extract_node_edge_node_triplet(edge, nodes) for edge in edges]
+
+
+ def extract_node_edge_node_triplet(
+     edge: EntityEdge, nodes: list[EntityNode]
+ ) -> NodeEdgeNodeTriplet:
+     source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
+     target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
+     if not source_node or not target_node:
+         raise ValueError(f'Source or target node not found for edge {edge.uuid}')
+     return (source_node, edge, target_node)
+
+
+ def prepare_edges_for_invalidation(
+     existing_edges: list[EntityEdge],
+     new_edges: list[EntityEdge],
+     nodes: list[EntityNode],
+ ) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
+     existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = []
+     new_edges_with_nodes: list[NodeEdgeNodeTriplet] = []
+
+     for edge_list, result_list in [
+         (existing_edges, existing_edges_pending_invalidation),
+         (new_edges, new_edges_with_nodes),
+     ]:
+         for edge in edge_list:
+             source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
+             target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
+
+             if source_node and target_node:
+                 result_list.append((source_node, edge, target_node))
+
+     return existing_edges_pending_invalidation, new_edges_with_nodes
+
+
+ async def invalidate_edges(
+     llm_client: LLMClient,
+     existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet],
+     new_edges: list[NodeEdgeNodeTriplet],
+     current_episode: EpisodicNode,
+     previous_episodes: list[EpisodicNode],
+ ) -> list[EntityEdge]:
+     invalidated_edges: list[EntityEdge] = []  # reassigned below once the LLM response is processed
+
+     context = prepare_invalidation_context(
+         existing_edges_pending_invalidation,
+         new_edges,
+         current_episode,
+         previous_episodes,
+     )
+     llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context))
+
+     edges_to_invalidate = llm_response.get('invalidated_edges', [])
+     invalidated_edges = process_edge_invalidation_llm_response(
+         edges_to_invalidate, existing_edges_pending_invalidation
+     )
+
+     return invalidated_edges
+
+
+ def extract_date_strings_from_edge(edge: EntityEdge) -> str:
+     start = edge.valid_at
+     end = edge.invalid_at
+     date_string = f'Start Date: {start.isoformat()}' if start else ''
+     if end:
+         date_string += f' (End Date: {end.isoformat()})'
+     return date_string
+
+
+ def prepare_invalidation_context(
+     existing_edges: list[NodeEdgeNodeTriplet],
+     new_edges: list[NodeEdgeNodeTriplet],
+     current_episode: EpisodicNode,
+     previous_episodes: list[EpisodicNode],
+ ) -> dict:
+     return {
+         'existing_edges': [
+             f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
+             for source_node, edge, target_node in sorted(
+                 existing_edges, key=lambda x: x[1].created_at, reverse=True
+             )
+         ],
+         'new_edges': [
+             f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
+             for source_node, edge, target_node in sorted(
+                 new_edges, key=lambda x: x[1].created_at, reverse=True
+             )
+         ],
+         'current_episode': current_episode.content,
+         'previous_episodes': [episode.content for episode in previous_episodes],
+     }
+
+
+ def process_edge_invalidation_llm_response(
+     edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet]
+ ) -> List[EntityEdge]:
+     invalidated_edges = []
+     for edge_to_invalidate in edges_to_invalidate:
+         edge_uuid = edge_to_invalidate['edge_uuid']
+         edge_to_update = next(
+             (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid),
+             None,
+         )
+         if edge_to_update:
+             edge_to_update.expired_at = datetime.now()
+             edge_to_update.fact = edge_to_invalidate['fact']
+             invalidated_edges.append(edge_to_update)
+             logger.info(
+                 f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}"
+             )
+     return invalidated_edges
+
+
+ async def extract_edge_dates(
+     llm_client: LLMClient,
+     edge: EntityEdge,
+     reference_time: datetime,
+     current_episode: EpisodicNode,
+     previous_episodes: List[EpisodicNode],
+ ) -> tuple[datetime | None, datetime | None, str]:
+     context = {
+         'edge_name': edge.name,
+         'edge_fact': edge.fact,
+         'current_episode': current_episode.content,
+         'previous_episodes': [ep.content for ep in previous_episodes],
+         'reference_timestamp': reference_time.isoformat(),
+     }
+     llm_response = await llm_client.generate_response(prompt_library.extract_edge_dates.v1(context))
+
+     valid_at = llm_response.get('valid_at')
+     invalid_at = llm_response.get('invalid_at')
+     explanation = llm_response.get('explanation', '')
+
+     valid_at_datetime = None
+     invalid_at_datetime = None
+
+     if valid_at:
+         try:
+             valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
+         except ValueError as e:
+             logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')
+
+     if invalid_at:
+         try:
+             invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
+         except ValueError as e:
+             logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')
+
+     logger.info(f'Edge date extraction explanation: {explanation}')
+
+     return valid_at_datetime, invalid_at_datetime, explanation
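The temporal helpers compose into a single pass: pair edges with their endpoint nodes, let the LLM mark contradicted edges as expired, then date each new edge. A hypothetical sketch of that pass (not part of the package; the per-edge date loop and all argument objects are assumptions):

from datetime import datetime, timezone

from graphiti_core.utils.maintenance.temporal_operations import (
    extract_edge_dates,
    invalidate_edges,
    prepare_edges_for_invalidation,
)

async def apply_temporal_checks(llm_client, existing_edges, new_edges, nodes, episode, previous_episodes):
    # Hypothetical driver function: pair each edge with its source/target nodes.
    existing_triplets, new_triplets = prepare_edges_for_invalidation(existing_edges, new_edges, nodes)
    # Expire existing edges that the new information contradicts.
    invalidated = await invalidate_edges(
        llm_client, existing_triplets, new_triplets, episode, previous_episodes
    )
    # Attach valid_at/invalid_at dates to each new edge.
    for edge in new_edges:
        valid_at, invalid_at, _ = await extract_edge_dates(
            llm_client, edge, datetime.now(timezone.utc), episode, previous_episodes
        )
        edge.valid_at, edge.invalid_at = valid_at, invalid_at
    return invalidated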