llm-ie 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_ie/__init__.py +4 -4
- llm_ie/asset/prompt_guide/AttributeExtractor_prompt_guide.txt +52 -0
- llm_ie/engines.py +497 -250
- llm_ie/extractors.py +479 -681
- llm_ie/prompt_editor.py +13 -13
- {llm_ie-1.0.0.dist-info → llm_ie-1.2.0.dist-info}/METADATA +2 -2
- {llm_ie-1.0.0.dist-info → llm_ie-1.2.0.dist-info}/RECORD +8 -7
- {llm_ie-1.0.0.dist-info → llm_ie-1.2.0.dist-info}/WHEEL +0 -0
llm_ie/extractors.py
CHANGED
@@ -17,7 +17,7 @@ from colorama import Fore, Style
 
 
 class Extractor:
-    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None
+    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
         """
         This is the abstract class for (frame and relation) extractors.
         Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -172,7 +172,7 @@ class Extractor:
 class FrameExtractor(Extractor):
     from nltk.tokenize import RegexpTokenizer
     def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
-                 prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None
+                 prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None):
         """
         This is the abstract class for frame extraction.
         Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -192,8 +192,7 @@ class FrameExtractor(Extractor):
         """
         super().__init__(inference_engine=inference_engine,
                          prompt_template=prompt_template,
-                         system_prompt=system_prompt
-                         **kwrs)
+                         system_prompt=system_prompt)
 
         self.unit_chunker = unit_chunker
         if context_chunker is None:
@@ -332,7 +331,7 @@ class FrameExtractor(Extractor):
         return entity_spans
 
     @abc.abstractmethod
-    def extract(self, text_content:Union[str, Dict[str,str]],
+    def extract(self, text_content:Union[str, Dict[str,str]], return_messages_log:bool=False, **kwrs) -> str:
         """
         This method inputs text content and outputs a string generated by LLM
 
@@ -342,8 +341,6 @@ class FrameExtractor(Extractor):
            the input text content to put in prompt template.
            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM can generate.
        return_messages_log : bool, Optional
            if True, a list of messages will be returned.
 
@@ -354,7 +351,7 @@ class FrameExtractor(Extractor):
 
 
     @abc.abstractmethod
-    def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str,
+    def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str,
                        document_key:str=None, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
         """
         This method inputs text content and outputs a list of LLMInformationExtractionFrame
@@ -368,8 +365,6 @@ class FrameExtractor(Extractor):
            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
        entity_key : str
            the key (in ouptut JSON) for entity text. Any extraction that does not include entity key will be dropped.
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM should generate.
        document_key : str, Optional
            specify the key in text_content where document text is.
            If text_content is str, this parameter will be ignored.
@@ -384,7 +379,7 @@ class FrameExtractor(Extractor):
 
 class DirectFrameExtractor(FrameExtractor):
     def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
-                 prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None
+                 prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None):
         """
         This class is for general unit-context frame extraction.
         Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -406,12 +401,11 @@ class DirectFrameExtractor(FrameExtractor):
                         unit_chunker=unit_chunker,
                         prompt_template=prompt_template,
                         system_prompt=system_prompt,
-                        context_chunker=context_chunker
-                        **kwrs)
+                        context_chunker=context_chunker)
 
 
-    def extract(self, text_content:Union[str, Dict[str,str]],
-                document_key:str=None,
+    def extract(self, text_content:Union[str, Dict[str,str]],
+                document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
         """
         This method inputs a text and outputs a list of outputs per unit.
 
@@ -421,13 +415,9 @@ class DirectFrameExtractor(FrameExtractor):
            the input text content to put in prompt template.
            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens : int, Optional
-            the max number of new tokens LLM should generate.
        document_key : str, Optional
            specify the key in text_content where document text is.
            If text_content is str, this parameter will be ignored.
-        temperature : float, Optional
-            the temperature for token sampling.
        verbose : bool, Optional
            if True, LLM generated text will be printed in terminal in real-time.
        return_messages_log : bool, Optional
@@ -491,27 +481,12 @@ class DirectFrameExtractor(FrameExtractor):
 
                print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
 
-            response_stream = self.inference_engine.chat(
-                messages=messages,
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                stream=True,
-                **kwrs
-            )
-
-            gen_text = ""
-            for chunk in response_stream:
-                gen_text += chunk
-                print(chunk, end='', flush=True)
 
-
-
-
-
-
-                stream=False,
-                **kwrs
-            )
+            gen_text = self.inference_engine.chat(
+                messages=messages,
+                verbose=verbose,
+                stream=False
+            )
 
            if return_messages_log:
                messages.append({"role": "assistant", "content": gen_text})
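Note on the hunk above: extract() no longer threads max_new_tokens/temperature to the engine; the call collapses to one non-streaming chat() carrying a verbose flag. A hypothetical usage sketch of the 1.2.0 signature follows (engine, unit_chunker, prompt_template, system_prompt, and note_text are placeholders built elsewhere; sampling settings are assumed to now live in the inference engine configuration rather than on extract()):

    # Hedged sketch, not from the package docs.
    from llm_ie.extractors import DirectFrameExtractor

    extractor = DirectFrameExtractor(inference_engine=engine,          # an InferenceEngine built elsewhere
                                     unit_chunker=unit_chunker,        # e.g. a sentence-level chunker
                                     prompt_template=prompt_template,
                                     system_prompt=system_prompt)
    # 1.2.0: no max_new_tokens/temperature here; verbose=True prints LLM output live.
    unit_results = extractor.extract(text_content=note_text, verbose=True)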
@@ -530,8 +505,8 @@ class DirectFrameExtractor(FrameExtractor):
 
        return output
 
-    def stream(self, text_content: Union[str, Dict[str, str]],
-
+    def stream(self, text_content: Union[str, Dict[str, str]],
+               document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnitResult]]:
        """
        Streams LLM responses per unit with structured event types,
        and returns collected data for post-processing.
@@ -542,7 +517,8 @@ class DirectFrameExtractor(FrameExtractor):
        - {"type": "info", "data": str_message}: General informational messages.
        - {"type": "unit", "data": dict_unit_info}: Signals start of a new unit. dict_unit_info contains {'id', 'text', 'start', 'end'}
        - {"type": "context", "data": str_context}: Context string for the current unit.
-        - {"type": "
+        - {"type": "reasoning", "data": str_chunk}: A reasoning model thinking chunk from the LLM.
+        - {"type": "response", "data": str_chunk}: A response/answer chunk from the LLM.
 
        Returns:
        --------
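A hypothetical consumer loop for the typed events documented above (extractor and note_text are placeholders; event names follow the docstring, including the new reasoning/response chunk types):

    # Hedged sketch: every yielded item is assumed to be a {"type", "data"} dict per the docstring.
    for event in extractor.stream(text_content=note_text):
        if event["type"] == "unit":
            print(f"\n== unit {event['data']['id']} ==")
        elif event["type"] == "reasoning":
            pass                                     # thinking tokens; often hidden from end users
        elif event["type"] == "response":
            print(event["data"], end="", flush=True) # answer tokens as they arrive
        elif event["type"] in {"info", "context"}:
            print(event["data"])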
@@ -601,13 +577,10 @@ class DirectFrameExtractor(FrameExtractor):
 
            response_stream = self.inference_engine.chat(
                messages=messages,
-
-                temperature=temperature,
-                stream=True,
-                **kwrs
+                stream=True
            )
            for chunk in response_stream:
-                yield
+                yield chunk
                current_gen_text += chunk
 
            # Store the result for this unit
@@ -622,8 +595,8 @@ class DirectFrameExtractor(FrameExtractor):
        yield {"type": "info", "data": "All units processed by LLM."}
        return collected_results
 
-    async def extract_async(self, text_content:Union[str, Dict[str,str]],
-                            concurrent_batch_size:int=32, return_messages_log:bool=False
+    async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+                            concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
        """
        This is the asynchronous version of the extract() method.
 
@@ -633,13 +606,9 @@ class DirectFrameExtractor(FrameExtractor):
            the input text content to put in prompt template.
            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens : int, Optional
-            the max number of new tokens LLM should generate.
        document_key : str, Optional
            specify the key in text_content where document text is.
            If text_content is str, this parameter will be ignored.
-        temperature : float, Optional
-            the temperature for token sampling.
        concurrent_batch_size : int, Optional
            the batch size for concurrent processing.
        return_messages_log : bool, Optional
@@ -701,17 +670,14 @@ class DirectFrameExtractor(FrameExtractor):
        # Process units concurrently with asyncio.Semaphore
        semaphore = asyncio.Semaphore(concurrent_batch_size)
 
-        async def semaphore_helper(task_data: Dict,
+        async def semaphore_helper(task_data: Dict, **kwrs):
            unit = task_data["unit"]
            messages = task_data["messages"]
            original_index = task_data["original_index"]
 
            async with semaphore:
                gen_text = await self.inference_engine.chat_async(
-                    messages=messages
-                    max_new_tokens=max_new_tokens,
-                    temperature=temperature,
-                    **kwrs
+                    messages=messages
                )
                return {"original_index": original_index, "unit": unit, "gen_text": gen_text, "messages": messages}
 
@@ -719,10 +685,7 @@ class DirectFrameExtractor(FrameExtractor):
        tasks = []
        for task_inp in tasks_input:
            task = asyncio.create_task(semaphore_helper(
-                task_inp
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                **kwrs
+                task_inp
            ))
            tasks.append(task)
 
@@ -759,11 +722,10 @@ class DirectFrameExtractor(FrameExtractor):
        return output
 
 
-    def extract_frames(self, text_content:Union[str, Dict[str,str]],
-
-                       concurrent:bool=False, concurrent_batch_size:int=32,
+    def extract_frames(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+                       verbose:bool=False, concurrent:bool=False, concurrent_batch_size:int=32,
                       case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
-                       allow_overlap_entities:bool=False, return_messages_log:bool=False
+                       allow_overlap_entities:bool=False, return_messages_log:bool=False) -> List[LLMInformationExtractionFrame]:
        """
        This method inputs a text and outputs a list of LLMInformationExtractionFrame
        It use the extract() method and post-process outputs into frames.
@@ -774,13 +736,9 @@ class DirectFrameExtractor(FrameExtractor):
            the input text content to put in prompt template.
            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM should generate.
        document_key : str, Optional
            specify the key in text_content where document text is.
            If text_content is str, this parameter will be ignored.
-        temperature : float, Optional
-            the temperature for token sampling.
        verbose : bool, Optional
            if True, LLM generated text will be printed in terminal in real-time.
        concurrent : bool, Optional
@@ -812,21 +770,15 @@ class DirectFrameExtractor(FrameExtractor):
 
            nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
            extraction_results = asyncio.run(self.extract_async(text_content=text_content,
-                                                max_new_tokens=max_new_tokens,
                                                document_key=document_key,
-                                                temperature=temperature,
                                                concurrent_batch_size=concurrent_batch_size,
-                                                return_messages_log=return_messages_log
-                                                **kwrs)
+                                                return_messages_log=return_messages_log)
                                             )
        else:
            extraction_results = self.extract(text_content=text_content,
-                                              max_new_tokens=max_new_tokens,
                                              document_key=document_key,
-                                              temperature=temperature,
                                              verbose=verbose,
-                                              return_messages_log=return_messages_log
-                                              **kwrs)
+                                              return_messages_log=return_messages_log)
 
        llm_output_results, messages_log = extraction_results if return_messages_log else (extraction_results, None)
 
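A hypothetical call against the rewritten extract_frames() signature above (placeholders as before; concurrent=True routes through extract_async() per the body shown in this hunk):

    # Hedged sketch of the 1.2.0 call; fuzzy-matching arguments keep their shown defaults.
    frames = extractor.extract_frames(text_content=note_text,
                                      concurrent=True,           # uses asyncio + nest_asyncio under the hood
                                      concurrent_batch_size=16,
                                      fuzzy_match=True,
                                      fuzzy_score_cutoff=0.8)
    for frame in frames:
        print(frame.frame_id, frame.start, frame.end)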
@@ -869,8 +821,8 @@ class DirectFrameExtractor(FrameExtractor):
 
 
 class ReviewFrameExtractor(DirectFrameExtractor):
-    def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker,
-
+    def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker, inference_engine:InferenceEngine,
+                 prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None):
        """
        This class add a review step after the DirectFrameExtractor.
        The Review process asks LLM to review its output and:
@@ -901,8 +853,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
            unit_chunker=unit_chunker,
            prompt_template=prompt_template,
            system_prompt=system_prompt,
-            context_chunker=context_chunker
-            **kwrs)
+            context_chunker=context_chunker)
        # check review mode
        if review_mode not in {"addition", "revision"}:
            raise ValueError('review_mode must be one of {"addition", "revision"}.')
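A hypothetical construction sketch for the review flow configured above; BasicReviewFrameExtractor (rewritten later in this diff) takes the same review arguments without requiring chunkers:

    # Hedged sketch; engine, prompt_template, and system_prompt are placeholders.
    reviewer = BasicReviewFrameExtractor(inference_engine=engine,
                                         prompt_template=prompt_template,
                                         review_mode="addition",   # must be "addition" or "revision", else ValueError
                                         review_prompt=None,       # None falls back to the packaged review prompt, else raises ValueError
                                         system_prompt=system_prompt)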
@@ -939,8 +890,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
        if self.review_prompt is None:
            raise ValueError(f"Cannot find review prompt for {self.__class__.__name__} in the package. Please provide a review_prompt.")
 
-    def extract(self, text_content:Union[str, Dict[str,str]],
-
+    def extract(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+                verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
        """
        This method inputs a text and outputs a list of outputs per unit.
 
@@ -950,13 +901,9 @@ class ReviewFrameExtractor(DirectFrameExtractor):
            the input text content to put in prompt template.
            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens : int, Optional
-            the max number of new tokens LLM should generate.
        document_key : str, Optional
            specify the key in text_content where document text is.
            If text_content is str, this parameter will be ignored.
-        temperature : float, Optional
-            the temperature for token sampling.
        verbose : bool, Optional
            if True, LLM generated text will be printed in terminal in real-time.
        return_messages_log : bool, Optional
@@ -1020,28 +967,13 @@ class ReviewFrameExtractor(DirectFrameExtractor):
                print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
 
                print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
-
-            response_stream = self.inference_engine.chat(
-                messages=messages,
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                stream=True,
-                **kwrs
-            )
-
-            initial = ""
-            for chunk in response_stream:
-                initial += chunk
-                print(chunk, end='', flush=True)
 
-
-
-
-
-
-
-                **kwrs
-            )
+
+            initial = self.inference_engine.chat(
+                messages=messages,
+                verbose=verbose,
+                stream=False
+            )
 
            if return_messages_log:
                messages.append({"role": "assistant", "content": initial})
@@ -1053,29 +985,12 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
            messages.append({'role': 'assistant', 'content': initial})
            messages.append({'role': 'user', 'content': self.review_prompt})
-
-            if verbose:
-                response_stream = self.inference_engine.chat(
-                    messages=messages,
-                    max_new_tokens=max_new_tokens,
-                    temperature=temperature,
-                    stream=True,
-                    **kwrs
-                )
-
-                review = ""
-                for chunk in response_stream:
-                    review += chunk
-                    print(chunk, end='', flush=True)
 
-
-
-
-
-
-                    stream=False,
-                    **kwrs
-                )
+            review = self.inference_engine.chat(
+                messages=messages,
+                verbose=verbose,
+                stream=False
+            )
 
            # Output
            if self.review_mode == "revision":
@@ -1101,8 +1016,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
        return output
 
 
-    def stream(self, text_content:Union[str, Dict[str,str]],
-               document_key:str=None, temperature:float=0.0, **kwrs) -> Generator[str, None, None]:
+    def stream(self, text_content:Union[str, Dict[str,str]], document_key:str=None) -> Generator[str, None, None]:
        """
        This method inputs a text and outputs a list of outputs per unit.
 
@@ -1112,13 +1026,9 @@ class ReviewFrameExtractor(DirectFrameExtractor):
            the input text content to put in prompt template.
            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens : int, Optional
-            the max number of new tokens LLM should generate.
        document_key : str, Optional
            specify the key in text_content where document text is.
            If text_content is str, this parameter will be ignored.
-        temperature : float, Optional
-            the temperature for token sampling.
 
        Return : List[FrameExtractionUnitResult]
            the output from LLM for each unit. Contains the start, end, text, and generated text.
@@ -1176,10 +1086,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
            response_stream = self.inference_engine.chat(
                messages=messages,
-
-                temperature=temperature,
-                stream=True,
-                **kwrs
+                stream=True
            )
 
            initial = ""
@@ -1195,16 +1102,13 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
            response_stream = self.inference_engine.chat(
                messages=messages,
-
-                temperature=temperature,
-                stream=True,
-                **kwrs
+                stream=True
            )
 
            for chunk in response_stream:
                yield chunk
 
-    async def extract_async(self, text_content:Union[str, Dict[str,str]],
+    async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
                            concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
        """
        This is the asynchronous version of the extract() method with the review step.
@@ -1215,13 +1119,9 @@ class ReviewFrameExtractor(DirectFrameExtractor):
            the input text content to put in prompt template.
            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens : int, Optional
-            the max number of new tokens LLM should generate.
        document_key : str, Optional
            specify the key in text_content where document text is.
            If text_content is str, this parameter will be ignored.
-        temperature : float, Optional
-            the temperature for token sampling.
        concurrent_batch_size : int, Optional
            the batch size for concurrent processing.
        return_messages_log : bool, Optional
@@ -1282,17 +1182,14 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
        semaphore = asyncio.Semaphore(concurrent_batch_size)
 
-        async def initial_semaphore_helper(task_data: Dict
+        async def initial_semaphore_helper(task_data: Dict):
            unit = task_data["unit"]
            messages = task_data["messages"]
            original_index = task_data["original_index"]
 
            async with semaphore:
                gen_text = await self.inference_engine.chat_async(
-                    messages=messages
-                    max_new_tokens=max_new_tokens,
-                    temperature=temperature,
-                    **kwrs
+                    messages=messages
                )
                # Return initial generation result along with the messages used and the unit
                return {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text, "initial_messages": messages}
@@ -1300,10 +1197,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
        # Create and gather initial generation tasks
        initial_tasks = [
            asyncio.create_task(initial_semaphore_helper(
-                task_inp
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                **kwrs
+                task_inp
            ))
            for task_inp in initial_tasks_input
        ]
@@ -1333,16 +1227,13 @@ class ReviewFrameExtractor(DirectFrameExtractor):
            })
 
 
-        async def review_semaphore_helper(task_data: Dict,
+        async def review_semaphore_helper(task_data: Dict, **kwrs):
            messages = task_data["messages"]
            original_index = task_data["original_index"]
 
            async with semaphore:
                review_gen_text = await self.inference_engine.chat_async(
-                    messages=messages
-                    max_new_tokens=max_new_tokens,
-                    temperature=temperature,
-                    **kwrs
+                    messages=messages
                )
                # Combine initial and review results
                task_data["review_gen_text"] = review_gen_text
@@ -1354,10 +1245,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
        # Create and gather review tasks
        review_tasks = [
            asyncio.create_task(review_semaphore_helper(
-                task_inp
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                **kwrs
+                task_inp
            ))
            for task_inp in review_tasks_input
        ]
@@ -1405,7 +1293,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
 
 class BasicFrameExtractor(DirectFrameExtractor):
-    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None
+    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
        """
        This class diretly prompt LLM for frame extraction.
        Input system prompt (optional), prompt template (with instruction, few-shot examples),
@@ -1424,11 +1312,10 @@ class BasicFrameExtractor(DirectFrameExtractor):
            unit_chunker=WholeDocumentUnitChunker(),
            prompt_template=prompt_template,
            system_prompt=system_prompt,
-            context_chunker=NoContextChunker()
-            **kwrs)
+            context_chunker=NoContextChunker())
 
 class BasicReviewFrameExtractor(ReviewFrameExtractor):
-    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None
+    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None):
        """
        This class add a review step after the BasicFrameExtractor.
        The Review process asks LLM to review its output and:
@@ -1457,13 +1344,12 @@ class BasicReviewFrameExtractor(ReviewFrameExtractor):
            review_mode=review_mode,
            review_prompt=review_prompt,
            system_prompt=system_prompt,
-            context_chunker=NoContextChunker()
-            **kwrs)
+            context_chunker=NoContextChunker())
 
 
 class SentenceFrameExtractor(DirectFrameExtractor):
    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None,
-                 context_sentences:Union[str, int]="all"
+                 context_sentences:Union[str, int]="all"):
        """
        This class performs sentence-by-sentence information extraction.
        The process is as follows:
@@ -1507,14 +1393,13 @@ class SentenceFrameExtractor(DirectFrameExtractor):
            unit_chunker=SentenceUnitChunker(),
            prompt_template=prompt_template,
            system_prompt=system_prompt,
-            context_chunker=context_chunker
-            **kwrs)
+            context_chunker=context_chunker)
 
 
 class SentenceReviewFrameExtractor(ReviewFrameExtractor):
    def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
                 review_mode:str, review_prompt:str=None, system_prompt:str=None,
-                 context_sentences:Union[str, int]="all"
+                 context_sentences:Union[str, int]="all"):
        """
        This class adds a review step after the SentenceFrameExtractor.
        For each sentence, the review process asks LLM to review its output and:
@@ -1561,15 +1446,14 @@ class SentenceReviewFrameExtractor(ReviewFrameExtractor):
            review_mode=review_mode,
            review_prompt=review_prompt,
            system_prompt=system_prompt,
-            context_chunker=context_chunker
-            **kwrs)
+            context_chunker=context_chunker)
 
 
-class
-    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None
+class AttributeExtractor(Extractor):
+    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
        """
-        This is
-
+        This class is for attribute extraction for frames. Though FrameExtractors can also extract attributes, when
+        the number of attribute increases, it is more efficient to use a dedicated AttributeExtractor.
 
        Parameters
        ----------
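The hunk above introduces AttributeExtractor (paired with the new AttributeExtractor_prompt_guide.txt asset listed at the top). Per the template validation added in the next hunk, its prompt template must contain {{context}} and {{frame}} placeholders; a hypothetical minimal template and constructor call (wording of the template is illustrative only):

    # Hedged sketch of a conforming prompt template.
    attr_prompt_template = """Extract attributes for the annotated entity.
    Context:
    {{context}}
    Frame:
    {{frame}}
    Return a JSON list with one object of attribute key-value pairs."""
    # The extractor reads the first object of a JSON list from the LLM output.
    attr_extractor = AttributeExtractor(inference_engine=engine,
                                        prompt_template=attr_prompt_template)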
@@ -1582,350 +1466,475 @@ class RelationExtractor(Extractor):
|
|
|
1582
1466
|
"""
|
|
1583
1467
|
super().__init__(inference_engine=inference_engine,
|
|
1584
1468
|
prompt_template=prompt_template,
|
|
1585
|
-
system_prompt=system_prompt
|
|
1586
|
-
|
|
1469
|
+
system_prompt=system_prompt)
|
|
1470
|
+
# validate prompt template
|
|
1471
|
+
if "{{context}}" not in self.prompt_template or "{{frame}}" not in self.prompt_template:
|
|
1472
|
+
raise ValueError("prompt_template must contain both {{context}} and {{frame}} placeholders.")
|
|
1587
1473
|
|
|
1588
|
-
def
|
|
1589
|
-
text:str, buffer_size:int=100) -> str:
|
|
1474
|
+
def _get_context(self, frame:LLMInformationExtractionFrame, text:str, context_size:int=256) -> str:
|
|
1590
1475
|
"""
|
|
1591
|
-
This method returns the
|
|
1592
|
-
The returned text has the
|
|
1476
|
+
This method returns the context that covers the frame. Leaves a context_size of characters before and after.
|
|
1477
|
+
The returned text has the frame inline annotated with <entity>.
|
|
1593
1478
|
|
|
1594
1479
|
Parameters:
|
|
1595
1480
|
-----------
|
|
1596
|
-
|
|
1481
|
+
frame : LLMInformationExtractionFrame
|
|
1597
1482
|
a frame
|
|
1598
|
-
frame_2 : LLMInformationExtractionFrame
|
|
1599
|
-
the other frame
|
|
1600
1483
|
text : str
|
|
1601
1484
|
the entire document text
|
|
1602
|
-
|
|
1603
|
-
the number of characters before and after the
|
|
1485
|
+
context_size : int, Optional
|
|
1486
|
+
the number of characters before and after the frame in the context text.
|
|
1604
1487
|
|
|
1605
1488
|
Return : str
|
|
1606
|
-
the
|
|
1489
|
+
the context text with the frame inline annotated with <entity>.
|
|
1607
1490
|
"""
|
|
1608
|
-
|
|
1609
|
-
|
|
1610
|
-
|
|
1611
|
-
|
|
1612
|
-
start = max(left_frame.start - buffer_size, 0)
|
|
1613
|
-
end = min(right_frame.end + buffer_size, len(text))
|
|
1614
|
-
roi = text[start:end]
|
|
1491
|
+
start = max(frame.start - context_size, 0)
|
|
1492
|
+
end = min(frame.end + context_size, len(text))
|
|
1493
|
+
context = text[start:end]
|
|
1615
1494
|
|
|
1616
|
-
|
|
1617
|
-
f
|
|
1618
|
-
|
|
1619
|
-
f"</
|
|
1620
|
-
|
|
1621
|
-
f'<{right_frame_name}>' + \
|
|
1622
|
-
roi[right_frame.start - start:right_frame.end - start] + \
|
|
1623
|
-
f"</{right_frame_name}>" + \
|
|
1624
|
-
roi[right_frame.end - start:end - start]
|
|
1495
|
+
context_annotated = context[0:frame.start - start] + \
|
|
1496
|
+
f"<entity> " + \
|
|
1497
|
+
context[frame.start - start:frame.end - start] + \
|
|
1498
|
+
f" </entity>" + \
|
|
1499
|
+
context[frame.end - start:end - start]
|
|
1625
1500
|
|
|
1626
1501
|
if start > 0:
|
|
1627
|
-
|
|
1502
|
+
context_annotated = "..." + context_annotated
|
|
1628
1503
|
if end < len(text):
|
|
1629
|
-
|
|
1630
|
-
return
|
|
1504
|
+
context_annotated = context_annotated + "..."
|
|
1505
|
+
return context_annotated
|
|
1631
1506
|
|
|
1632
|
-
|
|
1633
|
-
|
|
1634
|
-
def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
|
|
1635
|
-
temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
|
|
1507
|
+
def _extract_from_frame(self, frame:LLMInformationExtractionFrame, text:str,
|
|
1508
|
+
context_size:int=256, verbose:bool=False, return_messages_log:bool=False) -> Dict[str, Any]:
|
|
1636
1509
|
"""
|
|
1637
|
-
This method
|
|
1510
|
+
This method extracts attributes from a single frame.
|
|
1638
1511
|
|
|
1639
1512
|
Parameters:
|
|
1640
1513
|
-----------
|
|
1641
|
-
|
|
1642
|
-
a
|
|
1643
|
-
|
|
1644
|
-
the
|
|
1645
|
-
|
|
1646
|
-
the
|
|
1647
|
-
|
|
1648
|
-
|
|
1649
|
-
stream : bool, Optional
|
|
1650
|
-
if True, LLM generated text will be printed in terminal in real-time.
|
|
1514
|
+
frame : LLMInformationExtractionFrame
|
|
1515
|
+
a frame to extract attributes from.
|
|
1516
|
+
text : str
|
|
1517
|
+
the entire document text.
|
|
1518
|
+
context_size : int, Optional
|
|
1519
|
+
the number of characters before and after the frame in the context text.
|
|
1520
|
+
verbose : bool, Optional
|
|
1521
|
+
if True, LLM generated text will be printed in terminal in real-time.
|
|
1651
1522
|
return_messages_log : bool, Optional
|
|
1652
1523
|
if True, a list of messages will be returned.
|
|
1653
1524
|
|
|
1654
|
-
Return :
|
|
1655
|
-
a
|
|
1525
|
+
Return : Dict[str, Any]
|
|
1526
|
+
a dictionary of attributes extracted from the frame.
|
|
1527
|
+
If return_messages_log is True, a list of messages will be returned as well.
|
|
1656
1528
|
"""
|
|
1657
|
-
|
|
1658
|
-
|
|
1529
|
+
# construct chat messages
|
|
1530
|
+
messages = []
|
|
1531
|
+
if self.system_prompt:
|
|
1532
|
+
messages.append({'role': 'system', 'content': self.system_prompt})
|
|
1659
1533
|
|
|
1660
|
-
|
|
1661
|
-
|
|
1662
|
-
system_prompt:str=None, **kwrs):
|
|
1663
|
-
"""
|
|
1664
|
-
This class extracts binary (yes/no) relations between two entities.
|
|
1665
|
-
Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
|
|
1534
|
+
context = self._get_context(frame, text, context_size)
|
|
1535
|
+
messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})
|
|
1666
1536
|
|
|
1667
|
-
|
|
1668
|
-
|
|
1669
|
-
|
|
1670
|
-
|
|
1671
|
-
prompt_template : str
|
|
1672
|
-
prompt template with "{{<placeholder name>}}" placeholder.
|
|
1673
|
-
possible_relation_func : Callable, Optional
|
|
1674
|
-
a function that inputs 2 frames and returns a bool indicating possible relations between them.
|
|
1675
|
-
system_prompt : str, Optional
|
|
1676
|
-
system prompt.
|
|
1677
|
-
"""
|
|
1678
|
-
super().__init__(inference_engine=inference_engine,
|
|
1679
|
-
prompt_template=prompt_template,
|
|
1680
|
-
system_prompt=system_prompt,
|
|
1681
|
-
**kwrs)
|
|
1682
|
-
|
|
1683
|
-
if possible_relation_func:
|
|
1684
|
-
# Check if possible_relation_func is a function
|
|
1685
|
-
if not callable(possible_relation_func):
|
|
1686
|
-
raise TypeError(f"Expect possible_relation_func as a function, received {type(possible_relation_func)} instead.")
|
|
1537
|
+
if verbose:
|
|
1538
|
+
print(f"\n\n{Fore.GREEN}Frame: {frame.frame_id}{Style.RESET_ALL}\n{frame.to_dict()}\n")
|
|
1539
|
+
if context != "":
|
|
1540
|
+
print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
|
|
1687
1541
|
|
|
1688
|
-
|
|
1689
|
-
# Check if frame_1, frame_2 are in input parameters
|
|
1690
|
-
if len(sig.parameters) != 2:
|
|
1691
|
-
raise ValueError("The possible_relation_func must have exactly frame_1 and frame_2 as parameters.")
|
|
1692
|
-
if "frame_1" not in sig.parameters.keys():
|
|
1693
|
-
raise ValueError("The possible_relation_func is missing frame_1 as a parameter.")
|
|
1694
|
-
if "frame_2" not in sig.parameters.keys():
|
|
1695
|
-
raise ValueError("The possible_relation_func is missing frame_2 as a parameter.")
|
|
1696
|
-
# Check if output is a bool
|
|
1697
|
-
if sig.return_annotation != bool:
|
|
1698
|
-
raise ValueError(f"Expect possible_relation_func to output a bool, current type hint suggests {sig.return_annotation} instead.")
|
|
1699
|
-
|
|
1700
|
-
self.possible_relation_func = possible_relation_func
|
|
1701
|
-
|
|
1702
|
-
|
|
1703
|
-
def _post_process(self, rel_json:str) -> bool:
|
|
1704
|
-
if len(rel_json) > 0:
|
|
1705
|
-
if "Relation" in rel_json[0]:
|
|
1706
|
-
rel = rel_json[0]["Relation"]
|
|
1707
|
-
if isinstance(rel, bool):
|
|
1708
|
-
return rel
|
|
1709
|
-
elif isinstance(rel, str) and rel in {"True", "False"}:
|
|
1710
|
-
return eval(rel)
|
|
1711
|
-
else:
|
|
1712
|
-
warnings.warn('Extractor output JSON "Relation" key does not have bool or {"True", "False"} as value.' + \
|
|
1713
|
-
'Following default, relation = False.', RuntimeWarning)
|
|
1714
|
-
else:
|
|
1715
|
-
warnings.warn('Extractor output JSON without "Relation" key. Following default, relation = False.', RuntimeWarning)
|
|
1716
|
-
else:
|
|
1717
|
-
warnings.warn('Extractor did not output a JSON list. Following default, relation = False.', RuntimeWarning)
|
|
1718
|
-
return False
|
|
1719
|
-
|
|
1542
|
+
print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
|
|
1720
1543
|
|
|
1721
|
-
|
|
1722
|
-
|
|
1544
|
+
get_text = self.inference_engine.chat(
|
|
1545
|
+
messages=messages,
|
|
1546
|
+
verbose=verbose,
|
|
1547
|
+
stream=False
|
|
1548
|
+
)
|
|
1549
|
+
if return_messages_log:
|
|
1550
|
+
messages.append({"role": "assistant", "content": get_text})
|
|
1551
|
+
|
|
1552
|
+
attribute_list = self._extract_json(gen_text=get_text)
|
|
1553
|
+
if isinstance(attribute_list, list) and len(attribute_list) > 0:
|
|
1554
|
+
attributes = attribute_list[0]
|
|
1555
|
+
if return_messages_log:
|
|
1556
|
+
return attributes, messages
|
|
1557
|
+
return attributes
|
|
1558
|
+
|
|
1559
|
+
|
|
1560
|
+
def extract(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256, verbose:bool=False,
|
|
1561
|
+
return_messages_log:bool=False, inplace:bool=True) -> Union[None, List[LLMInformationExtractionFrame]]:
|
|
1723
1562
|
"""
|
|
1724
|
-
This method
|
|
1725
|
-
Outputs pairs that are related.
|
|
1563
|
+
This method extracts attributes from the document.
|
|
1726
1564
|
|
|
1727
1565
|
Parameters:
|
|
1728
1566
|
-----------
|
|
1729
|
-
|
|
1730
|
-
a
|
|
1731
|
-
|
|
1732
|
-
the
|
|
1733
|
-
|
|
1734
|
-
the
|
|
1735
|
-
|
|
1736
|
-
the temperature for token sampling.
|
|
1737
|
-
stream : bool, Optional
|
|
1567
|
+
frames : List[LLMInformationExtractionFrame]
|
|
1568
|
+
a list of frames to extract attributes from.
|
|
1569
|
+
text : str
|
|
1570
|
+
the entire document text.
|
|
1571
|
+
context_size : int, Optional
|
|
1572
|
+
the number of characters before and after the frame in the context text.
|
|
1573
|
+
verbose : bool, Optional
|
|
1738
1574
|
if True, LLM generated text will be printed in terminal in real-time.
|
|
1739
1575
|
return_messages_log : bool, Optional
|
|
1740
1576
|
if True, a list of messages will be returned.
|
|
1577
|
+
inplace : bool, Optional
|
|
1578
|
+
if True, the method will modify the frames in-place.
|
|
1579
|
+
|
|
1580
|
+
Return : Union[None, List[LLMInformationExtractionFrame]]
|
|
1581
|
+
if inplace is True, the method will modify the frames in-place.
|
|
1582
|
+
if inplace is False, the method will return a list of frames with attributes extracted.
|
|
1583
|
+
"""
|
|
1584
|
+
for frame in frames:
|
|
1585
|
+
if not isinstance(frame, LLMInformationExtractionFrame):
|
|
1586
|
+
raise TypeError(f"Expect frame as LLMInformationExtractionFrame, received {type(frame)} instead.")
|
|
1587
|
+
if not isinstance(text, str):
|
|
1588
|
+
raise TypeError(f"Expect text as str, received {type(text)} instead.")
|
|
1589
|
+
|
|
1590
|
+
new_frames = []
|
|
1591
|
+
messages_log = [] if return_messages_log else None
|
|
1741
1592
|
|
|
1742
|
-
|
|
1743
|
-
|
|
1744
|
-
|
|
1745
|
-
|
|
1593
|
+
for frame in frames:
|
|
1594
|
+
if return_messages_log:
|
|
1595
|
+
attr, messages = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
|
|
1596
|
+
verbose=verbose, return_messages_log=return_messages_log)
|
|
1597
|
+
messages_log.append(messages)
|
|
1598
|
+
else:
|
|
1599
|
+
attr = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
|
|
1600
|
+
verbose=verbose, return_messages_log=return_messages_log)
|
|
1601
|
+
|
|
1602
|
+
if inplace:
|
|
1603
|
+
frame.attr.update(attr)
|
|
1604
|
+
else:
|
|
1605
|
+
new_frame = frame.copy()
|
|
1606
|
+
new_frame.attr.update(attr)
|
|
1607
|
+
new_frames.append(new_frame)
|
|
1746
1608
|
|
|
1747
|
-
if
|
|
1748
|
-
messages_log
|
|
1609
|
+
if inplace:
|
|
1610
|
+
return messages_log if return_messages_log else None
|
|
1611
|
+
else:
|
|
1612
|
+
return (new_frames, messages_log) if return_messages_log else new_frames
|
|
1749
1613
|
|
|
1750
|
-
output = []
|
|
1751
|
-
for frame_1, frame_2 in pairs:
|
|
1752
|
-
pos_rel = self.possible_relation_func(frame_1, frame_2)
|
|
1753
1614
|
|
|
1754
|
-
|
|
1755
|
-
|
|
1756
|
-
|
|
1757
|
-
|
|
1758
|
-
|
|
1615
|
+
async def extract_async(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
|
|
1616
|
+
concurrent_batch_size:int=32, inplace:bool=True, return_messages_log:bool=False) -> Union[None, List[LLMInformationExtractionFrame]]:
|
|
1617
|
+
"""
|
|
1618
|
+
This method extracts attributes from the document asynchronously.
|
|
1619
|
+
|
|
1620
|
+
Parameters:
|
|
1621
|
+
-----------
|
|
1622
|
+
frames : List[LLMInformationExtractionFrame]
|
|
1623
|
+
a list of frames to extract attributes from.
|
|
1624
|
+
text : str
|
|
1625
|
+
the entire document text.
|
|
1626
|
+
context_size : int, Optional
|
|
1627
|
+
the number of characters before and after the frame in the context text.
|
|
1628
|
+
concurrent_batch_size : int, Optional
|
|
1629
|
+
the batch size for concurrent processing.
|
|
1630
|
+
inplace : bool, Optional
|
|
1631
|
+
if True, the method will modify the frames in-place.
|
|
1632
|
+
return_messages_log : bool, Optional
|
|
1633
|
+
if True, a list of messages will be returned.
|
|
1634
|
+
|
|
1635
|
+
Return : Union[None, List[LLMInformationExtractionFrame]]
|
|
1636
|
+
if inplace is True, the method will modify the frames in-place.
|
|
1637
|
+
if inplace is False, the method will return a list of frames with attributes extracted.
|
|
1638
|
+
"""
|
|
1639
|
+
# validation
|
|
1640
|
+
for frame in frames:
|
|
1641
|
+
if not isinstance(frame, LLMInformationExtractionFrame):
|
|
1642
|
+
raise TypeError(f"Expect frame as LLMInformationExtractionFrame, received {type(frame)} instead.")
|
|
1643
|
+
if not isinstance(text, str):
|
|
1644
|
+
raise TypeError(f"Expect text as str, received {type(text)} instead.")
|
|
1645
|
+
|
|
1646
|
+
# async helper
|
|
1647
|
+
semaphore = asyncio.Semaphore(concurrent_batch_size)
|
|
1648
|
+
|
|
1649
|
+
async def semaphore_helper(frame:LLMInformationExtractionFrame, text:str, context_size:int) -> dict:
|
|
1650
|
+
async with semaphore:
|
|
1759
1651
|
messages = []
|
|
1760
1652
|
if self.system_prompt:
|
|
1761
1653
|
messages.append({'role': 'system', 'content': self.system_prompt})
|
|
1762
1654
|
|
|
1763
|
-
|
|
1764
|
-
|
|
1765
|
-
"frame_2": str(frame_2.to_dict())}
|
|
1766
|
-
)})
|
|
1767
|
-
|
|
1768
|
-
gen_text = self.inference_engine.chat(
|
|
1769
|
-
messages=messages,
|
|
1770
|
-
max_new_tokens=max_new_tokens,
|
|
1771
|
-
temperature=temperature,
|
|
1772
|
-
stream=stream,
|
|
1773
|
-
**kwrs
|
|
1774
|
-
)
|
|
1775
|
-
rel_json = self._extract_json(gen_text)
|
|
1776
|
-
if self._post_process(rel_json):
|
|
1777
|
-
output.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})
|
|
1655
|
+
context = self._get_context(frame, text, context_size)
|
|
1656
|
+
messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})
|
|
1778
1657
|
|
|
1658
|
+
gen_text = await self.inference_engine.chat_async(messages=messages)
|
|
1659
|
+
|
|
1779
1660
|
if return_messages_log:
|
|
1780
1661
|
messages.append({"role": "assistant", "content": gen_text})
|
|
1781
|
-
messages_log.append(messages)
|
|
1782
1662
|
|
|
1783
|
-
|
|
1784
|
-
|
|
1785
|
-
|
|
1786
|
-
|
|
1787
|
-
|
|
1788
|
-
|
|
1789
|
-
|
|
1663
|
+
attribute_list = self._extract_json(gen_text=gen_text)
|
|
1664
|
+
attributes = attribute_list[0] if isinstance(attribute_list, list) and len(attribute_list) > 0 else {}
|
|
1665
|
+
return {"frame": frame, "attributes": attributes, "messages": messages}
|
|
1666
|
+
|
|
1667
|
+
# create tasks
|
|
1668
|
+
tasks = [asyncio.create_task(semaphore_helper(frame, text, context_size)) for frame in frames]
|
|
1669
|
+
results = await asyncio.gather(*tasks)
|
|
1670
|
+
|
|
1671
|
+
# process results
|
|
1672
|
+
new_frames = []
|
|
1673
|
+
messages_log = [] if return_messages_log else None
|
|
1674
|
+
|
|
1675
|
+
for result in results:
|
|
1676
|
+
if return_messages_log:
|
|
1677
|
+
messages_log.append(result["messages"])
|
|
1678
|
+
|
|
1679
|
+
if inplace:
|
|
1680
|
+
result["frame"].attr.update(result["attributes"])
|
|
1681
|
+
else:
|
|
1682
|
+
new_frame = result["frame"].copy()
|
|
1683
|
+
new_frame.attr.update(result["attributes"])
|
|
1684
|
+
new_frames.append(new_frame)
|
|
1685
|
+
|
|
1686
|
+
# output
|
|
1687
|
+
if inplace:
|
|
1688
|
+
return messages_log if return_messages_log else None
|
|
1689
|
+
else:
|
|
1690
|
+
return (new_frames, messages_log) if return_messages_log else new_frames
|
|
1691
|
+
|
|
1692
|
+
def extract_attributes(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
|
|
1693
|
+
concurrent:bool=False, concurrent_batch_size:int=32, verbose:bool=False,
|
|
1694
|
+
return_messages_log:bool=False, inplace:bool=True) -> Union[None, List[LLMInformationExtractionFrame]]:
|
|
1790
1695
|
"""
|
|
1791
|
-
This
|
|
1696
|
+
This method extracts attributes from the document.
|
|
1792
1697
|
|
|
1793
1698
|
Parameters:
|
|
1794
1699
|
-----------
|
|
1795
|
-
|
|
1796
|
-
a
|
|
1797
|
-
|
|
1798
|
-
the
|
|
1799
|
-
|
|
1800
|
-
the
|
|
1801
|
-
|
|
1802
|
-
the
|
|
1700
|
+
frames : List[LLMInformationExtractionFrame]
|
|
1701
|
+
a list of frames to extract attributes from.
|
|
1702
|
+
text : str
|
|
1703
|
+
the entire document text.
|
|
1704
|
+
context_size : int, Optional
|
|
1705
|
+
the number of characters before and after the frame in the context text.
|
|
1706
|
+
concurrent : bool, Optional
|
|
1707
|
+
if True, the method will run in concurrent mode with batch size concurrent_batch_size.
|
|
1803
1708
|
concurrent_batch_size : int, Optional
|
|
1804
|
-
the
|
|
1709
|
+
the batch size for concurrent processing.
|
|
1710
|
+
verbose : bool, Optional
|
|
1711
|
+
if True, LLM generated text will be printed in terminal in real-time.
|
|
1805
1712
|
return_messages_log : bool, Optional
|
|
1806
1713
|
if True, a list of messages will be returned.
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
a list of dict with {"frame_1", "frame_2"}.
|
|
1810
|
-
"""
|
|
1811
|
-
# Check if self.inference_engine.chat_async() is implemented
|
|
1812
|
-
if not hasattr(self.inference_engine, 'chat_async'):
|
|
1813
|
-
raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
|
|
1714
|
+
inplace : bool, Optional
|
|
1715
|
+
if True, the method will modify the frames in-place.
|
|
1814
1716
|
|
|
1815
|
-
|
|
1816
|
-
|
|
1817
|
-
|
|
1818
|
-
|
|
1819
|
-
|
|
1820
|
-
|
|
1821
|
-
|
|
1822
|
-
for i in range(0, num_pairs, concurrent_batch_size):
|
|
1823
|
-
rel_pair_list = []
|
|
1824
|
-
tasks = []
|
|
1825
|
-
batch = list(itertools.islice(pairs, concurrent_batch_size))
|
|
1826
|
-
batch_messages = []
|
|
1827
|
-
for frame_1, frame_2 in batch:
|
|
1828
|
-
pos_rel = self.possible_relation_func(frame_1, frame_2)
|
|
1829
|
-
|
|
1830
|
-
if pos_rel:
|
|
1831
|
-
rel_pair_list.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})
|
|
1832
|
-
roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
|
|
1833
|
-
messages = []
|
|
1834
|
-
if self.system_prompt:
|
|
1835
|
-
messages.append({'role': 'system', 'content': self.system_prompt})
|
|
1836
|
-
|
|
1837
|
-
messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
|
|
1838
|
-
"frame_1": str(frame_1.to_dict()),
|
|
1839
|
-
"frame_2": str(frame_2.to_dict())}
|
|
1840
|
-
)})
|
|
1841
|
-
|
|
1842
|
-
task = asyncio.create_task(
|
|
1843
|
-
self.inference_engine.chat_async(
|
|
1844
|
-
messages=messages,
|
|
1845
|
-
max_new_tokens=max_new_tokens,
|
|
1846
|
-
temperature=temperature,
|
|
1847
|
-
**kwrs
|
|
1848
|
-
)
|
|
1849
|
-
)
|
|
1850
|
-
tasks.append(task)
|
|
1851
|
-
batch_messages.append(messages)
|
|
1717
|
+
Return : Union[None, List[LLMInformationExtractionFrame]]
|
|
1718
|
+
if inplace is True, the method will modify the frames in-place.
|
|
1719
|
+
if inplace is False, the method will return a list of frames with attributes extracted.
|
|
1720
|
+
"""
|
|
1721
|
+
if concurrent:
|
|
1722
|
+
if verbose:
|
|
1723
|
+
warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
|
|
1852
1724
|
|
|
1853
|
-
|
|
1725
|
+
nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
|
|
1854
1726
|
|
|
1855
|
-
|
|
1856
|
-
|
|
1857
|
-
|
|
1858
|
-
|
|
1727
|
+
return asyncio.run(self.extract_async(frames=frames, text=text, context_size=context_size,
|
|
1728
|
+
concurrent_batch_size=concurrent_batch_size,
|
|
1729
|
+
inplace=inplace, return_messages_log=return_messages_log))
|
|
1730
|
+
else:
|
|
1731
|
+
return self.extract(frames=frames, text=text, context_size=context_size,
|
|
1732
|
+
verbose=verbose, return_messages_log=return_messages_log, inplace=inplace)
|
|
1859
1733
|
|
|
1860
|
-
rel_json = self._extract_json(response)
|
|
1861
|
-
if self._post_process(rel_json):
|
|
1862
|
-
output.append(d)
|
|
1863
|
-
|
|
1864
|
-
if return_messages_log:
|
|
1865
|
-
return output, messages_log
|
|
1866
|
-
return output
|
|
1867
1734
|
|
|
1735
|
+
class RelationExtractor(Extractor):
|
|
1736
|
+
def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
|
|
1737
|
+
"""
|
|
1738
|
+
This is the abstract class for relation extraction.
|
|
1739
|
+
Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
|
|
1868
1740
|
|
|
1869
|
-
|
|
1870
|
-
|
|
1871
|
-
|
|
1741
|
+
Parameters
|
|
1742
|
+
----------
|
|
1743
|
+
inference_engine : InferenceEngine
|
|
1744
|
+
the LLM inferencing engine object. Must implements the chat() method.
|
|
1745
|
+
prompt_template : str
|
|
1746
|
+
prompt template with "{{<placeholder name>}}" placeholder.
|
|
1747
|
+
system_prompt : str, Optional
|
|
1748
|
+
system prompt.
|
|
1872
1749
|
"""
|
|
1873
|
-
|
|
1750
|
+
super().__init__(inference_engine=inference_engine,
|
|
1751
|
+
prompt_template=prompt_template,
|
|
1752
|
+
system_prompt=system_prompt)
|
|
1753
|
+
|
|
1754
|
+
def _get_ROI(self, frame_1:LLMInformationExtractionFrame, frame_2:LLMInformationExtractionFrame,
|
|
1755
|
+
text:str, buffer_size:int=128) -> str:
|
|
1756
|
+
"""
|
|
1757
|
+
This method returns the Region of Interest (ROI) that covers the two frames. Leaves a buffer_size of characters before and after.
|
|
1758
|
+
The returned text has the two frames inline annotated with <entity_1>, <entity_2>.
|
|
1874
1759
|
|
|
1875
1760
|
Parameters:
|
|
1876
1761
|
-----------
|
|
1877
|
-
|
|
1878
|
-
a
|
|
1762
|
+
frame_1 : LLMInformationExtractionFrame
|
|
1763
|
+
a frame
|
|
1764
|
+
frame_2 : LLMInformationExtractionFrame
|
|
1765
|
+
the other frame
|
|
1766
|
+
text : str
|
|
1767
|
+
the entire document text
|
|
1879
1768
|
buffer_size : int, Optional
|
|
1880
1769
|
the number of characters before and after the two frames in the ROI text.
|
|
1881
|
-
max_new_tokens : str, Optional
|
|
1882
|
-
the max number of new tokens LLM should generate.
|
|
1883
|
-
temperature : float, Optional
|
|
1884
|
-
the temperature for token sampling.
|
|
1885
|
-
concurrent: bool, Optional
|
|
1886
|
-
if True, the extraction will be done in concurrent.
|
|
1887
|
-
concurrent_batch_size : int, Optional
|
|
1888
|
-
the number of frame pairs to process in concurrent.
|
|
1889
|
-
stream : bool, Optional
|
|
1890
|
-
if True, LLM generated text will be printed in terminal in real-time.
|
|
1891
|
-
return_messages_log : bool, Optional
|
|
1892
|
-
if True, a list of messages will be returned.
|
|
1893
1770
|
|
|
1894
|
-
Return :
|
|
1895
|
-
|
|
1771
|
+
Return : str
|
|
1772
|
+
the ROI text with the two frames inline annotated with <entity_1>, <entity_2>.
|
|
1896
1773
|
"""
|
|
1774
|
+
left_frame, right_frame = sorted([frame_1, frame_2], key=lambda f: f.start)
|
|
1775
|
+
left_frame_name = "entity_1" if left_frame.frame_id == frame_1.frame_id else "entity_2"
|
|
1776
|
+
right_frame_name = "entity_1" if right_frame.frame_id == frame_1.frame_id else "entity_2"
|
|
1777
|
+
|
|
1778
|
+
start = max(left_frame.start - buffer_size, 0)
|
|
1779
|
+
end = min(right_frame.end + buffer_size, len(text))
|
|
1780
|
+
roi = text[start:end]
|
|
1781
|
+
|
|
1782
|
+
roi_annotated = roi[0:left_frame.start - start] + \
|
|
1783
|
+
f"<{left_frame_name}> " + \
|
|
1784
|
+
roi[left_frame.start - start:left_frame.end - start] + \
|
|
1785
|
+
f" </{left_frame_name}>" + \
|
|
1786
|
+
roi[left_frame.end - start:right_frame.start - start] + \
|
|
1787
|
+
f"<{right_frame_name}> " + \
|
|
1788
|
+
roi[right_frame.start - start:right_frame.end - start] + \
|
|
1789
|
+
f" </{right_frame_name}>" + \
|
|
1790
|
+
roi[right_frame.end - start:end - start]
|
|
1791
|
+
|
|
1792
|
+
if start > 0:
|
|
1793
|
+
roi_annotated = "..." + roi_annotated
|
|
1794
|
+
if end < len(text):
|
|
1795
|
+
roi_annotated = roi_annotated + "..."
|
|
1796
|
+
return roi_annotated
|
|
1797
|
+
|
|
1798
|
+
@abc.abstractmethod
|
|
1799
|
+
def _get_task_if_possible(self, frame_1: LLMInformationExtractionFrame, frame_2: LLMInformationExtractionFrame,
|
|
1800
|
+
text: str, buffer_size: int) -> Optional[Dict[str, Any]]:
|
|
1801
|
+
"""Checks if a relation is possible and constructs the task payload."""
|
|
1802
|
+
raise NotImplementedError
|
|
1803
|
+
|
|
1804
|
+
@abc.abstractmethod
|
|
1805
|
+
def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
|
1806
|
+
"""Processes the LLM output for a single pair and returns the final relation dictionary."""
|
|
1807
|
+
raise NotImplementedError
|
|
1808
|
+
|
+    def _extract(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, verbose: bool = False,
+                 return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
+        pairs = itertools.combinations(doc.frames, 2)
+        relations = []
+        messages_log = [] if return_messages_log else None
+
+        for frame_1, frame_2 in pairs:
+            task_payload = self._get_task_if_possible(frame_1, frame_2, doc.text, buffer_size)
+            if task_payload:
+                if verbose:
+                    print(f"\n\n{Fore.GREEN}Evaluating pair:{Style.RESET_ALL} ({frame_1.frame_id}, {frame_2.frame_id})")
+                    print(f"{Fore.YELLOW}ROI Text:{Style.RESET_ALL}\n{task_payload['roi_text']}\n")
+                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+
+                gen_text = self.inference_engine.chat(
+                    messages=task_payload['messages'],
+                    verbose=verbose
+                )
+                relation = self._post_process_result(gen_text, task_payload)
+                if relation:
+                    relations.append(relation)
+
+                if return_messages_log:
+                    task_payload['messages'].append({"role": "assistant", "content": gen_text})
+                    messages_log.append(task_payload['messages'])
+
+        return (relations, messages_log) if return_messages_log else relations
+
+    async def _extract_async(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent_batch_size: int = 32, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
+        pairs = list(itertools.combinations(doc.frames, 2))
+        tasks_input = [self._get_task_if_possible(f1, f2, doc.text, buffer_size) for f1, f2 in pairs]
+        # Filter out impossible pairs
+        tasks_input = [task for task in tasks_input if task is not None]
+
+        relations = []
+        messages_log = [] if return_messages_log else None
+        semaphore = asyncio.Semaphore(concurrent_batch_size)
+
+        async def semaphore_helper(task_payload: Dict):
+            async with semaphore:
+                gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'])
+                return gen_text, task_payload
+
+        tasks = [asyncio.create_task(semaphore_helper(payload)) for payload in tasks_input]
+        results = await asyncio.gather(*tasks)
+
+        for gen_text, task_payload in results:
+            relation = self._post_process_result(gen_text, task_payload)
+            if relation:
+                relations.append(relation)
+
+            if return_messages_log:
+                task_payload['messages'].append({"role": "assistant", "content": gen_text})
+                messages_log.append(task_payload['messages'])
+
+        return (relations, messages_log) if return_messages_log else relations
+
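`_extract_async` bounds concurrency with an `asyncio.Semaphore` instead of the fixed batching used in 1.0.0, so at most `concurrent_batch_size` chat calls are in flight at any moment while all pairs are still gathered in a single pass. A small standalone sketch of that pattern; `fake_chat` is only a stand-in for `inference_engine.chat_async`.

```python
import asyncio

async def bounded_gather(payloads, worker, limit: int = 32):
    # Same idea as _extract_async: cap in-flight coroutines with a semaphore,
    # then gather every result in submission order.
    semaphore = asyncio.Semaphore(limit)

    async def run_one(p):
        async with semaphore:
            return await worker(p)

    return await asyncio.gather(*(run_one(p) for p in payloads))

async def fake_chat(payload):
    # Stand-in for inference_engine.chat_async(messages=...)
    await asyncio.sleep(0.01)
    return f"response for {payload}"

print(asyncio.run(bounded_gather(range(5), fake_chat, limit=2)))
```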
+    def extract_relations(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent: bool = False, concurrent_batch_size: int = 32, verbose: bool = False, return_messages_log: bool = False) -> List[Dict]:
         if not doc.has_frame():
             raise ValueError("Input document must have frames.")
-
         if doc.has_duplicate_frame_ids():
             raise ValueError("All frame_ids in the input document must be unique.")
 
         if concurrent:
-            if
-                warnings.warn("
-
-
-            return asyncio.run(self.extract_async(doc=doc,
-                                                  buffer_size=buffer_size,
-                                                  max_new_tokens=max_new_tokens,
-                                                  temperature=temperature,
-                                                  concurrent_batch_size=concurrent_batch_size,
-                                                  return_messages_log=return_messages_log,
-                                                  **kwrs)
-                                )
+            if verbose:
+                warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
+            nest_asyncio.apply()
+            return asyncio.run(self._extract_async(doc, buffer_size, concurrent_batch_size, return_messages_log))
         else:
-            return self.
-
-
-
-
-
+            return self._extract(doc, buffer_size, verbose, return_messages_log)
+
+
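With the reworked signature above, `extract_relations()` now takes `verbose` instead of `stream` and no longer accepts `max_new_tokens`, `temperature`, or `**kwrs`. A hypothetical call, assuming an engine, a prompt template with `{{roi_text}}`, `{{frame_1}}`, `{{frame_2}}` placeholders, and a frame-bearing document were created elsewhere, using the `BinaryRelationExtractor` defined just below:

```python
# Hypothetical usage; my_engine, my_prompt_template, and doc are assumed to exist already.
extractor = BinaryRelationExtractor(
    inference_engine=my_engine,              # any InferenceEngine with chat()/chat_async()
    prompt_template=my_prompt_template,      # uses {{roi_text}}, {{frame_1}}, {{frame_2}}
    possible_relation_func=lambda f1, f2: abs(f1.start - f2.start) <= 500,
)

# Sequential, with real-time printing of each pair:
relations = extractor.extract_relations(doc, buffer_size=128, verbose=True)

# Concurrent (verbose is ignored with a RuntimeWarning):
relations = extractor.extract_relations(doc, concurrent=True, concurrent_batch_size=16)
# -> [{'frame_1_id': '...', 'frame_2_id': '...'}, ...]
```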
+class BinaryRelationExtractor(RelationExtractor):
+    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, possible_relation_func: Callable,
+                 system_prompt:str=None):
+        """
+        This class extracts binary (yes/no) relations between two entities.
+        Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
+
+        Parameters
+        ----------
+        inference_engine : InferenceEngine
+            the LLM inferencing engine object. Must implements the chat() method.
+        prompt_template : str
+            prompt template with "{{<placeholder name>}}" placeholder.
+        possible_relation_func : Callable, Optional
+            a function that inputs 2 frames and returns a bool indicating possible relations between them.
+        system_prompt : str, Optional
+            system prompt.
+        """
+        super().__init__(inference_engine, prompt_template, system_prompt)
+        if not callable(possible_relation_func):
+            raise TypeError(f"Expect possible_relation_func as a function, received {type(possible_relation_func)} instead.")
+
+        sig = inspect.signature(possible_relation_func)
+        if len(sig.parameters) != 2:
+            raise ValueError("The possible_relation_func must have exactly two parameters.")
+
+        if sig.return_annotation not in {bool, inspect.Signature.empty}:
+            warnings.warn(f"Expected possible_relation_func return annotation to be bool, but got {sig.return_annotation}.")
+
+        self.possible_relation_func = possible_relation_func
+
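The constructor above validates `possible_relation_func` with `inspect.signature`: it must be callable, take exactly two parameters, and ideally be annotated to return `bool`. A hypothetical filter that passes those checks, relying only on the `start` offsets that frames are known to carry:

```python
def frames_are_nearby(frame_1, frame_2) -> bool:
    # Only ask the LLM about pairs whose spans are within 500 characters of each other;
    # returning False skips the pair entirely (no prompt is built for it).
    return abs(frame_1.start - frame_2.start) <= 500
```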
+    def _get_task_if_possible(self, frame_1: LLMInformationExtractionFrame, frame_2: LLMInformationExtractionFrame,
+                              text: str, buffer_size: int) -> Optional[Dict[str, Any]]:
+        if self.possible_relation_func(frame_1, frame_2):
+            roi_text = self._get_ROI(frame_1, frame_2, text, buffer_size)
+            messages = []
+            if self.system_prompt:
+                messages.append({'role': 'system', 'content': self.system_prompt})
+
+            messages.append({'role': 'user', 'content': self._get_user_prompt(
+                text_content={"roi_text": roi_text, "frame_1": str(frame_1.to_dict()), "frame_2": str(frame_2.to_dict())}
+            )})
+            return {"frame_1": frame_1, "frame_2": frame_2, "messages": messages, "roi_text": roi_text}
+        return None
+
+    def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+        rel_json = self._extract_json(gen_text)
+        if len(rel_json) > 0 and "Relation" in rel_json[0]:
+            rel = rel_json[0]["Relation"]
+            if (isinstance(rel, bool) and rel) or (isinstance(rel, str) and rel.lower() == 'true'):
+                return {'frame_1_id': pair_data['frame_1'].frame_id, 'frame_2_id': pair_data['frame_2'].frame_id}
+        return None
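`_post_process_result` above expects the model reply to contain a JSON object with a `"Relation"` key and keeps the pair only for a truthy value, either the boolean `true` or the string `"true"`. A rough, self-contained illustration of that contract; the regex is only a stand-in for `_extract_json`.

```python
import json
import re

gen_text = 'The two mentions are related. {"Relation": true}'  # example model reply
match = re.search(r'\{.*\}', gen_text, re.DOTALL)              # stand-in for _extract_json()
rel = json.loads(match.group()) if match else {}
val = rel.get("Relation")
keep = (isinstance(val, bool) and val) or (isinstance(val, str) and val.lower() == "true")
print(keep)  # True -> {'frame_1_id': ..., 'frame_2_id': ...} would be emitted for this pair
```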
 
 
 class MultiClassRelationExtractor(RelationExtractor):
     def __init__(self, inference_engine:InferenceEngine, prompt_template:str, possible_relation_types_func: Callable,
-                 system_prompt:str=None
+                 system_prompt:str=None):
         """
         This class extracts relations with relation types.
         Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).

@@ -1944,8 +1953,7 @@ class MultiClassRelationExtractor(RelationExtractor):
         """
         super().__init__(inference_engine=inference_engine,
                          prompt_template=prompt_template,
-                         system_prompt=system_prompt
-                         **kwrs)
+                         system_prompt=system_prompt)
 
         if possible_relation_types_func:
             # Check if possible_relation_types_func is a function

@@ -1967,235 +1975,25 @@ class MultiClassRelationExtractor(RelationExtractor):
         self.possible_relation_types_func = possible_relation_types_func
 
 
-    def
-
-
-
-
-
-
-
-
-
-
-
-            the relation type (str) or None for no relation.
-        """
-        if len(rel_json) > 0:
-            if "RelationType" in rel_json[0]:
-                if rel_json[0]["RelationType"] in pos_rel_types:
-                    return rel_json[0]["RelationType"]
-            else:
-                warnings.warn('Extractor output JSON without "RelationType" key. Following default, relation = "No Relation".', RuntimeWarning)
-        else:
-            warnings.warn('Extractor did not output a JSON. Following default, relation = "No Relation".', RuntimeWarning)
+    def _get_task_if_possible(self, frame_1: LLMInformationExtractionFrame, frame_2: LLMInformationExtractionFrame,
+                              text: str, buffer_size: int) -> Optional[Dict[str, Any]]:
+        pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
+        if pos_rel_types:
+            roi_text = self._get_ROI(frame_1, frame_2, text, buffer_size)
+            messages = []
+            if self.system_prompt:
+                messages.append({'role': 'system', 'content': self.system_prompt})
+            messages.append({'role': 'user', 'content': self._get_user_prompt(
+                text_content={"roi_text": roi_text, "frame_1": str(frame_1.to_dict()), "frame_2": str(frame_2.to_dict()), "pos_rel_types": str(pos_rel_types)}
+            )})
+            return {"frame_1": frame_1, "frame_2": frame_2, "messages": messages, "pos_rel_types": pos_rel_types, "roi_text": roi_text}
         return None
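`_get_task_if_possible` above calls `possible_relation_types_func` and only builds a prompt when it returns a non-empty list of candidate types, which are also passed into the prompt as `pos_rel_types`. A hypothetical implementation is sketched below; the `attr` lookup assumes the frames carry a `Type` attribute from an earlier frame-extraction step, so adjust it to your own schema.

```python
def possible_relation_types(frame_1, frame_2) -> list:
    # Assumes each frame has an attr dict with a "Type" entry; this is schema-specific.
    types = {frame_1.attr.get("Type"), frame_2.attr.get("Type")}
    if types == {"Medication", "AdverseReaction"}:
        return ["CausedBy", "No Relation"]
    return []  # empty list -> the pair is skipped, no LLM call is made
```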
-
-
-    def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
-        """
-        This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
-
-        Parameters:
-        -----------
-        doc : LLMInformationExtractionDocument
-            a document with frames.
-        buffer_size : int, Optional
-            the number of characters before and after the two frames in the ROI text.
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM should generate.
-        temperature : float, Optional
-            the temperature for token sampling.
-        stream : bool, Optional
-            if True, LLM generated text will be printed in terminal in real-time.
-        return_messages_log : bool, Optional
-            if True, a list of messages will be returned.
-
-        Return : List[Dict]
-            a list of dict with {"frame_1_id", "frame_2_id", "relation"} for all frame pairs.
-        """
-        pairs = itertools.combinations(doc.frames, 2)
-
-        if return_messages_log:
-            messages_log = []
-
-        output = []
-        for frame_1, frame_2 in pairs:
-            pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
-
-            if pos_rel_types:
-                roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
-                if stream:
-                    print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
-                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
-                messages = []
-                if self.system_prompt:
-                    messages.append({'role': 'system', 'content': self.system_prompt})
-
-                messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
-                                                                                                "frame_1": str(frame_1.to_dict()),
-                                                                                                "frame_2": str(frame_2.to_dict()),
-                                                                                                "pos_rel_types":str(pos_rel_types)}
-                                                                                  )})
-
-                gen_text = self.inference_engine.chat(
-                    messages=messages,
-                    max_new_tokens=max_new_tokens,
-                    temperature=temperature,
-                    stream=stream,
-                    **kwrs
-                )
-
-                if return_messages_log:
-                    messages.append({"role": "assistant", "content": gen_text})
-                    messages_log.append(messages)
-
-                rel_json = self._extract_json(gen_text)
-                rel = self._post_process(rel_json, pos_rel_types)
-                if rel:
-                    output.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id, 'relation':rel})
-
-        if return_messages_log:
-            return output, messages_log
-        return output
-
-
-    async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                            temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict]:
-        """
-        This is the asynchronous version of the extract() method.
-
-        Parameters:
-        -----------
-        doc : LLMInformationExtractionDocument
-            a document with frames.
-        buffer_size : int, Optional
-            the number of characters before and after the two frames in the ROI text.
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM should generate.
-        temperature : float, Optional
-            the temperature for token sampling.
-        concurrent_batch_size : int, Optional
-            the number of frame pairs to process in concurrent.
-        return_messages_log : bool, Optional
-            if True, a list of messages will be returned.
-
-        Return : List[Dict]
-            a list of dict with {"frame_1_id", "frame_2_id", "relation"} for all frame pairs.
-        """
-        # Check if self.inference_engine.chat_async() is implemented
-        if not hasattr(self.inference_engine, 'chat_async'):
-            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
-
-        pairs = itertools.combinations(doc.frames, 2)
-        if return_messages_log:
-            messages_log = []
-
-        n_frames = len(doc.frames)
-        num_pairs = (n_frames * (n_frames-1)) // 2
-        output = []
-        for i in range(0, num_pairs, concurrent_batch_size):
-            rel_pair_list = []
-            tasks = []
-            batch = list(itertools.islice(pairs, concurrent_batch_size))
-            batch_messages = []
-            for frame_1, frame_2 in batch:
-                pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
-
-                if pos_rel_types:
-                    rel_pair_list.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'pos_rel_types':pos_rel_types})
-                    roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
-                    messages = []
-                    if self.system_prompt:
-                        messages.append({'role': 'system', 'content': self.system_prompt})
-
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
-                                                                                                    "frame_1": str(frame_1.to_dict()),
-                                                                                                    "frame_2": str(frame_2.to_dict()),
-                                                                                                    "pos_rel_types":str(pos_rel_types)}
-                                                                                      )})
-                    task = asyncio.create_task(
-                        self.inference_engine.chat_async(
-                            messages=messages,
-                            max_new_tokens=max_new_tokens,
-                            temperature=temperature,
-                            **kwrs
-                        )
-                    )
-                    tasks.append(task)
-                    batch_messages.append(messages)
-
-            responses = await asyncio.gather(*tasks)
-
-            for d, response, messages in zip(rel_pair_list, responses, batch_messages):
-                if return_messages_log:
-                    messages.append({"role": "assistant", "content": response})
-                    messages_log.append(messages)
-
-                rel_json = self._extract_json(response)
-                rel = self._post_process(rel_json, d['pos_rel_types'])
-                if rel:
-                    output.append({'frame_1_id':d['frame_1'], 'frame_2_id':d['frame_2'], 'relation':rel})
-
-        if return_messages_log:
-            return output, messages_log
-        return output
-
 
-    def
-
-
-        ""
-
-
-
-
-        doc : LLMInformationExtractionDocument
-            a document with frames.
-        buffer_size : int, Optional
-            the number of characters before and after the two frames in the ROI text.
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM should generate.
-        temperature : float, Optional
-            the temperature for token sampling.
-        concurrent: bool, Optional
-            if True, the extraction will be done in concurrent.
-        concurrent_batch_size : int, Optional
-            the number of frame pairs to process in concurrent.
-        stream : bool, Optional
-            if True, LLM generated text will be printed in terminal in real-time.
-        return_messages_log : bool, Optional
-            if True, a list of messages will be returned.
-
-        Return : List[Dict]
-            a list of dict with {"frame_1", "frame_2", "relation"} for all relations.
-        """
-        if not doc.has_frame():
-            raise ValueError("Input document must have frames.")
-
-        if doc.has_duplicate_frame_ids():
-            raise ValueError("All frame_ids in the input document must be unique.")
-
-        if concurrent:
-            if stream:
-                warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
-
-            nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
-            return asyncio.run(self.extract_async(doc=doc,
-                                                  buffer_size=buffer_size,
-                                                  max_new_tokens=max_new_tokens,
-                                                  temperature=temperature,
-                                                  concurrent_batch_size=concurrent_batch_size,
-                                                  return_messages_log=return_messages_log,
-                                                  **kwrs)
-                                )
-        else:
-            return self.extract(doc=doc,
-                                buffer_size=buffer_size,
-                                max_new_tokens=max_new_tokens,
-                                temperature=temperature,
-                                stream=stream,
-                                return_messages_log=return_messages_log,
-                                **kwrs)
-
+    def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+        rel_json = self._extract_json(gen_text)
+        pos_rel_types = pair_data['pos_rel_types']
+        if len(rel_json) > 0 and "RelationType" in rel_json[0]:
+            rel_type = rel_json[0]["RelationType"]
+            if rel_type in pos_rel_types:
+                return {'frame_1_id': pair_data['frame_1'].frame_id, 'frame_2_id': pair_data['frame_2'].frame_id, 'relation': rel_type}
+        return None
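The new `_post_process_result` replaces the old `_post_process` plus per-method JSON handling: the reply must contain a JSON object whose `"RelationType"` value is one of the types offered for that pair, otherwise the pair is silently dropped (the 1.0.0 warnings about missing JSON or a missing key are gone). A small self-contained sketch of that contract; the regex stands in for `_extract_json`, and the frame ids are invented.

```python
import json
import re

pos_rel_types = ["CausedBy", "No Relation"]           # from possible_relation_types_func
gen_text = 'The symptom follows the drug. {"RelationType": "CausedBy"}'  # example model reply

match = re.search(r'\{.*\}', gen_text, re.DOTALL)     # rough stand-in for _extract_json()
rel_type = json.loads(match.group()).get("RelationType") if match else None
if rel_type in pos_rel_types:
    relation = {"frame_1_id": "f0", "frame_2_id": "f3", "relation": rel_type}  # illustrative ids
    print(relation)
# {'frame_1_id': 'f0', 'frame_2_id': 'f3', 'relation': 'CausedBy'}
```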