llm-ie 1.2.2__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_ie/__init__.py +5 -4
- llm_ie/asset/default_prompts/LLMUnitChunker_user_prompt.txt +129 -0
- llm_ie/chunkers.py +145 -6
- llm_ie/data_types.py +23 -37
- llm_ie/engines.py +621 -61
- llm_ie/extractors.py +341 -297
- llm_ie/prompt_editor.py +9 -32
- llm_ie/utils.py +95 -0
- {llm_ie-1.2.2.dist-info → llm_ie-1.2.4.dist-info}/METADATA +1 -1
- {llm_ie-1.2.2.dist-info → llm_ie-1.2.4.dist-info}/RECORD +11 -9
- {llm_ie-1.2.2.dist-info → llm_ie-1.2.4.dist-info}/WHEEL +0 -0
llm_ie/extractors.py
CHANGED
@@ -1,18 +1,18 @@
 import abc
 import re
-import json
-import json_repair
 import inspect
 import importlib.resources
 import warnings
 import itertools
 import asyncio
 import nest_asyncio
-from
-from
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional, AsyncGenerator
+from llm_ie.utils import extract_json, apply_prompt_template
+from llm_ie.data_types import FrameExtractionUnit, LLMInformationExtractionFrame, LLMInformationExtractionDocument
 from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
 from llm_ie.chunkers import ContextChunker, NoContextChunker, WholeDocumentContextChunker, SlideWindowContextChunker
-from llm_ie.engines import InferenceEngine
+from llm_ie.engines import InferenceEngine, MessagesLogger
 from colorama import Fore, Style
 
 
@@ -95,79 +95,8 @@ class Extractor:
         Returns : str
             a user prompt.
         """
-
-        if isinstance(text_content, str):
-            matches = pattern.findall(self.prompt_template)
-            if len(matches) != 1:
-                raise ValueError("When text_content is str, the prompt template must has exactly 1 placeholder {{<placeholder name>}}.")
-            text = re.sub(r'\\', r'\\\\', text_content)
-            prompt = pattern.sub(text, self.prompt_template)
-
-        elif isinstance(text_content, dict):
-            # Check if all values are str
-            if not all([isinstance(v, str) for v in text_content.values()]):
-                raise ValueError("All values in text_content must be str.")
-            # Check if all keys are in the prompt template
-            placeholders = pattern.findall(self.prompt_template)
-            if len(placeholders) != len(text_content):
-                raise ValueError(f"Expect text_content ({len(text_content)}) and prompt template placeholder ({len(placeholders)}) to have equal size.")
-            if not all([k in placeholders for k, _ in text_content.items()]):
-                raise ValueError(f"All keys in text_content ({text_content.keys()}) must match placeholders in prompt template ({placeholders}).")
-
-            prompt = pattern.sub(lambda match: re.sub(r'\\', r'\\\\', text_content[match.group(1)]), self.prompt_template)
-
-        return prompt
-
-    def _find_dict_strings(self, text: str) -> List[str]:
-        """
-        Extracts balanced JSON-like dictionaries from a string, even if nested.
+        return apply_prompt_template(self.prompt_template, text_content)
 
-        Parameters:
-        -----------
-        text : str
-            the input text containing JSON-like structures.
-
-        Returns : List[str]
-            A list of valid JSON-like strings representing dictionaries.
-        """
-        open_brace = 0
-        start = -1
-        json_objects = []
-
-        for i, char in enumerate(text):
-            if char == '{':
-                if open_brace == 0:
-                    # start of a new JSON object
-                    start = i
-                open_brace += 1
-            elif char == '}':
-                open_brace -= 1
-                if open_brace == 0 and start != -1:
-                    json_objects.append(text[start:i + 1])
-                    start = -1
-
-        return json_objects
-
-
-    def _extract_json(self, gen_text:str) -> List[Dict[str, str]]:
-        """
-        This method inputs a generated text and output a JSON of information tuples
-        """
-        out = []
-        dict_str_list = self._find_dict_strings(gen_text)
-        for dict_str in dict_str_list:
-            try:
-                dict_obj = json.loads(dict_str)
-                out.append(dict_obj)
-            except json.JSONDecodeError:
-                dict_obj = json_repair.repair_json(dict_str, skip_json_loads=True, return_objects=True)
-                if dict_obj:
-                    warnings.warn(f'JSONDecodeError detected, fixed with repair_json:\n{dict_str}', RuntimeWarning)
-                    out.append(dict_obj)
-                else:
-                    warnings.warn(f'JSONDecodeError could not be fixed:\n{dict_str}', RuntimeWarning)
-        return out
-
 
 class FrameExtractor(Extractor):
     from nltk.tokenize import RegexpTokenizer
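Note: the hunk above deletes the hand-rolled prompt-filling and JSON-recovery helpers from Extractor; per the import hunk and the file summary (llm_ie/utils.py +95), they now live in the new llm_ie/utils.py as apply_prompt_template and extract_json. A minimal sketch of the relocated API, inferred only from the call sites in this diff; the template string, model output, and expected result below are hypothetical:

    from llm_ie.utils import apply_prompt_template, extract_json

    # Fill the single {{placeholder}} in a template with a str; a dict fills
    # multiple placeholders by key (mirrors the removed Extractor logic).
    prompt = apply_prompt_template("Extract medications from: {{text}}",
                                   "aspirin 81 mg daily")

    # Recover a list of dicts from free-form LLM output; malformed JSON is
    # presumably repaired via json_repair, as the removed _extract_json did.
    entities = extract_json('Here is the output: {"entity_text": "aspirin"}')
    # expected: [{"entity_text": "aspirin"}]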
@@ -405,7 +334,7 @@ class DirectFrameExtractor(FrameExtractor):
 
 
     def extract(self, text_content:Union[str, Dict[str,str]],
-                document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[
+                document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
         """
         This method inputs a text and outputs a list of outputs per unit.
 
@@ -423,11 +352,9 @@ class DirectFrameExtractor(FrameExtractor):
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.
 
-        Return : List[
+        Return : List[FrameExtractionUnit]
            the output from LLM for each unit. Contains the start, end, text, and generated text.
        """
-        # define output
-        output = []
        # unit chunking
        if isinstance(text_content, str):
            doc_text = text_content
@@ -440,76 +367,70 @@ class DirectFrameExtractor(FrameExtractor):
        units = self.unit_chunker.chunk(doc_text)
        # context chunker init
        self.context_chunker.fit(doc_text, units)
+
        # messages log
-        if return_messages_log
-            messages_log = []
+        messages_logger = MessagesLogger() if return_messages_log else None
 
        # generate unit by unit
        for i, unit in enumerate(units):
-
-
-
-
-
-            context = self.context_chunker.chunk(unit)
-
-            if context == "":
-                # no context, just place unit in user prompt
-                if isinstance(text_content, str):
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
-                else:
-                    unit_content = text_content.copy()
-                    unit_content[document_key] = unit.text
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
-            else:
-                # insert context to user prompt
-                if isinstance(text_content, str):
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
-                else:
-                    context_content = text_content.copy()
-                    context_content[document_key] = context
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
-                # simulate conversation where assistant confirms
-                messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
-                # place unit of interest
-                messages.append({'role': 'user', 'content': unit.text})
+            try:
+                # construct chat messages
+                messages = []
+                if self.system_prompt:
+                    messages.append({'role': 'system', 'content': self.system_prompt})
 
-
-            print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
-            if context != "":
-                print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+                context = self.context_chunker.chunk(unit)
 
-
+                if context == "":
+                    # no context, just place unit in user prompt
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                    else:
+                        unit_content = text_content.copy()
+                        unit_content[document_key] = unit.text
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+                else:
+                    # insert context to user prompt
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                    else:
+                        context_content = text_content.copy()
+                        context_content[document_key] = context
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                    # simulate conversation where assistant confirms
+                    messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                    # place unit of interest
+                    messages.append({'role': 'user', 'content': unit.text})
 
-
-
-
-
-
-
+                if verbose:
+                    print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
+                    if context != "":
+                        print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+
+                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+                gen_text = self.inference_engine.chat(
+                    messages=messages,
+                    verbose=verbose,
+                    stream=False,
+                    messages_logger=messages_logger
+                )
+
+                # add generated text to unit
+                unit.set_generated_text(gen_text["response"])
+                unit.set_status("success")
+            except Exception as e:
+                unit.set_status("fail")
+                warnings.warn(f"LLM inference failed for unit {i} ({unit.start}, {unit.end}): {e}", RuntimeWarning)
 
         if return_messages_log:
-            return
+            return units, messages_logger.get_messages_log()
 
-        return
+        return units
 
     def stream(self, text_content: Union[str, Dict[str, str]],
-               document_key: str = None) -> Generator[Dict[str, Any], None, List[
+               document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnit]]:
         """
         Streams LLM responses per unit with structured event types,
         and returns collected data for post-processing.
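Note: extract() no longer builds a separate result list; it returns the chunker's own FrameExtractionUnit objects, annotated in place via set_generated_text()/set_status(), and a failed LLM call no longer aborts the loop (the unit is marked "fail" and skipped downstream). A sketch of the new calling convention, assuming `extractor` is any already-configured DirectFrameExtractor (construction omitted) and `note_text` is a plain string:

    units, messages_log = extractor.extract(note_text, return_messages_log=True)
    for unit in units:
        if unit.status == "success":
            # raw per-unit LLM output, with document-level offsets
            print(unit.start, unit.end, unit.gen_text)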
@@ -525,12 +446,10 @@ class DirectFrameExtractor(FrameExtractor):
 
         Returns:
         --------
-        List[
-            A list of
+        List[FrameExtractionUnit]:
+            A list of FrameExtractionUnit objects, each containing the
             original unit details and the fully accumulated 'gen_text' from the LLM.
         """
-        collected_results: List[FrameExtractionUnitResult] = []
-
         if isinstance(text_content, str):
             doc_text = text_content
         elif isinstance(text_content, dict):
@@ -588,19 +507,14 @@ class DirectFrameExtractor(FrameExtractor):
                 current_gen_text += chunk["data"]
 
             # Store the result for this unit
-
-
-                end=unit.end,
-                text=unit.text,
-                gen_text=current_gen_text
-            )
-            collected_results.append(result_for_unit)
+            unit.set_generated_text(current_gen_text)
+            unit.set_status("success")
 
         yield {"type": "info", "data": "All units processed by LLM."}
-        return
+        return units
 
     async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
-                            concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[
+                            concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
         """
         This is the asynchronous version of the extract() method.
 
@@ -618,7 +532,7 @@ class DirectFrameExtractor(FrameExtractor):
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.
 
-        Return : List[
+        Return : List[FrameExtractionUnit]
             the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
         if isinstance(text_content, str):
@@ -637,6 +551,9 @@ class DirectFrameExtractor(FrameExtractor):
         # context chunker init
         self.context_chunker.fit(doc_text, units)
 
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # Prepare inputs for all units first
         tasks_input = []
         for i, unit in enumerate(units):
@@ -677,17 +594,15 @@ class DirectFrameExtractor(FrameExtractor):
         async def semaphore_helper(task_data: Dict, **kwrs):
             unit = task_data["unit"]
             messages = task_data["messages"]
-            original_index = task_data["original_index"]
 
             async with semaphore:
                 gen_text = await self.inference_engine.chat_async(
-                    messages=messages
+                    messages=messages,
+                    messages_logger=messages_logger
                 )
 
-
-
-                out["reasoning"] = gen_text["reasoning"]
-                return out
+                unit.set_generated_text(gen_text["response"])
+                unit.set_status("success")
 
         # Create and gather tasks
         tasks = []
@@ -697,40 +612,13 @@ class DirectFrameExtractor(FrameExtractor):
             ))
             tasks.append(task)
 
-
-
-        # Sort results back into original order using the index stored
-        results_raw.sort(key=lambda x: x["original_index"])
-
-        # Restructure the results
-        output: List[FrameExtractionUnitResult] = []
-        messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
-
-        for result_data in results_raw:
-            unit = result_data["unit"]
-            gen_text = result_data["gen_text"]
-
-            # Create result object
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=gen_text
-            )
-            output.append(result)
-
-            # Append to messages log if requested
-            if return_messages_log:
-                message = {"role": "assistant", "content": gen_text}
-                if "reasoning" in result_data:
-                    message["reasoning"] = result_data["reasoning"]
-                final_messages = result_data["messages"] + [message]
-                messages_log.append(final_messages)
+        await asyncio.gather(*tasks)
 
+        # Return units
         if return_messages_log:
-            return
+            return units, messages_logger.get_messages_log()
         else:
-            return
+            return units
 
 
     def extract_frames(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
@@ -738,7 +626,7 @@ class DirectFrameExtractor(FrameExtractor):
                        case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
                        allow_overlap_entities:bool=False, return_messages_log:bool=False) -> List[LLMInformationExtractionFrame]:
         """
-        This method inputs a text and outputs a list of LLMInformationExtractionFrame
+        This method inputs a document text and outputs a list of LLMInformationExtractionFrame
         It use the extract() method and post-process outputs into frames.
 
         Parameters:
@@ -791,18 +679,21 @@ class DirectFrameExtractor(FrameExtractor):
                                          verbose=verbose,
                                          return_messages_log=return_messages_log)
 
-
+        units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
 
         frame_list = []
-        for
+        for unit in units:
             entity_json = []
-
+            if unit.status != "success":
+                warnings.warn(f"Skipping failed unit ({unit.start}, {unit.end}): {unit.text}", RuntimeWarning)
+                continue
+            for entity in extract_json(gen_text=unit.gen_text):
                 if ENTITY_KEY in entity:
                     entity_json.append(entity)
                 else:
                     warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
 
-            spans = self._find_entity_spans(text=
+            spans = self._find_entity_spans(text=unit.text,
                                             entities=[e[ENTITY_KEY] for e in entity_json],
                                             case_sensitive=case_sensitive,
                                             fuzzy_match=fuzzy_match,
@@ -812,9 +703,9 @@ class DirectFrameExtractor(FrameExtractor):
             for ent, span in zip(entity_json, spans):
                 if span is not None:
                     start, end = span
-                    entity_text =
-                    start +=
-                    end +=
+                    entity_text = unit.text[start:end]
+                    start += unit.start
+                    end += unit.start
                     attr = {}
                     if "attr" in ent and ent["attr"] is not None:
                         attr = ent["attr"]
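Note: the span arithmetic in the hunk above maps entity offsets found inside a unit back to document coordinates by slicing first, then shifting by unit.start. A worked example with hypothetical values:

    unit_start = 120                    # unit begins at character 120 of the document
    unit_text = "HTN: aspirin 81 mg"
    start, end = 5, 12                  # span of "aspirin" within unit_text
    entity_text = unit_text[start:end]  # slice before shifting: "aspirin"
    start += unit_start                 # document-level start: 125
    end += unit_start                   # document-level end:   132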
@@ -831,6 +722,208 @@ class DirectFrameExtractor(FrameExtractor):
         return frame_list
 
 
+    async def extract_frames_from_documents(self, text_contents:List[Union[str,Dict[str, any]]], document_key:str="text",
+                                            cpu_concurrency:int=4, llm_concurrency:int=32, case_sensitive:bool=False,
+                                            fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
+                                            allow_overlap_entities:bool=False, return_messages_log:bool=False) -> AsyncGenerator[Dict[str, any], None]:
+        """
+        This method inputs a list of documents and yields the results for each document as soon as it is complete.
+
+        Parameters:
+        -----------
+        text_contents : List[Union[str,Dict[str, any]]]
+            a list of input text contents to put in prompt template.
+            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+        document_key: str, optional
+            The key in the `text_contents` dictionaries that holds the document text.
+        cpu_concurrency: int, optional
+            The number of parallel threads to use for CPU-bound tasks like chunking.
+        llm_concurrency: int, optional
+            The number of concurrent requests to make to the LLM.
+        case_sensitive : bool, Optional
+            if True, entity text matching will be case-sensitive.
+        fuzzy_match : bool, Optional
+            if True, fuzzy matching will be applied to find entity text.
+        fuzzy_buffer_size : float, Optional
+            the buffer size for fuzzy matching. Default is 20% of entity text length.
+        fuzzy_score_cutoff : float, Optional
+            the Jaccard score cutoff for fuzzy matching.
+            Matched entity text must have a score higher than this value or a None will be returned.
+        allow_overlap_entities : bool, Optional
+            if True, entities can overlap in the text.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
+
+        Yields:
+        -------
+        AsyncGenerator[Dict[str, any], None]
+            A dictionary for each completed document, containing its 'idx' and extracted 'frames'.
+        """
+        # Validate text_contents must be a list of str or dict, and not both
+        if not isinstance(text_contents, list):
+            raise ValueError("text_contents must be a list of strings or dictionaries.")
+        if all(isinstance(doc, str) for doc in text_contents):
+            pass
+        elif all(isinstance(doc, dict) for doc in text_contents):
+            pass
+        # Set CPU executor and queues
+        cpu_executor = ThreadPoolExecutor(max_workers=cpu_concurrency)
+        tasks_queue = asyncio.Queue(maxsize=llm_concurrency * 2)
+        # Store to track units and pending counts
+        results_store = {
+            idx: {'pending': 0, 'units': [], 'text': doc if isinstance(doc, str) else doc.get(document_key, "")}
+            for idx, doc in enumerate(text_contents)
+        }
+
+        output_queue = asyncio.Queue()
+        messages_logger = MessagesLogger() if return_messages_log else None
+
+        async def producer():
+            try:
+                for idx, text_content in enumerate(text_contents):
+                    text = text_content if isinstance(text_content, str) else text_content.get(document_key, "")
+                    if not text:
+                        warnings.warn(f"Document at index {idx} is empty or missing the document key '{document_key}'.")
+                        # signal that this document is done
+                        await output_queue.put({'idx': idx, 'frames': []})
+                        continue
+
+                    units = await self.unit_chunker.chunk_async(text, cpu_executor)
+                    await self.context_chunker.fit_async(text, units, cpu_executor)
+                    results_store[idx]['pending'] = len(units)
+
+                    # Handle cases where a document yields no units
+                    if not units:
+                        # signal that this document is done
+                        await output_queue.put({'idx': idx, 'frames': []})
+                        continue
+
+                    # Iterate through units
+                    for unit in units:
+                        context = await self.context_chunker.chunk_async(unit, cpu_executor)
+                        messages = []
+                        if self.system_prompt:
+                            messages.append({'role': 'system', 'content': self.system_prompt})
+
+                        if not context:
+                            if isinstance(text_content, str):
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                            else:
+                                unit_content = text_content.copy()
+                                unit_content[document_key] = unit.text
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+                        else:
+                            # insert context to user prompt
+                            if isinstance(text_content, str):
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                            else:
+                                context_content = text_content.copy()
+                                context_content[document_key] = context
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                            # simulate conversation where assistant confirms
+                            messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                            # place unit of interest
+                            messages.append({'role': 'user', 'content': unit.text})
+
+                        await tasks_queue.put({'idx': idx, 'unit': unit, 'messages': messages})
+            finally:
+                for _ in range(llm_concurrency):
+                    await tasks_queue.put(None)
+
+        async def worker():
+            while True:
+                task_item = await tasks_queue.get()
+                if task_item is None:
+                    tasks_queue.task_done()
+                    break
+
+                idx = task_item['idx']
+                unit = task_item['unit']
+                doc_results = results_store[idx]
+
+                try:
+                    gen_text = await self.inference_engine.chat_async(
+                        messages=task_item['messages'], messages_logger=messages_logger
+                    )
+                    unit.set_generated_text(gen_text["response"])
+                    unit.set_status("success")
+                    doc_results['units'].append(unit)
+                except Exception as e:
+                    warnings.warn(f"Error processing unit for doc idx {idx}: {e}")
+                finally:
+                    doc_results['pending'] -= 1
+                    if doc_results['pending'] <= 0:
+                        final_frames = self._post_process_and_create_frames(doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities)
+                        output_payload = {'idx': idx, 'frames': final_frames}
+                        if return_messages_log:
+                            output_payload['messages_log'] = messages_logger.get_messages_log()
+                        await output_queue.put(output_payload)
+
+                    tasks_queue.task_done()
+
+        # Start producer and workers
+        producer_task = asyncio.create_task(producer())
+        worker_tasks = [asyncio.create_task(worker()) for _ in range(llm_concurrency)]
+
+        # Main loop to gather results
+        docs_completed = 0
+        while docs_completed < len(text_contents):
+            result = await output_queue.get()
+            yield result
+            docs_completed += 1
+
+        # Final cleanup
+        await producer_task
+        await tasks_queue.join()
+
+        # Cancel any lingering worker tasks
+        for task in worker_tasks:
+            task.cancel()
+        await asyncio.gather(*worker_tasks, return_exceptions=True)
+
+        cpu_executor.shutdown(wait=False)
+
+
+    def _post_process_and_create_frames(self, doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities):
+        """Helper function to run post-processing logic for a completed document."""
+        ENTITY_KEY = "entity_text"
+        frame_list = []
+        for res in sorted(doc_results['units'], key=lambda r: r.start):
+            entity_json = []
+            for entity in extract_json(gen_text=res.gen_text):
+                if ENTITY_KEY in entity:
+                    entity_json.append(entity)
+                else:
+                    warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
+
+            spans = self._find_entity_spans(
+                text=res.text,
+                entities=[e[ENTITY_KEY] for e in entity_json],
+                case_sensitive=case_sensitive,
+                fuzzy_match=fuzzy_match,
+                fuzzy_buffer_size=fuzzy_buffer_size,
+                fuzzy_score_cutoff=fuzzy_score_cutoff,
+                allow_overlap_entities=allow_overlap_entities
+            )
+            for ent, span in zip(entity_json, spans):
+                if span is not None:
+                    start, end = span
+                    entity_text = res.text[start:end]
+                    start += res.start
+                    end += res.start
+                    attr = ent.get("attr", {}) or {}
+                    frame = LLMInformationExtractionFrame(
+                        frame_id=f"{len(frame_list)}",
+                        start=start,
+                        end=end,
+                        entity_text=entity_text,
+                        attr=attr
+                    )
+                    frame_list.append(frame)
+        return frame_list
+
+
 class ReviewFrameExtractor(DirectFrameExtractor):
     def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker, inference_engine:InferenceEngine,
                  prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None):
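Note: extract_frames_from_documents, added above, is the headline feature of this release: a producer/worker pipeline that chunks documents on a thread pool, fans units out to llm_concurrency async workers, and yields each document's {'idx', 'frames'} payload as soon as its last unit finishes, so results arrive in completion order rather than input order. A minimal consumption sketch, assuming an already-configured extractor and plain-string documents:

    import asyncio

    async def main(extractor, docs):
        frames_by_doc = {}
        async for payload in extractor.extract_frames_from_documents(docs, llm_concurrency=16):
            frames_by_doc[payload["idx"]] = payload["frames"]  # match results by 'idx'
        return [frames_by_doc[i] for i in range(len(docs))]

    # asyncio.run(main(extractor, ["note one ...", "note two ..."]))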
@@ -902,7 +995,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             raise ValueError(f"Cannot find review prompt for {self.__class__.__name__} in the package. Please provide a review_prompt.")
 
     def extract(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
-                verbose:bool=False, return_messages_log:bool=False) -> List[
+                verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
         """
         This method inputs a text and outputs a list of outputs per unit.
 
@@ -923,8 +1016,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         Return : List[FrameExtractionUnitResult]
             the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
-        # define output
-        output = []
         # unit chunking
         if isinstance(text_content, str):
             doc_text = text_content
@@ -937,9 +1028,9 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         units = self.unit_chunker.chunk(doc_text)
         # context chunker init
         self.context_chunker.fit(doc_text, units)
-
-
-
+
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
 
         # generate unit by unit
         for i, unit in enumerate(units):
@@ -973,7 +1064,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
                 messages.append({'role': 'user', 'content': unit.text})
 
             if verbose:
-                print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
+                print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
                 if context != "":
                     print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
 
@@ -983,7 +1074,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             initial = self.inference_engine.chat(
                 messages=messages,
                 verbose=verbose,
-                stream=False
+                stream=False,
+                messages_logger=messages_logger
             )
 
             # <--- Review step --->
@@ -996,7 +1088,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             review = self.inference_engine.chat(
                 messages=messages,
                 verbose=verbose,
-                stream=False
+                stream=False,
+                messages_logger=messages_logger
             )
 
             # Output
@@ -1005,28 +1098,14 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             elif self.review_mode == "addition":
                 gen_text = initial["response"] + '\n' + review["response"]
 
-
-
-
-
-            message = {"role": "assistant", "content": review["response"]}
-            if "reasoning" in review:
-                message["reasoning"] = review["reasoning"]
-            messages.append(message)
-            messages_log.append(messages)
-
-            # add to output
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=gen_text)
-            output.append(result)
+            # add generated text to unit
+            unit.set_generated_text(gen_text)
+            unit.set_status("success")
 
         if return_messages_log:
-            return
+            return units, messages_logger.get_messages_log()
 
-        return
+        return units
 
 
     def stream(self, text_content:Union[str, Dict[str,str]], document_key:str=None) -> Generator[str, None, None]:
@@ -1122,7 +1201,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             yield chunk
 
     async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
-                            concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[
+                            concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnit]:
         """
         This is the asynchronous version of the extract() method with the review step.
 
@@ -1154,11 +1233,15 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         else:
             raise TypeError("text_content must be a string or a dictionary.")
 
+        # unit chunking
         units = self.unit_chunker.chunk(doc_text)
 
         # context chunker init
         self.context_chunker.fit(doc_text, units)
 
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # <--- Initial generation step --->
         initial_tasks_input = []
         for i, unit in enumerate(units):
@@ -1202,7 +1285,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
             async with semaphore:
                 gen_text = await self.inference_engine.chat_async(
-                    messages=messages
+                    messages=messages,
+                    messages_logger=messages_logger
                 )
                 # Return initial generation result along with the messages used and the unit
                 out = {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text["response"], "initial_messages": messages}
@@ -1253,16 +1337,11 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
             async with semaphore:
                 review_gen_text = await self.inference_engine.chat_async(
-                    messages=messages
+                    messages=messages,
+                    messages_logger=messages_logger
                 )
                 # Combine initial and review results
                 task_data["review_gen_text"] = review_gen_text["response"]
-                if return_messages_log:
-                    # Log for the review call itself
-                    message = {'role': 'assistant', 'content': review_gen_text["response"]}
-                    if "reasoning" in review_gen_text:
-                        message["reasoning"] = review_gen_text["reasoning"]
-                    task_data["full_review_log"] = task_data["full_initial_log"] + [message]
                 return task_data # Return the augmented dictionary
 
         # Create and gather review tasks
@@ -1279,9 +1358,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         final_results_raw.sort(key=lambda x: x["original_index"])
 
         # <--- Process final results --->
-        output: List[FrameExtractionUnitResult] = []
-        messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
-
         for result_data in final_results_raw:
             unit = result_data["unit"]
             initial_gen = result_data["initial_gen_text"]
@@ -1296,23 +1372,13 @@ class ReviewFrameExtractor(DirectFrameExtractor):
                 final_gen_text = review_gen # Default to revision if mode is somehow invalid
 
             # Create final result object
-
-
-                end=unit.end,
-                text=unit.text,
-                gen_text=final_gen_text # Use the combined/reviewed text
-            )
-            output.append(result)
-
-            # Append full conversation log if requested
-            if return_messages_log:
-                full_log_for_unit = result_data["full_review_log"]
-                messages_log.append(full_log_for_unit)
+            unit.set_generated_text(final_gen_text)
+            unit.set_status("success")
 
         if return_messages_log:
-            return
+            return units, messages_logger.get_messages_log()
         else:
-            return
+            return units
 
 
 class BasicFrameExtractor(DirectFrameExtractor):
@@ -1549,6 +1615,9 @@ class AttributeExtractor(Extractor):
             a dictionary of attributes extracted from the frame.
             If return_messages_log is True, a list of messages will be returned as well.
         """
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # construct chat messages
         messages = []
         if self.system_prompt:
@@ -1567,19 +1636,15 @@ class AttributeExtractor(Extractor):
         gen_text = self.inference_engine.chat(
             messages=messages,
             verbose=verbose,
-            stream=False
+            stream=False,
+            messages_logger=messages_logger
         )
-        if return_messages_log:
-            message = {"role": "assistant", "content": gen_text["response"]}
-            if "reasoning" in gen_text:
-                message["reasoning"] = gen_text["reasoning"]
-            messages.append(message)
 
-        attribute_list =
+        attribute_list = extract_json(gen_text=gen_text["response"])
         if isinstance(attribute_list, list) and len(attribute_list) > 0:
             attributes = attribute_list[0]
         if return_messages_log:
-            return attributes,
+            return attributes, messages_logger.get_messages_log()
         return attributes
 
 
@@ -1620,7 +1685,7 @@ class AttributeExtractor(Extractor):
         if return_messages_log:
             attr, messages = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
                                                       verbose=verbose, return_messages_log=return_messages_log)
-            messages_log.
+            messages_log.extend(messages)
         else:
             attr = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
                                             verbose=verbose, return_messages_log=return_messages_log)
@@ -1669,6 +1734,9 @@ class AttributeExtractor(Extractor):
         if not isinstance(text, str):
             raise TypeError(f"Expect text as str, received {type(text)} instead.")
 
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # async helper
         semaphore = asyncio.Semaphore(concurrent_batch_size)
 
@@ -1681,15 +1749,8 @@ class AttributeExtractor(Extractor):
             context = self._get_context(frame, text, context_size)
             messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})
 
-            gen_text = await self.inference_engine.chat_async(messages=messages)
-
-            if return_messages_log:
-                message = {"role": "assistant", "content": gen_text["response"]}
-                if "reasoning" in gen_text:
-                    message["reasoning"] = gen_text["reasoning"]
-                messages.append(message)
-
-            attribute_list = self._extract_json(gen_text=gen_text["response"])
+            gen_text = await self.inference_engine.chat_async(messages=messages, messages_logger=messages_logger)
+            attribute_list = extract_json(gen_text=gen_text["response"])
             attributes = attribute_list[0] if isinstance(attribute_list, list) and len(attribute_list) > 0 else {}
             return {"frame": frame, "attributes": attributes, "messages": messages}
 
@@ -1699,12 +1760,8 @@ class AttributeExtractor(Extractor):
 
         # process results
         new_frames = []
-        messages_log = [] if return_messages_log else None
 
         for result in results:
-            if return_messages_log:
-                messages_log.append(result["messages"])
-
             if inplace:
                 result["frame"].attr.update(result["attributes"])
             else:
@@ -1714,9 +1771,9 @@ class AttributeExtractor(Extractor):
 
         # output
         if inplace:
-            return
+            return messages_logger.get_messages_log() if return_messages_log else None
         else:
-            return (new_frames,
+            return (new_frames, messages_logger.get_messages_log()) if return_messages_log else new_frames
 
     def extract_attributes(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
                            concurrent:bool=False, concurrent_batch_size:int=32, verbose:bool=False,
@@ -1839,7 +1896,7 @@ class RelationExtractor(Extractor):
                  return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
         pairs = itertools.combinations(doc.frames, 2)
         relations = []
-
+        messages_logger = MessagesLogger() if return_messages_log else None
 
         for frame_1, frame_2 in pairs:
             task_payload = self._get_task_if_possible(frame_1, frame_2, doc.text, buffer_size)
@@ -1851,20 +1908,14 @@ class RelationExtractor(Extractor):
 
                 gen_text = self.inference_engine.chat(
                     messages=task_payload['messages'],
-                    verbose=verbose
+                    verbose=verbose,
+                    messages_logger=messages_logger
                 )
                 relation = self._post_process_result(gen_text["response"], task_payload)
                 if relation:
                     relations.append(relation)
 
-
-                message = {"role": "assistant", "content": gen_text["response"]}
-                if "reasoning" in gen_text:
-                    message["reasoning"] = gen_text["reasoning"]
-                task_payload['messages'].append(message)
-                messages_log.append(task_payload['messages'])
-
-        return (relations, messages_log) if return_messages_log else relations
+        return (relations, messages_logger.get_messages_log()) if return_messages_log else relations
 
     async def _extract_async(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent_batch_size: int = 32, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
         pairs = list(itertools.combinations(doc.frames, 2))
@@ -1873,12 +1924,12 @@ class RelationExtractor(Extractor):
         tasks_input = [task for task in tasks_input if task is not None]
 
         relations = []
-
+        messages_logger = MessagesLogger() if return_messages_log else None
         semaphore = asyncio.Semaphore(concurrent_batch_size)
 
         async def semaphore_helper(task_payload: Dict):
             async with semaphore:
-                gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'])
+                gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'], messages_logger=messages_logger)
             return gen_text, task_payload
 
         tasks = [asyncio.create_task(semaphore_helper(payload)) for payload in tasks_input]
@@ -1889,14 +1940,7 @@ class RelationExtractor(Extractor):
             if relation:
                 relations.append(relation)
 
-
-            message = {"role": "assistant", "content": gen_text["response"]}
-            if "reasoning" in gen_text:
-                message["reasoning"] = gen_text["reasoning"]
-            task_payload['messages'].append(message)
-            messages_log.append(task_payload['messages'])
-
-        return (relations, messages_log) if return_messages_log else relations
+        return (relations, messages_logger.get_messages_log()) if return_messages_log else relations
 
     def extract_relations(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent: bool = False, concurrent_batch_size: int = 32, verbose: bool = False, return_messages_log: bool = False) -> List[Dict]:
         if not doc.has_frame():
@@ -1959,7 +2003,7 @@ class BinaryRelationExtractor(RelationExtractor):
             return None
 
     def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
-        rel_json =
+        rel_json = extract_json(gen_text)
         if len(rel_json) > 0 and "Relation" in rel_json[0]:
             rel = rel_json[0]["Relation"]
             if (isinstance(rel, bool) and rel) or (isinstance(rel, str) and rel.lower() == 'true'):
@@ -2025,7 +2069,7 @@ class MultiClassRelationExtractor(RelationExtractor):
             return None
 
     def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
-        rel_json =
+        rel_json = extract_json(gen_text)
        pos_rel_types = pair_data['pos_rel_types']
        if len(rel_json) > 0 and "RelationType" in rel_json[0]:
            rel_type = rel_json[0]["RelationType"]