llm-ie 1.3.0__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_ie/extractors.py
CHANGED
|
@@ -6,8 +6,7 @@ import warnings
|
|
|
6
6
|
import itertools
|
|
7
7
|
import asyncio
|
|
8
8
|
import nest_asyncio
|
|
9
|
-
from
|
|
10
|
-
from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional, AsyncGenerator
|
|
9
|
+
from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional
|
|
11
10
|
from llm_ie.utils import extract_json, apply_prompt_template
|
|
12
11
|
from llm_ie.data_types import FrameExtractionUnit, LLMInformationExtractionFrame, LLMInformationExtractionDocument
|
|
13
12
|
from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
|
|
@@ -206,7 +205,6 @@ class StructExtractor(Extractor):
|
|
|
206
205
|
gen_text = self.inference_engine.chat(
|
|
207
206
|
messages=messages,
|
|
208
207
|
verbose=verbose,
|
|
209
|
-
stream=False,
|
|
210
208
|
messages_logger=messages_logger
|
|
211
209
|
)
|
|
212
210
|
|
|
@@ -290,9 +288,8 @@ class StructExtractor(Extractor):
|
|
|
290
288
|
|
|
291
289
|
current_gen_text = ""
|
|
292
290
|
|
|
293
|
-
response_stream = self.inference_engine.
|
|
294
|
-
messages=messages
|
|
295
|
-
stream=True
|
|
291
|
+
response_stream = self.inference_engine.chat_stream(
|
|
292
|
+
messages=messages
|
|
296
293
|
)
|
|
297
294
|
for chunk in response_stream:
|
|
298
295
|
yield chunk
|
|
@@ -306,7 +303,7 @@ class StructExtractor(Extractor):
|
|
|
306
303
|
yield {"type": "info", "data": "All units processed by LLM."}
|
|
307
304
|
return units
|
|
308
305
|
|
|
309
|
-
async def
|
|
306
|
+
async def _extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
|
|
310
307
|
concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
|
|
311
308
|
"""
|
|
312
309
|
This is the asynchronous version of the extract() method.
|
|
@@ -422,6 +419,28 @@ class StructExtractor(Extractor):
|
|
|
422
419
|
for struct in structs:
|
|
423
420
|
aggregated_struct.update(struct)
|
|
424
421
|
return aggregated_struct
|
|
422
|
+
|
|
423
|
+
def _post_process_struct(self, units: List[FrameExtractionUnit]) -> Dict[str, Any]:
|
|
424
|
+
"""
|
|
425
|
+
Helper method to post-process units into a structured dictionary.
|
|
426
|
+
Shared by extract_struct and extract_struct_async.
|
|
427
|
+
"""
|
|
428
|
+
struct_json = []
|
|
429
|
+
for unit in units:
|
|
430
|
+
if unit.status != "success":
|
|
431
|
+
continue
|
|
432
|
+
try:
|
|
433
|
+
unit_struct_json = extract_json(unit.get_generated_text())
|
|
434
|
+
struct_json.extend(unit_struct_json)
|
|
435
|
+
except Exception as e:
|
|
436
|
+
unit.set_status("fail")
|
|
437
|
+
warnings.warn(f"Struct extraction failed for unit ({unit.start}, {unit.end}): {e}", RuntimeWarning)
|
|
438
|
+
|
|
439
|
+
if self.aggregation_func is None:
|
|
440
|
+
struct = self._default_struct_aggregate(struct_json)
|
|
441
|
+
else:
|
|
442
|
+
struct = self.aggregation_func(struct_json)
|
|
443
|
+
return struct
|
|
425
444
|
|
|
426
445
|
|
|
427
446
|
def extract_struct(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
|
|
@@ -457,7 +476,7 @@ class StructExtractor(Extractor):
|
|
|
457
476
|
warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
|
|
458
477
|
|
|
459
478
|
nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
|
|
460
|
-
extraction_results = asyncio.run(self.
|
|
479
|
+
extraction_results = asyncio.run(self._extract_async(text_content=text_content,
|
|
461
480
|
document_key=document_key,
|
|
462
481
|
concurrent_batch_size=concurrent_batch_size,
|
|
463
482
|
return_messages_log=return_messages_log)
|
|
@@ -470,26 +489,29 @@ class StructExtractor(Extractor):
|
|
|
470
489
|
|
|
471
490
|
units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
|
|
472
491
|
|
|
473
|
-
|
|
474
|
-
for unit in units:
|
|
475
|
-
if unit.status != "success":
|
|
476
|
-
continue
|
|
477
|
-
try:
|
|
478
|
-
unit_struct_json = extract_json(unit.get_generated_text())
|
|
479
|
-
struct_json.extend(unit_struct_json)
|
|
480
|
-
except Exception as e:
|
|
481
|
-
unit.set_status("fail")
|
|
482
|
-
warnings.warn(f"Struct extraction failed for unit ({unit.start}, {unit.end}): {e}", RuntimeWarning)
|
|
483
|
-
|
|
484
|
-
if self.aggregation_func is None:
|
|
485
|
-
struct = self._default_struct_aggregate(struct_json)
|
|
486
|
-
else:
|
|
487
|
-
struct = self.aggregation_func(struct_json)
|
|
492
|
+
struct = self._post_process_struct(units)
|
|
488
493
|
|
|
489
494
|
if return_messages_log:
|
|
490
495
|
return struct, messages_log
|
|
491
496
|
return struct
|
|
492
497
|
|
|
498
|
+
async def extract_struct_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
|
|
499
|
+
concurrent_batch_size:int=32, return_messages_log:bool=False) -> Dict[str, Any]:
|
|
500
|
+
"""
|
|
501
|
+
This is the async version of extract_struct.
|
|
502
|
+
"""
|
|
503
|
+
extraction_results = await self._extract_async(text_content=text_content,
|
|
504
|
+
document_key=document_key,
|
|
505
|
+
concurrent_batch_size=concurrent_batch_size,
|
|
506
|
+
return_messages_log=return_messages_log)
|
|
507
|
+
|
|
508
|
+
units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
|
|
509
|
+
struct = self._post_process_struct(units)
|
|
510
|
+
|
|
511
|
+
if return_messages_log:
|
|
512
|
+
return struct, messages_log
|
|
513
|
+
return struct
|
|
514
|
+
|
|
493
515
|
|
|
494
516
|
class BasicStructExtractor(StructExtractor):
|
|
495
517
|
def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
|
|
@@ -725,6 +747,14 @@ class FrameExtractor(Extractor):
|
|
|
725
747
|
"""
|
|
726
748
|
return NotImplemented
|
|
727
749
|
|
|
750
|
+
@abc.abstractmethod
|
|
751
|
+
async def extract_frames_async(self, text_content:Union[str, Dict[str,str]], entity_key:str,
|
|
752
|
+
document_key:str=None, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
|
|
753
|
+
"""
|
|
754
|
+
This is the async version of extract_frames.
|
|
755
|
+
"""
|
|
756
|
+
return NotImplemented
|
|
757
|
+
|
|
728
758
|
|
|
729
759
|
class DirectFrameExtractor(FrameExtractor):
|
|
730
760
|
def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
|
|
@@ -833,7 +863,6 @@ class DirectFrameExtractor(FrameExtractor):
|
|
|
833
863
|
gen_text = self.inference_engine.chat(
|
|
834
864
|
messages=messages,
|
|
835
865
|
verbose=verbose,
|
|
836
|
-
stream=False,
|
|
837
866
|
messages_logger=messages_logger
|
|
838
867
|
)
|
|
839
868
|
|
|
@@ -917,9 +946,8 @@ class DirectFrameExtractor(FrameExtractor):
|
|
|
917
946
|
|
|
918
947
|
current_gen_text = ""
|
|
919
948
|
|
|
920
|
-
response_stream = self.inference_engine.
|
|
921
|
-
messages=messages
|
|
922
|
-
stream=True
|
|
949
|
+
response_stream = self.inference_engine.chat_stream(
|
|
950
|
+
messages=messages
|
|
923
951
|
)
|
|
924
952
|
for chunk in response_stream:
|
|
925
953
|
yield chunk
|
|
@@ -933,7 +961,7 @@ class DirectFrameExtractor(FrameExtractor):
|
|
|
933
961
|
yield {"type": "info", "data": "All units processed by LLM."}
|
|
934
962
|
return units
|
|
935
963
|
|
|
936
|
-
async def
|
|
964
|
+
async def _extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
|
|
937
965
|
concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
|
|
938
966
|
"""
|
|
939
967
|
This is the asynchronous version of the extract() method.
|
|
@@ -1040,6 +1068,45 @@ class DirectFrameExtractor(FrameExtractor):
|
|
|
1040
1068
|
else:
|
|
1041
1069
|
return units
|
|
1042
1070
|
|
|
1071
|
+
def _post_process_units_to_frames(self, units, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities):
|
|
1072
|
+
ENTITY_KEY = "entity_text"
|
|
1073
|
+
frame_list = []
|
|
1074
|
+
for unit in units:
|
|
1075
|
+
entity_json = []
|
|
1076
|
+
if unit.status != "success":
|
|
1077
|
+
warnings.warn(f"Skipping failed unit ({unit.start}, {unit.end}): {unit.text}", RuntimeWarning)
|
|
1078
|
+
continue
|
|
1079
|
+
for entity in extract_json(gen_text=unit.gen_text):
|
|
1080
|
+
if ENTITY_KEY in entity:
|
|
1081
|
+
entity_json.append(entity)
|
|
1082
|
+
else:
|
|
1083
|
+
warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
|
|
1084
|
+
|
|
1085
|
+
spans = self._find_entity_spans(text=unit.text,
|
|
1086
|
+
entities=[e[ENTITY_KEY] for e in entity_json],
|
|
1087
|
+
case_sensitive=case_sensitive,
|
|
1088
|
+
fuzzy_match=fuzzy_match,
|
|
1089
|
+
fuzzy_buffer_size=fuzzy_buffer_size,
|
|
1090
|
+
fuzzy_score_cutoff=fuzzy_score_cutoff,
|
|
1091
|
+
allow_overlap_entities=allow_overlap_entities)
|
|
1092
|
+
for ent, span in zip(entity_json, spans):
|
|
1093
|
+
if span is not None:
|
|
1094
|
+
start, end = span
|
|
1095
|
+
entity_text = unit.text[start:end]
|
|
1096
|
+
start += unit.start
|
|
1097
|
+
end += unit.start
|
|
1098
|
+
attr = {}
|
|
1099
|
+
if "attr" in ent and ent["attr"] is not None:
|
|
1100
|
+
attr = ent["attr"]
|
|
1101
|
+
|
|
1102
|
+
frame = LLMInformationExtractionFrame(frame_id=f"{len(frame_list)}",
|
|
1103
|
+
start=start,
|
|
1104
|
+
end=end,
|
|
1105
|
+
entity_text=entity_text,
|
|
1106
|
+
attr=attr)
|
|
1107
|
+
frame_list.append(frame)
|
|
1108
|
+
return frame_list
|
|
1109
|
+
|
|
1043
1110
|
|
|
1044
1111
|
def extract_frames(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
|
|
1045
1112
|
verbose:bool=False, concurrent:bool=False, concurrent_batch_size:int=32,
|
|
@@ -1088,7 +1155,7 @@ class DirectFrameExtractor(FrameExtractor):
|
|
|
1088
1155
|
warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
|
|
1089
1156
|
|
|
1090
1157
|
nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
|
|
1091
|
-
extraction_results = asyncio.run(self.
|
|
1158
|
+
extraction_results = asyncio.run(self._extract_async(text_content=text_content,
|
|
1092
1159
|
document_key=document_key,
|
|
1093
1160
|
concurrent_batch_size=concurrent_batch_size,
|
|
1094
1161
|
return_messages_log=return_messages_log)
|
|
@@ -1101,248 +1168,31 @@ class DirectFrameExtractor(FrameExtractor):
|
|
|
1101
1168
|
|
|
1102
1169
|
units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
|
|
1103
1170
|
|
|
1104
|
-
frame_list =
|
|
1105
|
-
for unit in units:
|
|
1106
|
-
entity_json = []
|
|
1107
|
-
if unit.status != "success":
|
|
1108
|
-
warnings.warn(f"Skipping failed unit ({unit.start}, {unit.end}): {unit.text}", RuntimeWarning)
|
|
1109
|
-
continue
|
|
1110
|
-
for entity in extract_json(gen_text=unit.gen_text):
|
|
1111
|
-
if ENTITY_KEY in entity:
|
|
1112
|
-
entity_json.append(entity)
|
|
1113
|
-
else:
|
|
1114
|
-
warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
|
|
1115
|
-
|
|
1116
|
-
spans = self._find_entity_spans(text=unit.text,
|
|
1117
|
-
entities=[e[ENTITY_KEY] for e in entity_json],
|
|
1118
|
-
case_sensitive=case_sensitive,
|
|
1119
|
-
fuzzy_match=fuzzy_match,
|
|
1120
|
-
fuzzy_buffer_size=fuzzy_buffer_size,
|
|
1121
|
-
fuzzy_score_cutoff=fuzzy_score_cutoff,
|
|
1122
|
-
allow_overlap_entities=allow_overlap_entities)
|
|
1123
|
-
for ent, span in zip(entity_json, spans):
|
|
1124
|
-
if span is not None:
|
|
1125
|
-
start, end = span
|
|
1126
|
-
entity_text = unit.text[start:end]
|
|
1127
|
-
start += unit.start
|
|
1128
|
-
end += unit.start
|
|
1129
|
-
attr = {}
|
|
1130
|
-
if "attr" in ent and ent["attr"] is not None:
|
|
1131
|
-
attr = ent["attr"]
|
|
1132
|
-
|
|
1133
|
-
frame = LLMInformationExtractionFrame(frame_id=f"{len(frame_list)}",
|
|
1134
|
-
start=start,
|
|
1135
|
-
end=end,
|
|
1136
|
-
entity_text=entity_text,
|
|
1137
|
-
attr=attr)
|
|
1138
|
-
frame_list.append(frame)
|
|
1171
|
+
frame_list = self._post_process_units_to_frames(units, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities)
|
|
1139
1172
|
|
|
1140
1173
|
if return_messages_log:
|
|
1141
1174
|
return frame_list, messages_log
|
|
1142
1175
|
return frame_list
|
|
1143
|
-
|
|
1144
1176
|
|
|
1145
|
-
async def
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1177
|
+
async def extract_frames_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
|
|
1178
|
+
concurrent_batch_size:int=32, case_sensitive:bool=False,
|
|
1179
|
+
fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
|
|
1180
|
+
allow_overlap_entities:bool=False, return_messages_log:bool=False) -> List[LLMInformationExtractionFrame]:
|
|
1149
1181
|
"""
|
|
1150
|
-
This
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
The key in the `text_contents` dictionaries that holds the document text.
|
|
1160
|
-
cpu_concurrency: int, optional
|
|
1161
|
-
The number of parallel threads to use for CPU-bound tasks like chunking.
|
|
1162
|
-
llm_concurrency: int, optional
|
|
1163
|
-
The number of concurrent requests to make to the LLM.
|
|
1164
|
-
case_sensitive : bool, Optional
|
|
1165
|
-
if True, entity text matching will be case-sensitive.
|
|
1166
|
-
fuzzy_match : bool, Optional
|
|
1167
|
-
if True, fuzzy matching will be applied to find entity text.
|
|
1168
|
-
fuzzy_buffer_size : float, Optional
|
|
1169
|
-
the buffer size for fuzzy matching. Default is 20% of entity text length.
|
|
1170
|
-
fuzzy_score_cutoff : float, Optional
|
|
1171
|
-
the Jaccard score cutoff for fuzzy matching.
|
|
1172
|
-
Matched entity text must have a score higher than this value or a None will be returned.
|
|
1173
|
-
allow_overlap_entities : bool, Optional
|
|
1174
|
-
if True, entities can overlap in the text.
|
|
1175
|
-
return_messages_log : bool, Optional
|
|
1176
|
-
if True, a list of messages will be returned.
|
|
1177
|
-
|
|
1178
|
-
Yields:
|
|
1179
|
-
-------
|
|
1180
|
-
AsyncGenerator[Dict[str, any], None]
|
|
1181
|
-
A dictionary for each completed document, containing its 'idx' and extracted 'frames'.
|
|
1182
|
-
"""
|
|
1183
|
-
# Validate text_contents must be a list of str or dict, and not both
|
|
1184
|
-
if not isinstance(text_contents, list):
|
|
1185
|
-
raise ValueError("text_contents must be a list of strings or dictionaries.")
|
|
1186
|
-
if all(isinstance(doc, str) for doc in text_contents):
|
|
1187
|
-
pass
|
|
1188
|
-
elif all(isinstance(doc, dict) for doc in text_contents):
|
|
1189
|
-
pass
|
|
1190
|
-
# Set CPU executor and queues
|
|
1191
|
-
cpu_executor = ThreadPoolExecutor(max_workers=cpu_concurrency)
|
|
1192
|
-
tasks_queue = asyncio.Queue(maxsize=llm_concurrency * 2)
|
|
1193
|
-
# Store to track units and pending counts
|
|
1194
|
-
results_store = {
|
|
1195
|
-
idx: {'pending': 0, 'units': [], 'text': doc if isinstance(doc, str) else doc.get(document_key, "")}
|
|
1196
|
-
for idx, doc in enumerate(text_contents)
|
|
1197
|
-
}
|
|
1198
|
-
|
|
1199
|
-
output_queue = asyncio.Queue()
|
|
1200
|
-
messages_logger = MessagesLogger() if return_messages_log else None
|
|
1201
|
-
|
|
1202
|
-
async def producer():
|
|
1203
|
-
try:
|
|
1204
|
-
for idx, text_content in enumerate(text_contents):
|
|
1205
|
-
text = text_content if isinstance(text_content, str) else text_content.get(document_key, "")
|
|
1206
|
-
if not text:
|
|
1207
|
-
warnings.warn(f"Document at index {idx} is empty or missing the document key '{document_key}'.")
|
|
1208
|
-
# signal that this document is done
|
|
1209
|
-
await output_queue.put({'idx': idx, 'frames': []})
|
|
1210
|
-
continue
|
|
1211
|
-
|
|
1212
|
-
units = await self.unit_chunker.chunk_async(text, cpu_executor)
|
|
1213
|
-
await self.context_chunker.fit_async(text, units, cpu_executor)
|
|
1214
|
-
results_store[idx]['pending'] = len(units)
|
|
1215
|
-
|
|
1216
|
-
# Handle cases where a document yields no units
|
|
1217
|
-
if not units:
|
|
1218
|
-
# signal that this document is done
|
|
1219
|
-
await output_queue.put({'idx': idx, 'frames': []})
|
|
1220
|
-
continue
|
|
1221
|
-
|
|
1222
|
-
# Iterate through units
|
|
1223
|
-
for unit in units:
|
|
1224
|
-
context = await self.context_chunker.chunk_async(unit, cpu_executor)
|
|
1225
|
-
messages = []
|
|
1226
|
-
if self.system_prompt:
|
|
1227
|
-
messages.append({'role': 'system', 'content': self.system_prompt})
|
|
1228
|
-
|
|
1229
|
-
if not context:
|
|
1230
|
-
if isinstance(text_content, str):
|
|
1231
|
-
messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
|
|
1232
|
-
else:
|
|
1233
|
-
unit_content = text_content.copy()
|
|
1234
|
-
unit_content[document_key] = unit.text
|
|
1235
|
-
messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
|
|
1236
|
-
else:
|
|
1237
|
-
# insert context to user prompt
|
|
1238
|
-
if isinstance(text_content, str):
|
|
1239
|
-
messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
|
|
1240
|
-
else:
|
|
1241
|
-
context_content = text_content.copy()
|
|
1242
|
-
context_content[document_key] = context
|
|
1243
|
-
messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
|
|
1244
|
-
# simulate conversation where assistant confirms
|
|
1245
|
-
messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
|
|
1246
|
-
# place unit of interest
|
|
1247
|
-
messages.append({'role': 'user', 'content': unit.text})
|
|
1248
|
-
|
|
1249
|
-
await tasks_queue.put({'idx': idx, 'unit': unit, 'messages': messages})
|
|
1250
|
-
finally:
|
|
1251
|
-
for _ in range(llm_concurrency):
|
|
1252
|
-
await tasks_queue.put(None)
|
|
1253
|
-
|
|
1254
|
-
async def worker():
|
|
1255
|
-
while True:
|
|
1256
|
-
task_item = await tasks_queue.get()
|
|
1257
|
-
if task_item is None:
|
|
1258
|
-
tasks_queue.task_done()
|
|
1259
|
-
break
|
|
1260
|
-
|
|
1261
|
-
idx = task_item['idx']
|
|
1262
|
-
unit = task_item['unit']
|
|
1263
|
-
doc_results = results_store[idx]
|
|
1264
|
-
|
|
1265
|
-
try:
|
|
1266
|
-
gen_text = await self.inference_engine.chat_async(
|
|
1267
|
-
messages=task_item['messages'], messages_logger=messages_logger
|
|
1268
|
-
)
|
|
1269
|
-
unit.set_generated_text(gen_text["response"])
|
|
1270
|
-
unit.set_status("success")
|
|
1271
|
-
doc_results['units'].append(unit)
|
|
1272
|
-
except Exception as e:
|
|
1273
|
-
warnings.warn(f"Error processing unit for doc idx {idx}: {e}")
|
|
1274
|
-
finally:
|
|
1275
|
-
doc_results['pending'] -= 1
|
|
1276
|
-
if doc_results['pending'] <= 0:
|
|
1277
|
-
final_frames = self._post_process_and_create_frames(doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities)
|
|
1278
|
-
output_payload = {'idx': idx, 'frames': final_frames}
|
|
1279
|
-
if return_messages_log:
|
|
1280
|
-
output_payload['messages_log'] = messages_logger.get_messages_log()
|
|
1281
|
-
await output_queue.put(output_payload)
|
|
1282
|
-
|
|
1283
|
-
tasks_queue.task_done()
|
|
1284
|
-
|
|
1285
|
-
# Start producer and workers
|
|
1286
|
-
producer_task = asyncio.create_task(producer())
|
|
1287
|
-
worker_tasks = [asyncio.create_task(worker()) for _ in range(llm_concurrency)]
|
|
1288
|
-
|
|
1289
|
-
# Main loop to gather results
|
|
1290
|
-
docs_completed = 0
|
|
1291
|
-
while docs_completed < len(text_contents):
|
|
1292
|
-
result = await output_queue.get()
|
|
1293
|
-
yield result
|
|
1294
|
-
docs_completed += 1
|
|
1295
|
-
|
|
1296
|
-
# Final cleanup
|
|
1297
|
-
await producer_task
|
|
1298
|
-
await tasks_queue.join()
|
|
1299
|
-
|
|
1300
|
-
# Cancel any lingering worker tasks
|
|
1301
|
-
for task in worker_tasks:
|
|
1302
|
-
task.cancel()
|
|
1303
|
-
await asyncio.gather(*worker_tasks, return_exceptions=True)
|
|
1304
|
-
|
|
1305
|
-
cpu_executor.shutdown(wait=False)
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
def _post_process_and_create_frames(self, doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities):
|
|
1309
|
-
"""Helper function to run post-processing logic for a completed document."""
|
|
1310
|
-
ENTITY_KEY = "entity_text"
|
|
1311
|
-
frame_list = []
|
|
1312
|
-
for res in sorted(doc_results['units'], key=lambda r: r.start):
|
|
1313
|
-
entity_json = []
|
|
1314
|
-
for entity in extract_json(gen_text=res.gen_text):
|
|
1315
|
-
if ENTITY_KEY in entity:
|
|
1316
|
-
entity_json.append(entity)
|
|
1317
|
-
else:
|
|
1318
|
-
warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
|
|
1182
|
+
This is the async version of extract_frames.
|
|
1183
|
+
"""
|
|
1184
|
+
extraction_results = await self._extract_async(text_content=text_content,
|
|
1185
|
+
document_key=document_key,
|
|
1186
|
+
concurrent_batch_size=concurrent_batch_size,
|
|
1187
|
+
return_messages_log=return_messages_log)
|
|
1188
|
+
|
|
1189
|
+
units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
|
|
1190
|
+
frame_list = self._post_process_units_to_frames(units, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities)
|
|
1319
1191
|
|
|
1320
|
-
|
|
1321
|
-
|
|
1322
|
-
entities=[e[ENTITY_KEY] for e in entity_json],
|
|
1323
|
-
case_sensitive=case_sensitive,
|
|
1324
|
-
fuzzy_match=fuzzy_match,
|
|
1325
|
-
fuzzy_buffer_size=fuzzy_buffer_size,
|
|
1326
|
-
fuzzy_score_cutoff=fuzzy_score_cutoff,
|
|
1327
|
-
allow_overlap_entities=allow_overlap_entities
|
|
1328
|
-
)
|
|
1329
|
-
for ent, span in zip(entity_json, spans):
|
|
1330
|
-
if span is not None:
|
|
1331
|
-
start, end = span
|
|
1332
|
-
entity_text = res.text[start:end]
|
|
1333
|
-
start += res.start
|
|
1334
|
-
end += res.start
|
|
1335
|
-
attr = ent.get("attr", {}) or {}
|
|
1336
|
-
frame = LLMInformationExtractionFrame(
|
|
1337
|
-
frame_id=f"{len(frame_list)}",
|
|
1338
|
-
start=start,
|
|
1339
|
-
end=end,
|
|
1340
|
-
entity_text=entity_text,
|
|
1341
|
-
attr=attr
|
|
1342
|
-
)
|
|
1343
|
-
frame_list.append(frame)
|
|
1192
|
+
if return_messages_log:
|
|
1193
|
+
return frame_list, messages_log
|
|
1344
1194
|
return frame_list
|
|
1345
|
-
|
|
1195
|
+
|
|
1346
1196
|
|
|
1347
1197
|
class ReviewFrameExtractor(DirectFrameExtractor):
|
|
1348
1198
|
def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker, inference_engine:InferenceEngine,
|
|
@@ -1494,7 +1344,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
|
|
|
1494
1344
|
initial = self.inference_engine.chat(
|
|
1495
1345
|
messages=messages,
|
|
1496
1346
|
verbose=verbose,
|
|
1497
|
-
stream=False,
|
|
1498
1347
|
messages_logger=messages_logger
|
|
1499
1348
|
)
|
|
1500
1349
|
|
|
@@ -1508,7 +1357,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
|
|
|
1508
1357
|
review = self.inference_engine.chat(
|
|
1509
1358
|
messages=messages,
|
|
1510
1359
|
verbose=verbose,
|
|
1511
|
-
stream=False,
|
|
1512
1360
|
messages_logger=messages_logger
|
|
1513
1361
|
)
|
|
1514
1362
|
|
|
@@ -1596,9 +1444,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
|
|
|
1596
1444
|
|
|
1597
1445
|
yield f"{Fore.BLUE}Extraction:{Style.RESET_ALL}\n"
|
|
1598
1446
|
|
|
1599
|
-
response_stream = self.inference_engine.
|
|
1600
|
-
messages=messages
|
|
1601
|
-
stream=True
|
|
1447
|
+
response_stream = self.inference_engine.chat_stream(
|
|
1448
|
+
messages=messages
|
|
1602
1449
|
)
|
|
1603
1450
|
|
|
1604
1451
|
initial = ""
|
|
@@ -1612,15 +1459,14 @@ class ReviewFrameExtractor(DirectFrameExtractor):
|
|
|
1612
1459
|
messages.append({'role': 'assistant', 'content': initial})
|
|
1613
1460
|
messages.append({'role': 'user', 'content': self.review_prompt})
|
|
1614
1461
|
|
|
1615
|
-
response_stream = self.inference_engine.
|
|
1616
|
-
messages=messages
|
|
1617
|
-
stream=True
|
|
1462
|
+
response_stream = self.inference_engine.chat_stream(
|
|
1463
|
+
messages=messages
|
|
1618
1464
|
)
|
|
1619
1465
|
|
|
1620
1466
|
for chunk in response_stream:
|
|
1621
1467
|
yield chunk
|
|
1622
1468
|
|
|
1623
|
-
async def
|
|
1469
|
+
async def _extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
|
|
1624
1470
|
concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnit]:
|
|
1625
1471
|
"""
|
|
1626
1472
|
This is the asynchronous version of the extract() method with the review step.
|
|
@@ -2056,7 +1902,6 @@ class AttributeExtractor(Extractor):
|
|
|
2056
1902
|
gen_text = self.inference_engine.chat(
|
|
2057
1903
|
messages=messages,
|
|
2058
1904
|
verbose=verbose,
|
|
2059
|
-
stream=False,
|
|
2060
1905
|
messages_logger=messages_logger
|
|
2061
1906
|
)
|
|
2062
1907
|
|
|
@@ -2123,7 +1968,7 @@ class AttributeExtractor(Extractor):
|
|
|
2123
1968
|
return (new_frames, messages_log) if return_messages_log else new_frames
|
|
2124
1969
|
|
|
2125
1970
|
|
|
2126
|
-
async def
|
|
1971
|
+
async def _extract_async(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
|
|
2127
1972
|
concurrent_batch_size:int=32, inplace:bool=True, return_messages_log:bool=False) -> Union[None, List[LLMInformationExtractionFrame]]:
|
|
2128
1973
|
"""
|
|
2129
1974
|
This method extracts attributes from the document asynchronously.
|
|
@@ -2195,6 +2040,16 @@ class AttributeExtractor(Extractor):
|
|
|
2195
2040
|
else:
|
|
2196
2041
|
return (new_frames, messages_logger.get_messages_log()) if return_messages_log else new_frames
|
|
2197
2042
|
|
|
2043
|
+
async def extract_attributes_async(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
|
|
2044
|
+
concurrent_batch_size:int=32, inplace:bool=True,
|
|
2045
|
+
return_messages_log:bool=False) -> Union[None, List[LLMInformationExtractionFrame]]:
|
|
2046
|
+
"""
|
|
2047
|
+
This is the async version of extract_attributes.
|
|
2048
|
+
"""
|
|
2049
|
+
return await self._extract_async(frames=frames, text=text, context_size=context_size,
|
|
2050
|
+
concurrent_batch_size=concurrent_batch_size, inplace=inplace, return_messages_log=return_messages_log)
|
|
2051
|
+
|
|
2052
|
+
|
|
2198
2053
|
def extract_attributes(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
|
|
2199
2054
|
concurrent:bool=False, concurrent_batch_size:int=32, verbose:bool=False,
|
|
2200
2055
|
return_messages_log:bool=False, inplace:bool=True) -> Union[None, List[LLMInformationExtractionFrame]]:
|
|
@@ -2230,7 +2085,7 @@ class AttributeExtractor(Extractor):
|
|
|
2230
2085
|
|
|
2231
2086
|
nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
|
|
2232
2087
|
|
|
2233
|
-
return asyncio.run(self.
|
|
2088
|
+
return asyncio.run(self._extract_async(frames=frames, text=text, context_size=context_size,
|
|
2234
2089
|
concurrent_batch_size=concurrent_batch_size,
|
|
2235
2090
|
inplace=inplace, return_messages_log=return_messages_log))
|
|
2236
2091
|
else:
|
|
@@ -2375,6 +2230,17 @@ class RelationExtractor(Extractor):
|
|
|
2375
2230
|
return asyncio.run(self._extract_async(doc, buffer_size, concurrent_batch_size, return_messages_log))
|
|
2376
2231
|
else:
|
|
2377
2232
|
return self._extract(doc, buffer_size, verbose, return_messages_log)
|
|
2233
|
+
|
|
2234
|
+
async def extract_relations_async(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent_batch_size: int = 32, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
|
|
2235
|
+
"""
|
|
2236
|
+
This is the async version of extract_relations.
|
|
2237
|
+
"""
|
|
2238
|
+
if not doc.has_frame():
|
|
2239
|
+
raise ValueError("Input document must have frames.")
|
|
2240
|
+
if doc.has_duplicate_frame_ids():
|
|
2241
|
+
raise ValueError("All frame_ids in the input document must be unique.")
|
|
2242
|
+
|
|
2243
|
+
return await self._extract_async(doc, buffer_size, concurrent_batch_size, return_messages_log)
|
|
2378
2244
|
|
|
2379
2245
|
|
|
2380
2246
|
class BinaryRelationExtractor(RelationExtractor):
|
llm_ie/prompt_editor.py
CHANGED
|
@@ -270,5 +270,5 @@ class PromptEditor:
|
|
|
270
270
|
|
|
271
271
|
messages = [{"role": "system", "content": self.system_prompt + guideline}] + messages
|
|
272
272
|
|
|
273
|
-
stream_generator = self.inference_engine.
|
|
273
|
+
stream_generator = self.inference_engine.chat_stream(messages)
|
|
274
274
|
yield from stream_generator
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: llm-ie
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.4.1
|
|
4
4
|
Summary: A comprehensive toolkit that provides building blocks for LLM-based named entity recognition, attribute extraction, and relation extraction pipelines.
|
|
5
5
|
License: MIT
|
|
6
6
|
Author: Enshuo (David) Hsu
|
|
@@ -10,7 +10,8 @@ Classifier: Programming Language :: Python :: 3
|
|
|
10
10
|
Classifier: Programming Language :: Python :: 3.11
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.12
|
|
12
12
|
Requires-Dist: colorama (>=0.4.6,<0.5.0)
|
|
13
|
-
Requires-Dist: json_repair (>=0.30
|
|
13
|
+
Requires-Dist: json_repair (>=0.30)
|
|
14
|
+
Requires-Dist: llm-inference-engine (>=0.1.5)
|
|
14
15
|
Requires-Dist: nest_asyncio (>=1.6.0,<2.0.0)
|
|
15
16
|
Requires-Dist: nltk (>=3.8,<4.0)
|
|
16
17
|
Description-Content-Type: text/markdown
|
|
@@ -22,10 +22,10 @@ llm_ie/asset/prompt_guide/SentenceReviewFrameExtractor_prompt_guide.txt,sha256=9
|
|
|
22
22
|
llm_ie/asset/prompt_guide/StructExtractor_prompt_guide.txt,sha256=x8L4n_LVl6ofQu6cDE9YP4SB2FSQ4GrTee8y1XKwwwc,1922
|
|
23
23
|
llm_ie/chunkers.py,sha256=b4APRwaLMU40QXVEhOK8m1DZi_jr-VCHAFwbMjqVBgA,11308
|
|
24
24
|
llm_ie/data_types.py,sha256=iG_jdqhpBi33xnsfFQYayCXNBK-2N-8u1xIhoKfJzRI,18294
|
|
25
|
-
llm_ie/engines.py,sha256=
|
|
26
|
-
llm_ie/extractors.py,sha256=
|
|
27
|
-
llm_ie/prompt_editor.py,sha256=
|
|
25
|
+
llm_ie/engines.py,sha256=Lxzj0gfbUjaU8TpWWM7MqS71Vmpqdq_mIHoLiXqOmXs,1089
|
|
26
|
+
llm_ie/extractors.py,sha256=hA7VoWZU2z6aWXfg5rTwFAmK5L0weILbCGWUaUNJU9w,114859
|
|
27
|
+
llm_ie/prompt_editor.py,sha256=ZAr6A9HRbqKWumVa5kRgcnH2rXdHSmPhYP1Hdp3Ic2o,12049
|
|
28
28
|
llm_ie/utils.py,sha256=k6M4l8GsKOMcmO6UwONQ353Zk-TeoBj6HXGjlAn-JE0,3679
|
|
29
|
-
llm_ie-1.
|
|
30
|
-
llm_ie-1.
|
|
31
|
-
llm_ie-1.
|
|
29
|
+
llm_ie-1.4.1.dist-info/METADATA,sha256=Tfp40uGnIbgQa-1IxtPYJpiViOaeXbOKGYi08yCCUzg,768
|
|
30
|
+
llm_ie-1.4.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
31
|
+
llm_ie-1.4.1.dist-info/RECORD,,
|
|
File without changes
|