graphiti-core 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

graphiti_core/graphiti.py CHANGED
@@ -262,6 +262,7 @@ class Graphiti:
         group_id: str = '',
         uuid: str | None = None,
         update_communities: bool = False,
+        entity_types: dict[str, BaseModel] | None = None,
     ) -> AddEpisodeResults:
         """
         Process an episode and update the graph.
@@ -336,7 +337,9 @@ class Graphiti:

             # Extract entities as nodes

-            extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
+            extracted_nodes = await extract_nodes(
+                self.llm_client, episode, previous_episodes, entity_types
+            )
             logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

             # Calculate Embeddings
@@ -348,7 +351,10 @@ class Graphiti:
             # Find relevant nodes already in the graph
             existing_nodes_lists: list[list[EntityNode]] = list(
                 await semaphore_gather(
-                    *[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
+                    *[
+                        get_relevant_nodes(self.driver, SearchFilters(), [node])
+                        for node in extracted_nodes
+                    ]
                 )
             )

@@ -362,6 +368,7 @@ class Graphiti:
                     existing_nodes_lists,
                     episode,
                     previous_episodes,
+                    entity_types,
                 ),
                 extract_edges(
                     self.llm_client, episode, extracted_nodes, previous_episodes, group_id
@@ -728,8 +735,8 @@ class Graphiti:
             self.llm_client,
             [source_node, target_node],
             [
-                await get_relevant_nodes(self.driver, [source_node]),
-                await get_relevant_nodes(self.driver, [target_node]),
+                await get_relevant_nodes(self.driver, SearchFilters(), [source_node]),
+                await get_relevant_nodes(self.driver, SearchFilters(), [target_node]),
             ],
         )

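The new entity_types parameter on add_episode lets callers supply custom entity types as Pydantic models; extraction then classifies each node against those types and stores their fields as node attributes. A minimal sketch of the call, assuming an already-configured Graphiti client; the Person model and episode values are illustrative, and the remaining add_episode arguments are assumed unchanged from 0.6.x:

from datetime import datetime, timezone

from pydantic import BaseModel, Field


class Person(BaseModel):
    occupation: str | None = Field(None, description='Occupation of the person')


async def ingest(graphiti):
    # entity_types maps a label name to the Pydantic model describing its attributes
    await graphiti.add_episode(
        name='conversation-1',
        episode_body='Alice mentioned that she works as a nurse.',
        source_description='chat transcript',
        reference_time=datetime.now(timezone.utc),
        entity_types={'Person': Person},
    )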
graphiti_core/models/nodes/node_db_queries.py CHANGED
@@ -31,14 +31,16 @@ EPISODIC_NODE_SAVE_BULK = """

 ENTITY_NODE_SAVE = """
     MERGE (n:Entity {uuid: $uuid})
-    SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
+    SET n:$($labels)
+    SET n = $entity_data
     WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
     RETURN n.uuid AS uuid"""

 ENTITY_NODE_SAVE_BULK = """
     UNWIND $nodes AS node
     MERGE (n:Entity {uuid: node.uuid})
-    SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at}
+    SET n:$(node.labels)
+    SET n = node
     WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
     RETURN n.uuid AS uuid
 """
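Both queries now use Cypher's dynamic-label syntax (SET n:$($labels), available in newer Neo4j 5.x releases) and assign one whole property map instead of enumerating properties. An illustrative sketch of the parameters the updated save path sends for ENTITY_NODE_SAVE (see the nodes.py changes below); all values are made up:

params = {
    'labels': ['Preference', 'Entity'],  # consumed by SET n:$($labels)
    'entity_data': {                     # consumed by SET n = $entity_data
        'uuid': 'node-uuid-1',
        'name': 'jazz',
        'group_id': 'group-1',
        'summary': 'A music preference mentioned by the user.',
        'created_at': '2024-11-01T00:00:00+00:00',
        'name_embedding': [0.1, 0.2, 0.3],
        'strength': 'strong',            # custom attribute flattened into the same map
    },
}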
graphiti_core/nodes.py CHANGED
@@ -255,6 +255,9 @@ class EpisodicNode(Node):
 class EntityNode(Node):
     name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
     summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
+    attributes: dict[str, Any] = Field(
+        default={}, description='Additional attributes of the node. Dependent on node labels'
+    )

     async def generate_name_embedding(self, embedder: EmbedderClient):
         start = time()
@@ -266,14 +269,21 @@ class EntityNode(Node):
         return self.name_embedding

     async def save(self, driver: AsyncDriver):
+        entity_data: dict[str, Any] = {
+            'uuid': self.uuid,
+            'name': self.name,
+            'name_embedding': self.name_embedding,
+            'group_id': self.group_id,
+            'summary': self.summary,
+            'created_at': self.created_at,
+        }
+
+        entity_data.update(self.attributes or {})
+
         result = await driver.execute_query(
             ENTITY_NODE_SAVE,
-            uuid=self.uuid,
-            name=self.name,
-            group_id=self.group_id,
-            summary=self.summary,
-            name_embedding=self.name_embedding,
-            created_at=self.created_at,
+            labels=self.labels + ['Entity'],
+            entity_data=entity_data,
             database_=DEFAULT_DATABASE,
         )

@@ -292,7 +302,9 @@ class EntityNode(Node):
             n.name_embedding AS name_embedding,
             n.group_id AS group_id,
             n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
         """,
         uuid=uuid,
         database_=DEFAULT_DATABASE,
@@ -317,7 +329,9 @@ class EntityNode(Node):
             n.name_embedding AS name_embedding,
             n.group_id AS group_id,
             n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
         """,
         uuids=uuids,
         database_=DEFAULT_DATABASE,
@@ -351,7 +365,9 @@ class EntityNode(Node):
             n.name_embedding AS name_embedding,
             n.group_id AS group_id,
             n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
         ORDER BY n.uuid DESC
         """
         + limit_query,
@@ -503,9 +519,10 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
         name=record['name'],
         group_id=record['group_id'],
         name_embedding=record['name_embedding'],
-        labels=['Entity'],
+        labels=record['labels'],
         created_at=record['created_at'].to_native(),
         summary=record['summary'],
+        attributes=record['attributes'],
     )


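Taken together, the nodes.py changes round-trip custom types: labels and flattened attributes are written on save and read back through labels(n) and properties(n). A small sketch using the new fields, with illustrative values; the commented save call assumes an open Neo4j AsyncDriver:

from graphiti_core.nodes import EntityNode
from graphiti_core.utils.datetime_utils import utc_now

node = EntityNode(
    name='Alice',
    group_id='group-1',
    labels=['Entity', 'Person'],
    summary='',
    created_at=utc_now(),
    attributes={'occupation': 'nurse'},
)
# await node.save(driver)  # writes the labels plus the flattened attribute map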
graphiti_core/prompts/extract_nodes.py CHANGED
@@ -30,11 +30,19 @@ class MissedEntities(BaseModel):
     missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")


+class EntityClassification(BaseModel):
+    entity_classification: str = Field(
+        ...,
+        description='Dictionary of entity classifications. Key is the entity name and value is the entity type',
+    )
+
+
 class Prompt(Protocol):
     extract_message: PromptVersion
     extract_json: PromptVersion
     extract_text: PromptVersion
     reflexion: PromptVersion
+    classify_nodes: PromptVersion


 class Versions(TypedDict):
@@ -42,6 +50,7 @@ class Versions(TypedDict):
     extract_json: PromptFunction
     extract_text: PromptFunction
     reflexion: PromptFunction
+    classify_nodes: PromptFunction


 def extract_message(context: dict[str, Any]) -> list[Message]:
@@ -66,6 +75,7 @@ Guidelines:
 4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later).
 5. Be as explicit as possible in your node names, using full names.
 6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context.
+7. Extract preferences as their own nodes
 """
     return [
         Message(role='system', content=sys_prompt),
@@ -109,7 +119,7 @@ def extract_text(context: dict[str, Any]) -> list[Message]:

 {context['custom_prompt']}

-Given the following text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned:
+Given the above text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned:

 Guidelines:
 1. Extract significant entities, concepts, or actors mentioned in the conversation.
@@ -147,9 +157,41 @@ extracted.
     ]


+def classify_nodes(context: dict[str, Any]) -> list[Message]:
+    sys_prompt = """You are an AI assistant that classifies entity nodes given the context from which they were extracted"""
+
+    user_prompt = f"""
+    <PREVIOUS MESSAGES>
+    {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
+    </PREVIOUS MESSAGES>
+    <CURRENT MESSAGE>
+    {context["episode_content"]}
+    </CURRENT MESSAGE>
+
+    <EXTRACTED ENTITIES>
+    {context['extracted_entities']}
+    </EXTRACTED ENTITIES>
+
+    <ENTITY TYPES>
+    {context['entity_types']}
+    </ENTITY TYPES>
+
+    Given the above conversation, extracted entities, and provided entity types, classify the extracted entities.
+
+    Guidelines:
+    1. Each entity must have exactly one type
+    2. If none of the provided entity types accurately classify an extracted node, the type should be set to None
+    """
+    return [
+        Message(role='system', content=sys_prompt),
+        Message(role='user', content=user_prompt),
+    ]
+
+
 versions: Versions = {
     'extract_message': extract_message,
     'extract_json': extract_json,
     'extract_text': extract_text,
     'reflexion': reflexion,
+    'classify_nodes': classify_nodes,
 }
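The context keys the new classify_nodes prompt reads, shown with made-up values; node_operations.py (below) builds the same dict and calls it through prompt_library:

from graphiti_core.prompts import prompt_library

context = {
    'episode_content': 'Alice said she prefers jazz over pop.',
    'previous_episodes': ['Alice introduced herself.'],
    'extracted_entities': ['Alice', 'jazz', 'pop'],
    'entity_types': ['Person', 'Preference'],
}
messages = prompt_library.extract_nodes.classify_nodes(context)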
graphiti_core/prompts/summarize_nodes.py CHANGED
@@ -24,7 +24,8 @@ from .models import Message, PromptFunction, PromptVersion

 class Summary(BaseModel):
     summary: str = Field(
-        ..., description='Summary containing the important information from both summaries'
+        ...,
+        description='Summary containing the important information about the entity. Under 500 words',
     )


@@ -68,7 +69,7 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
-            content='You are a helpful assistant that combines summaries with new conversation context.',
+            content='You are a helpful assistant that extracts entity properties from the provided text.',
         ),
         Message(
             role='user',
@@ -79,18 +80,23 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
         {json.dumps(context['episode_content'], indent=2)}
         </MESSAGES>

-        Given the above MESSAGES and the following ENTITY name and ENTITY CONTEXT, create a summary for the ENTITY. Your summary must only use
-        information from the provided MESSAGES and from the ENTITY CONTEXT. Your summary should also only contain information relevant to the
-        provided ENTITY.
+        Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use
+        information from the provided MESSAGES. Your summary should also only contain information relevant to the
+        provided ENTITY. Summaries must be under 500 words.

-        Summaries must be under 500 words.
+        In addition, extract any values for the provided entity properties based on their descriptions.

         <ENTITY>
         {context['node_name']}
         </ENTITY>
+
         <ENTITY CONTEXT>
         {context['node_summary']}
         </ENTITY CONTEXT>
+
+        <ATTRIBUTES>
+        {json.dumps(context['attributes'], indent=2)}
+        </ATTRIBUTES>
         """,
         ),
     ]
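summarize_context now does double duty: it still produces the summary, and it also asks for values of the listed attribute names. An illustrative context with the new 'attributes' key (values made up):

from graphiti_core.prompts import prompt_library

context = {
    'episode_content': ['Alice said she works night shifts as a nurse.'],
    'node_name': 'Alice',
    'node_summary': '',
    'attributes': ['occupation'],  # field names collected from the matching entity types
}
messages = prompt_library.summarize_nodes.summarize_context(context)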
graphiti_core/search/search.py CHANGED
@@ -100,6 +100,7 @@ async def search(
             query_vector,
             group_ids,
             config.node_config,
+            search_filter,
             center_node_uuid,
             bfs_origin_node_uuids,
             config.limit,
@@ -233,6 +234,7 @@ async def node_search(
     query_vector: list[float],
     group_ids: list[str] | None,
     config: NodeSearchConfig | None,
+    search_filter: SearchFilters,
     center_node_uuid: str | None = None,
     bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
@@ -243,11 +245,13 @@
     search_results: list[list[EntityNode]] = list(
         await semaphore_gather(
             *[
-                node_fulltext_search(driver, query, group_ids, 2 * limit),
+                node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
                 node_similarity_search(
-                    driver, query_vector, group_ids, 2 * limit, config.sim_min_score
+                    driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
+                ),
+                node_bfs_search(
+                    driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
                 ),
-                node_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
             ]
         )
     )
@@ -255,7 +259,9 @@
     if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
         origin_node_uuids = [node.uuid for result in search_results for node in result]
         search_results.append(
-            await node_bfs_search(driver, origin_node_uuids, config.bfs_max_depth, 2 * limit)
+            await node_bfs_search(
+                driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
+            )
        )

     search_result_uuids = [[node.uuid for node in result] for result in search_results]
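node_search now threads one SearchFilters through every sub-search, and the parameter is required, sitting between config and center_node_uuid. A sketch of a call site, assuming driver, query, query_vector, and config already exist; keyword arguments are used so the exact positional order does not matter:

from graphiti_core.search.search import node_search
from graphiti_core.search.search_filters import SearchFilters

async def filtered_node_search(driver, query, query_vector, config):
    # restrict all three sub-searches (fulltext, similarity, BFS) to Person nodes
    return await node_search(
        driver,
        query,
        query_vector,
        group_ids=None,
        config=config,
        search_filter=SearchFilters(node_labels=['Person']),
    )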
graphiti_core/search/search_filters.py CHANGED
@@ -39,18 +39,37 @@ class DateFilter(BaseModel):


 class SearchFilters(BaseModel):
+    node_labels: list[str] | None = Field(
+        default=None, description='List of node labels to filter on'
+    )
     valid_at: list[list[DateFilter]] | None = Field(default=None)
     invalid_at: list[list[DateFilter]] | None = Field(default=None)
     created_at: list[list[DateFilter]] | None = Field(default=None)
     expired_at: list[list[DateFilter]] | None = Field(default=None)


-def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]:
+def node_search_filter_query_constructor(
+    filters: SearchFilters,
+) -> tuple[LiteralString, dict[str, Any]]:
+    filter_query: LiteralString = ''
+    filter_params: dict[str, Any] = {}
+
+    if filters.node_labels is not None:
+        node_labels = ':'.join(filters.node_labels)
+        node_label_filter = ' AND n:' + node_labels
+        filter_query += node_label_filter
+
+    return filter_query, filter_params
+
+
+def edge_search_filter_query_constructor(
+    filters: SearchFilters,
+) -> tuple[LiteralString, dict[str, Any]]:
     filter_query: LiteralString = ''
     filter_params: dict[str, Any] = {}

     if filters.valid_at is not None:
-        valid_at_filter = 'AND ('
+        valid_at_filter = ' AND ('
         for i, or_list in enumerate(filters.valid_at):
             for j, date_filter in enumerate(or_list):
                 filter_params['valid_at_' + str(j)] = date_filter.date
@@ -75,7 +94,7 @@ def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralStri
         filter_query += valid_at_filter

     if filters.invalid_at is not None:
-        invalid_at_filter = 'AND ('
+        invalid_at_filter = ' AND ('
         for i, or_list in enumerate(filters.invalid_at):
             for j, date_filter in enumerate(or_list):
                 filter_params['invalid_at_' + str(j)] = date_filter.date
@@ -100,7 +119,7 @@ def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralStri
         filter_query += invalid_at_filter

     if filters.created_at is not None:
-        created_at_filter = 'AND ('
+        created_at_filter = ' AND ('
        for i, or_list in enumerate(filters.created_at):
             for j, date_filter in enumerate(or_list):
                 filter_params['created_at_' + str(j)] = date_filter.date
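The net effect of the split: the old date-filter logic moves to edge_search_filter_query_constructor unchanged apart from the leading space added to the 'AND (' literals, while node-label filtering is new. What the node constructor produces, straight from the code above:

from graphiti_core.search.search_filters import (
    SearchFilters,
    node_search_filter_query_constructor,
)

filters = SearchFilters(node_labels=['Person', 'Preference'])
filter_query, filter_params = node_search_filter_query_constructor(filters)
assert filter_query == ' AND n:Person:Preference'  # labels joined with ':', node must carry all of them
assert filter_params == {}  # the node filter never adds query parameters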
graphiti_core/search/search_utils.py CHANGED
@@ -38,7 +38,11 @@ from graphiti_core.nodes import (
     get_community_node_from_record,
     get_entity_node_from_record,
 )
-from graphiti_core.search.search_filters import SearchFilters, search_filter_query_constructor
+from graphiti_core.search.search_filters import (
+    SearchFilters,
+    edge_search_filter_query_constructor,
+    node_search_filter_query_constructor,
+)

 logger = logging.getLogger(__name__)

@@ -97,7 +101,9 @@ async def get_mentioned_nodes(
             n.name AS name,
             n.name_embedding AS name_embedding,
             n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
         """,
         uuids=episode_uuids,
         database_=DEFAULT_DATABASE,
@@ -146,7 +152,7 @@ async def edge_fulltext_search(
     if fuzzy_query == '':
         return []

-    filter_query, filter_params = search_filter_query_constructor(search_filter)
+    filter_query, filter_params = edge_search_filter_query_constructor(search_filter)

     cypher_query = Query(
         """
@@ -205,7 +211,7 @@ async def edge_similarity_search(

     query_params: dict[str, Any] = {}

-    filter_query, filter_params = search_filter_query_constructor(search_filter)
+    filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
     query_params.update(filter_params)

     group_filter_query: LiteralString = ''
@@ -223,8 +229,8 @@

     query: LiteralString = (
         """
-    MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-    """
+        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+        """
         + group_filter_query
         + filter_query
         + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -276,7 +282,7 @@ async def edge_bfs_search(
     if bfs_origin_node_uuids is None:
         return []

-    filter_query, filter_params = search_filter_query_constructor(search_filter)
+    filter_query, filter_params = edge_search_filter_query_constructor(search_filter)

     query = Query(
         """
@@ -323,6 +329,7 @@ async def edge_bfs_search(
 async def node_fulltext_search(
     driver: AsyncDriver,
     query: str,
+    search_filter: SearchFilters,
     group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
@@ -331,20 +338,30 @@ async def node_fulltext_search(
     if fuzzy_query == '':
         return []

+    filter_query, filter_params = node_search_filter_query_constructor(search_filter)
+
     records, _, _ = await driver.execute_query(
         """
         CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
-        YIELD node AS n, score
+        YIELD node AS node, score
+        MATCH (n:Entity)
+        WHERE n.uuid = node.uuid
+        """
+        + filter_query
+        + """
         RETURN
             n.uuid AS uuid,
             n.group_id AS group_id,
             n.name AS name,
             n.name_embedding AS name_embedding,
             n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
         ORDER BY score DESC
         LIMIT $limit
         """,
+        filter_params,
         query=fuzzy_query,
         group_ids=group_ids,
         limit=limit,
@@ -359,6 +376,7 @@ async def node_fulltext_search(
 async def node_similarity_search(
     driver: AsyncDriver,
     search_vector: list[float],
+    search_filter: SearchFilters,
     group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
     min_score: float = DEFAULT_MIN_SCORE,
@@ -375,12 +393,16 @@ async def node_similarity_search(
         group_filter_query += 'WHERE n.group_id IN $group_ids'
         query_params['group_ids'] = group_ids

+    filter_query, filter_params = node_search_filter_query_constructor(search_filter)
+    query_params.update(filter_params)
+
     records, _, _ = await driver.execute_query(
         runtime_query
         + """
         MATCH (n:Entity)
         """
         + group_filter_query
+        + filter_query
         + """
         WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
         WHERE score > $min_score
@@ -390,7 +412,9 @@ async def node_similarity_search(
             n.name AS name,
             n.name_embedding AS name_embedding,
             n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
         ORDER BY score DESC
         LIMIT $limit
         """,
@@ -410,6 +434,7 @@ async def node_similarity_search(
 async def node_bfs_search(
     driver: AsyncDriver,
     bfs_origin_node_uuids: list[str] | None,
+    search_filter: SearchFilters,
     bfs_max_depth: int,
     limit: int,
 ) -> list[EntityNode]:
@@ -417,19 +442,28 @@ async def node_bfs_search(
     if bfs_origin_node_uuids is None:
         return []

+    filter_query, filter_params = node_search_filter_query_constructor(search_filter)
+
     records, _, _ = await driver.execute_query(
         """
         UNWIND $bfs_origin_node_uuids AS origin_uuid
         MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
-        RETURN DISTINCT
-            n.uuid As uuid,
-            n.group_id AS group_id,
-            n.name AS name,
-            n.name_embedding AS name_embedding,
-            n.created_at AS created_at,
-            n.summary AS summary
-        LIMIT $limit
-        """,
+        WHERE n.group_id = origin.group_id
+        """
+        + filter_query
+        + """
+        RETURN DISTINCT
+            n.uuid As uuid,
+            n.group_id AS group_id,
+            n.name AS name,
+            n.name_embedding AS name_embedding,
+            n.created_at AS created_at,
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
+        LIMIT $limit
+        """,
+        filter_params,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
         depth=bfs_max_depth,
         limit=limit,
@@ -531,6 +565,7 @@ async def hybrid_node_search(
     queries: list[str],
     embeddings: list[list[float]],
     driver: AsyncDriver,
+    search_filter: SearchFilters,
     group_ids: list[str] | None = None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
@@ -575,8 +610,14 @@ async def hybrid_node_search(
     start = time()
     results: list[list[EntityNode]] = list(
         await semaphore_gather(
-            *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
-            *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
+            *[
+                node_fulltext_search(driver, q, search_filter, group_ids, 2 * limit)
+                for q in queries
+            ],
+            *[
+                node_similarity_search(driver, e, search_filter, group_ids, 2 * limit)
+                for e in embeddings
+            ],
         )
     )

@@ -596,6 +637,7 @@

 async def get_relevant_nodes(
     driver: AsyncDriver,
+    search_filter: SearchFilters,
     nodes: list[EntityNode],
 ) -> list[EntityNode]:
     """
@@ -627,6 +669,7 @@ async def get_relevant_nodes(
         [node.name for node in nodes],
         [node.name_embedding for node in nodes if node.name_embedding is not None],
         driver,
+        search_filter,
         [node.group_id for node in nodes],
     )

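Every node-search helper in this file now takes the SearchFilters explicitly; an empty SearchFilters() reproduces the old unfiltered behaviour, which is exactly what graphiti.py and bulk_utils.py pass at their call sites. A sketch, assuming an open AsyncDriver and a list of extracted EntityNode objects:

from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_relevant_nodes

async def find_duplicate_candidates(driver, nodes):
    # no node_labels set, so no ' AND n:...' clause is added to the queries
    return await get_relevant_nodes(driver, SearchFilters(), nodes)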
graphiti_core/utils/bulk_utils.py CHANGED
@@ -23,6 +23,7 @@ from math import ceil
 from neo4j import AsyncDriver, AsyncManagedTransaction
 from numpy import dot, sqrt
 from pydantic import BaseModel
+from typing_extensions import Any

 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
 from graphiti_core.helpers import semaphore_gather
@@ -36,6 +37,7 @@ from graphiti_core.models.nodes.node_db_queries import (
     EPISODIC_NODE_SAVE_BULK,
 )
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
 from graphiti_core.utils.datetime_utils import utc_now
 from graphiti_core.utils.maintenance.edge_operations import (
@@ -109,8 +111,23 @@ async def add_nodes_and_edges_bulk_tx(
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
         episode['source'] = str(episode['source'].value)
+    nodes: list[dict[str, Any]] = []
+    for node in entity_nodes:
+        entity_data: dict[str, Any] = {
+            'uuid': node.uuid,
+            'name': node.name,
+            'name_embedding': node.name_embedding,
+            'group_id': node.group_id,
+            'summary': node.summary,
+            'created_at': node.created_at,
+        }
+
+        entity_data.update(node.attributes or {})
+        entity_data['labels'] = list(set(node.labels + ['Entity']))
+        nodes.append(entity_data)
+
     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
-    await tx.run(ENTITY_NODE_SAVE_BULK, nodes=[dict(entity) for entity in entity_nodes])
+    await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
     await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
     await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])

@@ -172,7 +189,7 @@ async def dedupe_nodes_bulk(

     existing_nodes_chunks: list[list[EntityNode]] = list(
         await semaphore_gather(
-            *[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
+            *[get_relevant_nodes(driver, SearchFilters(), node_chunk) for node_chunk in node_chunks]
         )
     )

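Unlike the single-node save, the bulk path carries the labels inside each node dict, where SET n:$(node.labels) picks them up, and list(set(...)) keeps 'Entity' from being duplicated. Illustrative shape of one entry in the $nodes list (values made up):

bulk_node = {
    'uuid': 'node-uuid-1',
    'name': 'Alice',
    'name_embedding': [0.1, 0.2, 0.3],
    'group_id': 'group-1',
    'summary': 'A nurse mentioned in the conversation.',
    'created_at': '2024-11-01T00:00:00+00:00',
    'occupation': 'nurse',           # flattened custom attribute
    'labels': ['Person', 'Entity'],  # deduplicated via list(set(node.labels + ['Entity']))
}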
graphiti_core/utils/maintenance/node_operations.py CHANGED
@@ -14,15 +14,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import ast
 import logging
 from time import time

+import pydantic
+from pydantic import BaseModel
+
 from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.prompts import prompt_library
 from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
-from graphiti_core.prompts.extract_nodes import ExtractedNodes, MissedEntities
+from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
 from graphiti_core.prompts.summarize_nodes import Summary
 from graphiti_core.utils.datetime_utils import utc_now

@@ -114,6 +118,7 @@ async def extract_nodes(
     llm_client: LLMClient,
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
+    entity_types: dict[str, BaseModel] | None = None,
 ) -> list[EntityNode]:
     start = time()
     extracted_node_names: list[str] = []
@@ -144,15 +149,35 @@ async def extract_nodes(
     for entity in missing_entities:
         custom_prompt += f'\n{entity},'

+    node_classification_context = {
+        'episode_content': episode.content,
+        'previous_episodes': [ep.content for ep in previous_episodes],
+        'extracted_entities': extracted_node_names,
+        'entity_types': entity_types.keys() if entity_types is not None else [],
+    }
+
+    node_classifications: dict[str, str | None] = {}
+
+    if entity_types is not None:
+        llm_response = await llm_client.generate_response(
+            prompt_library.extract_nodes.classify_nodes(node_classification_context),
+            response_model=EntityClassification,
+        )
+        response_string = llm_response.get('entity_classification', '{}')
+        node_classifications.update(ast.literal_eval(response_string))
+
     end = time()
     logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
     # Convert the extracted data into EntityNode objects
     new_nodes = []
     for name in extracted_node_names:
+        entity_type = node_classifications.get(name)
+        labels = ['Entity'] if entity_type is None else ['Entity', entity_type]
+
         new_node = EntityNode(
             name=name,
             group_id=episode.group_id,
-            labels=['Entity'],
+            labels=labels,
             summary='',
             created_at=utc_now(),
         )
@@ -218,6 +243,7 @@ async def resolve_extracted_nodes(
     existing_nodes_lists: list[list[EntityNode]],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
+    entity_types: dict[str, BaseModel] | None = None,
 ) -> tuple[list[EntityNode], dict[str, str]]:
     uuid_map: dict[str, str] = {}
     resolved_nodes: list[EntityNode] = []
@@ -225,7 +251,12 @@ async def resolve_extracted_nodes(
         await semaphore_gather(
             *[
                 resolve_extracted_node(
-                    llm_client, extracted_node, existing_nodes, episode, previous_episodes
+                    llm_client,
+                    extracted_node,
+                    existing_nodes,
+                    episode,
+                    previous_episodes,
+                    entity_types,
                 )
                 for extracted_node, existing_nodes in zip(extracted_nodes, existing_nodes_lists)
             ]
@@ -245,6 +276,7 @@ async def resolve_extracted_node(
     existing_nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
+    entity_types: dict[str, BaseModel] | None = None,
 ) -> tuple[EntityNode, dict[str, str]]:
     start = time()

@@ -273,19 +305,39 @@
         'previous_episodes': [ep.content for ep in previous_episodes]
         if previous_episodes is not None
         else [],
+        'attributes': [],
     }

-    llm_response, node_summary_response = await semaphore_gather(
+    entity_type_classes: tuple[BaseModel, ...] = tuple()
+    if entity_types is not None:  # type: ignore
+        entity_type_classes = entity_type_classes + tuple(
+            filter(
+                lambda x: x is not None,  # type: ignore
+                [entity_types.get(entity_type) for entity_type in extracted_node.labels],  # type: ignore
+            )
+        )
+
+    for entity_type in entity_type_classes:
+        for field_name in entity_type.model_fields:
+            summary_context.get('attributes', []).append(field_name)  # type: ignore
+
+    entity_attributes_model = pydantic.create_model(  # type: ignore
+        'EntityAttributes',
+        __base__=entity_type_classes + (Summary,),  # type: ignore
+    )
+
+    llm_response, node_attributes_response = await semaphore_gather(
         llm_client.generate_response(
             prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
         ),
         llm_client.generate_response(
             prompt_library.summarize_nodes.summarize_context(summary_context),
-            response_model=Summary,
+            response_model=entity_attributes_model,
         ),
     )

-    extracted_node.summary = node_summary_response.get('summary', '')
+    extracted_node.summary = node_attributes_response.get('summary', '')
+    extracted_node.attributes.update(node_attributes_response)

     is_duplicate: bool = llm_response.get('is_duplicate', False)
     uuid: str | None = llm_response.get('uuid', None)
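The dynamic response model is the core trick here: the entity-type models matching the node's labels are combined with Summary through pydantic.create_model, so one LLM call returns the summary and the typed attributes together. A standalone sketch of that combination; Person is a hypothetical entity type, and the real code builds the base tuple from entity_types and extracted_node.labels:

import pydantic
from pydantic import BaseModel, Field


class Summary(BaseModel):
    summary: str = Field(..., description='Summary of the entity. Under 500 words')


class Person(BaseModel):
    occupation: str | None = Field(None, description='Occupation of the person')


# combine the bases into one response model, as resolve_extracted_node does
EntityAttributes = pydantic.create_model(
    'EntityAttributes',
    __base__=(Person, Summary),
)

print(sorted(EntityAttributes.model_fields))  # ['occupation', 'summary']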
graphiti_core-0.6.1.dist-info/METADATA → graphiti_core-0.7.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: graphiti-core
-Version: 0.6.1
+Version: 0.7.1
 Summary: A temporal graph building library
 License: Apache-2.0
 Author: Paul Paliychuk
graphiti_core-0.6.1.dist-info/RECORD → graphiti_core-0.7.1.dist-info/RECORD CHANGED
@@ -9,7 +9,7 @@ graphiti_core/embedder/client.py,sha256=HKIlpPLnzFT81jurPkry6z8F8nxfZVfejdcfxHVU
 graphiti_core/embedder/openai.py,sha256=FzEM9rtSDK1wTb4iYKjNjjdFf8BEBTDxG2vM_E-5W-8,1621
 graphiti_core/embedder/voyage.py,sha256=7kqrLG75J3Q6cdA2Nlx1JSYtpk2141ckdl3OtDDw0vU,1882
 graphiti_core/errors.py,sha256=ddHrHGQxhwkVAtSph4AV84UoOlgwZufMczXPwB7uqPo,1795
-graphiti_core/graphiti.py,sha256=IaQ2xUM3Z1BG7ByJpznRAdg3FWtcOuIOq9YkY_JfiLE,28974
+graphiti_core/graphiti.py,sha256=BfsR_JF89_bX0D9PJ2Q2IHQrnph9hd4I7-ayGvvZxpU,29231
 graphiti_core/helpers.py,sha256=z7ApOgrm_J7hk5FN_XPAwkKyopEY943BgHjDJbSXr2s,2869
 graphiti_core/llm_client/__init__.py,sha256=PA80TSMeX-sUXITXEAxMDEt3gtfZgcJrGJUcyds1mSo,207
 graphiti_core/llm_client/anthropic_client.py,sha256=RlD6e49XvMJsTKU0krpq46gPSFm6-hfLkkq4Sfx27BE,2574
@@ -24,38 +24,38 @@ graphiti_core/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
 graphiti_core/models/edges/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 graphiti_core/models/edges/edge_db_queries.py,sha256=2UoLkmazO-FJYqjc3g0LuL-pyjekzQxxed_XHVv_HZE,2671
 graphiti_core/models/nodes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-graphiti_core/models/nodes/node_db_queries.py,sha256=I0top_N23FN0U5ZbypaS5IXvtfx2zgJmKUCT_7mpUdo,2257
-graphiti_core/nodes.py,sha256=_ExaTj2HU-xDczbls4aFcLdpc8zwPZUZ8JgVOrBiEdw,16098
+graphiti_core/models/nodes/node_db_queries.py,sha256=f4_UT6XL8UDt4_CO9YIHeI8pvpw_vrutA9SYrgi6QCU,2121
+graphiti_core/nodes.py,sha256=dKllAYBvNy6uCDxvacvNoVHiEm-wJm_cIK3KKTahVkM,16709
 graphiti_core/prompts/__init__.py,sha256=EA-x9xUki9l8wnu2l8ek_oNf75-do5tq5hVq7Zbv8Kw,101
 graphiti_core/prompts/dedupe_edges.py,sha256=EuX8ngeItBzrlMBOgeHrpExzxIFHD2aoDyaX1ZniF6I,3556
 graphiti_core/prompts/dedupe_nodes.py,sha256=mqvNATL-4Vo33vaxUEZfOq6hXXOiL-ftY0zcx2G-82I,4624
 graphiti_core/prompts/eval.py,sha256=csW494kKBMvWSm2SYLIRuGgNghhwNR3YwGn3veo3g_Y,3691
 graphiti_core/prompts/extract_edge_dates.py,sha256=td2yx2wnX-nLioMa0mtla3WcRyO71_wSjemT79YZGQ0,4096
 graphiti_core/prompts/extract_edges.py,sha256=vyEdW7JAPOT_eLWUi6nRmxbvucyVoyoYX2SxXfknRUg,3467
-graphiti_core/prompts/extract_nodes.py,sha256=JXLHeL1VcFo0auGf2roVnoWu1CyZJDWxBCu6BXE9fUQ,5289
+graphiti_core/prompts/extract_nodes.py,sha256=-01MpcVd9drtmMDIpQkkzZe8YwVhedmdbZq7UNGfo24,6651
 graphiti_core/prompts/invalidate_edges.py,sha256=DV2mEyIhhjc0hdKEMFLQMeG0FiUCkv_X0ctCliYjQ2c,3577
 graphiti_core/prompts/lib.py,sha256=oxhlpGEgV15VOLEZiwirxmIJBIdfzfiyL58iyzFDskE,4254
 graphiti_core/prompts/models.py,sha256=cvx_Bv5RMFUD_5IUawYrbpOKLPHogai7_bm7YXrSz84,867
 graphiti_core/prompts/prompt_helpers.py,sha256=-9TABwIcIQUVHcNANx6wIZd-FT2DgYKyGTfx4IGYq2I,64
-graphiti_core/prompts/summarize_nodes.py,sha256=XOJykwT7LtzWk2bRquFgv4tRAU3JOkkNkWBg-mkYOKc,3593
+graphiti_core/prompts/summarize_nodes.py,sha256=ONDZdkvC7-RPaKx2geWSVjNaJAsHxRisV8tiU2ukw4k,3781
 graphiti_core/py.typed,sha256=vlmmzQOt7bmeQl9L3XJP4W6Ry0iiELepnOrinKz5KQg,79
 graphiti_core/search/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-graphiti_core/search/search.py,sha256=4DaeP5aRT7ZOByDO3H5UK0edxfwQ4mzAOdFjnjwaDJs,12454
+graphiti_core/search/search.py,sha256=DX-tcIa0SiKI2HX-b_WdjGE74A8RLWQor4p90dJluUA,12643
 graphiti_core/search/search_config.py,sha256=UZN8jFA4pBlw2O5N1cuhVRBdTwMLR9N3Oyo6sQ4MDVw,3117
 graphiti_core/search/search_config_recipes.py,sha256=yUqiLnn9vFg39M8eVwjVKfBCL_ptGrfDMQ47m_Blb0g,6885
-graphiti_core/search/search_filters.py,sha256=_E_Od3hUoZm6H2UVCcxhfS34AqGF2lNx0NJPCw0gAQs,5333
-graphiti_core/search/search_utils.py,sha256=GwF7tsvjKgVXtv6q4lXA1tZn1_0izy6rHNwL8d0cYU4,24348
+graphiti_core/search/search_filters.py,sha256=4MJmCXD-blMc71xB4F9K4a72qidDCigADQ_ztdG15kc,5884
+graphiti_core/search/search_utils.py,sha256=i9qTBOZOiwnuiUNlIw9OoTYIrooBrM2unPwylGVNVq8,25657
 graphiti_core/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-graphiti_core/utils/bulk_utils.py,sha256=FYal4tSspGVohNsnDoyW_YjMiscySuYPuQLPSwVCy24,14110
+graphiti_core/utils/bulk_utils.py,sha256=P4LKO46Yle4tBdNcQ3hDHcSQFaR8UBLfoL-z1M2Wua0,14690
 graphiti_core/utils/datetime_utils.py,sha256=Ti-2tnrDFRzBsbfblzsHybsM3jaDLP4-VT2t0VhpIzU,1357
 graphiti_core/utils/maintenance/__init__.py,sha256=TRY3wWWu5kn3Oahk_KKhltrWnh0NACw0FskjqF6OtlA,314
 graphiti_core/utils/maintenance/community_operations.py,sha256=gIw1M5HGgc2c3TXag5ygPPpAv5WsG-yoC8Lhmfr6FMs,10011
 graphiti_core/utils/maintenance/edge_operations.py,sha256=tNw56vN586JYZMgie6RLRTiHZ680-kWzDIxW8ucL6SU,12780
 graphiti_core/utils/maintenance/graph_data_operations.py,sha256=qds9ALk9PhpQs1CNZTZGpi70mqJ93Y2KhIh9X2r8MUI,6533
-graphiti_core/utils/maintenance/node_operations.py,sha256=lrlp27clVhWrxy2BxofTjIISZpwqNG12evHO5wNwOY8,12084
+graphiti_core/utils/maintenance/node_operations.py,sha256=gihbPEBH6StLQCSd9wSu582d4Owaw3l5JLR1IBDrnVs,14137
 graphiti_core/utils/maintenance/temporal_operations.py,sha256=RdNtubCyYhOVrvcOIq2WppHls1Q-BEjtsN8r38l-Rtc,3691
 graphiti_core/utils/maintenance/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-graphiti_core-0.6.1.dist-info/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
-graphiti_core-0.6.1.dist-info/METADATA,sha256=T7rqCclsf8c92WTRWiYXFzWpQR36gy3whh_w-uXWjvA,10242
-graphiti_core-0.6.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-graphiti_core-0.6.1.dist-info/RECORD,,
+graphiti_core-0.7.1.dist-info/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
+graphiti_core-0.7.1.dist-info/METADATA,sha256=7jGgBXFuCT17KdyQVeSWAN1R1KQrBSd5Up92tqR30-c,10242
+graphiti_core-0.7.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+graphiti_core-0.7.1.dist-info/RECORD,,