llm_ie-1.2.2-py3-none-any.whl → llm_ie-1.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_ie/__init__.py +5 -4
- llm_ie/chunkers.py +44 -5
- llm_ie/data_types.py +23 -37
- llm_ie/engines.py +577 -61
- llm_ie/extractors.py +335 -219
- {llm_ie-1.2.2.dist-info → llm_ie-1.2.3.dist-info}/METADATA +1 -1
- {llm_ie-1.2.2.dist-info → llm_ie-1.2.3.dist-info}/RECORD +8 -8
- {llm_ie-1.2.2.dist-info → llm_ie-1.2.3.dist-info}/WHEEL +0 -0
llm_ie/extractors.py
CHANGED
```diff
@@ -8,11 +8,12 @@ import warnings
 import itertools
 import asyncio
 import nest_asyncio
-from typing import …
-from llm_ie.data_types import …
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional, AsyncGenerator
+from llm_ie.data_types import FrameExtractionUnit, LLMInformationExtractionFrame, LLMInformationExtractionDocument
 from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
 from llm_ie.chunkers import ContextChunker, NoContextChunker, WholeDocumentContextChunker, SlideWindowContextChunker
-from llm_ie.engines import InferenceEngine
+from llm_ie.engines import InferenceEngine, MessagesLogger
 from colorama import Fore, Style
 
 
```
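The new `MessagesLogger` is threaded through every `chat()`/`chat_async()` call in this file: one logger is created per top-level call when `return_messages_log=True`, and it is drained once at the end via `get_messages_log()`. The class itself lives in `engines.py`, which this diff does not show, so everything beyond the constructor and `get_messages_log()` is an assumption. A minimal sketch of the implied contract:

```python
# Hypothetical sketch of the contract implied by the call sites in this diff.
# The real MessagesLogger lives in llm_ie.engines (not shown here); the name
# of the recording method ("log") is a placeholder.
from typing import Dict, List

class MessagesLoggerSketch:
    def __init__(self) -> None:
        # one entry per completed conversation
        self._log: List[List[Dict[str, str]]] = []

    def log(self, messages: List[Dict[str, str]]) -> None:
        # presumably called inside InferenceEngine.chat()/chat_async() with the
        # full message list, including the assistant reply (and "reasoning"
        # when present) — replacing the per-extractor bookkeeping this diff deletes
        self._log.append(messages)

    def get_messages_log(self) -> List[List[Dict[str, str]]]:
        # the only accessor the extractors rely on
        return self._log
```

Centralizing the logging in the engine removes the repeated "append assistant message, copy reasoning, append to log" blocks that previously appeared after every chat call.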
```diff
@@ -405,7 +406,7 @@ class DirectFrameExtractor(FrameExtractor):
 
 
     def extract(self, text_content:Union[str, Dict[str,str]],
-                document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+                document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
         """
         This method inputs a text and outputs a list of outputs per unit.
 
@@ -423,11 +424,9 @@ class DirectFrameExtractor(FrameExtractor):
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.
 
-        Return : List[FrameExtractionUnitResult]
+        Return : List[FrameExtractionUnit]
             the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
-        # define output
-        output = []
         # unit chunking
         if isinstance(text_content, str):
             doc_text = text_content
@@ -440,76 +439,70 @@ class DirectFrameExtractor(FrameExtractor):
         units = self.unit_chunker.chunk(doc_text)
         # context chunker init
         self.context_chunker.fit(doc_text, units)
+
         # messages log
-        if return_messages_log:
-            messages_log = []
+        messages_logger = MessagesLogger() if return_messages_log else None
 
         # generate unit by unit
         for i, unit in enumerate(units):
-            # construct chat messages
-            messages = []
-            if self.system_prompt:
-                messages.append({'role': 'system', 'content': self.system_prompt})
-
-            context = self.context_chunker.chunk(unit)
-
-            if context == "":
-                # no context, just place unit in user prompt
-                if isinstance(text_content, str):
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
-                else:
-                    unit_content = text_content.copy()
-                    unit_content[document_key] = unit.text
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
-            else:
-                # insert context to user prompt
-                if isinstance(text_content, str):
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
-                else:
-                    context_content = text_content.copy()
-                    context_content[document_key] = context
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
-                # simulate conversation where assistant confirms
-                messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
-                # place unit of interest
-                messages.append({'role': 'user', 'content': unit.text})
+            try:
+                # construct chat messages
+                messages = []
+                if self.system_prompt:
+                    messages.append({'role': 'system', 'content': self.system_prompt})
 
-            if verbose:
-                print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
-                if context != "":
-                    print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+                context = self.context_chunker.chunk(unit)
 
-                print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+                if context == "":
+                    # no context, just place unit in user prompt
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                    else:
+                        unit_content = text_content.copy()
+                        unit_content[document_key] = unit.text
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+                else:
+                    # insert context to user prompt
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                    else:
+                        context_content = text_content.copy()
+                        context_content[document_key] = context
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                    # simulate conversation where assistant confirms
+                    messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                    # place unit of interest
+                    messages.append({'role': 'user', 'content': unit.text})
 
-            gen_text = self.inference_engine.chat(
-                messages=messages,
-                verbose=verbose,
-                stream=False
-            )
-
+                if verbose:
+                    print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
+                    if context != "":
+                        print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
 
-            if return_messages_log:
-                message = {"role": "assistant", "content": gen_text["response"]}
-                if "reasoning" in gen_text:
-                    message["reasoning"] = gen_text["reasoning"]
-                messages.append(message)
-                messages_log.append(messages)
-
-            # add to output
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=gen_text["response"])
-            output.append(result)
+                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
 
+                gen_text = self.inference_engine.chat(
+                    messages=messages,
+                    verbose=verbose,
+                    stream=False,
+                    messages_logger=messages_logger
+                )
+
+                # add generated text to unit
+                unit.set_generated_text(gen_text["response"])
+                unit.set_status("success")
+            except Exception as e:
+                unit.set_status("fail")
+                warnings.warn(f"LLM inference failed for unit {i} ({unit.start}, {unit.end}): {e}", RuntimeWarning)
 
         if return_messages_log:
-            return output, messages_log
+            return units, messages_logger.get_messages_log()
 
-        return output
+        return units
 
     def stream(self, text_content: Union[str, Dict[str, str]],
-               document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnitResult]]:
+               document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnit]]:
         """
         Streams LLM responses per unit with structured event types,
         and returns collected data for post-processing.
```
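`extract()` now returns the chunker's own `FrameExtractionUnit` objects, mutated in place via `set_generated_text()`/`set_status()`, instead of building separate `FrameExtractionUnitResult` records, and a per-unit `try`/`except` turns LLM failures into a `"fail"` status rather than an aborted run. A consumption sketch (the `extractor` instance and `note_text` are assumed; the `status`/`gen_text` attribute names follow the code above):

```python
# Minimal consumption sketch; `extractor` is assumed to be a configured
# DirectFrameExtractor and `note_text` an input document string.
units, messages_log = extractor.extract(note_text, return_messages_log=True)

for unit in units:
    if unit.status == "success":
        print(unit.start, unit.end, unit.gen_text)
    else:
        # failed units are kept (status "fail") instead of raising, so one
        # bad LLM call no longer loses the rest of the document
        print(f"unit ({unit.start}, {unit.end}) failed")
```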
```diff
@@ -525,12 +518,10 @@ class DirectFrameExtractor(FrameExtractor):
 
         Returns:
         --------
-        List[FrameExtractionUnitResult]:
-            A list of FrameExtractionUnitResult objects, each containing the
+        List[FrameExtractionUnit]:
+            A list of FrameExtractionUnit objects, each containing the
             original unit details and the fully accumulated 'gen_text' from the LLM.
         """
-        collected_results: List[FrameExtractionUnitResult] = []
-
         if isinstance(text_content, str):
             doc_text = text_content
         elif isinstance(text_content, dict):
@@ -588,19 +579,14 @@ class DirectFrameExtractor(FrameExtractor):
                 current_gen_text += chunk["data"]
 
             # Store the result for this unit
-            result_for_unit = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=current_gen_text
-            )
-            collected_results.append(result_for_unit)
+            unit.set_generated_text(current_gen_text)
+            unit.set_status("success")
 
         yield {"type": "info", "data": "All units processed by LLM."}
-        return collected_results
+        return units
 
     async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
-                            concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+                            concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
         """
         This is the asynchronous version of the extract() method.
 
@@ -618,7 +604,7 @@ class DirectFrameExtractor(FrameExtractor):
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.
 
-        Return : List[FrameExtractionUnitResult]
+        Return : List[FrameExtractionUnit]
             the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
         if isinstance(text_content, str):
@@ -637,6 +623,9 @@ class DirectFrameExtractor(FrameExtractor):
         # context chunker init
         self.context_chunker.fit(doc_text, units)
 
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # Prepare inputs for all units first
         tasks_input = []
         for i, unit in enumerate(units):
@@ -677,17 +666,15 @@ class DirectFrameExtractor(FrameExtractor):
         async def semaphore_helper(task_data: Dict, **kwrs):
             unit = task_data["unit"]
             messages = task_data["messages"]
-            original_index = task_data["original_index"]
 
             async with semaphore:
                 gen_text = await self.inference_engine.chat_async(
-                    messages=messages
+                    messages=messages,
+                    messages_logger=messages_logger
                 )
 
-            out = {"original_index": original_index, "unit": unit, "gen_text": gen_text["response"], "messages": messages}
-            if "reasoning" in gen_text:
-                out["reasoning"] = gen_text["reasoning"]
-            return out
+            unit.set_generated_text(gen_text["response"])
+            unit.set_status("success")
 
         # Create and gather tasks
         tasks = []
@@ -697,40 +684,13 @@ class DirectFrameExtractor(FrameExtractor):
             ))
             tasks.append(task)
 
-        # Run tasks concurrently
-        results_raw = await asyncio.gather(*tasks)
-
-        # Sort results back into original order using the index stored
-        results_raw.sort(key=lambda x: x["original_index"])
-
-        # Restructure the results
-        output: List[FrameExtractionUnitResult] = []
-        messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
-
-        for result_data in results_raw:
-            unit = result_data["unit"]
-            gen_text = result_data["gen_text"]
-
-            # Create result object
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=gen_text
-            )
-            output.append(result)
-
-            # Append to messages log if requested
-            if return_messages_log:
-                message = {"role": "assistant", "content": gen_text}
-                if "reasoning" in result_data:
-                    message["reasoning"] = result_data["reasoning"]
-                final_messages = result_data["messages"] + [message]
-                messages_log.append(final_messages)
+        await asyncio.gather(*tasks)
 
+        # Return units
         if return_messages_log:
-            return output, messages_log
+            return units, messages_logger.get_messages_log()
         else:
-            return output
+            return units
 
     def extract_frames(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
```
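`extract_async()` keeps its semaphore-bounded fan-out but drops the sort-and-restructure pass: the awaited tasks now mutate the units directly, so the already-ordered `units` list can be returned as-is. A usage sketch (same assumed objects as above):

```python
import asyncio

async def run(extractor, note_text: str):
    # concurrent_batch_size bounds in-flight LLM calls via the semaphore above
    units = await extractor.extract_async(note_text, concurrent_batch_size=8)
    ok = sum(1 for u in units if u.status == "success")
    print(f"{ok}/{len(units)} units succeeded")
    return units

# units = asyncio.run(run(extractor, note_text))
```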
```diff
@@ -738,7 +698,7 @@ class DirectFrameExtractor(FrameExtractor):
                        case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
                        allow_overlap_entities:bool=False, return_messages_log:bool=False) -> List[LLMInformationExtractionFrame]:
         """
-        This method inputs a text and outputs a list of LLMInformationExtractionFrame
+        This method inputs a document text and outputs a list of LLMInformationExtractionFrame
         It use the extract() method and post-process outputs into frames.
 
         Parameters:
@@ -791,18 +751,21 @@ class DirectFrameExtractor(FrameExtractor):
                                              verbose=verbose,
                                              return_messages_log=return_messages_log)
 
-        results, messages_log = extraction_results if return_messages_log else (extraction_results, None)
+        units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
 
         frame_list = []
-        for result in results:
+        for unit in units:
             entity_json = []
-            for entity in self._extract_json(gen_text=result.gen_text):
+            if unit.status != "success":
+                warnings.warn(f"Skipping failed unit ({unit.start}, {unit.end}): {unit.text}", RuntimeWarning)
+                continue
+            for entity in self._extract_json(gen_text=unit.gen_text):
                 if ENTITY_KEY in entity:
                     entity_json.append(entity)
                 else:
                     warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
 
-            spans = self._find_entity_spans(text=result.text,
+            spans = self._find_entity_spans(text=unit.text,
                                             entities=[e[ENTITY_KEY] for e in entity_json],
                                             case_sensitive=case_sensitive,
                                             fuzzy_match=fuzzy_match,
@@ -812,9 +775,9 @@ class DirectFrameExtractor(FrameExtractor):
             for ent, span in zip(entity_json, spans):
                 if span is not None:
                     start, end = span
-                    entity_text = result.text[start:end]
-                    start += result.start
-                    end += result.start
+                    entity_text = unit.text[start:end]
+                    start += unit.start
+                    end += unit.start
                     attr = {}
                     if "attr" in ent and ent["attr"] is not None:
                         attr = ent["attr"]
```
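Downstream, `extract_frames()` unpacks the new return shape, skips units whose status is not `"success"`, and maps each unit-relative span back to document offsets by adding `unit.start`. A call sketch (only keyword defaults shown in this hunk are assumed):

```python
# Call sketch; `extractor` and `note_text` as before.
frames = extractor.extract_frames(
    note_text,
    case_sensitive=False,
    fuzzy_match=True,         # locate entity spans by fuzzy search
    fuzzy_buffer_size=0.2,    # search window: 20% of entity length
    fuzzy_score_cutoff=0.8,   # minimum similarity to accept a span
)
for frame in frames:
    # start/end are document-level offsets (unit-relative span + unit.start)
    print(frame.frame_id, frame.start, frame.end, frame.entity_text)
```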
```diff
@@ -831,6 +794,208 @@ class DirectFrameExtractor(FrameExtractor):
         return frame_list
 
 
+    async def extract_frames_from_documents(self, text_contents:List[Union[str,Dict[str, any]]], document_key:str="text",
+                                            cpu_concurrency:int=4, llm_concurrency:int=32, case_sensitive:bool=False,
+                                            fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
+                                            allow_overlap_entities:bool=False, return_messages_log:bool=False) -> AsyncGenerator[Dict[str, any], None]:
+        """
+        This method inputs a list of documents and yields the results for each document as soon as it is complete.
+
+        Parameters:
+        -----------
+        text_contents : List[Union[str,Dict[str, any]]]
+            a list of input text contents to put in prompt template.
+            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+        document_key: str, optional
+            The key in the `text_contents` dictionaries that holds the document text.
+        cpu_concurrency: int, optional
+            The number of parallel threads to use for CPU-bound tasks like chunking.
+        llm_concurrency: int, optional
+            The number of concurrent requests to make to the LLM.
+        case_sensitive : bool, Optional
+            if True, entity text matching will be case-sensitive.
+        fuzzy_match : bool, Optional
+            if True, fuzzy matching will be applied to find entity text.
+        fuzzy_buffer_size : float, Optional
+            the buffer size for fuzzy matching. Default is 20% of entity text length.
+        fuzzy_score_cutoff : float, Optional
+            the Jaccard score cutoff for fuzzy matching.
+            Matched entity text must have a score higher than this value or a None will be returned.
+        allow_overlap_entities : bool, Optional
+            if True, entities can overlap in the text.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
+
+        Yields:
+        -------
+        AsyncGenerator[Dict[str, any], None]
+            A dictionary for each completed document, containing its 'idx' and extracted 'frames'.
+        """
+        # Validate text_contents must be a list of str or dict, and not both
+        if not isinstance(text_contents, list):
+            raise ValueError("text_contents must be a list of strings or dictionaries.")
+        if all(isinstance(doc, str) for doc in text_contents):
+            pass
+        elif all(isinstance(doc, dict) for doc in text_contents):
+            pass
+        # Set CPU executor and queues
+        cpu_executor = ThreadPoolExecutor(max_workers=cpu_concurrency)
+        tasks_queue = asyncio.Queue(maxsize=llm_concurrency * 2)
+        # Store to track units and pending counts
+        results_store = {
+            idx: {'pending': 0, 'units': [], 'text': doc if isinstance(doc, str) else doc.get(document_key, "")}
+            for idx, doc in enumerate(text_contents)
+        }
+
+        output_queue = asyncio.Queue()
+        messages_logger = MessagesLogger() if return_messages_log else None
+
+        async def producer():
+            try:
+                for idx, text_content in enumerate(text_contents):
+                    text = text_content if isinstance(text_content, str) else text_content.get(document_key, "")
+                    if not text:
+                        warnings.warn(f"Document at index {idx} is empty or missing the document key '{document_key}'.")
+                        # signal that this document is done
+                        await output_queue.put({'idx': idx, 'frames': []})
+                        continue
+
+                    units = await self.unit_chunker.chunk_async(text, cpu_executor)
+                    await self.context_chunker.fit_async(text, units, cpu_executor)
+                    results_store[idx]['pending'] = len(units)
+
+                    # Handle cases where a document yields no units
+                    if not units:
+                        # signal that this document is done
+                        await output_queue.put({'idx': idx, 'frames': []})
+                        continue
+
+                    # Iterate through units
+                    for unit in units:
+                        context = await self.context_chunker.chunk_async(unit, cpu_executor)
+                        messages = []
+                        if self.system_prompt:
+                            messages.append({'role': 'system', 'content': self.system_prompt})
+
+                        if not context:
+                            if isinstance(text_content, str):
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                            else:
+                                unit_content = text_content.copy()
+                                unit_content[document_key] = unit.text
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+                        else:
+                            # insert context to user prompt
+                            if isinstance(text_content, str):
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                            else:
+                                context_content = text_content.copy()
+                                context_content[document_key] = context
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                            # simulate conversation where assistant confirms
+                            messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                            # place unit of interest
+                            messages.append({'role': 'user', 'content': unit.text})
+
+                        await tasks_queue.put({'idx': idx, 'unit': unit, 'messages': messages})
+            finally:
+                for _ in range(llm_concurrency):
+                    await tasks_queue.put(None)
+
+        async def worker():
+            while True:
+                task_item = await tasks_queue.get()
+                if task_item is None:
+                    tasks_queue.task_done()
+                    break
+
+                idx = task_item['idx']
+                unit = task_item['unit']
+                doc_results = results_store[idx]
+
+                try:
+                    gen_text = await self.inference_engine.chat_async(
+                        messages=task_item['messages'], messages_logger=messages_logger
+                    )
+                    unit.set_generated_text(gen_text["response"])
+                    unit.set_status("success")
+                    doc_results['units'].append(unit)
+                except Exception as e:
+                    warnings.warn(f"Error processing unit for doc idx {idx}: {e}")
+                finally:
+                    doc_results['pending'] -= 1
+                    if doc_results['pending'] <= 0:
+                        final_frames = self._post_process_and_create_frames(doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities)
+                        output_payload = {'idx': idx, 'frames': final_frames}
+                        if return_messages_log:
+                            output_payload['messages_log'] = messages_logger.get_messages_log()
+                        await output_queue.put(output_payload)
+
+                    tasks_queue.task_done()
+
+        # Start producer and workers
+        producer_task = asyncio.create_task(producer())
+        worker_tasks = [asyncio.create_task(worker()) for _ in range(llm_concurrency)]
+
+        # Main loop to gather results
+        docs_completed = 0
+        while docs_completed < len(text_contents):
+            result = await output_queue.get()
+            yield result
+            docs_completed += 1
+
+        # Final cleanup
+        await producer_task
+        await tasks_queue.join()
+
+        # Cancel any lingering worker tasks
+        for task in worker_tasks:
+            task.cancel()
+        await asyncio.gather(*worker_tasks, return_exceptions=True)
+
+        cpu_executor.shutdown(wait=False)
+
+
+    def _post_process_and_create_frames(self, doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities):
+        """Helper function to run post-processing logic for a completed document."""
+        ENTITY_KEY = "entity_text"
+        frame_list = []
+        for res in sorted(doc_results['units'], key=lambda r: r.start):
+            entity_json = []
+            for entity in self._extract_json(gen_text=res.gen_text):
+                if ENTITY_KEY in entity:
+                    entity_json.append(entity)
+                else:
+                    warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
+
+            spans = self._find_entity_spans(
+                text=res.text,
+                entities=[e[ENTITY_KEY] for e in entity_json],
+                case_sensitive=case_sensitive,
+                fuzzy_match=fuzzy_match,
+                fuzzy_buffer_size=fuzzy_buffer_size,
+                fuzzy_score_cutoff=fuzzy_score_cutoff,
+                allow_overlap_entities=allow_overlap_entities
+            )
+            for ent, span in zip(entity_json, spans):
+                if span is not None:
+                    start, end = span
+                    entity_text = res.text[start:end]
+                    start += res.start
+                    end += res.start
+                    attr = ent.get("attr", {}) or {}
+                    frame = LLMInformationExtractionFrame(
+                        frame_id=f"{len(frame_list)}",
+                        start=start,
+                        end=end,
+                        entity_text=entity_text,
+                        attr=attr
+                    )
+                    frame_list.append(frame)
+        return frame_list
+
+
 class ReviewFrameExtractor(DirectFrameExtractor):
     def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker, inference_engine:InferenceEngine,
                  prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None):
```
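The new `extract_frames_from_documents()` is a producer/worker pipeline: the producer chunks documents on a thread pool and enqueues per-unit chat tasks, `llm_concurrency` workers drain the queue, and a document's frames are yielded as soon as its pending-unit counter reaches zero. Results therefore arrive out of input order and carry an `idx` to map them back. A consumption sketch (the `extractor` instance and document strings are assumed):

```python
import asyncio

async def run_batch(extractor, documents: list) -> list:
    # results arrive per document, earliest-finished first; 'idx' restores order
    frames_by_doc = [None] * len(documents)
    async for payload in extractor.extract_frames_from_documents(
        documents,
        document_key="text",   # only consulted for dict inputs
        cpu_concurrency=4,
        llm_concurrency=16,
    ):
        frames_by_doc[payload["idx"]] = payload["frames"]
    return frames_by_doc

# frames = asyncio.run(run_batch(extractor, ["first note ...", "second note ..."]))
```

Note that one shared `MessagesLogger` serves the whole batch, so the `messages_log` attached to each yielded payload is the accumulated log up to that point rather than a per-document slice.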
```diff
@@ -902,7 +1067,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             raise ValueError(f"Cannot find review prompt for {self.__class__.__name__} in the package. Please provide a review_prompt.")
 
     def extract(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
-                verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+                verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
         """
         This method inputs a text and outputs a list of outputs per unit.
 
@@ -923,8 +1088,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         Return : List[FrameExtractionUnitResult]
             the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
-        # define output
-        output = []
         # unit chunking
         if isinstance(text_content, str):
             doc_text = text_content
@@ -937,9 +1100,9 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         units = self.unit_chunker.chunk(doc_text)
         # context chunker init
         self.context_chunker.fit(doc_text, units)
-        # messages log
-        if return_messages_log:
-            messages_log = []
+
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
 
         # generate unit by unit
         for i, unit in enumerate(units):
@@ -973,7 +1136,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             messages.append({'role': 'user', 'content': unit.text})
 
             if verbose:
-                print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
+                print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
                 if context != "":
                     print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
 
@@ -983,7 +1146,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             initial = self.inference_engine.chat(
                 messages=messages,
                 verbose=verbose,
-                stream=False
+                stream=False,
+                messages_logger=messages_logger
             )
 
             # <--- Review step --->
@@ -996,7 +1160,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             review = self.inference_engine.chat(
                 messages=messages,
                 verbose=verbose,
-                stream=False
+                stream=False,
+                messages_logger=messages_logger
             )
 
             # Output
@@ -1005,28 +1170,14 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             elif self.review_mode == "addition":
                 gen_text = initial["response"] + '\n' + review["response"]
 
-
-
-            # messages log
-            if return_messages_log:
-                message = {"role": "assistant", "content": review["response"]}
-                if "reasoning" in review:
-                    message["reasoning"] = review["reasoning"]
-                messages.append(message)
-                messages_log.append(messages)
-
-            # add to output
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=gen_text)
-            output.append(result)
+            # add generated text to unit
+            unit.set_generated_text(gen_text)
+            unit.set_status("success")
 
         if return_messages_log:
-            return output, messages_log
+            return units, messages_logger.get_messages_log()
 
-        return output
+        return units
 
     def stream(self, text_content:Union[str, Dict[str,str]], document_key:str=None) -> Generator[str, None, None]:
```
```diff
@@ -1122,7 +1273,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
                 yield chunk
 
     async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
-                            concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
+                            concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnit]:
         """
         This is the asynchronous version of the extract() method with the review step.
 
@@ -1154,11 +1305,15 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         else:
             raise TypeError("text_content must be a string or a dictionary.")
 
+        # unit chunking
         units = self.unit_chunker.chunk(doc_text)
 
         # context chunker init
         self.context_chunker.fit(doc_text, units)
 
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # <--- Initial generation step --->
         initial_tasks_input = []
         for i, unit in enumerate(units):
@@ -1202,7 +1357,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
             async with semaphore:
                 gen_text = await self.inference_engine.chat_async(
-                    messages=messages
+                    messages=messages,
+                    messages_logger=messages_logger
                 )
             # Return initial generation result along with the messages used and the unit
             out = {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text["response"], "initial_messages": messages}
@@ -1253,16 +1409,11 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
             async with semaphore:
                 review_gen_text = await self.inference_engine.chat_async(
-                    messages=messages
+                    messages=messages,
+                    messages_logger=messages_logger
                 )
             # Combine initial and review results
             task_data["review_gen_text"] = review_gen_text["response"]
-            if return_messages_log:
-                # Log for the review call itself
-                message = {'role': 'assistant', 'content': review_gen_text["response"]}
-                if "reasoning" in review_gen_text:
-                    message["reasoning"] = review_gen_text["reasoning"]
-                task_data["full_review_log"] = task_data["full_initial_log"] + [message]
             return task_data  # Return the augmented dictionary
 
         # Create and gather review tasks
@@ -1279,9 +1430,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         final_results_raw.sort(key=lambda x: x["original_index"])
 
         # <--- Process final results --->
-        output: List[FrameExtractionUnitResult] = []
-        messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
-
         for result_data in final_results_raw:
             unit = result_data["unit"]
             initial_gen = result_data["initial_gen_text"]
@@ -1296,23 +1444,13 @@ class ReviewFrameExtractor(DirectFrameExtractor):
                 final_gen_text = review_gen  # Default to revision if mode is somehow invalid
 
             # Create final result object
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=final_gen_text  # Use the combined/reviewed text
-            )
-            output.append(result)
-
-            # Append full conversation log if requested
-            if return_messages_log:
-                full_log_for_unit = result_data["full_review_log"]
-                messages_log.append(full_log_for_unit)
+            unit.set_generated_text(final_gen_text)
+            unit.set_status("success")
 
         if return_messages_log:
-            return output, messages_log
+            return units, messages_logger.get_messages_log()
         else:
-            return output
+            return units
 
 class BasicFrameExtractor(DirectFrameExtractor):
```
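`ReviewFrameExtractor` gets the same treatment in both `extract()` and `extract_async()`: units carry the combined initial-plus-review text, and the logger replaces the hand-rolled review logs. A construction sketch using the `__init__` signature shown above (the `engine` instance, prompt string, and argument-free chunker constructors are all assumptions):

```python
from llm_ie.chunkers import SentenceUnitChunker, NoContextChunker
from llm_ie.extractors import ReviewFrameExtractor

# `engine` is an InferenceEngine implementation; PROMPT_TEMPLATE is a
# placeholder prompt with {{...}}-style slots, per the docstrings above.
reviewer = ReviewFrameExtractor(
    unit_chunker=SentenceUnitChunker(),
    context_chunker=NoContextChunker(),
    inference_engine=engine,
    prompt_template=PROMPT_TEMPLATE,
    review_mode="addition",   # "addition" appends the review to the initial output
    review_prompt=None,       # None falls back to the packaged review prompt
)
units = reviewer.extract(note_text, verbose=True)
```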
```diff
@@ -1549,6 +1687,9 @@ class AttributeExtractor(Extractor):
             a dictionary of attributes extracted from the frame.
             If return_messages_log is True, a list of messages will be returned as well.
         """
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # construct chat messages
         messages = []
         if self.system_prompt:
@@ -1567,19 +1708,15 @@ class AttributeExtractor(Extractor):
         gen_text = self.inference_engine.chat(
             messages=messages,
             verbose=verbose,
-            stream=False
+            stream=False,
+            messages_logger=messages_logger
         )
-        if return_messages_log:
-            message = {"role": "assistant", "content": gen_text["response"]}
-            if "reasoning" in gen_text:
-                message["reasoning"] = gen_text["reasoning"]
-            messages.append(message)
 
         attribute_list = self._extract_json(gen_text=gen_text["response"])
         if isinstance(attribute_list, list) and len(attribute_list) > 0:
             attributes = attribute_list[0]
         if return_messages_log:
-            return attributes, messages
+            return attributes, messages_logger.get_messages_log()
         return attributes
 
 
@@ -1620,7 +1757,7 @@ class AttributeExtractor(Extractor):
                 if return_messages_log:
                     attr, messages = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
                                                               verbose=verbose, return_messages_log=return_messages_log)
-                    messages_log.append(messages)
+                    messages_log.extend(messages)
                 else:
                     attr = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
                                                     verbose=verbose, return_messages_log=return_messages_log)
@@ -1669,6 +1806,9 @@ class AttributeExtractor(Extractor):
         if not isinstance(text, str):
             raise TypeError(f"Expect text as str, received {type(text)} instead.")
 
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # async helper
         semaphore = asyncio.Semaphore(concurrent_batch_size)
 
@@ -1681,14 +1821,7 @@ class AttributeExtractor(Extractor):
             context = self._get_context(frame, text, context_size)
             messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})
 
-            gen_text = await self.inference_engine.chat_async(messages=messages)
-
-            if return_messages_log:
-                message = {"role": "assistant", "content": gen_text["response"]}
-                if "reasoning" in gen_text:
-                    message["reasoning"] = gen_text["reasoning"]
-                messages.append(message)
-
+            gen_text = await self.inference_engine.chat_async(messages=messages, messages_logger=messages_logger)
             attribute_list = self._extract_json(gen_text=gen_text["response"])
             attributes = attribute_list[0] if isinstance(attribute_list, list) and len(attribute_list) > 0 else {}
             return {"frame": frame, "attributes": attributes, "messages": messages}
@@ -1699,12 +1832,8 @@ class AttributeExtractor(Extractor):
 
         # process results
         new_frames = []
-        messages_log = [] if return_messages_log else None
 
         for result in results:
-            if return_messages_log:
-                messages_log.append(result["messages"])
-
             if inplace:
                 result["frame"].attr.update(result["attributes"])
             else:
@@ -1714,9 +1843,9 @@ class AttributeExtractor(Extractor):
 
         # output
         if inplace:
-            return messages_log if return_messages_log else None
+            return messages_logger.get_messages_log() if return_messages_log else None
         else:
-            return (new_frames, messages_log) if return_messages_log else new_frames
+            return (new_frames, messages_logger.get_messages_log()) if return_messages_log else new_frames
 
     def extract_attributes(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
                            concurrent:bool=False, concurrent_batch_size:int=32, verbose:bool=False,
```
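For `AttributeExtractor` the same substitution applies, with one wrinkle: the caller now uses `messages_log.extend(messages)` instead of `append`, because `get_messages_log()` returns a list of conversations rather than a single one. A call sketch (the `attr_extractor` instance is assumed; `frames` and `note_text` come from the earlier steps, and the tuple return shape follows the non-inplace branch in the hunk above):

```python
# Call sketch; extract_attributes() drives _extract_from_frame() per frame.
new_frames, messages_log = attr_extractor.extract_attributes(
    frames,
    text=note_text,
    context_size=256,          # characters of surrounding text shown to the LLM
    concurrent=True,           # use the asyncio path with its semaphore
    concurrent_batch_size=32,
    return_messages_log=True,
)
```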
```diff
@@ -1839,7 +1968,7 @@ class RelationExtractor(Extractor):
                       return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
         pairs = itertools.combinations(doc.frames, 2)
         relations = []
-        messages_log = [] if return_messages_log else None
+        messages_logger = MessagesLogger() if return_messages_log else None
 
         for frame_1, frame_2 in pairs:
             task_payload = self._get_task_if_possible(frame_1, frame_2, doc.text, buffer_size)
@@ -1851,20 +1980,14 @@ class RelationExtractor(Extractor):
 
                 gen_text = self.inference_engine.chat(
                     messages=task_payload['messages'],
-                    verbose=verbose
+                    verbose=verbose,
+                    messages_logger=messages_logger
                 )
                 relation = self._post_process_result(gen_text["response"], task_payload)
                 if relation:
                     relations.append(relation)
 
-                if return_messages_log:
-                    message = {"role": "assistant", "content": gen_text["response"]}
-                    if "reasoning" in gen_text:
-                        message["reasoning"] = gen_text["reasoning"]
-                    task_payload['messages'].append(message)
-                    messages_log.append(task_payload['messages'])
-
-        return (relations, messages_log) if return_messages_log else relations
+        return (relations, messages_logger.get_messages_log()) if return_messages_log else relations
 
     async def _extract_async(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent_batch_size: int = 32, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
         pairs = list(itertools.combinations(doc.frames, 2))
@@ -1873,12 +1996,12 @@ class RelationExtractor(Extractor):
         tasks_input = [task for task in tasks_input if task is not None]
 
         relations = []
-        messages_log = [] if return_messages_log else None
+        messages_logger = MessagesLogger() if return_messages_log else None
         semaphore = asyncio.Semaphore(concurrent_batch_size)
 
         async def semaphore_helper(task_payload: Dict):
             async with semaphore:
-                gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'])
+                gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'], messages_logger=messages_logger)
             return gen_text, task_payload
 
         tasks = [asyncio.create_task(semaphore_helper(payload)) for payload in tasks_input]
@@ -1889,14 +2012,7 @@ class RelationExtractor(Extractor):
             if relation:
                 relations.append(relation)
 
-            if return_messages_log:
-                message = {"role": "assistant", "content": gen_text["response"]}
-                if "reasoning" in gen_text:
-                    message["reasoning"] = gen_text["reasoning"]
-                task_payload['messages'].append(message)
-                messages_log.append(task_payload['messages'])
-
-        return (relations, messages_log) if return_messages_log else relations
+        return (relations, messages_logger.get_messages_log()) if return_messages_log else relations
 
     def extract_relations(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent: bool = False, concurrent_batch_size: int = 32, verbose: bool = False, return_messages_log: bool = False) -> List[Dict]:
         if not doc.has_frame():
```