llm-ie 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
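For orientation before the file diff: the caller-visible change in extractors.py is that the per-call sampling arguments (max_new_tokens, temperature) and the **kwrs pass-through were removed from the extractor constructors and extraction methods, with printing now driven by the verbose flag forwarded to the inference engine's chat() call. Below is a minimal caller-side sketch of the 1.2.0 signatures as they appear in this diff; the engine construction and where sampling options are configured in 1.2.0 are assumptions, since the diff only shows their removal from the extractor API.

# Sketch only: constructor and method signatures follow the 1.2.0 side of this diff.
# The inference engine below is a stand-in; how max_new_tokens/temperature are set in
# 1.2.0 (e.g., on the engine itself) is an assumption and not shown in this diff.
from llm_ie.extractors import DirectFrameExtractor
# SentenceUnitChunker and NoContextChunker are referenced in this diff; their import
# path is assumed here.
from llm_ie.chunkers import SentenceUnitChunker, NoContextChunker

engine = ...          # any InferenceEngine implementing chat() / chat_async()
prompt_template = "Extract findings from: {{input}}"   # illustrative template
document_text = "..."                                   # illustrative document

extractor = DirectFrameExtractor(
    inference_engine=engine,
    unit_chunker=SentenceUnitChunker(),
    prompt_template=prompt_template,
    system_prompt=None,
    context_chunker=NoContextChunker(),
)

# 1.0.0 accepted extract_frames(..., max_new_tokens=512, temperature=0.0, **kwrs);
# in 1.2.0 those arguments are gone and only the options below remain.
frames = extractor.extract_frames(
    text_content=document_text,
    verbose=True,
    concurrent=False,
)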
llm_ie/extractors.py CHANGED
@@ -17,7 +17,7 @@ from colorama import Fore, Style
 
 
  class Extractor:
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
  """
  This is the abstract class for (frame and relation) extractors.
  Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -172,7 +172,7 @@ class Extractor:
  class FrameExtractor(Extractor):
  from nltk.tokenize import RegexpTokenizer
  def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
- prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None, **kwrs):
+ prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None):
  """
  This is the abstract class for frame extraction.
  Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -192,8 +192,7 @@ class FrameExtractor(Extractor):
  """
  super().__init__(inference_engine=inference_engine,
  prompt_template=prompt_template,
- system_prompt=system_prompt,
- **kwrs)
+ system_prompt=system_prompt)
 
  self.unit_chunker = unit_chunker
  if context_chunker is None:
@@ -332,7 +331,7 @@ class FrameExtractor(Extractor):
  return entity_spans
 
  @abc.abstractmethod
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, return_messages_log:bool=False, **kwrs) -> str:
+ def extract(self, text_content:Union[str, Dict[str,str]], return_messages_log:bool=False, **kwrs) -> str:
  """
  This method inputs text content and outputs a string generated by LLM
 
@@ -342,8 +341,6 @@ class FrameExtractor(Extractor):
  the input text content to put in prompt template.
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
- max_new_tokens : str, Optional
- the max number of new tokens LLM can generate.
  return_messages_log : bool, Optional
  if True, a list of messages will be returned.
 
@@ -354,7 +351,7 @@ class FrameExtractor(Extractor):
 
 
  @abc.abstractmethod
- def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=2048,
+ def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str,
  document_key:str=None, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
  """
  This method inputs text content and outputs a list of LLMInformationExtractionFrame
@@ -368,8 +365,6 @@ class FrameExtractor(Extractor):
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
  entity_key : str
  the key (in ouptut JSON) for entity text. Any extraction that does not include entity key will be dropped.
- max_new_tokens : str, Optional
- the max number of new tokens LLM should generate.
  document_key : str, Optional
  specify the key in text_content where document text is.
  If text_content is str, this parameter will be ignored.
@@ -384,7 +379,7 @@ class FrameExtractor(Extractor):
 
  class DirectFrameExtractor(FrameExtractor):
  def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
- prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None, **kwrs):
+ prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None):
  """
  This class is for general unit-context frame extraction.
  Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -406,12 +401,11 @@ class DirectFrameExtractor(FrameExtractor):
  unit_chunker=unit_chunker,
  prompt_template=prompt_template,
  system_prompt=system_prompt,
- context_chunker=context_chunker,
- **kwrs)
+ context_chunker=context_chunker)
 
 
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048,
- document_key:str=None, temperature:float=0.0, verbose:bool=False, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
+ def extract(self, text_content:Union[str, Dict[str,str]],
+ document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
  """
  This method inputs a text and outputs a list of outputs per unit.
 
@@ -421,13 +415,9 @@ class DirectFrameExtractor(FrameExtractor):
  the input text content to put in prompt template.
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
- max_new_tokens : int, Optional
- the max number of new tokens LLM should generate.
  document_key : str, Optional
  specify the key in text_content where document text is.
  If text_content is str, this parameter will be ignored.
- temperature : float, Optional
- the temperature for token sampling.
  verbose : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.
  return_messages_log : bool, Optional
@@ -491,27 +481,12 @@ class DirectFrameExtractor(FrameExtractor):
 
  print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
 
- response_stream = self.inference_engine.chat(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- stream=True,
- **kwrs
- )
-
- gen_text = ""
- for chunk in response_stream:
- gen_text += chunk
- print(chunk, end='', flush=True)
 
- else:
- gen_text = self.inference_engine.chat(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- stream=False,
- **kwrs
- )
+ gen_text = self.inference_engine.chat(
+ messages=messages,
+ verbose=verbose,
+ stream=False
+ )
 
  if return_messages_log:
  messages.append({"role": "assistant", "content": gen_text})
@@ -530,8 +505,8 @@ class DirectFrameExtractor(FrameExtractor):
 
  return output
 
- def stream(self, text_content: Union[str, Dict[str, str]], max_new_tokens: int = 2048, document_key: str = None,
- temperature: float = 0.0, **kwrs) -> Generator[Dict[str, Any], None, List[FrameExtractionUnitResult]]:
+ def stream(self, text_content: Union[str, Dict[str, str]],
+ document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnitResult]]:
  """
  Streams LLM responses per unit with structured event types,
  and returns collected data for post-processing.
@@ -542,7 +517,8 @@ class DirectFrameExtractor(FrameExtractor):
  - {"type": "info", "data": str_message}: General informational messages.
  - {"type": "unit", "data": dict_unit_info}: Signals start of a new unit. dict_unit_info contains {'id', 'text', 'start', 'end'}
  - {"type": "context", "data": str_context}: Context string for the current unit.
- - {"type": "llm_chunk", "data": str_chunk}: A raw chunk from the LLM.
+ - {"type": "reasoning", "data": str_chunk}: A reasoning model thinking chunk from the LLM.
+ - {"type": "response", "data": str_chunk}: A response/answer chunk from the LLM.
 
  Returns:
  --------
@@ -601,13 +577,10 @@ class DirectFrameExtractor(FrameExtractor):
 
  response_stream = self.inference_engine.chat(
  messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- stream=True,
- **kwrs
+ stream=True
  )
  for chunk in response_stream:
- yield {"type": "llm_chunk", "data": chunk}
+ yield chunk
  current_gen_text += chunk
 
  # Store the result for this unit
@@ -622,8 +595,8 @@ class DirectFrameExtractor(FrameExtractor):
  yield {"type": "info", "data": "All units processed by LLM."}
  return collected_results
 
- async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, document_key:str=None, temperature:float=0.0,
- concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
+ async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+ concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
  """
  This is the asynchronous version of the extract() method.
 
@@ -633,13 +606,9 @@ class DirectFrameExtractor(FrameExtractor):
  the input text content to put in prompt template.
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
- max_new_tokens : int, Optional
- the max number of new tokens LLM should generate.
  document_key : str, Optional
  specify the key in text_content where document text is.
  If text_content is str, this parameter will be ignored.
- temperature : float, Optional
- the temperature for token sampling.
  concurrent_batch_size : int, Optional
  the batch size for concurrent processing.
  return_messages_log : bool, Optional
@@ -701,17 +670,14 @@ class DirectFrameExtractor(FrameExtractor):
  # Process units concurrently with asyncio.Semaphore
  semaphore = asyncio.Semaphore(concurrent_batch_size)
 
- async def semaphore_helper(task_data: Dict, max_new_tokens: int, temperature: float, **kwrs):
+ async def semaphore_helper(task_data: Dict, **kwrs):
  unit = task_data["unit"]
  messages = task_data["messages"]
  original_index = task_data["original_index"]
 
  async with semaphore:
  gen_text = await self.inference_engine.chat_async(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- **kwrs
+ messages=messages
  )
  return {"original_index": original_index, "unit": unit, "gen_text": gen_text, "messages": messages}
 
@@ -719,10 +685,7 @@ class DirectFrameExtractor(FrameExtractor):
  tasks = []
  for task_inp in tasks_input:
  task = asyncio.create_task(semaphore_helper(
- task_inp,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- **kwrs
+ task_inp
  ))
  tasks.append(task)
 
@@ -759,11 +722,10 @@ class DirectFrameExtractor(FrameExtractor):
  return output
 
 
- def extract_frames(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
- document_key:str=None, temperature:float=0.0, verbose:bool=False,
- concurrent:bool=False, concurrent_batch_size:int=32,
+ def extract_frames(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+ verbose:bool=False, concurrent:bool=False, concurrent_batch_size:int=32,
  case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
- allow_overlap_entities:bool=False, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
+ allow_overlap_entities:bool=False, return_messages_log:bool=False) -> List[LLMInformationExtractionFrame]:
  """
  This method inputs a text and outputs a list of LLMInformationExtractionFrame
  It use the extract() method and post-process outputs into frames.
@@ -774,13 +736,9 @@ class DirectFrameExtractor(FrameExtractor):
  the input text content to put in prompt template.
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
- max_new_tokens : str, Optional
- the max number of new tokens LLM should generate.
  document_key : str, Optional
  specify the key in text_content where document text is.
  If text_content is str, this parameter will be ignored.
- temperature : float, Optional
- the temperature for token sampling.
  verbose : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.
  concurrent : bool, Optional
@@ -812,21 +770,15 @@ class DirectFrameExtractor(FrameExtractor):
 
  nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
  extraction_results = asyncio.run(self.extract_async(text_content=text_content,
- max_new_tokens=max_new_tokens,
  document_key=document_key,
- temperature=temperature,
  concurrent_batch_size=concurrent_batch_size,
- return_messages_log=return_messages_log,
- **kwrs)
+ return_messages_log=return_messages_log)
  )
  else:
  extraction_results = self.extract(text_content=text_content,
- max_new_tokens=max_new_tokens,
  document_key=document_key,
- temperature=temperature,
  verbose=verbose,
- return_messages_log=return_messages_log,
- **kwrs)
+ return_messages_log=return_messages_log)
 
  llm_output_results, messages_log = extraction_results if return_messages_log else (extraction_results, None)
 
@@ -869,8 +821,8 @@ class DirectFrameExtractor(FrameExtractor):
 
 
  class ReviewFrameExtractor(DirectFrameExtractor):
- def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker,
- inference_engine:InferenceEngine, prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None, **kwrs):
+ def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker, inference_engine:InferenceEngine,
+ prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None):
  """
  This class add a review step after the DirectFrameExtractor.
  The Review process asks LLM to review its output and:
@@ -901,8 +853,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  unit_chunker=unit_chunker,
  prompt_template=prompt_template,
  system_prompt=system_prompt,
- context_chunker=context_chunker,
- **kwrs)
+ context_chunker=context_chunker)
  # check review mode
  if review_mode not in {"addition", "revision"}:
  raise ValueError('review_mode must be one of {"addition", "revision"}.')
@@ -939,8 +890,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  if self.review_prompt is None:
  raise ValueError(f"Cannot find review prompt for {self.__class__.__name__} in the package. Please provide a review_prompt.")
 
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, document_key:str=None,
- temperature:float=0.0, verbose:bool=False, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
+ def extract(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+ verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
  """
  This method inputs a text and outputs a list of outputs per unit.
 
@@ -950,13 +901,9 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  the input text content to put in prompt template.
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
- max_new_tokens : int, Optional
- the max number of new tokens LLM should generate.
  document_key : str, Optional
  specify the key in text_content where document text is.
  If text_content is str, this parameter will be ignored.
- temperature : float, Optional
- the temperature for token sampling.
  verbose : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.
  return_messages_log : bool, Optional
@@ -1020,28 +967,13 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
 
  print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
-
- response_stream = self.inference_engine.chat(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- stream=True,
- **kwrs
- )
-
- initial = ""
- for chunk in response_stream:
- initial += chunk
- print(chunk, end='', flush=True)
 
- else:
- initial = self.inference_engine.chat(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- stream=False,
- **kwrs
- )
+
+ initial = self.inference_engine.chat(
+ messages=messages,
+ verbose=verbose,
+ stream=False
+ )
 
  if return_messages_log:
  messages.append({"role": "assistant", "content": initial})
@@ -1053,29 +985,12 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
  messages.append({'role': 'assistant', 'content': initial})
  messages.append({'role': 'user', 'content': self.review_prompt})
-
- if verbose:
- response_stream = self.inference_engine.chat(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- stream=True,
- **kwrs
- )
-
- review = ""
- for chunk in response_stream:
- review += chunk
- print(chunk, end='', flush=True)
 
- else:
- review = self.inference_engine.chat(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- stream=False,
- **kwrs
- )
+ review = self.inference_engine.chat(
+ messages=messages,
+ verbose=verbose,
+ stream=False
+ )
 
  # Output
  if self.review_mode == "revision":
@@ -1101,8 +1016,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  return output
 
 
- def stream(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048,
- document_key:str=None, temperature:float=0.0, **kwrs) -> Generator[str, None, None]:
+ def stream(self, text_content:Union[str, Dict[str,str]], document_key:str=None) -> Generator[str, None, None]:
  """
  This method inputs a text and outputs a list of outputs per unit.
 
@@ -1112,13 +1026,9 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  the input text content to put in prompt template.
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
- max_new_tokens : int, Optional
- the max number of new tokens LLM should generate.
  document_key : str, Optional
  specify the key in text_content where document text is.
  If text_content is str, this parameter will be ignored.
- temperature : float, Optional
- the temperature for token sampling.
 
  Return : List[FrameExtractionUnitResult]
  the output from LLM for each unit. Contains the start, end, text, and generated text.
@@ -1176,10 +1086,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
  response_stream = self.inference_engine.chat(
  messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- stream=True,
- **kwrs
+ stream=True
  )
 
  initial = ""
@@ -1195,16 +1102,13 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
  response_stream = self.inference_engine.chat(
  messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- stream=True,
- **kwrs
+ stream=True
  )
 
  for chunk in response_stream:
  yield chunk
 
- async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, document_key:str=None, temperature:float=0.0,
+ async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
  concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
  """
  This is the asynchronous version of the extract() method with the review step.
@@ -1215,13 +1119,9 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  the input text content to put in prompt template.
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
- max_new_tokens : int, Optional
- the max number of new tokens LLM should generate.
  document_key : str, Optional
  specify the key in text_content where document text is.
  If text_content is str, this parameter will be ignored.
- temperature : float, Optional
- the temperature for token sampling.
  concurrent_batch_size : int, Optional
  the batch size for concurrent processing.
  return_messages_log : bool, Optional
@@ -1282,17 +1182,14 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
  semaphore = asyncio.Semaphore(concurrent_batch_size)
 
- async def initial_semaphore_helper(task_data: Dict, max_new_tokens: int, temperature: float, **kwrs):
+ async def initial_semaphore_helper(task_data: Dict):
  unit = task_data["unit"]
  messages = task_data["messages"]
  original_index = task_data["original_index"]
 
  async with semaphore:
  gen_text = await self.inference_engine.chat_async(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- **kwrs
+ messages=messages
  )
  # Return initial generation result along with the messages used and the unit
  return {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text, "initial_messages": messages}
@@ -1300,10 +1197,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  # Create and gather initial generation tasks
  initial_tasks = [
  asyncio.create_task(initial_semaphore_helper(
- task_inp,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- **kwrs
+ task_inp
  ))
  for task_inp in initial_tasks_input
  ]
@@ -1333,16 +1227,13 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  })
 
 
- async def review_semaphore_helper(task_data: Dict, max_new_tokens: int, temperature: float, **kwrs):
+ async def review_semaphore_helper(task_data: Dict, **kwrs):
  messages = task_data["messages"]
  original_index = task_data["original_index"]
 
  async with semaphore:
  review_gen_text = await self.inference_engine.chat_async(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- **kwrs
+ messages=messages
  )
  # Combine initial and review results
  task_data["review_gen_text"] = review_gen_text
@@ -1354,10 +1245,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  # Create and gather review tasks
  review_tasks = [
  asyncio.create_task(review_semaphore_helper(
- task_inp,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- **kwrs
+ task_inp
  ))
  for task_inp in review_tasks_input
  ]
@@ -1405,7 +1293,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
 
  class BasicFrameExtractor(DirectFrameExtractor):
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
  """
  This class diretly prompt LLM for frame extraction.
  Input system prompt (optional), prompt template (with instruction, few-shot examples),
@@ -1424,11 +1312,10 @@ class BasicFrameExtractor(DirectFrameExtractor):
  unit_chunker=WholeDocumentUnitChunker(),
  prompt_template=prompt_template,
  system_prompt=system_prompt,
- context_chunker=NoContextChunker(),
- **kwrs)
+ context_chunker=NoContextChunker())
 
  class BasicReviewFrameExtractor(ReviewFrameExtractor):
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None, **kwrs):
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None):
  """
  This class add a review step after the BasicFrameExtractor.
  The Review process asks LLM to review its output and:
@@ -1457,13 +1344,12 @@ class BasicReviewFrameExtractor(ReviewFrameExtractor):
  review_mode=review_mode,
  review_prompt=review_prompt,
  system_prompt=system_prompt,
- context_chunker=NoContextChunker(),
- **kwrs)
+ context_chunker=NoContextChunker())
 
 
  class SentenceFrameExtractor(DirectFrameExtractor):
  def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None,
- context_sentences:Union[str, int]="all", **kwrs):
+ context_sentences:Union[str, int]="all"):
  """
  This class performs sentence-by-sentence information extraction.
  The process is as follows:
@@ -1507,14 +1393,13 @@ class SentenceFrameExtractor(DirectFrameExtractor):
  unit_chunker=SentenceUnitChunker(),
  prompt_template=prompt_template,
  system_prompt=system_prompt,
- context_chunker=context_chunker,
- **kwrs)
+ context_chunker=context_chunker)
 
 
  class SentenceReviewFrameExtractor(ReviewFrameExtractor):
  def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
  review_mode:str, review_prompt:str=None, system_prompt:str=None,
- context_sentences:Union[str, int]="all", **kwrs):
+ context_sentences:Union[str, int]="all"):
  """
  This class adds a review step after the SentenceFrameExtractor.
  For each sentence, the review process asks LLM to review its output and:
@@ -1561,15 +1446,14 @@ class SentenceReviewFrameExtractor(ReviewFrameExtractor):
  review_mode=review_mode,
  review_prompt=review_prompt,
  system_prompt=system_prompt,
- context_chunker=context_chunker,
- **kwrs)
+ context_chunker=context_chunker)
 
 
- class RelationExtractor(Extractor):
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
+ class AttributeExtractor(Extractor):
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
  """
- This is the abstract class for relation extraction.
- Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
+ This class is for attribute extraction for frames. Though FrameExtractors can also extract attributes, when
+ the number of attribute increases, it is more efficient to use a dedicated AttributeExtractor.
 
  Parameters
  ----------
@@ -1582,350 +1466,475 @@ class RelationExtractor(Extractor):
  """
  super().__init__(inference_engine=inference_engine,
  prompt_template=prompt_template,
- system_prompt=system_prompt,
- **kwrs)
+ system_prompt=system_prompt)
+ # validate prompt template
+ if "{{context}}" not in self.prompt_template or "{{frame}}" not in self.prompt_template:
+ raise ValueError("prompt_template must contain both {{context}} and {{frame}} placeholders.")
 
- def _get_ROI(self, frame_1:LLMInformationExtractionFrame, frame_2:LLMInformationExtractionFrame,
- text:str, buffer_size:int=100) -> str:
+ def _get_context(self, frame:LLMInformationExtractionFrame, text:str, context_size:int=256) -> str:
  """
- This method returns the Region of Interest (ROI) that covers the two frames. Leaves a buffer_size of characters before and after.
- The returned text has the two frames inline annotated with <entity_1>, <entity_2>.
+ This method returns the context that covers the frame. Leaves a context_size of characters before and after.
+ The returned text has the frame inline annotated with <entity>.
 
  Parameters:
  -----------
- frame_1 : LLMInformationExtractionFrame
+ frame : LLMInformationExtractionFrame
  a frame
- frame_2 : LLMInformationExtractionFrame
- the other frame
  text : str
  the entire document text
- buffer_size : int, Optional
- the number of characters before and after the two frames in the ROI text.
+ context_size : int, Optional
+ the number of characters before and after the frame in the context text.
 
  Return : str
- the ROI text with the two frames inline annotated with <entity_1>, <entity_2>.
+ the context text with the frame inline annotated with <entity>.
  """
- left_frame, right_frame = sorted([frame_1, frame_2], key=lambda f: f.start)
- left_frame_name = "entity_1" if left_frame == frame_1 else "entity_2"
- right_frame_name = "entity_1" if right_frame == frame_1 else "entity_2"
-
- start = max(left_frame.start - buffer_size, 0)
- end = min(right_frame.end + buffer_size, len(text))
- roi = text[start:end]
+ start = max(frame.start - context_size, 0)
+ end = min(frame.end + context_size, len(text))
+ context = text[start:end]
 
- roi_annotated = roi[0:left_frame.start - start] + \
- f'<{left_frame_name}>' + \
- roi[left_frame.start - start:left_frame.end - start] + \
- f"</{left_frame_name}>" + \
- roi[left_frame.end - start:right_frame.start - start] + \
- f'<{right_frame_name}>' + \
- roi[right_frame.start - start:right_frame.end - start] + \
- f"</{right_frame_name}>" + \
- roi[right_frame.end - start:end - start]
+ context_annotated = context[0:frame.start - start] + \
+ f"<entity> " + \
+ context[frame.start - start:frame.end - start] + \
+ f" </entity>" + \
+ context[frame.end - start:end - start]
 
  if start > 0:
- roi_annotated = "..." + roi_annotated
+ context_annotated = "..." + context_annotated
  if end < len(text):
- roi_annotated = roi_annotated + "..."
- return roi_annotated
+ context_annotated = context_annotated + "..."
+ return context_annotated
 
-
- @abc.abstractmethod
- def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
- temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
+ def _extract_from_frame(self, frame:LLMInformationExtractionFrame, text:str,
+ context_size:int=256, verbose:bool=False, return_messages_log:bool=False) -> Dict[str, Any]:
  """
- This method considers all combinations of two frames.
+ This method extracts attributes from a single frame.
 
  Parameters:
  -----------
- doc : LLMInformationExtractionDocument
- a document with frames.
- buffer_size : int, Optional
- the number of characters before and after the two frames in the ROI text.
- max_new_tokens : str, Optional
- the max number of new tokens LLM should generate.
- temperature : float, Optional
- the temperature for token sampling.
- stream : bool, Optional
- if True, LLM generated text will be printed in terminal in real-time.
+ frame : LLMInformationExtractionFrame
+ a frame to extract attributes from.
+ text : str
+ the entire document text.
+ context_size : int, Optional
+ the number of characters before and after the frame in the context text.
+ verbose : bool, Optional
+ if True, LLM generated text will be printed in terminal in real-time.
  return_messages_log : bool, Optional
  if True, a list of messages will be returned.
 
- Return : List[Dict]
- a list of dict with {"frame_1", "frame_2"} for all relations.
+ Return : Dict[str, Any]
+ a dictionary of attributes extracted from the frame.
+ If return_messages_log is True, a list of messages will be returned as well.
  """
- return NotImplemented
-
+ # construct chat messages
+ messages = []
+ if self.system_prompt:
+ messages.append({'role': 'system', 'content': self.system_prompt})
 
- class BinaryRelationExtractor(RelationExtractor):
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, possible_relation_func: Callable,
- system_prompt:str=None, **kwrs):
- """
- This class extracts binary (yes/no) relations between two entities.
- Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
+ context = self._get_context(frame, text, context_size)
+ messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})
 
- Parameters
- ----------
- inference_engine : InferenceEngine
- the LLM inferencing engine object. Must implements the chat() method.
- prompt_template : str
- prompt template with "{{<placeholder name>}}" placeholder.
- possible_relation_func : Callable, Optional
- a function that inputs 2 frames and returns a bool indicating possible relations between them.
- system_prompt : str, Optional
- system prompt.
- """
- super().__init__(inference_engine=inference_engine,
- prompt_template=prompt_template,
- system_prompt=system_prompt,
- **kwrs)
-
- if possible_relation_func:
- # Check if possible_relation_func is a function
- if not callable(possible_relation_func):
- raise TypeError(f"Expect possible_relation_func as a function, received {type(possible_relation_func)} instead.")
+ if verbose:
+ print(f"\n\n{Fore.GREEN}Frame: {frame.frame_id}{Style.RESET_ALL}\n{frame.to_dict()}\n")
+ if context != "":
+ print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
 
- sig = inspect.signature(possible_relation_func)
- # Check if frame_1, frame_2 are in input parameters
- if len(sig.parameters) != 2:
- raise ValueError("The possible_relation_func must have exactly frame_1 and frame_2 as parameters.")
- if "frame_1" not in sig.parameters.keys():
- raise ValueError("The possible_relation_func is missing frame_1 as a parameter.")
- if "frame_2" not in sig.parameters.keys():
- raise ValueError("The possible_relation_func is missing frame_2 as a parameter.")
- # Check if output is a bool
- if sig.return_annotation != bool:
- raise ValueError(f"Expect possible_relation_func to output a bool, current type hint suggests {sig.return_annotation} instead.")
-
- self.possible_relation_func = possible_relation_func
-
-
- def _post_process(self, rel_json:str) -> bool:
- if len(rel_json) > 0:
- if "Relation" in rel_json[0]:
- rel = rel_json[0]["Relation"]
- if isinstance(rel, bool):
- return rel
- elif isinstance(rel, str) and rel in {"True", "False"}:
- return eval(rel)
- else:
- warnings.warn('Extractor output JSON "Relation" key does not have bool or {"True", "False"} as value.' + \
- 'Following default, relation = False.', RuntimeWarning)
- else:
- warnings.warn('Extractor output JSON without "Relation" key. Following default, relation = False.', RuntimeWarning)
- else:
- warnings.warn('Extractor did not output a JSON list. Following default, relation = False.', RuntimeWarning)
- return False
-
+ print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
 
- def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
- temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
+ get_text = self.inference_engine.chat(
+ messages=messages,
+ verbose=verbose,
+ stream=False
+ )
+ if return_messages_log:
+ messages.append({"role": "assistant", "content": get_text})
+
+ attribute_list = self._extract_json(gen_text=get_text)
+ if isinstance(attribute_list, list) and len(attribute_list) > 0:
+ attributes = attribute_list[0]
+ if return_messages_log:
+ return attributes, messages
+ return attributes
+
+
+ def extract(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256, verbose:bool=False,
+ return_messages_log:bool=False, inplace:bool=True) -> Union[None, List[LLMInformationExtractionFrame]]:
  """
- This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
- Outputs pairs that are related.
+ This method extracts attributes from the document.
 
  Parameters:
  -----------
- doc : LLMInformationExtractionDocument
- a document with frames.
- buffer_size : int, Optional
- the number of characters before and after the two frames in the ROI text.
- max_new_tokens : str, Optional
- the max number of new tokens LLM should generate.
- temperature : float, Optional
- the temperature for token sampling.
- stream : bool, Optional
+ frames : List[LLMInformationExtractionFrame]
+ a list of frames to extract attributes from.
+ text : str
+ the entire document text.
+ context_size : int, Optional
+ the number of characters before and after the frame in the context text.
+ verbose : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.
  return_messages_log : bool, Optional
  if True, a list of messages will be returned.
+ inplace : bool, Optional
+ if True, the method will modify the frames in-place.
+
+ Return : Union[None, List[LLMInformationExtractionFrame]]
+ if inplace is True, the method will modify the frames in-place.
+ if inplace is False, the method will return a list of frames with attributes extracted.
+ """
+ for frame in frames:
+ if not isinstance(frame, LLMInformationExtractionFrame):
+ raise TypeError(f"Expect frame as LLMInformationExtractionFrame, received {type(frame)} instead.")
+ if not isinstance(text, str):
+ raise TypeError(f"Expect text as str, received {type(text)} instead.")
+
+ new_frames = []
+ messages_log = [] if return_messages_log else None
 
- Return : List[Dict]
- a list of dict with {"frame_1_id", "frame_2_id"}.
- """
- pairs = itertools.combinations(doc.frames, 2)
+ for frame in frames:
+ if return_messages_log:
+ attr, messages = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
+ verbose=verbose, return_messages_log=return_messages_log)
+ messages_log.append(messages)
+ else:
+ attr = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
+ verbose=verbose, return_messages_log=return_messages_log)
+
+ if inplace:
+ frame.attr.update(attr)
+ else:
+ new_frame = frame.copy()
+ new_frame.attr.update(attr)
+ new_frames.append(new_frame)
 
- if return_messages_log:
- messages_log = []
+ if inplace:
+ return messages_log if return_messages_log else None
+ else:
+ return (new_frames, messages_log) if return_messages_log else new_frames
 
- output = []
- for frame_1, frame_2 in pairs:
- pos_rel = self.possible_relation_func(frame_1, frame_2)
 
- if pos_rel:
- roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
- if stream:
- print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
- print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+ async def extract_async(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
+ concurrent_batch_size:int=32, inplace:bool=True, return_messages_log:bool=False) -> Union[None, List[LLMInformationExtractionFrame]]:
+ """
+ This method extracts attributes from the document asynchronously.
+
+ Parameters:
+ -----------
+ frames : List[LLMInformationExtractionFrame]
+ a list of frames to extract attributes from.
+ text : str
+ the entire document text.
+ context_size : int, Optional
+ the number of characters before and after the frame in the context text.
+ concurrent_batch_size : int, Optional
+ the batch size for concurrent processing.
+ inplace : bool, Optional
+ if True, the method will modify the frames in-place.
+ return_messages_log : bool, Optional
+ if True, a list of messages will be returned.
+
+ Return : Union[None, List[LLMInformationExtractionFrame]]
+ if inplace is True, the method will modify the frames in-place.
+ if inplace is False, the method will return a list of frames with attributes extracted.
+ """
+ # validation
+ for frame in frames:
+ if not isinstance(frame, LLMInformationExtractionFrame):
+ raise TypeError(f"Expect frame as LLMInformationExtractionFrame, received {type(frame)} instead.")
+ if not isinstance(text, str):
+ raise TypeError(f"Expect text as str, received {type(text)} instead.")
+
+ # async helper
+ semaphore = asyncio.Semaphore(concurrent_batch_size)
+
+ async def semaphore_helper(frame:LLMInformationExtractionFrame, text:str, context_size:int) -> dict:
+ async with semaphore:
  messages = []
  if self.system_prompt:
  messages.append({'role': 'system', 'content': self.system_prompt})
 
- messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
- "frame_1": str(frame_1.to_dict()),
- "frame_2": str(frame_2.to_dict())}
- )})
-
- gen_text = self.inference_engine.chat(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- stream=stream,
- **kwrs
- )
- rel_json = self._extract_json(gen_text)
- if self._post_process(rel_json):
- output.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})
+ context = self._get_context(frame, text, context_size)
+ messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})
 
+ gen_text = await self.inference_engine.chat_async(messages=messages)
+
  if return_messages_log:
  messages.append({"role": "assistant", "content": gen_text})
- messages_log.append(messages)
 
- if return_messages_log:
- return output, messages_log
- return output
-
-
- async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
- temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict]:
+ attribute_list = self._extract_json(gen_text=gen_text)
+ attributes = attribute_list[0] if isinstance(attribute_list, list) and len(attribute_list) > 0 else {}
+ return {"frame": frame, "attributes": attributes, "messages": messages}
+
+ # create tasks
+ tasks = [asyncio.create_task(semaphore_helper(frame, text, context_size)) for frame in frames]
+ results = await asyncio.gather(*tasks)
+
+ # process results
+ new_frames = []
+ messages_log = [] if return_messages_log else None
+
+ for result in results:
+ if return_messages_log:
+ messages_log.append(result["messages"])
+
+ if inplace:
+ result["frame"].attr.update(result["attributes"])
+ else:
+ new_frame = result["frame"].copy()
+ new_frame.attr.update(result["attributes"])
+ new_frames.append(new_frame)
+
+ # output
+ if inplace:
+ return messages_log if return_messages_log else None
+ else:
+ return (new_frames, messages_log) if return_messages_log else new_frames
+
+ def extract_attributes(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
+ concurrent:bool=False, concurrent_batch_size:int=32, verbose:bool=False,
+ return_messages_log:bool=False, inplace:bool=True) -> Union[None, List[LLMInformationExtractionFrame]]:
  """
- This is the asynchronous version of the extract() method.
+ This method extracts attributes from the document.
 
  Parameters:
  -----------
- doc : LLMInformationExtractionDocument
- a document with frames.
- buffer_size : int, Optional
- the number of characters before and after the two frames in the ROI text.
- max_new_tokens : str, Optional
- the max number of new tokens LLM should generate.
- temperature : float, Optional
- the temperature for token sampling.
+ frames : List[LLMInformationExtractionFrame]
+ a list of frames to extract attributes from.
+ text : str
+ the entire document text.
+ context_size : int, Optional
+ the number of characters before and after the frame in the context text.
+ concurrent : bool, Optional
+ if True, the method will run in concurrent mode with batch size concurrent_batch_size.
  concurrent_batch_size : int, Optional
- the number of frame pairs to process in concurrent.
+ the batch size for concurrent processing.
+ verbose : bool, Optional
+ if True, LLM generated text will be printed in terminal in real-time.
  return_messages_log : bool, Optional
  if True, a list of messages will be returned.
-
- Return : List[Dict]
- a list of dict with {"frame_1", "frame_2"}.
- """
- # Check if self.inference_engine.chat_async() is implemented
- if not hasattr(self.inference_engine, 'chat_async'):
- raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+ inplace : bool, Optional
+ if True, the method will modify the frames in-place.
 
- pairs = itertools.combinations(doc.frames, 2)
- if return_messages_log:
- messages_log = []
-
- n_frames = len(doc.frames)
- num_pairs = (n_frames * (n_frames-1)) // 2
- output = []
- for i in range(0, num_pairs, concurrent_batch_size):
- rel_pair_list = []
- tasks = []
- batch = list(itertools.islice(pairs, concurrent_batch_size))
- batch_messages = []
- for frame_1, frame_2 in batch:
- pos_rel = self.possible_relation_func(frame_1, frame_2)
-
- if pos_rel:
- rel_pair_list.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})
- roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
- messages = []
- if self.system_prompt:
- messages.append({'role': 'system', 'content': self.system_prompt})
-
- messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
- "frame_1": str(frame_1.to_dict()),
- "frame_2": str(frame_2.to_dict())}
- )})
-
- task = asyncio.create_task(
- self.inference_engine.chat_async(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- **kwrs
- )
- )
- tasks.append(task)
- batch_messages.append(messages)
+ Return : Union[None, List[LLMInformationExtractionFrame]]
+ if inplace is True, the method will modify the frames in-place.
+ if inplace is False, the method will return a list of frames with attributes extracted.
+ """
+ if concurrent:
+ if verbose:
+ warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
 
1853
- responses = await asyncio.gather(*tasks)
1725
+ nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
1854
1726
 
1855
- for d, response, messages in zip(rel_pair_list, responses, batch_messages):
1856
- if return_messages_log:
1857
- messages.append({"role": "assistant", "content": response})
1858
- messages_log.append(messages)
1727
+ return asyncio.run(self.extract_async(frames=frames, text=text, context_size=context_size,
1728
+ concurrent_batch_size=concurrent_batch_size,
1729
+ inplace=inplace, return_messages_log=return_messages_log))
1730
+ else:
1731
+ return self.extract(frames=frames, text=text, context_size=context_size,
1732
+ verbose=verbose, return_messages_log=return_messages_log, inplace=inplace)
1859
1733
 
1860
- rel_json = self._extract_json(response)
1861
- if self._post_process(rel_json):
1862
- output.append(d)
1863
-
1864
- if return_messages_log:
1865
- return output, messages_log
1866
- return output
1867
1734
 
1735
+ class RelationExtractor(Extractor):
1736
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
1737
+ """
1738
+ This is the abstract class for relation extraction.
1739
+ Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
1868
1740
 
1869
- def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1870
- temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32,
1871
- stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1741
+ Parameters
1742
+ ----------
1743
+ inference_engine : InferenceEngine
1744
+ the LLM inferencing engine object. Must implements the chat() method.
1745
+ prompt_template : str
1746
+ prompt template with "{{<placeholder name>}}" placeholder.
1747
+ system_prompt : str, Optional
1748
+ system prompt.
1872
1749
  """
1873
- This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
1750
+ super().__init__(inference_engine=inference_engine,
1751
+ prompt_template=prompt_template,
1752
+ system_prompt=system_prompt)
1753
+
1754
+ def _get_ROI(self, frame_1:LLMInformationExtractionFrame, frame_2:LLMInformationExtractionFrame,
1755
+ text:str, buffer_size:int=128) -> str:
1756
+ """
1757
+ This method returns the Region of Interest (ROI) that covers the two frames. Leaves a buffer_size of characters before and after.
1758
+ The returned text has the two frames inline annotated with <entity_1>, <entity_2>.
1874
1759
 
1875
1760
  Parameters:
1876
1761
  -----------
1877
- doc : LLMInformationExtractionDocument
1878
- a document with frames.
1762
+ frame_1 : LLMInformationExtractionFrame
1763
+ a frame
1764
+ frame_2 : LLMInformationExtractionFrame
1765
+ the other frame
1766
+ text : str
1767
+ the entire document text
1879
1768
  buffer_size : int, Optional
1880
1769
  the number of characters before and after the two frames in the ROI text.
1881
- max_new_tokens : str, Optional
1882
- the max number of new tokens LLM should generate.
1883
- temperature : float, Optional
1884
- the temperature for token sampling.
1885
- concurrent: bool, Optional
1886
- if True, the extraction will be done in concurrent.
1887
- concurrent_batch_size : int, Optional
1888
- the number of frame pairs to process in concurrent.
1889
- stream : bool, Optional
1890
- if True, LLM generated text will be printed in terminal in real-time.
1891
- return_messages_log : bool, Optional
1892
- if True, a list of messages will be returned.
1893
1770
 
1894
- Return : List[Dict]
1895
- a list of dict with {"frame_1", "frame_2"} for all relations.
1771
+ Return : str
1772
+ the ROI text with the two frames inline annotated with <entity_1>, <entity_2>.
1896
1773
  """
1774
+ left_frame, right_frame = sorted([frame_1, frame_2], key=lambda f: f.start)
1775
+ left_frame_name = "entity_1" if left_frame.frame_id == frame_1.frame_id else "entity_2"
1776
+ right_frame_name = "entity_1" if right_frame.frame_id == frame_1.frame_id else "entity_2"
1777
+
1778
+ start = max(left_frame.start - buffer_size, 0)
1779
+ end = min(right_frame.end + buffer_size, len(text))
1780
+ roi = text[start:end]
1781
+
1782
+ roi_annotated = roi[0:left_frame.start - start] + \
1783
+ f"<{left_frame_name}> " + \
1784
+ roi[left_frame.start - start:left_frame.end - start] + \
1785
+ f" </{left_frame_name}>" + \
1786
+ roi[left_frame.end - start:right_frame.start - start] + \
1787
+ f"<{right_frame_name}> " + \
1788
+ roi[right_frame.start - start:right_frame.end - start] + \
1789
+ f" </{right_frame_name}>" + \
1790
+ roi[right_frame.end - start:end - start]
1791
+
1792
+ if start > 0:
1793
+ roi_annotated = "..." + roi_annotated
1794
+ if end < len(text):
1795
+ roi_annotated = roi_annotated + "..."
1796
+ return roi_annotated
1797
+
1798
+ @abc.abstractmethod
1799
+ def _get_task_if_possible(self, frame_1: LLMInformationExtractionFrame, frame_2: LLMInformationExtractionFrame,
1800
+ text: str, buffer_size: int) -> Optional[Dict[str, Any]]:
1801
+ """Checks if a relation is possible and constructs the task payload."""
1802
+ raise NotImplementedError
1803
+
1804
+ @abc.abstractmethod
1805
+ def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
1806
+ """Processes the LLM output for a single pair and returns the final relation dictionary."""
1807
+ raise NotImplementedError
1808
+
1809
+ def _extract(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, verbose: bool = False,
1810
+ return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
1811
+ pairs = itertools.combinations(doc.frames, 2)
1812
+ relations = []
1813
+ messages_log = [] if return_messages_log else None
1814
+
1815
+ for frame_1, frame_2 in pairs:
1816
+ task_payload = self._get_task_if_possible(frame_1, frame_2, doc.text, buffer_size)
1817
+ if task_payload:
1818
+ if verbose:
1819
+ print(f"\n\n{Fore.GREEN}Evaluating pair:{Style.RESET_ALL} ({frame_1.frame_id}, {frame_2.frame_id})")
1820
+ print(f"{Fore.YELLOW}ROI Text:{Style.RESET_ALL}\n{task_payload['roi_text']}\n")
1821
+ print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
1822
+
1823
+ gen_text = self.inference_engine.chat(
1824
+ messages=task_payload['messages'],
1825
+ verbose=verbose
1826
+ )
1827
+ relation = self._post_process_result(gen_text, task_payload)
1828
+ if relation:
1829
+ relations.append(relation)
1830
+
1831
+ if return_messages_log:
1832
+ task_payload['messages'].append({"role": "assistant", "content": gen_text})
1833
+ messages_log.append(task_payload['messages'])
1834
+
1835
+ return (relations, messages_log) if return_messages_log else relations
1836
+
1837
+ async def _extract_async(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent_batch_size: int = 32, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
1838
+ pairs = list(itertools.combinations(doc.frames, 2))
1839
+ tasks_input = [self._get_task_if_possible(f1, f2, doc.text, buffer_size) for f1, f2 in pairs]
1840
+ # Filter out impossible pairs
1841
+ tasks_input = [task for task in tasks_input if task is not None]
1842
+
1843
+ relations = []
1844
+ messages_log = [] if return_messages_log else None
1845
+ semaphore = asyncio.Semaphore(concurrent_batch_size)
1846
+
1847
+ async def semaphore_helper(task_payload: Dict):
1848
+ async with semaphore:
1849
+ gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'])
1850
+ return gen_text, task_payload
1851
+
1852
+ tasks = [asyncio.create_task(semaphore_helper(payload)) for payload in tasks_input]
1853
+ results = await asyncio.gather(*tasks)
1854
+
1855
+ for gen_text, task_payload in results:
1856
+ relation = self._post_process_result(gen_text, task_payload)
1857
+ if relation:
1858
+ relations.append(relation)
1859
+
1860
+ if return_messages_log:
1861
+ task_payload['messages'].append({"role": "assistant", "content": gen_text})
1862
+ messages_log.append(task_payload['messages'])
1863
+
1864
+ return (relations, messages_log) if return_messages_log else relations
1865
+
1866
+ def extract_relations(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent: bool = False, concurrent_batch_size: int = 32, verbose: bool = False, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
1897
1867
  if not doc.has_frame():
1898
1868
  raise ValueError("Input document must have frames.")
1899
-
1900
1869
  if doc.has_duplicate_frame_ids():
1901
1870
  raise ValueError("All frame_ids in the input document must be unique.")
1902
1871
 
1903
1872
  if concurrent:
1904
- if stream:
1905
- warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
1906
-
1907
- nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
1908
- return asyncio.run(self.extract_async(doc=doc,
1909
- buffer_size=buffer_size,
1910
- max_new_tokens=max_new_tokens,
1911
- temperature=temperature,
1912
- concurrent_batch_size=concurrent_batch_size,
1913
- return_messages_log=return_messages_log,
1914
- **kwrs)
1915
- )
1873
+ if verbose:
1874
+ warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
1875
+ nest_asyncio.apply()
1876
+ return asyncio.run(self._extract_async(doc, buffer_size, concurrent_batch_size, return_messages_log))
1916
1877
  else:
1917
- return self.extract(doc=doc,
1918
- buffer_size=buffer_size,
1919
- max_new_tokens=max_new_tokens,
1920
- temperature=temperature,
1921
- stream=stream,
1922
- return_messages_log=return_messages_log,
1923
- **kwrs)
1878
+ return self._extract(doc, buffer_size, verbose, return_messages_log)
1879
+
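
A hedged usage sketch of the dispatcher above. `extractor` is any concrete subclass (such as the BinaryRelationExtractor below) and `doc` is an LLMInformationExtractionDocument that already holds uniquely-identified frames; both names are illustrative.

# Sequential mode: one chat() call per candidate frame pair, optionally echoing
# the ROI text and the raw LLM output as it goes.
relations = extractor.extract_relations(doc, buffer_size=128, verbose=True)

# Concurrent mode: pairs are fanned out through chat_async() with at most
# concurrent_batch_size requests in flight; verbose output is not supported here.
relations, messages_log = extractor.extract_relations(
    doc,
    concurrent=True,
    concurrent_batch_size=16,
    return_messages_log=True,
)
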
1880
+
1881
+ class BinaryRelationExtractor(RelationExtractor):
1882
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, possible_relation_func: Callable,
1883
+ system_prompt:str=None):
1884
+ """
1885
+ This class extracts binary (yes/no) relations between two entities.
1886
+ Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
1887
+
1888
+ Parameters
1889
+ ----------
1890
+ inference_engine : InferenceEngine
1891
+ the LLM inference engine object. Must implement the chat() method (and chat_async() for concurrent extraction).
1892
+ prompt_template : str
1893
+ prompt template with "{{<placeholder name>}}" placeholder.
1894
+ possible_relation_func : Callable
1895
+ a function that inputs two frames and returns a bool indicating whether a relation between them is possible.
1896
+ system_prompt : str, Optional
1897
+ system prompt.
1898
+ """
1899
+ super().__init__(inference_engine, prompt_template, system_prompt)
1900
+ if not callable(possible_relation_func):
1901
+ raise TypeError(f"Expect possible_relation_func as a function, received {type(possible_relation_func)} instead.")
1902
+
1903
+ sig = inspect.signature(possible_relation_func)
1904
+ if len(sig.parameters) != 2:
1905
+ raise ValueError("The possible_relation_func must have exactly two parameters.")
1906
+
1907
+ if sig.return_annotation not in {bool, inspect.Signature.empty}:
1908
+ warnings.warn(f"Expected possible_relation_func return annotation to be bool, but got {sig.return_annotation}.")
1909
+
1910
+ self.possible_relation_func = possible_relation_func
1911
+
1912
+ def _get_task_if_possible(self, frame_1: LLMInformationExtractionFrame, frame_2: LLMInformationExtractionFrame,
1913
+ text: str, buffer_size: int) -> Optional[Dict[str, Any]]:
1914
+ if self.possible_relation_func(frame_1, frame_2):
1915
+ roi_text = self._get_ROI(frame_1, frame_2, text, buffer_size)
1916
+ messages = []
1917
+ if self.system_prompt:
1918
+ messages.append({'role': 'system', 'content': self.system_prompt})
1919
+
1920
+ messages.append({'role': 'user', 'content': self._get_user_prompt(
1921
+ text_content={"roi_text": roi_text, "frame_1": str(frame_1.to_dict()), "frame_2": str(frame_2.to_dict())}
1922
+ )})
1923
+ return {"frame_1": frame_1, "frame_2": frame_2, "messages": messages, "roi_text": roi_text}
1924
+ return None
1925
+
1926
+ def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
1927
+ rel_json = self._extract_json(gen_text)
1928
+ if len(rel_json) > 0 and "Relation" in rel_json[0]:
1929
+ rel = rel_json[0]["Relation"]
1930
+ if (isinstance(rel, bool) and rel) or (isinstance(rel, str) and rel.lower() == 'true'):
1931
+ return {'frame_1_id': pair_data['frame_1'].frame_id, 'frame_2_id': pair_data['frame_2'].frame_id}
1932
+ return None
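
An editorial sketch of wiring this class up. The distance threshold, the engine, the prompt template, and the document are all assumptions for illustration; the template must contain the {{roi_text}}, {{frame_1}}, and {{frame_2}} placeholders that _get_task_if_possible() fills in.

def close_enough(frame_1, frame_2) -> bool:
    # possible_relation_func must accept exactly two frames and return a bool.
    # Here, only pairs whose spans start within 500 characters of each other
    # are sent to the LLM; 500 is an arbitrary illustrative cutoff.
    return abs(frame_1.start - frame_2.start) < 500

extractor = BinaryRelationExtractor(
    inference_engine=engine,          # any InferenceEngine with chat()/chat_async()
    prompt_template=prompt_template,  # must include the placeholders noted above
    possible_relation_func=close_enough,
)

relations = extractor.extract_relations(doc, buffer_size=128)
# e.g. [{'frame_1_id': 'f0', 'frame_2_id': 'f3'}, ...] -- one dict per pair the
# LLM answered with {"Relation": true}
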
1924
1933
 
1925
1934
 
1926
1935
  class MultiClassRelationExtractor(RelationExtractor):
1927
1936
  def __init__(self, inference_engine:InferenceEngine, prompt_template:str, possible_relation_types_func: Callable,
1928
- system_prompt:str=None, **kwrs):
1937
+ system_prompt:str=None):
1929
1938
  """
1930
1939
  This class extracts relations with relation types.
1931
1940
  Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -1944,8 +1953,7 @@ class MultiClassRelationExtractor(RelationExtractor):
1944
1953
  """
1945
1954
  super().__init__(inference_engine=inference_engine,
1946
1955
  prompt_template=prompt_template,
1947
- system_prompt=system_prompt,
1948
- **kwrs)
1956
+ system_prompt=system_prompt)
1949
1957
 
1950
1958
  if possible_relation_types_func:
1951
1959
  # Check if possible_relation_types_func is a function
@@ -1967,235 +1975,25 @@ class MultiClassRelationExtractor(RelationExtractor):
1967
1975
  self.possible_relation_types_func = possible_relation_types_func
1968
1976
 
1969
1977
 
1970
- def _post_process(self, rel_json:List[Dict], pos_rel_types:List[str]) -> Union[str, None]:
1971
- """
1972
- This method post-processes the extracted relation JSON.
1973
-
1974
- Parameters:
1975
- -----------
1976
- rel_json : List[Dict]
1977
- the extracted relation JSON.
1978
- pos_rel_types : List[str]
1979
- possible relation types by the possible_relation_types_func.
1980
-
1981
- Return : Union[str, None]
1982
- the relation type (str) or None for no relation.
1983
- """
1984
- if len(rel_json) > 0:
1985
- if "RelationType" in rel_json[0]:
1986
- if rel_json[0]["RelationType"] in pos_rel_types:
1987
- return rel_json[0]["RelationType"]
1988
- else:
1989
- warnings.warn('Extractor output JSON without "RelationType" key. Following default, relation = "No Relation".', RuntimeWarning)
1990
- else:
1991
- warnings.warn('Extractor did not output a JSON. Following default, relation = "No Relation".', RuntimeWarning)
1978
+ def _get_task_if_possible(self, frame_1: LLMInformationExtractionFrame, frame_2: LLMInformationExtractionFrame,
1979
+ text: str, buffer_size: int) -> Optional[Dict[str, Any]]:
1980
+ pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
1981
+ if pos_rel_types:
1982
+ roi_text = self._get_ROI(frame_1, frame_2, text, buffer_size)
1983
+ messages = []
1984
+ if self.system_prompt:
1985
+ messages.append({'role': 'system', 'content': self.system_prompt})
1986
+ messages.append({'role': 'user', 'content': self._get_user_prompt(
1987
+ text_content={"roi_text": roi_text, "frame_1": str(frame_1.to_dict()), "frame_2": str(frame_2.to_dict()), "pos_rel_types": str(pos_rel_types)}
1988
+ )})
1989
+ return {"frame_1": frame_1, "frame_2": frame_2, "messages": messages, "pos_rel_types": pos_rel_types, "roi_text": roi_text}
1992
1990
  return None
1993
-
1994
-
1995
- def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1996
- temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1997
- """
1998
- This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
1999
-
2000
- Parameters:
2001
- -----------
2002
- doc : LLMInformationExtractionDocument
2003
- a document with frames.
2004
- buffer_size : int, Optional
2005
- the number of characters before and after the two frames in the ROI text.
2006
- max_new_tokens : str, Optional
2007
- the max number of new tokens LLM should generate.
2008
- temperature : float, Optional
2009
- the temperature for token sampling.
2010
- stream : bool, Optional
2011
- if True, LLM generated text will be printed in terminal in real-time.
2012
- return_messages_log : bool, Optional
2013
- if True, a list of messages will be returned.
2014
-
2015
- Return : List[Dict]
2016
- a list of dict with {"frame_1_id", "frame_2_id", "relation"} for all frame pairs.
2017
- """
2018
- pairs = itertools.combinations(doc.frames, 2)
2019
-
2020
- if return_messages_log:
2021
- messages_log = []
2022
-
2023
- output = []
2024
- for frame_1, frame_2 in pairs:
2025
- pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
2026
-
2027
- if pos_rel_types:
2028
- roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
2029
- if stream:
2030
- print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
2031
- print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
2032
- messages = []
2033
- if self.system_prompt:
2034
- messages.append({'role': 'system', 'content': self.system_prompt})
2035
-
2036
- messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
2037
- "frame_1": str(frame_1.to_dict()),
2038
- "frame_2": str(frame_2.to_dict()),
2039
- "pos_rel_types":str(pos_rel_types)}
2040
- )})
2041
-
2042
- gen_text = self.inference_engine.chat(
2043
- messages=messages,
2044
- max_new_tokens=max_new_tokens,
2045
- temperature=temperature,
2046
- stream=stream,
2047
- **kwrs
2048
- )
2049
-
2050
- if return_messages_log:
2051
- messages.append({"role": "assistant", "content": gen_text})
2052
- messages_log.append(messages)
2053
-
2054
- rel_json = self._extract_json(gen_text)
2055
- rel = self._post_process(rel_json, pos_rel_types)
2056
- if rel:
2057
- output.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id, 'relation':rel})
2058
-
2059
- if return_messages_log:
2060
- return output, messages_log
2061
- return output
2062
-
2063
-
2064
- async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
2065
- temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict]:
2066
- """
2067
- This is the asynchronous version of the extract() method.
2068
-
2069
- Parameters:
2070
- -----------
2071
- doc : LLMInformationExtractionDocument
2072
- a document with frames.
2073
- buffer_size : int, Optional
2074
- the number of characters before and after the two frames in the ROI text.
2075
- max_new_tokens : str, Optional
2076
- the max number of new tokens LLM should generate.
2077
- temperature : float, Optional
2078
- the temperature for token sampling.
2079
- concurrent_batch_size : int, Optional
2080
- the number of frame pairs to process in concurrent.
2081
- return_messages_log : bool, Optional
2082
- if True, a list of messages will be returned.
2083
-
2084
- Return : List[Dict]
2085
- a list of dict with {"frame_1_id", "frame_2_id", "relation"} for all frame pairs.
2086
- """
2087
- # Check if self.inference_engine.chat_async() is implemented
2088
- if not hasattr(self.inference_engine, 'chat_async'):
2089
- raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
2090
-
2091
- pairs = itertools.combinations(doc.frames, 2)
2092
- if return_messages_log:
2093
- messages_log = []
2094
-
2095
- n_frames = len(doc.frames)
2096
- num_pairs = (n_frames * (n_frames-1)) // 2
2097
- output = []
2098
- for i in range(0, num_pairs, concurrent_batch_size):
2099
- rel_pair_list = []
2100
- tasks = []
2101
- batch = list(itertools.islice(pairs, concurrent_batch_size))
2102
- batch_messages = []
2103
- for frame_1, frame_2 in batch:
2104
- pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
2105
-
2106
- if pos_rel_types:
2107
- rel_pair_list.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'pos_rel_types':pos_rel_types})
2108
- roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
2109
- messages = []
2110
- if self.system_prompt:
2111
- messages.append({'role': 'system', 'content': self.system_prompt})
2112
-
2113
- messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
2114
- "frame_1": str(frame_1.to_dict()),
2115
- "frame_2": str(frame_2.to_dict()),
2116
- "pos_rel_types":str(pos_rel_types)}
2117
- )})
2118
- task = asyncio.create_task(
2119
- self.inference_engine.chat_async(
2120
- messages=messages,
2121
- max_new_tokens=max_new_tokens,
2122
- temperature=temperature,
2123
- **kwrs
2124
- )
2125
- )
2126
- tasks.append(task)
2127
- batch_messages.append(messages)
2128
-
2129
- responses = await asyncio.gather(*tasks)
2130
-
2131
- for d, response, messages in zip(rel_pair_list, responses, batch_messages):
2132
- if return_messages_log:
2133
- messages.append({"role": "assistant", "content": response})
2134
- messages_log.append(messages)
2135
-
2136
- rel_json = self._extract_json(response)
2137
- rel = self._post_process(rel_json, d['pos_rel_types'])
2138
- if rel:
2139
- output.append({'frame_1_id':d['frame_1'], 'frame_2_id':d['frame_2'], 'relation':rel})
2140
-
2141
- if return_messages_log:
2142
- return output, messages_log
2143
- return output
2144
-
2145
1991
 
2146
- def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
2147
- temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32,
2148
- stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
2149
- """
2150
- This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
2151
-
2152
- Parameters:
2153
- -----------
2154
- doc : LLMInformationExtractionDocument
2155
- a document with frames.
2156
- buffer_size : int, Optional
2157
- the number of characters before and after the two frames in the ROI text.
2158
- max_new_tokens : str, Optional
2159
- the max number of new tokens LLM should generate.
2160
- temperature : float, Optional
2161
- the temperature for token sampling.
2162
- concurrent: bool, Optional
2163
- if True, the extraction will be done in concurrent.
2164
- concurrent_batch_size : int, Optional
2165
- the number of frame pairs to process in concurrent.
2166
- stream : bool, Optional
2167
- if True, LLM generated text will be printed in terminal in real-time.
2168
- return_messages_log : bool, Optional
2169
- if True, a list of messages will be returned.
2170
-
2171
- Return : List[Dict]
2172
- a list of dict with {"frame_1", "frame_2", "relation"} for all relations.
2173
- """
2174
- if not doc.has_frame():
2175
- raise ValueError("Input document must have frames.")
2176
-
2177
- if doc.has_duplicate_frame_ids():
2178
- raise ValueError("All frame_ids in the input document must be unique.")
2179
-
2180
- if concurrent:
2181
- if stream:
2182
- warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
2183
-
2184
- nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
2185
- return asyncio.run(self.extract_async(doc=doc,
2186
- buffer_size=buffer_size,
2187
- max_new_tokens=max_new_tokens,
2188
- temperature=temperature,
2189
- concurrent_batch_size=concurrent_batch_size,
2190
- return_messages_log=return_messages_log,
2191
- **kwrs)
2192
- )
2193
- else:
2194
- return self.extract(doc=doc,
2195
- buffer_size=buffer_size,
2196
- max_new_tokens=max_new_tokens,
2197
- temperature=temperature,
2198
- stream=stream,
2199
- return_messages_log=return_messages_log,
2200
- **kwrs)
2201
-
1992
+ def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
1993
+ rel_json = self._extract_json(gen_text)
1994
+ pos_rel_types = pair_data['pos_rel_types']
1995
+ if len(rel_json) > 0 and "RelationType" in rel_json[0]:
1996
+ rel_type = rel_json[0]["RelationType"]
1997
+ if rel_type in pos_rel_types:
1998
+ return {'frame_1_id': pair_data['frame_1'].frame_id, 'frame_2_id': pair_data['frame_2'].frame_id, 'relation': rel_type}
1999
+ return None
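
An editorial sketch for the multi-class variant. possible_relation_types_func returns the candidate labels for a pair (an empty list skips the LLM call), and _post_process_result above keeps an answer only if its "RelationType" is among those candidates. The entity types, labels, and the frame attr lookup are illustrative assumptions; the prompt template must contain the {{roi_text}}, {{frame_1}}, {{frame_2}}, and {{pos_rel_types}} placeholders filled in _get_task_if_possible().

def possible_relation_types(frame_1, frame_2) -> list:
    # Assumes each frame carries an attr dict with a "Type" entry, as produced
    # by an upstream frame extractor; adjust the lookup to your own schema.
    types = {frame_1.attr.get("Type"), frame_2.attr.get("Type")}
    if types == {"Medication", "AdverseEvent"}:
        return ["Causes", "Worsens"]
    return []  # no candidates -> the pair never reaches the LLM

extractor = MultiClassRelationExtractor(
    inference_engine=engine,          # any InferenceEngine with chat()/chat_async()
    prompt_template=prompt_template,  # must include the placeholders noted above
    possible_relation_types_func=possible_relation_types,
)

relations = extractor.extract_relations(doc, concurrent=True, concurrent_batch_size=16)
# e.g. [{'frame_1_id': 'f0', 'frame_2_id': 'f3', 'relation': 'Causes'}, ...]
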