graphiti-core 0.10.5__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; consult the advisory on the package registry page for details.

@@ -18,16 +18,23 @@ import logging
18
18
  from datetime import datetime
19
19
  from time import time
20
20
 
21
- from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
21
+ from graphiti_core.edges import (
22
+ CommunityEdge,
23
+ EntityEdge,
24
+ EpisodicEdge,
25
+ create_entity_edge_embeddings,
26
+ )
27
+ from graphiti_core.graphiti_types import GraphitiClients
22
28
  from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
23
29
  from graphiti_core.llm_client import LLMClient
24
30
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
25
31
  from graphiti_core.prompts import prompt_library
26
32
  from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
27
33
  from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
28
- from graphiti_core.utils.datetime_utils import utc_now
34
+ from graphiti_core.search.search_filters import SearchFilters
35
+ from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
36
+ from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
29
37
  from graphiti_core.utils.maintenance.temporal_operations import (
30
- extract_edge_dates,
31
38
  get_edge_contradictions,
32
39
  )
33
40
 
@@ -39,7 +46,7 @@ def build_episodic_edges(
39
46
  episode: EpisodicNode,
40
47
  created_at: datetime,
41
48
  ) -> list[EpisodicEdge]:
42
- edges: list[EpisodicEdge] = [
49
+ episodic_edges: list[EpisodicEdge] = [
43
50
  EpisodicEdge(
44
51
  source_node_uuid=episode.uuid,
45
52
  target_node_uuid=node.uuid,
@@ -49,7 +56,9 @@ def build_episodic_edges(
49
56
  for node in entity_nodes
50
57
  ]
51
58
 
52
- return edges
59
+ logger.debug(f'Built episodic edges: {episodic_edges}')
60
+
61
+ return episodic_edges
53
62
 
54
63
 
55
64
  def build_community_edges(
@@ -71,7 +80,7 @@ def build_community_edges(
71
80
 
72
81
 
73
82
  async def extract_edges(
74
- llm_client: LLMClient,
83
+ clients: GraphitiClients,
75
84
  episode: EpisodicNode,
76
85
  nodes: list[EntityNode],
77
86
  previous_episodes: list[EpisodicNode],
@@ -79,7 +88,9 @@ async def extract_edges(
79
88
  ) -> list[EntityEdge]:
80
89
  start = time()
81
90
 
82
- EXTRACT_EDGES_MAX_TOKENS = 16384
91
+ extract_edges_max_tokens = 16384
92
+ llm_client = clients.llm_client
93
+ embedder = clients.embedder
83
94
 
84
95
  node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
85
96
 
@@ -88,16 +99,17 @@ async def extract_edges(
88
99
  'episode_content': episode.content,
89
100
  'nodes': [node.name for node in nodes],
90
101
  'previous_episodes': [ep.content for ep in previous_episodes],
102
+ 'reference_time': episode.valid_at,
91
103
  'custom_prompt': '',
92
104
  }
93
105
 
94
106
  facts_missed = True
95
107
  reflexion_iterations = 0
96
- while facts_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS:
108
+ while facts_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
97
109
  llm_response = await llm_client.generate_response(
98
110
  prompt_library.extract_edges.edge(context),
99
111
  response_model=ExtractedEdges,
100
- max_tokens=EXTRACT_EDGES_MAX_TOKENS,
112
+ max_tokens=extract_edges_max_tokens,
101
113
  )
102
114
  edges_data = llm_response.get('edges', [])
103
115
 
@@ -106,7 +118,9 @@ async def extract_edges(
106
118
  reflexion_iterations += 1
107
119
  if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
108
120
  reflexion_response = await llm_client.generate_response(
109
- prompt_library.extract_edges.reflexion(context), response_model=MissingFacts
121
+ prompt_library.extract_edges.reflexion(context),
122
+ response_model=MissingFacts,
123
+ max_tokens=extract_edges_max_tokens,
110
124
  )
111
125
 
112
126
  missing_facts = reflexion_response.get('missing_facts', [])
@@ -122,9 +136,33 @@ async def extract_edges(
122
136
  end = time()
123
137
  logger.debug(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')
124
138
 
139
+ if len(edges_data) == 0:
140
+ return []
141
+
125
142
  # Convert the extracted data into EntityEdge objects
126
143
  edges = []
127
144
  for edge_data in edges_data:
145
+ # Validate Edge Date information
146
+ valid_at = edge_data.get('valid_at', None)
147
+ invalid_at = edge_data.get('invalid_at', None)
148
+ valid_at_datetime = None
149
+ invalid_at_datetime = None
150
+
151
+ if valid_at:
152
+ try:
153
+ valid_at_datetime = ensure_utc(
154
+ datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
155
+ )
156
+ except ValueError as e:
157
+ logger.warning(f'WARNING: Error parsing valid_at date: {e}. Input: {valid_at}')
158
+
159
+ if invalid_at:
160
+ try:
161
+ invalid_at_datetime = ensure_utc(
162
+ datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
163
+ )
164
+ except ValueError as e:
165
+ logger.warning(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}')
128
166
  edge = EntityEdge(
129
167
  source_node_uuid=node_uuids_by_name_map.get(
130
168
  edge_data.get('source_entity_name', ''), ''
@@ -137,14 +175,18 @@ async def extract_edges(
137
175
  fact=edge_data.get('fact', ''),
138
176
  episodes=[episode.uuid],
139
177
  created_at=utc_now(),
140
- valid_at=None,
141
- invalid_at=None,
178
+ valid_at=valid_at_datetime,
179
+ invalid_at=invalid_at_datetime,
142
180
  )
143
181
  edges.append(edge)
144
182
  logger.debug(
145
183
  f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
146
184
  )
147
185
 
186
+ await create_entity_edge_embeddings(embedder, edges)
187
+
188
+ logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
189
+
148
190
  return edges
149
191
 
150
192
 
@@ -193,14 +235,24 @@ async def dedupe_extracted_edges(
193
235
 
194
236
 
195
237
  async def resolve_extracted_edges(
196
- llm_client: LLMClient,
238
+ clients: GraphitiClients,
197
239
  extracted_edges: list[EntityEdge],
198
- related_edges_lists: list[list[EntityEdge]],
199
- existing_edges_lists: list[list[EntityEdge]],
200
- current_episode: EpisodicNode,
201
- previous_episodes: list[EpisodicNode],
202
240
  ) -> tuple[list[EntityEdge], list[EntityEdge]]:
203
- # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
241
+ driver = clients.driver
242
+ llm_client = clients.llm_client
243
+
244
+ search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
245
+ get_relevant_edges(driver, extracted_edges, SearchFilters()),
246
+ get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
247
+ )
248
+
249
+ related_edges_lists, edge_invalidation_candidates = search_results
250
+
251
+ logger.debug(
252
+ f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
253
+ )
254
+
255
+ # resolve edges with related edges in the graph and find invalidation candidates
204
256
  results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
205
257
  await semaphore_gather(
206
258
  *[
@@ -209,11 +261,9 @@ async def resolve_extracted_edges(
209
261
  extracted_edge,
210
262
  related_edges,
211
263
  existing_edges,
212
- current_episode,
213
- previous_episodes,
214
264
  )
215
265
  for extracted_edge, related_edges, existing_edges in zip(
216
- extracted_edges, related_edges_lists, existing_edges_lists, strict=False
266
+ extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True
217
267
  )
218
268
  ]
219
269
  )
@@ -228,12 +278,17 @@ async def resolve_extracted_edges(
228
278
  resolved_edges.append(resolved_edge)
229
279
  invalidated_edges.extend(invalidated_edge_chunk)
230
280
 
281
+ logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
282
+
231
283
  return resolved_edges, invalidated_edges
232
284
 
233
285
 
234
286
  def resolve_edge_contradictions(
235
287
  resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge]
236
288
  ) -> list[EntityEdge]:
289
+ if len(invalidation_candidates) == 0:
290
+ return []
291
+
237
292
  # Determine which contradictory edges need to be expired
238
293
  invalidated_edges: list[EntityEdge] = []
239
294
  for edge in invalidation_candidates:
@@ -266,21 +321,15 @@ async def resolve_extracted_edge(
266
321
  extracted_edge: EntityEdge,
267
322
  related_edges: list[EntityEdge],
268
323
  existing_edges: list[EntityEdge],
269
- current_episode: EpisodicNode,
270
- previous_episodes: list[EpisodicNode],
271
324
  ) -> tuple[EntityEdge, list[EntityEdge]]:
272
- resolved_edge, (valid_at, invalid_at), invalidation_candidates = await semaphore_gather(
325
+ resolved_edge, invalidation_candidates = await semaphore_gather(
273
326
  dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
274
- extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
275
327
  get_edge_contradictions(llm_client, extracted_edge, existing_edges),
276
328
  )
277
329
 
278
330
  now = utc_now()
279
331
 
280
- resolved_edge.valid_at = valid_at if valid_at else resolved_edge.valid_at
281
- resolved_edge.invalid_at = invalid_at if invalid_at else resolved_edge.invalid_at
282
-
283
- if invalid_at and not resolved_edge.expired_at:
332
+ if resolved_edge.invalid_at and not resolved_edge.expired_at:
284
333
  resolved_edge.expired_at = now
285
334
 
286
335
  # Determine if the new_edge needs to be expired
@@ -308,16 +357,17 @@ async def resolve_extracted_edge(
308
357
  async def dedupe_extracted_edge(
309
358
  llm_client: LLMClient, extracted_edge: EntityEdge, related_edges: list[EntityEdge]
310
359
  ) -> EntityEdge:
360
+ if len(related_edges) == 0:
361
+ return extracted_edge
362
+
311
363
  start = time()
312
364
 
313
365
  # Prepare context for LLM
314
366
  related_edges_context = [
315
- {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in related_edges
367
+ {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
316
368
  ]
317
369
 
318
370
  extracted_edge_context = {
319
- 'uuid': extracted_edge.uuid,
320
- 'name': extracted_edge.name,
321
371
  'fact': extracted_edge.fact,
322
372
  }
323
373
 
@@ -330,15 +380,13 @@ async def dedupe_extracted_edge(
330
380
  prompt_library.dedupe_edges.edge(context), response_model=EdgeDuplicate
331
381
  )
332
382
 
333
- is_duplicate: bool = llm_response.get('is_duplicate', False)
334
- uuid: str | None = llm_response.get('uuid', None)
383
+ duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
335
384
 
336
- edge = extracted_edge
337
- if is_duplicate:
338
- for existing_edge in related_edges:
339
- if existing_edge.uuid != uuid:
340
- continue
341
- edge = existing_edge
385
+ edge = (
386
+ related_edges[duplicate_fact_id]
387
+ if 0 <= duplicate_fact_id < len(related_edges)
388
+ else extracted_edge
389
+ )
342
390
 
343
391
  end = time()
344
392
  logger.debug(
@@ -117,6 +117,7 @@ async def retrieve_episodes(
117
117
  reference_time: datetime,
118
118
  last_n: int = EPISODE_WINDOW_LEN,
119
119
  group_ids: list[str] | None = None,
120
+ source: EpisodeType | None = None,
120
121
  ) -> list[EpisodicNode]:
121
122
  """
122
123
  Retrieve the last n episodic nodes from the graph.
@@ -132,13 +133,17 @@ async def retrieve_episodes(
132
133
  Returns:
133
134
  list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
134
135
  """
135
- group_id_filter: LiteralString = 'AND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
136
+ group_id_filter: LiteralString = (
137
+ '\nAND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
138
+ )
139
+ source_filter: LiteralString = '\nAND e.source = $source' if source is not None else ''
136
140
 
137
141
  query: LiteralString = (
138
142
  """
139
- MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
140
- """
143
+ MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
144
+ """
141
145
  + group_id_filter
146
+ + source_filter
142
147
  + """
143
148
  RETURN e.content AS content,
144
149
  e.created_at AS created_at,
@@ -156,6 +161,7 @@ async def retrieve_episodes(
156
161
  result = await driver.execute_query(
157
162
  query,
158
163
  reference_time=reference_time,
164
+ source=source.name if source is not None else None,
159
165
  num_episodes=last_n,
160
166
  group_ids=group_ids,
161
167
  database_=DEFAULT_DATABASE,