graphiti-core 0.17.11__py3-none-any.whl → 0.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +20 -2
- graphiti_core/driver/falkordb_driver.py +16 -9
- graphiti_core/driver/neo4j_driver.py +8 -6
- graphiti_core/edges.py +73 -99
- graphiti_core/graph_queries.py +51 -97
- graphiti_core/graphiti.py +24 -9
- graphiti_core/helpers.py +3 -2
- graphiti_core/models/edges/edge_db_queries.py +106 -32
- graphiti_core/models/nodes/node_db_queries.py +101 -20
- graphiti_core/nodes.py +113 -128
- graphiti_core/prompts/dedupe_nodes.py +1 -1
- graphiti_core/prompts/extract_edges.py +4 -4
- graphiti_core/prompts/extract_nodes.py +12 -10
- graphiti_core/search/search.py +44 -32
- graphiti_core/search/search_config.py +8 -4
- graphiti_core/search/search_filters.py +5 -5
- graphiti_core/search/search_utils.py +154 -189
- graphiti_core/utils/bulk_utils.py +3 -5
- graphiti_core/utils/maintenance/community_operations.py +11 -7
- graphiti_core/utils/maintenance/edge_operations.py +19 -50
- graphiti_core/utils/maintenance/graph_data_operations.py +14 -29
- graphiti_core/utils/maintenance/node_operations.py +11 -55
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/METADATA +11 -3
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/RECORD +26 -26
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/WHEEL +0 -0
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/licenses/LICENSE +0 -0
graphiti_core/nodes.py
CHANGED
|
@@ -25,29 +25,22 @@ from uuid import uuid4
|
|
|
25
25
|
from pydantic import BaseModel, Field
|
|
26
26
|
from typing_extensions import LiteralString
|
|
27
27
|
|
|
28
|
-
from graphiti_core.driver.driver import GraphDriver
|
|
28
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
29
29
|
from graphiti_core.embedder import EmbedderClient
|
|
30
30
|
from graphiti_core.errors import NodeNotFoundError
|
|
31
31
|
from graphiti_core.helpers import parse_db_date
|
|
32
32
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
33
|
-
|
|
34
|
-
|
|
33
|
+
COMMUNITY_NODE_RETURN,
|
|
34
|
+
ENTITY_NODE_RETURN,
|
|
35
|
+
EPISODIC_NODE_RETURN,
|
|
35
36
|
EPISODIC_NODE_SAVE,
|
|
37
|
+
get_community_node_save_query,
|
|
38
|
+
get_entity_node_save_query,
|
|
36
39
|
)
|
|
37
40
|
from graphiti_core.utils.datetime_utils import utc_now
|
|
38
41
|
|
|
39
42
|
logger = logging.getLogger(__name__)
|
|
40
43
|
|
|
41
|
-
ENTITY_NODE_RETURN: LiteralString = """
|
|
42
|
-
RETURN
|
|
43
|
-
n.uuid As uuid,
|
|
44
|
-
n.name AS name,
|
|
45
|
-
n.group_id AS group_id,
|
|
46
|
-
n.created_at AS created_at,
|
|
47
|
-
n.summary AS summary,
|
|
48
|
-
labels(n) AS labels,
|
|
49
|
-
properties(n) AS attributes"""
|
|
50
|
-
|
|
51
44
|
|
|
52
45
|
class EpisodeType(Enum):
|
|
53
46
|
"""
|
|
@@ -96,18 +89,26 @@ class Node(BaseModel, ABC):
|
|
|
96
89
|
async def save(self, driver: GraphDriver): ...
|
|
97
90
|
|
|
98
91
|
async def delete(self, driver: GraphDriver):
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
92
|
+
if driver.provider == GraphProvider.FALKORDB:
|
|
93
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
94
|
+
await driver.execute_query(
|
|
95
|
+
f"""
|
|
96
|
+
MATCH (n:{label} {{uuid: $uuid}})
|
|
97
|
+
DETACH DELETE n
|
|
98
|
+
""",
|
|
99
|
+
uuid=self.uuid,
|
|
100
|
+
)
|
|
101
|
+
else:
|
|
102
|
+
await driver.execute_query(
|
|
103
|
+
"""
|
|
104
|
+
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
|
105
|
+
DETACH DELETE n
|
|
106
|
+
""",
|
|
107
|
+
uuid=self.uuid,
|
|
108
|
+
)
|
|
106
109
|
|
|
107
110
|
logger.debug(f'Deleted Node: {self.uuid}')
|
|
108
111
|
|
|
109
|
-
return result
|
|
110
|
-
|
|
111
112
|
def __hash__(self):
|
|
112
113
|
return hash(self.uuid)
|
|
113
114
|
|
|
@@ -118,15 +119,23 @@ class Node(BaseModel, ABC):
|
|
|
118
119
|
|
|
119
120
|
@classmethod
|
|
120
121
|
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
122
|
+
if driver.provider == GraphProvider.FALKORDB:
|
|
123
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
124
|
+
await driver.execute_query(
|
|
125
|
+
f"""
|
|
126
|
+
MATCH (n:{label} {{group_id: $group_id}})
|
|
127
|
+
DETACH DELETE n
|
|
128
|
+
""",
|
|
129
|
+
group_id=group_id,
|
|
130
|
+
)
|
|
131
|
+
else:
|
|
132
|
+
await driver.execute_query(
|
|
133
|
+
"""
|
|
134
|
+
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
|
135
|
+
DETACH DELETE n
|
|
136
|
+
""",
|
|
137
|
+
group_id=group_id,
|
|
138
|
+
)
|
|
130
139
|
|
|
131
140
|
@classmethod
|
|
132
141
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
|
@@ -169,17 +178,10 @@ class EpisodicNode(Node):
|
|
|
169
178
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
170
179
|
records, _, _ = await driver.execute_query(
|
|
171
180
|
"""
|
|
172
|
-
|
|
173
|
-
RETURN
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
e.uuid AS uuid,
|
|
177
|
-
e.name AS name,
|
|
178
|
-
e.group_id AS group_id,
|
|
179
|
-
e.source_description AS source_description,
|
|
180
|
-
e.source AS source,
|
|
181
|
-
e.entity_edges AS entity_edges
|
|
182
|
-
""",
|
|
181
|
+
MATCH (e:Episodic {uuid: $uuid})
|
|
182
|
+
RETURN
|
|
183
|
+
"""
|
|
184
|
+
+ EPISODIC_NODE_RETURN,
|
|
183
185
|
uuid=uuid,
|
|
184
186
|
routing_='r',
|
|
185
187
|
)
|
|
@@ -195,18 +197,11 @@ class EpisodicNode(Node):
|
|
|
195
197
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
196
198
|
records, _, _ = await driver.execute_query(
|
|
197
199
|
"""
|
|
198
|
-
|
|
200
|
+
MATCH (e:Episodic)
|
|
201
|
+
WHERE e.uuid IN $uuids
|
|
199
202
|
RETURN DISTINCT
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
e.valid_at AS valid_at,
|
|
203
|
-
e.uuid AS uuid,
|
|
204
|
-
e.name AS name,
|
|
205
|
-
e.group_id AS group_id,
|
|
206
|
-
e.source_description AS source_description,
|
|
207
|
-
e.source AS source,
|
|
208
|
-
e.entity_edges AS entity_edges
|
|
209
|
-
""",
|
|
203
|
+
"""
|
|
204
|
+
+ EPISODIC_NODE_RETURN,
|
|
210
205
|
uuids=uuids,
|
|
211
206
|
routing_='r',
|
|
212
207
|
)
|
|
@@ -228,22 +223,17 @@ class EpisodicNode(Node):
|
|
|
228
223
|
|
|
229
224
|
records, _, _ = await driver.execute_query(
|
|
230
225
|
"""
|
|
231
|
-
|
|
232
|
-
|
|
226
|
+
MATCH (e:Episodic)
|
|
227
|
+
WHERE e.group_id IN $group_ids
|
|
228
|
+
"""
|
|
233
229
|
+ cursor_query
|
|
234
230
|
+ """
|
|
235
231
|
RETURN DISTINCT
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
e.group_id AS group_id,
|
|
242
|
-
e.source_description AS source_description,
|
|
243
|
-
e.source AS source,
|
|
244
|
-
e.entity_edges AS entity_edges
|
|
245
|
-
ORDER BY e.uuid DESC
|
|
246
|
-
"""
|
|
232
|
+
"""
|
|
233
|
+
+ EPISODIC_NODE_RETURN
|
|
234
|
+
+ """
|
|
235
|
+
ORDER BY uuid DESC
|
|
236
|
+
"""
|
|
247
237
|
+ limit_query,
|
|
248
238
|
group_ids=group_ids,
|
|
249
239
|
uuid=uuid_cursor,
|
|
@@ -259,18 +249,10 @@ class EpisodicNode(Node):
|
|
|
259
249
|
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
|
|
260
250
|
records, _, _ = await driver.execute_query(
|
|
261
251
|
"""
|
|
262
|
-
|
|
252
|
+
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
|
263
253
|
RETURN DISTINCT
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
e.valid_at AS valid_at,
|
|
267
|
-
e.uuid AS uuid,
|
|
268
|
-
e.name AS name,
|
|
269
|
-
e.group_id AS group_id,
|
|
270
|
-
e.source_description AS source_description,
|
|
271
|
-
e.source AS source,
|
|
272
|
-
e.entity_edges AS entity_edges
|
|
273
|
-
""",
|
|
254
|
+
"""
|
|
255
|
+
+ EPISODIC_NODE_RETURN,
|
|
274
256
|
entity_node_uuid=entity_node_uuid,
|
|
275
257
|
routing_='r',
|
|
276
258
|
)
|
|
@@ -297,11 +279,14 @@ class EntityNode(Node):
|
|
|
297
279
|
return self.name_embedding
|
|
298
280
|
|
|
299
281
|
async def load_name_embedding(self, driver: GraphDriver):
|
|
300
|
-
|
|
282
|
+
records, _, _ = await driver.execute_query(
|
|
283
|
+
"""
|
|
301
284
|
MATCH (n:Entity {uuid: $uuid})
|
|
302
285
|
RETURN n.name_embedding AS name_embedding
|
|
303
|
-
|
|
304
|
-
|
|
286
|
+
""",
|
|
287
|
+
uuid=self.uuid,
|
|
288
|
+
routing_='r',
|
|
289
|
+
)
|
|
305
290
|
|
|
306
291
|
if len(records) == 0:
|
|
307
292
|
raise NodeNotFoundError(self.uuid)
|
|
@@ -317,12 +302,12 @@ class EntityNode(Node):
|
|
|
317
302
|
'summary': self.summary,
|
|
318
303
|
'created_at': self.created_at,
|
|
319
304
|
}
|
|
320
|
-
|
|
321
305
|
entity_data.update(self.attributes or {})
|
|
322
306
|
|
|
307
|
+
labels = ':'.join(self.labels + ['Entity'])
|
|
308
|
+
|
|
323
309
|
result = await driver.execute_query(
|
|
324
|
-
|
|
325
|
-
labels=self.labels + ['Entity'],
|
|
310
|
+
get_entity_node_save_query(driver.provider, labels),
|
|
326
311
|
entity_data=entity_data,
|
|
327
312
|
)
|
|
328
313
|
|
|
@@ -332,14 +317,12 @@ class EntityNode(Node):
|
|
|
332
317
|
|
|
333
318
|
@classmethod
|
|
334
319
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
335
|
-
query = (
|
|
336
|
-
"""
|
|
337
|
-
MATCH (n:Entity {uuid: $uuid})
|
|
338
|
-
"""
|
|
339
|
-
+ ENTITY_NODE_RETURN
|
|
340
|
-
)
|
|
341
320
|
records, _, _ = await driver.execute_query(
|
|
342
|
-
|
|
321
|
+
"""
|
|
322
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
323
|
+
RETURN
|
|
324
|
+
"""
|
|
325
|
+
+ ENTITY_NODE_RETURN,
|
|
343
326
|
uuid=uuid,
|
|
344
327
|
routing_='r',
|
|
345
328
|
)
|
|
@@ -355,8 +338,10 @@ class EntityNode(Node):
|
|
|
355
338
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
356
339
|
records, _, _ = await driver.execute_query(
|
|
357
340
|
"""
|
|
358
|
-
|
|
359
|
-
|
|
341
|
+
MATCH (n:Entity)
|
|
342
|
+
WHERE n.uuid IN $uuids
|
|
343
|
+
RETURN
|
|
344
|
+
"""
|
|
360
345
|
+ ENTITY_NODE_RETURN,
|
|
361
346
|
uuids=uuids,
|
|
362
347
|
routing_='r',
|
|
@@ -379,22 +364,26 @@ class EntityNode(Node):
|
|
|
379
364
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
|
380
365
|
with_embeddings_query: LiteralString = (
|
|
381
366
|
""",
|
|
382
|
-
|
|
383
|
-
|
|
367
|
+
n.name_embedding AS name_embedding
|
|
368
|
+
"""
|
|
384
369
|
if with_embeddings
|
|
385
370
|
else ''
|
|
386
371
|
)
|
|
387
372
|
|
|
388
373
|
records, _, _ = await driver.execute_query(
|
|
389
374
|
"""
|
|
390
|
-
|
|
391
|
-
|
|
375
|
+
MATCH (n:Entity)
|
|
376
|
+
WHERE n.group_id IN $group_ids
|
|
377
|
+
"""
|
|
392
378
|
+ cursor_query
|
|
379
|
+
+ """
|
|
380
|
+
RETURN
|
|
381
|
+
"""
|
|
393
382
|
+ ENTITY_NODE_RETURN
|
|
394
383
|
+ with_embeddings_query
|
|
395
384
|
+ """
|
|
396
|
-
|
|
397
|
-
|
|
385
|
+
ORDER BY n.uuid DESC
|
|
386
|
+
"""
|
|
398
387
|
+ limit_query,
|
|
399
388
|
group_ids=group_ids,
|
|
400
389
|
uuid=uuid_cursor,
|
|
@@ -413,7 +402,7 @@ class CommunityNode(Node):
|
|
|
413
402
|
|
|
414
403
|
async def save(self, driver: GraphDriver):
|
|
415
404
|
result = await driver.execute_query(
|
|
416
|
-
|
|
405
|
+
get_community_node_save_query(driver.provider),
|
|
417
406
|
uuid=self.uuid,
|
|
418
407
|
name=self.name,
|
|
419
408
|
group_id=self.group_id,
|
|
@@ -436,11 +425,14 @@ class CommunityNode(Node):
|
|
|
436
425
|
return self.name_embedding
|
|
437
426
|
|
|
438
427
|
async def load_name_embedding(self, driver: GraphDriver):
|
|
439
|
-
|
|
428
|
+
records, _, _ = await driver.execute_query(
|
|
429
|
+
"""
|
|
440
430
|
MATCH (c:Community {uuid: $uuid})
|
|
441
431
|
RETURN c.name_embedding AS name_embedding
|
|
442
|
-
|
|
443
|
-
|
|
432
|
+
""",
|
|
433
|
+
uuid=self.uuid,
|
|
434
|
+
routing_='r',
|
|
435
|
+
)
|
|
444
436
|
|
|
445
437
|
if len(records) == 0:
|
|
446
438
|
raise NodeNotFoundError(self.uuid)
|
|
@@ -451,14 +443,10 @@ class CommunityNode(Node):
|
|
|
451
443
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
452
444
|
records, _, _ = await driver.execute_query(
|
|
453
445
|
"""
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
n.group_id AS group_id,
|
|
459
|
-
n.created_at AS created_at,
|
|
460
|
-
n.summary AS summary
|
|
461
|
-
""",
|
|
446
|
+
MATCH (n:Community {uuid: $uuid})
|
|
447
|
+
RETURN
|
|
448
|
+
"""
|
|
449
|
+
+ COMMUNITY_NODE_RETURN,
|
|
462
450
|
uuid=uuid,
|
|
463
451
|
routing_='r',
|
|
464
452
|
)
|
|
@@ -474,14 +462,11 @@ class CommunityNode(Node):
|
|
|
474
462
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
475
463
|
records, _, _ = await driver.execute_query(
|
|
476
464
|
"""
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
n.created_at AS created_at,
|
|
483
|
-
n.summary AS summary
|
|
484
|
-
""",
|
|
465
|
+
MATCH (n:Community)
|
|
466
|
+
WHERE n.uuid IN $uuids
|
|
467
|
+
RETURN
|
|
468
|
+
"""
|
|
469
|
+
+ COMMUNITY_NODE_RETURN,
|
|
485
470
|
uuids=uuids,
|
|
486
471
|
routing_='r',
|
|
487
472
|
)
|
|
@@ -503,18 +488,17 @@ class CommunityNode(Node):
|
|
|
503
488
|
|
|
504
489
|
records, _, _ = await driver.execute_query(
|
|
505
490
|
"""
|
|
506
|
-
|
|
507
|
-
|
|
491
|
+
MATCH (n:Community)
|
|
492
|
+
WHERE n.group_id IN $group_ids
|
|
493
|
+
"""
|
|
508
494
|
+ cursor_query
|
|
509
495
|
+ """
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
n.
|
|
515
|
-
|
|
516
|
-
ORDER BY n.uuid DESC
|
|
517
|
-
"""
|
|
496
|
+
RETURN
|
|
497
|
+
"""
|
|
498
|
+
+ COMMUNITY_NODE_RETURN
|
|
499
|
+
+ """
|
|
500
|
+
ORDER BY n.uuid DESC
|
|
501
|
+
"""
|
|
518
502
|
+ limit_query,
|
|
519
503
|
group_ids=group_ids,
|
|
520
504
|
uuid=uuid_cursor,
|
|
@@ -586,6 +570,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
|
|
586
570
|
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
|
|
587
571
|
if not nodes: # Handle empty list case
|
|
588
572
|
return
|
|
573
|
+
|
|
589
574
|
name_embeddings = await embedder.create_batch([node.name for node in nodes])
|
|
590
575
|
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
|
|
591
576
|
node.name_embedding = name_embedding
|
|
@@ -68,6 +68,10 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
|
|
68
68
|
Message(
|
|
69
69
|
role='user',
|
|
70
70
|
content=f"""
|
|
71
|
+
<FACT TYPES>
|
|
72
|
+
{context['edge_types']}
|
|
73
|
+
</FACT TYPES>
|
|
74
|
+
|
|
71
75
|
<PREVIOUS_MESSAGES>
|
|
72
76
|
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
|
73
77
|
</PREVIOUS_MESSAGES>
|
|
@@ -84,10 +88,6 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
|
|
84
88
|
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
|
|
85
89
|
</REFERENCE_TIME>
|
|
86
90
|
|
|
87
|
-
<FACT TYPES>
|
|
88
|
-
{context['edge_types']}
|
|
89
|
-
</FACT TYPES>
|
|
90
|
-
|
|
91
91
|
# TASK
|
|
92
92
|
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
|
|
93
93
|
Only extract facts that:
|
|
@@ -75,6 +75,10 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
|
|
|
75
75
|
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation."""
|
|
76
76
|
|
|
77
77
|
user_prompt = f"""
|
|
78
|
+
<ENTITY TYPES>
|
|
79
|
+
{context['entity_types']}
|
|
80
|
+
</ENTITY TYPES>
|
|
81
|
+
|
|
78
82
|
<PREVIOUS MESSAGES>
|
|
79
83
|
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
|
80
84
|
</PREVIOUS MESSAGES>
|
|
@@ -83,10 +87,6 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
|
|
|
83
87
|
{context['episode_content']}
|
|
84
88
|
</CURRENT MESSAGE>
|
|
85
89
|
|
|
86
|
-
<ENTITY TYPES>
|
|
87
|
-
{context['entity_types']}
|
|
88
|
-
</ENTITY TYPES>
|
|
89
|
-
|
|
90
90
|
Instructions:
|
|
91
91
|
|
|
92
92
|
You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
|
|
@@ -124,15 +124,16 @@ def extract_json(context: dict[str, Any]) -> list[Message]:
|
|
|
124
124
|
Your primary task is to extract and classify relevant entities from JSON files"""
|
|
125
125
|
|
|
126
126
|
user_prompt = f"""
|
|
127
|
+
<ENTITY TYPES>
|
|
128
|
+
{context['entity_types']}
|
|
129
|
+
</ENTITY TYPES>
|
|
130
|
+
|
|
127
131
|
<SOURCE DESCRIPTION>:
|
|
128
132
|
{context['source_description']}
|
|
129
133
|
</SOURCE DESCRIPTION>
|
|
130
134
|
<JSON>
|
|
131
135
|
{context['episode_content']}
|
|
132
136
|
</JSON>
|
|
133
|
-
<ENTITY TYPES>
|
|
134
|
-
{context['entity_types']}
|
|
135
|
-
</ENTITY TYPES>
|
|
136
137
|
|
|
137
138
|
{context['custom_prompt']}
|
|
138
139
|
|
|
@@ -155,13 +156,14 @@ def extract_text(context: dict[str, Any]) -> list[Message]:
|
|
|
155
156
|
Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text."""
|
|
156
157
|
|
|
157
158
|
user_prompt = f"""
|
|
158
|
-
<TEXT>
|
|
159
|
-
{context['episode_content']}
|
|
160
|
-
</TEXT>
|
|
161
159
|
<ENTITY TYPES>
|
|
162
160
|
{context['entity_types']}
|
|
163
161
|
</ENTITY TYPES>
|
|
164
162
|
|
|
163
|
+
<TEXT>
|
|
164
|
+
{context['episode_content']}
|
|
165
|
+
</TEXT>
|
|
166
|
+
|
|
165
167
|
Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned.
|
|
166
168
|
For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
|
|
167
169
|
Indicate the classified entity type by providing its entity_type_id.
|