llm_ie-1.2.1-py3-none-any.whl → llm_ie-1.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_ie/__init__.py +5 -4
- llm_ie/chunkers.py +78 -4
- llm_ie/data_types.py +23 -37
- llm_ie/engines.py +663 -112
- llm_ie/extractors.py +357 -206
- llm_ie/prompt_editor.py +4 -4
- {llm_ie-1.2.1.dist-info → llm_ie-1.2.3.dist-info}/METADATA +1 -1
- {llm_ie-1.2.1.dist-info → llm_ie-1.2.3.dist-info}/RECORD +9 -9
- {llm_ie-1.2.1.dist-info → llm_ie-1.2.3.dist-info}/WHEEL +0 -0
llm_ie/extractors.py
CHANGED
@@ -8,11 +8,12 @@ import warnings
 import itertools
 import asyncio
 import nest_asyncio
-from
-from
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional, AsyncGenerator
+from llm_ie.data_types import FrameExtractionUnit, LLMInformationExtractionFrame, LLMInformationExtractionDocument
 from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
 from llm_ie.chunkers import ContextChunker, NoContextChunker, WholeDocumentContextChunker, SlideWindowContextChunker
-from llm_ie.engines import InferenceEngine
+from llm_ie.engines import InferenceEngine, MessagesLogger
 from colorama import Fore, Style
 
 
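Note: the `MessagesLogger` now imported from `llm_ie.engines` replaces the ad-hoc `messages_log` lists used throughout this file. Its implementation lives in `llm_ie/engines.py` and is not part of this diff; based only on the call sites below (constructed with no arguments, passed as `messages_logger=` into `chat()`/`chat_async()`, and read back with `get_messages_log()`), a minimal stand-in for local experimentation might look like this sketch:

# Hypothetical stand-in inferred from call sites in this diff; the real
# class in llm_ie/engines.py may differ (e.g., thread-safety, extra metadata).
class MessagesLogger:
    def __init__(self):
        self._log = []  # one conversation (list of message dicts) per chat call

    def log_messages(self, messages):  # assumed hook invoked inside chat()/chat_async()
        self._log.append(list(messages))

    def get_messages_log(self):
        return self._log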
@@ -405,7 +406,7 @@ class DirectFrameExtractor(FrameExtractor):
 
 
     def extract(self, text_content:Union[str, Dict[str,str]],
-                document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+                document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
         """
         This method inputs a text and outputs a list of outputs per unit.
 
@@ -423,11 +424,9 @@ class DirectFrameExtractor(FrameExtractor):
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.
 
-        Return : List[FrameExtractionUnitResult]
+        Return : List[FrameExtractionUnit]
             the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
-        # define output
-        output = []
         # unit chunking
         if isinstance(text_content, str):
             doc_text = text_content
@@ -440,73 +439,70 @@ class DirectFrameExtractor(FrameExtractor):
         units = self.unit_chunker.chunk(doc_text)
         # context chunker init
         self.context_chunker.fit(doc_text, units)
+
         # messages log
-        if return_messages_log:
-            messages_log = []
+        messages_logger = MessagesLogger() if return_messages_log else None
 
         # generate unit by unit
         for i, unit in enumerate(units):
-            # construct chat messages
-            messages = []
-            if self.system_prompt:
-                messages.append({'role': 'system', 'content': self.system_prompt})
-
-            context = self.context_chunker.chunk(unit)
-
-            if context == "":
-                # no context, just place unit in user prompt
-                if isinstance(text_content, str):
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
-                else:
-                    unit_content = text_content.copy()
-                    unit_content[document_key] = unit.text
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
-            else:
-                # insert context to user prompt
-                if isinstance(text_content, str):
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
-                else:
-                    context_content = text_content.copy()
-                    context_content[document_key] = context
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
-                # simulate conversation where assistant confirms
-                messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
-                # place unit of interest
-                messages.append({'role': 'user', 'content': unit.text})
+            try:
+                # construct chat messages
+                messages = []
+                if self.system_prompt:
+                    messages.append({'role': 'system', 'content': self.system_prompt})
 
-            if verbose:
-                print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
-                if context != "":
-                    print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+                context = self.context_chunker.chunk(unit)
 
-                print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+                if context == "":
+                    # no context, just place unit in user prompt
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                    else:
+                        unit_content = text_content.copy()
+                        unit_content[document_key] = unit.text
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+                else:
+                    # insert context to user prompt
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                    else:
+                        context_content = text_content.copy()
+                        context_content[document_key] = context
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                    # simulate conversation where assistant confirms
+                    messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                    # place unit of interest
+                    messages.append({'role': 'user', 'content': unit.text})
 
-            gen_text = self.inference_engine.chat(
-                messages=messages,
-                verbose=verbose,
-                stream=False
-            )
-
-            if return_messages_log:
-                messages.append({"role": "assistant", "content": gen_text})
-                messages_log.append(messages)
-
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=gen_text
-            )
-            output.append(result)
+                if verbose:
+                    print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
+                    if context != "":
+                        print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+
+                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
 
+                gen_text = self.inference_engine.chat(
+                    messages=messages,
+                    verbose=verbose,
+                    stream=False,
+                    messages_logger=messages_logger
+                )
+
+                # add generated text to unit
+                unit.set_generated_text(gen_text["response"])
+                unit.set_status("success")
+            except Exception as e:
+                unit.set_status("fail")
+                warnings.warn(f"LLM inference failed for unit {i} ({unit.start}, {unit.end}): {e}", RuntimeWarning)
 
         if return_messages_log:
-            return output, messages_log
+            return units, messages_logger.get_messages_log()
 
-        return output
+        return units
 
     def stream(self, text_content: Union[str, Dict[str, str]],
-               document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnitResult]]:
+               document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnit]]:
         """
         Streams LLM responses per unit with structured event types,
         and returns collected data for post-processing.
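With this rework, `extract()` no longer builds `FrameExtractionUnitResult` objects: each `FrameExtractionUnit` now carries its own generated text and a success/fail status, and a failed LLM call emits a `RuntimeWarning` instead of aborting the whole run. A caller can filter on that status directly; a minimal sketch (the `status` and `gen_text` attribute names are taken from later hunks in this diff, and the `extractor` object is assumed to be an already-configured DirectFrameExtractor):

# Sketch: consuming the new per-unit results returned by extract().
units = extractor.extract(text_content=document_text)
succeeded = [u for u in units if u.status == "success"]
for u in succeeded:
    print(u.start, u.end, u.gen_text)  # gen_text holds the raw LLM output per unit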
@@ -522,12 +518,10 @@ class DirectFrameExtractor(FrameExtractor):
 
         Returns:
         --------
-        List[FrameExtractionUnitResult]:
-            A list of FrameExtractionUnitResult objects, each containing the
+        List[FrameExtractionUnit]:
+            A list of FrameExtractionUnit objects, each containing the
             original unit details and the fully accumulated 'gen_text' from the LLM.
         """
-        collected_results: List[FrameExtractionUnitResult] = []
-
         if isinstance(text_content, str):
             doc_text = text_content
         elif isinstance(text_content, dict):
@@ -581,22 +575,18 @@ class DirectFrameExtractor(FrameExtractor):
             )
             for chunk in response_stream:
                 yield chunk
-                current_gen_text += chunk["data"]
+                if chunk["type"] == "response":
+                    current_gen_text += chunk["data"]
 
             # Store the result for this unit
-            result_for_unit = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=current_gen_text
-            )
-            collected_results.append(result_for_unit)
+            unit.set_generated_text(current_gen_text)
+            unit.set_status("success")
 
         yield {"type": "info", "data": "All units processed by LLM."}
-        return collected_results
+        return units
 
     async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
-                            concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+                            concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
         """
         This is the asynchronous version of the extract() method.
 
@@ -614,7 +604,7 @@ class DirectFrameExtractor(FrameExtractor):
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.
 
-        Return : List[FrameExtractionUnitResult]
+        Return : List[FrameExtractionUnit]
            the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
         if isinstance(text_content, str):
@@ -633,6 +623,9 @@ class DirectFrameExtractor(FrameExtractor):
         # context chunker init
         self.context_chunker.fit(doc_text, units)
 
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # Prepare inputs for all units first
         tasks_input = []
         for i, unit in enumerate(units):
@@ -673,13 +666,15 @@ class DirectFrameExtractor(FrameExtractor):
         async def semaphore_helper(task_data: Dict, **kwrs):
             unit = task_data["unit"]
             messages = task_data["messages"]
-            original_index = task_data["original_index"]
 
             async with semaphore:
                 gen_text = await self.inference_engine.chat_async(
-                    messages=messages
+                    messages=messages,
+                    messages_logger=messages_logger
                 )
-                return {"original_index": original_index, "unit": unit, "gen_text": gen_text, "messages": messages}
+
+                unit.set_generated_text(gen_text["response"])
+                unit.set_status("success")
 
         # Create and gather tasks
         tasks = []
@@ -689,37 +684,13 @@ class DirectFrameExtractor(FrameExtractor):
             ))
             tasks.append(task)
 
-        results_raw = await asyncio.gather(*tasks)
-
-        # Sort results back into original order using the index stored
-        results_raw.sort(key=lambda x: x["original_index"])
-
-        # Restructure the results
-        output: List[FrameExtractionUnitResult] = []
-        messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
-
-        for result_data in results_raw:
-            unit = result_data["unit"]
-            gen_text = result_data["gen_text"]
-
-            # Create result object
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=gen_text
-            )
-            output.append(result)
-
-            # Append to messages log if requested
-            if return_messages_log:
-                final_messages = result_data["messages"] + [{"role": "assistant", "content": gen_text}]
-                messages_log.append(final_messages)
+        await asyncio.gather(*tasks)
 
+        # Return units
         if return_messages_log:
-            return output, messages_log
+            return units, messages_logger.get_messages_log()
         else:
-            return output
+            return units
 
 
     def extract_frames(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
@@ -727,7 +698,7 @@ class DirectFrameExtractor(FrameExtractor):
                        case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
                        allow_overlap_entities:bool=False, return_messages_log:bool=False) -> List[LLMInformationExtractionFrame]:
         """
-        This method inputs a text and outputs a list of LLMInformationExtractionFrame
+        This method inputs a document text and outputs a list of LLMInformationExtractionFrame
         It use the extract() method and post-process outputs into frames.
 
         Parameters:
@@ -780,18 +751,21 @@ class DirectFrameExtractor(FrameExtractor):
                                          verbose=verbose,
                                          return_messages_log=return_messages_log)
 
-        llm_output_results, messages_log = extraction_results if return_messages_log else (extraction_results, None)
+        units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
 
         frame_list = []
-        for result in llm_output_results:
+        for unit in units:
             entity_json = []
-            for entity in self._extract_json(gen_text=result.gen_text):
+            if unit.status != "success":
+                warnings.warn(f"Skipping failed unit ({unit.start}, {unit.end}): {unit.text}", RuntimeWarning)
+                continue
+            for entity in self._extract_json(gen_text=unit.gen_text):
                 if ENTITY_KEY in entity:
                     entity_json.append(entity)
                 else:
                     warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
 
-            spans = self._find_entity_spans(text=result.text,
+            spans = self._find_entity_spans(text=unit.text,
                                             entities=[e[ENTITY_KEY] for e in entity_json],
                                             case_sensitive=case_sensitive,
                                             fuzzy_match=fuzzy_match,
@@ -801,9 +775,9 @@ class DirectFrameExtractor(FrameExtractor):
             for ent, span in zip(entity_json, spans):
                 if span is not None:
                     start, end = span
-                    entity_text = result.text[start:end]
-                    start += result.start
-                    end += result.start
+                    entity_text = unit.text[start:end]
+                    start += unit.start
+                    end += unit.start
                     attr = {}
                     if "attr" in ent and ent["attr"] is not None:
                         attr = ent["attr"]
@@ -820,6 +794,208 @@ class DirectFrameExtractor(FrameExtractor):
         return frame_list
 
 
+    async def extract_frames_from_documents(self, text_contents:List[Union[str,Dict[str, any]]], document_key:str="text",
+                                            cpu_concurrency:int=4, llm_concurrency:int=32, case_sensitive:bool=False,
+                                            fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
+                                            allow_overlap_entities:bool=False, return_messages_log:bool=False) -> AsyncGenerator[Dict[str, any], None]:
+        """
+        This method inputs a list of documents and yields the results for each document as soon as it is complete.
+
+        Parameters:
+        -----------
+        text_contents : List[Union[str,Dict[str, any]]]
+            a list of input text contents to put in prompt template.
+            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+        document_key: str, optional
+            The key in the `text_contents` dictionaries that holds the document text.
+        cpu_concurrency: int, optional
+            The number of parallel threads to use for CPU-bound tasks like chunking.
+        llm_concurrency: int, optional
+            The number of concurrent requests to make to the LLM.
+        case_sensitive : bool, Optional
+            if True, entity text matching will be case-sensitive.
+        fuzzy_match : bool, Optional
+            if True, fuzzy matching will be applied to find entity text.
+        fuzzy_buffer_size : float, Optional
+            the buffer size for fuzzy matching. Default is 20% of entity text length.
+        fuzzy_score_cutoff : float, Optional
+            the Jaccard score cutoff for fuzzy matching.
+            Matched entity text must have a score higher than this value or a None will be returned.
+        allow_overlap_entities : bool, Optional
+            if True, entities can overlap in the text.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
+
+        Yields:
+        -------
+        AsyncGenerator[Dict[str, any], None]
+            A dictionary for each completed document, containing its 'idx' and extracted 'frames'.
+        """
+        # Validate text_contents must be a list of str or dict, and not both
+        if not isinstance(text_contents, list):
+            raise ValueError("text_contents must be a list of strings or dictionaries.")
+        if all(isinstance(doc, str) for doc in text_contents):
+            pass
+        elif all(isinstance(doc, dict) for doc in text_contents):
+            pass
+        # Set CPU executor and queues
+        cpu_executor = ThreadPoolExecutor(max_workers=cpu_concurrency)
+        tasks_queue = asyncio.Queue(maxsize=llm_concurrency * 2)
+        # Store to track units and pending counts
+        results_store = {
+            idx: {'pending': 0, 'units': [], 'text': doc if isinstance(doc, str) else doc.get(document_key, "")}
+            for idx, doc in enumerate(text_contents)
+        }
+
+        output_queue = asyncio.Queue()
+        messages_logger = MessagesLogger() if return_messages_log else None
+
+        async def producer():
+            try:
+                for idx, text_content in enumerate(text_contents):
+                    text = text_content if isinstance(text_content, str) else text_content.get(document_key, "")
+                    if not text:
+                        warnings.warn(f"Document at index {idx} is empty or missing the document key '{document_key}'.")
+                        # signal that this document is done
+                        await output_queue.put({'idx': idx, 'frames': []})
+                        continue
+
+                    units = await self.unit_chunker.chunk_async(text, cpu_executor)
+                    await self.context_chunker.fit_async(text, units, cpu_executor)
+                    results_store[idx]['pending'] = len(units)
+
+                    # Handle cases where a document yields no units
+                    if not units:
+                        # signal that this document is done
+                        await output_queue.put({'idx': idx, 'frames': []})
+                        continue
+
+                    # Iterate through units
+                    for unit in units:
+                        context = await self.context_chunker.chunk_async(unit, cpu_executor)
+                        messages = []
+                        if self.system_prompt:
+                            messages.append({'role': 'system', 'content': self.system_prompt})
+
+                        if not context:
+                            if isinstance(text_content, str):
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                            else:
+                                unit_content = text_content.copy()
+                                unit_content[document_key] = unit.text
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+                        else:
+                            # insert context to user prompt
+                            if isinstance(text_content, str):
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                            else:
+                                context_content = text_content.copy()
+                                context_content[document_key] = context
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                            # simulate conversation where assistant confirms
+                            messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                            # place unit of interest
+                            messages.append({'role': 'user', 'content': unit.text})
+
+                        await tasks_queue.put({'idx': idx, 'unit': unit, 'messages': messages})
+            finally:
+                for _ in range(llm_concurrency):
+                    await tasks_queue.put(None)
+
+        async def worker():
+            while True:
+                task_item = await tasks_queue.get()
+                if task_item is None:
+                    tasks_queue.task_done()
+                    break
+
+                idx = task_item['idx']
+                unit = task_item['unit']
+                doc_results = results_store[idx]
+
+                try:
+                    gen_text = await self.inference_engine.chat_async(
+                        messages=task_item['messages'], messages_logger=messages_logger
+                    )
+                    unit.set_generated_text(gen_text["response"])
+                    unit.set_status("success")
+                    doc_results['units'].append(unit)
+                except Exception as e:
+                    warnings.warn(f"Error processing unit for doc idx {idx}: {e}")
+                finally:
+                    doc_results['pending'] -= 1
+                    if doc_results['pending'] <= 0:
+                        final_frames = self._post_process_and_create_frames(doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities)
+                        output_payload = {'idx': idx, 'frames': final_frames}
+                        if return_messages_log:
+                            output_payload['messages_log'] = messages_logger.get_messages_log()
+                        await output_queue.put(output_payload)
+
+                    tasks_queue.task_done()
+
+        # Start producer and workers
+        producer_task = asyncio.create_task(producer())
+        worker_tasks = [asyncio.create_task(worker()) for _ in range(llm_concurrency)]
+
+        # Main loop to gather results
+        docs_completed = 0
+        while docs_completed < len(text_contents):
+            result = await output_queue.get()
+            yield result
+            docs_completed += 1
+
+        # Final cleanup
+        await producer_task
+        await tasks_queue.join()
+
+        # Cancel any lingering worker tasks
+        for task in worker_tasks:
+            task.cancel()
+        await asyncio.gather(*worker_tasks, return_exceptions=True)
+
+        cpu_executor.shutdown(wait=False)
+
+
+    def _post_process_and_create_frames(self, doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities):
+        """Helper function to run post-processing logic for a completed document."""
+        ENTITY_KEY = "entity_text"
+        frame_list = []
+        for res in sorted(doc_results['units'], key=lambda r: r.start):
+            entity_json = []
+            for entity in self._extract_json(gen_text=res.gen_text):
+                if ENTITY_KEY in entity:
+                    entity_json.append(entity)
+                else:
+                    warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
+
+            spans = self._find_entity_spans(
+                text=res.text,
+                entities=[e[ENTITY_KEY] for e in entity_json],
+                case_sensitive=case_sensitive,
+                fuzzy_match=fuzzy_match,
+                fuzzy_buffer_size=fuzzy_buffer_size,
+                fuzzy_score_cutoff=fuzzy_score_cutoff,
+                allow_overlap_entities=allow_overlap_entities
+            )
+            for ent, span in zip(entity_json, spans):
+                if span is not None:
+                    start, end = span
+                    entity_text = res.text[start:end]
+                    start += res.start
+                    end += res.start
+                    attr = ent.get("attr", {}) or {}
+                    frame = LLMInformationExtractionFrame(
+                        frame_id=f"{len(frame_list)}",
+                        start=start,
+                        end=end,
+                        entity_text=entity_text,
+                        attr=attr
+                    )
+                    frame_list.append(frame)
+        return frame_list
+
+
 class ReviewFrameExtractor(DirectFrameExtractor):
     def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker, inference_engine:InferenceEngine,
                  prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None):
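The new `extract_frames_from_documents()` above is an async generator built on a producer/worker queue: chunking runs on a thread pool sized by `cpu_concurrency`, LLM calls are capped at `llm_concurrency`, and each document's frames are yielded as soon as its pending-unit counter reaches zero, so results can arrive out of input order (the `idx` key recovers it). A minimal driver, assuming an already-configured extractor instance:

# Sketch: driving the new per-document async generator; `extractor` and the
# document strings are placeholders, not values from this diff.
import asyncio

async def run(extractor, documents):
    async for payload in extractor.extract_frames_from_documents(documents, llm_concurrency=8):
        # payload carries the document's input index and its extracted frames
        print(payload["idx"], len(payload["frames"]))

# asyncio.run(run(extractor, ["note one ...", "note two ..."]))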
@@ -891,7 +1067,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
                 raise ValueError(f"Cannot find review prompt for {self.__class__.__name__} in the package. Please provide a review_prompt.")
 
     def extract(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
-                verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+                verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
         """
         This method inputs a text and outputs a list of outputs per unit.
 
@@ -912,8 +1088,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         Return : List[FrameExtractionUnitResult]
             the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
-        # define output
-        output = []
         # unit chunking
         if isinstance(text_content, str):
             doc_text = text_content
@@ -926,9 +1100,9 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         units = self.unit_chunker.chunk(doc_text)
         # context chunker init
         self.context_chunker.fit(doc_text, units)
-        # messages log
-        if return_messages_log:
-            messages_log = []
+
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
 
         # generate unit by unit
         for i, unit in enumerate(units):
@@ -962,7 +1136,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
                 messages.append({'role': 'user', 'content': unit.text})
 
             if verbose:
-                print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
+                print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
                 if context != "":
                     print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
 
@@ -972,48 +1146,38 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             initial = self.inference_engine.chat(
                 messages=messages,
                 verbose=verbose,
-                stream=False
+                stream=False,
+                messages_logger=messages_logger
             )
 
-            if return_messages_log:
-                messages.append({"role": "assistant", "content": initial})
-                messages_log.append(messages)
-
             # <--- Review step --->
             if verbose:
                 print(f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}")
 
-            messages.append({'role': 'assistant', 'content': initial})
+            messages.append({'role': 'assistant', 'content': initial["response"]})
             messages.append({'role': 'user', 'content': self.review_prompt})
 
             review = self.inference_engine.chat(
                 messages=messages,
                 verbose=verbose,
-                stream=False
+                stream=False,
+                messages_logger=messages_logger
             )
 
             # Output
             if self.review_mode == "revision":
-                gen_text = review
+                gen_text = review["response"]
             elif self.review_mode == "addition":
-                gen_text = initial + '\n' + review
+                gen_text = initial["response"] + '\n' + review["response"]
 
-            if return_messages_log:
-                messages.append({"role": "assistant", "content": review})
-                messages_log.append(messages)
-
-            # add to output
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=gen_text)
-            output.append(result)
+            # add generated text to unit
+            unit.set_generated_text(gen_text)
+            unit.set_status("success")
 
         if return_messages_log:
-            return output, messages_log
+            return units, messages_logger.get_messages_log()
 
-        return output
+        return units
 
 
     def stream(self, text_content:Union[str, Dict[str,str]], document_key:str=None) -> Generator[str, None, None]:
@@ -1109,7 +1273,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
                 yield chunk
 
     async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
-                            concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
+                            concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnit]:
         """
         This is the asynchronous version of the extract() method with the review step.
 
@@ -1141,11 +1305,15 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         else:
             raise TypeError("text_content must be a string or a dictionary.")
 
+        # unit chunking
         units = self.unit_chunker.chunk(doc_text)
 
         # context chunker init
         self.context_chunker.fit(doc_text, units)
 
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # <--- Initial generation step --->
         initial_tasks_input = []
         for i, unit in enumerate(units):
@@ -1189,10 +1357,14 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
             async with semaphore:
                 gen_text = await self.inference_engine.chat_async(
-                    messages=messages
+                    messages=messages,
+                    messages_logger=messages_logger
                 )
                 # Return initial generation result along with the messages used and the unit
-                return {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text, "initial_messages": messages}
+                out = {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text["response"], "initial_messages": messages}
+                if "reasoning" in gen_text:
+                    out["reasoning"] = gen_text["reasoning"]
+                return out
 
         # Create and gather initial generation tasks
         initial_tasks = [
@@ -1218,28 +1390,30 @@ class ReviewFrameExtractor(DirectFrameExtractor):
                 {'role': 'user', 'content': self.review_prompt}
             ]
             # Store data needed for review task
+            if "reasoning" in result_data:
+                message = {'role': 'assistant', 'content': initial_gen_text, "reasoning": result_data["reasoning"]}
+            else:
+                message = {'role': 'assistant', 'content': initial_gen_text}
+
             review_tasks_input.append({
                 "unit": result_data["unit"],
                 "initial_gen_text": initial_gen_text,
                 "messages": review_messages,
                 "original_index": result_data["original_index"],
-                "full_initial_log": initial_messages + [{'role': 'assistant', 'content': initial_gen_text}] if return_messages_log else None
+                "full_initial_log": initial_messages + [message] + [{'role': 'user', 'content': self.review_prompt}] if return_messages_log else None
            })
 
 
         async def review_semaphore_helper(task_data: Dict, **kwrs):
             messages = task_data["messages"]
-            original_index = task_data["original_index"]
 
             async with semaphore:
                 review_gen_text = await self.inference_engine.chat_async(
-                    messages=messages
+                    messages=messages,
+                    messages_logger=messages_logger
                 )
                 # Combine initial and review results
-                task_data["review_gen_text"] = review_gen_text
-                if return_messages_log:
-                    # Log for the review call itself
-                    task_data["full_review_log"] = messages + [{'role': 'assistant', 'content': review_gen_text}]
+                task_data["review_gen_text"] = review_gen_text["response"]
                 return task_data  # Return the augmented dictionary
 
         # Create and gather review tasks
@@ -1256,9 +1430,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         final_results_raw.sort(key=lambda x: x["original_index"])
 
         # <--- Process final results --->
-        output: List[FrameExtractionUnitResult] = []
-        messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
-
         for result_data in final_results_raw:
             unit = result_data["unit"]
             initial_gen = result_data["initial_gen_text"]
@@ -1273,23 +1444,13 @@ class ReviewFrameExtractor(DirectFrameExtractor):
                 final_gen_text = review_gen  # Default to revision if mode is somehow invalid
 
             # Create final result object
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=final_gen_text  # Use the combined/reviewed text
-            )
-            output.append(result)
-
-            # Append full conversation log if requested
-            if return_messages_log:
-                full_log_for_unit = result_data.get("full_initial_log", []) + [{'role': 'user', 'content': self.review_prompt}] + [{'role': 'assistant', 'content': review_gen}]
-                messages_log.append(full_log_for_unit)
+            unit.set_generated_text(final_gen_text)
+            unit.set_status("success")
 
         if return_messages_log:
-            return output, messages_log
+            return units, messages_logger.get_messages_log()
         else:
-            return output
+            return units
 
 
 class BasicFrameExtractor(DirectFrameExtractor):
@@ -1526,6 +1687,9 @@ class AttributeExtractor(Extractor):
             a dictionary of attributes extracted from the frame.
             If return_messages_log is True, a list of messages will be returned as well.
         """
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # construct chat messages
         messages = []
         if self.system_prompt:
@@ -1541,19 +1705,18 @@ class AttributeExtractor(Extractor):
 
             print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
 
-        gen_text = self.inference_engine.chat(
+        gen_text = self.inference_engine.chat(
             messages=messages,
             verbose=verbose,
-            stream=False
+            stream=False,
+            messages_logger=messages_logger
         )
-        if return_messages_log:
-            messages.append({"role": "assistant", "content": get_text})
 
-        attribute_list = self._extract_json(gen_text=gen_text)
+        attribute_list = self._extract_json(gen_text=gen_text["response"])
         if isinstance(attribute_list, list) and len(attribute_list) > 0:
             attributes = attribute_list[0]
         if return_messages_log:
-            return attributes, messages
+            return attributes, messages_logger.get_messages_log()
         return attributes
 
 
@@ -1594,7 +1757,7 @@ class AttributeExtractor(Extractor):
             if return_messages_log:
                 attr, messages = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
                                                           verbose=verbose, return_messages_log=return_messages_log)
-                messages_log.append(messages)
+                messages_log.extend(messages)
             else:
                 attr = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
                                                 verbose=verbose, return_messages_log=return_messages_log)
@@ -1643,6 +1806,9 @@ class AttributeExtractor(Extractor):
         if not isinstance(text, str):
             raise TypeError(f"Expect text as str, received {type(text)} instead.")
 
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # async helper
         semaphore = asyncio.Semaphore(concurrent_batch_size)
 
@@ -1655,12 +1821,8 @@ class AttributeExtractor(Extractor):
             context = self._get_context(frame, text, context_size)
             messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})
 
-            gen_text = await self.inference_engine.chat_async(messages=messages)
-
-            if return_messages_log:
-                messages.append({"role": "assistant", "content": gen_text})
-
-            attribute_list = self._extract_json(gen_text=gen_text)
+            gen_text = await self.inference_engine.chat_async(messages=messages, messages_logger=messages_logger)
+            attribute_list = self._extract_json(gen_text=gen_text["response"])
             attributes = attribute_list[0] if isinstance(attribute_list, list) and len(attribute_list) > 0 else {}
             return {"frame": frame, "attributes": attributes, "messages": messages}
 
@@ -1670,12 +1832,8 @@ class AttributeExtractor(Extractor):
 
         # process results
         new_frames = []
-        messages_log = [] if return_messages_log else None
 
         for result in results:
-            if return_messages_log:
-                messages_log.append(result["messages"])
-
             if inplace:
                 result["frame"].attr.update(result["attributes"])
             else:
@@ -1685,9 +1843,9 @@ class AttributeExtractor(Extractor):
 
         # output
         if inplace:
-            return messages_log if return_messages_log else None
+            return messages_logger.get_messages_log() if return_messages_log else None
         else:
-            return (new_frames, messages_log) if return_messages_log else new_frames
+            return (new_frames, messages_logger.get_messages_log()) if return_messages_log else new_frames
 
     def extract_attributes(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
                            concurrent:bool=False, concurrent_batch_size:int=32, verbose:bool=False,
@@ -1810,7 +1968,7 @@ class RelationExtractor(Extractor):
                  return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
        pairs = itertools.combinations(doc.frames, 2)
        relations = []
-        messages_log = [] if return_messages_log else None
+        messages_logger = MessagesLogger() if return_messages_log else None
 
        for frame_1, frame_2 in pairs:
            task_payload = self._get_task_if_possible(frame_1, frame_2, doc.text, buffer_size)
@@ -1822,17 +1980,14 @@ class RelationExtractor(Extractor):
 
            gen_text = self.inference_engine.chat(
                messages=task_payload['messages'],
-                verbose=verbose
+                verbose=verbose,
+                messages_logger=messages_logger
            )
-            relation = self._post_process_result(gen_text, task_payload)
+            relation = self._post_process_result(gen_text["response"], task_payload)
            if relation:
                relations.append(relation)
 
-            if return_messages_log:
-                task_payload['messages'].append({"role": "assistant", "content": gen_text})
-                messages_log.append(task_payload['messages'])
-
-        return (relations, messages_log) if return_messages_log else relations
+        return (relations, messages_logger.get_messages_log()) if return_messages_log else relations
 
    async def _extract_async(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent_batch_size: int = 32, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
        pairs = list(itertools.combinations(doc.frames, 2))
@@ -1841,27 +1996,23 @@ class RelationExtractor(Extractor):
        tasks_input = [task for task in tasks_input if task is not None]
 
        relations = []
-        messages_log = [] if return_messages_log else None
+        messages_logger = MessagesLogger() if return_messages_log else None
        semaphore = asyncio.Semaphore(concurrent_batch_size)
 
        async def semaphore_helper(task_payload: Dict):
            async with semaphore:
-                gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'])
+                gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'], messages_logger=messages_logger)
                return gen_text, task_payload
 
        tasks = [asyncio.create_task(semaphore_helper(payload)) for payload in tasks_input]
        results = await asyncio.gather(*tasks)
 
        for gen_text, task_payload in results:
-            relation = self._post_process_result(gen_text, task_payload)
+            relation = self._post_process_result(gen_text["response"], task_payload)
            if relation:
                relations.append(relation)
 
-            if return_messages_log:
-                task_payload['messages'].append({"role": "assistant", "content": gen_text})
-                messages_log.append(task_payload['messages'])
-
-        return (relations, messages_log) if return_messages_log else relations
+        return (relations, messages_logger.get_messages_log()) if return_messages_log else relations
 
    def extract_relations(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent: bool = False, concurrent_batch_size: int = 32, verbose: bool = False, return_messages_log: bool = False) -> List[Dict]:
        if not doc.has_frame():
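A pattern running through every hunk above: call sites of `chat()` and `chat_async()` stop treating the return value as a plain string and instead index `gen_text["response"]`, with an optional `"reasoning"` key surfacing in the ReviewFrameExtractor hunks. Only those two keys are visible in this diff; a small unpacking helper written under exactly that assumption:

# Assumes only what the call sites in this diff show: a dict with a "response"
# key and an optional "reasoning" key. Any other keys are unknown from this diff.
def unpack_chat_result(gen_text: dict):
    return gen_text["response"], gen_text.get("reasoning")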