llm-ie 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_ie/extractors.py CHANGED
@@ -1,10 +1,14 @@
  import abc
  import re
+ import copy
  import json
+ import json_repair
  import inspect
  import importlib.resources
  import warnings
  import itertools
+ import asyncio
+ import nest_asyncio
  from typing import Set, List, Dict, Tuple, Union, Callable
  from llm_ie.data_types import LLMInformationExtractionFrame, LLMInformationExtractionDocument
  from llm_ie.engines import InferenceEngine
@@ -18,7 +22,7 @@ class Extractor:
  This is the abstract class for (frame and relation) extractors.
  Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).

- Parameters
+ Parameters:
  ----------
  inference_engine : InferenceEngine
  the LLM inferencing engine object. Must implements the chat() method.
@@ -37,16 +41,20 @@ class Extractor:
  """
  This method returns the pre-defined prompt guideline for the extractor from the package asset.
  """
+ # Check if the prompt guide is available
  file_path = importlib.resources.files('llm_ie.asset.prompt_guide').joinpath(f"{cls.__name__}_prompt_guide.txt")
- with open(file_path, 'r', encoding="utf-8") as f:
- return f.read()
-
+ try:
+ with open(file_path, 'r', encoding="utf-8") as f:
+ return f.read()
+ except FileNotFoundError:
+ warnings.warn(f"Prompt guide for {cls.__name__} is not available. Is it a custom extractor?", UserWarning)
+ return None

  def _get_user_prompt(self, text_content:Union[str, Dict[str,str]]) -> str:
  """
  This method applies text_content to prompt_template and returns a prompt.

- Parameters
+ Parameters:
  ----------
  text_content : Union[str, Dict[str,str]]
  the input text content to put in prompt template.
@@ -117,7 +125,12 @@ class Extractor:
  dict_obj = json.loads(dict_str)
  out.append(dict_obj)
  except json.JSONDecodeError:
- warnings.warn(f'Post-processing failed:\n{dict_str}', RuntimeWarning)
+ dict_obj = json_repair.repair_json(dict_str, skip_json_loads=True, return_objects=True)
+ if dict_obj:
+ warnings.warn(f'JSONDecodeError detected, fixed with repair_json:\n{dict_str}', RuntimeWarning)
+ out.append(dict_obj)
+ else:
+ warnings.warn(f'JSONDecodeError could not be fixed:\n{dict_str}', RuntimeWarning)
  return out


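With this change, malformed LLM output (single quotes, trailing commas, unquoted keys) is repaired when possible instead of being dropped with a warning. A minimal standalone sketch of the fallback, using the same json_repair arguments as the added lines above:

    import json
    import json_repair

    broken = "{'entity_text': 'aspirin', 'Dosage': '81 mg',}"  # not valid JSON
    try:
        obj = json.loads(broken)
    except json.JSONDecodeError:
        # skip_json_loads: do not retry json.loads internally;
        # return_objects: return a Python object rather than a repaired string
        obj = json_repair.repair_json(broken, skip_json_loads=True, return_objects=True)
    print(obj)  # {'entity_text': 'aspirin', 'Dosage': '81 mg'}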
@@ -127,7 +140,7 @@ class FrameExtractor(Extractor):
  This is the abstract class for frame extraction.
  Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).

- Parameters
+ Parameters:
  ----------
  inference_engine : InferenceEngine
  the LLM inferencing engine object. Must implements the chat() method.
@@ -169,7 +182,7 @@ class FrameExtractor(Extractor):
  the substring must start with the same word token as the pattern. This is due to the observation that
  LLM often generate the first few words consistently.

- Parameters
+ Parameters:
  ----------
  text : str
  the input text.
@@ -213,7 +226,7 @@ class FrameExtractor(Extractor):
  outputs a list of spans (2-tuple) for each entity.
  Entities that are not found in the text will be None from output.

- Parameters
+ Parameters:
  ----------
  text : str
  text that contains entities
@@ -235,7 +248,10 @@ class FrameExtractor(Extractor):

  # Match entities
  entity_spans = []
- for entity in entities:
+ for entity in entities:
+ if not isinstance(entity, str):
+ entity_spans.append(None)
+ continue
  if not case_sensitive:
  entity = entity.lower()

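The new guard matters because json_repair can surface non-string items in the entity list; a None or a nested dict would previously have raised AttributeError on .lower(). An illustration with hypothetical values:

    entities = ["aspirin", None, {"text": "ibuprofen"}]
    # entity_spans -> [(start, end), None, None]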
@@ -316,7 +332,7 @@ class BasicFrameExtractor(FrameExtractor):
  Input system prompt (optional), prompt template (with instruction, few-shot examples),
  and specify a LLM.

- Parameters
+ Parameters:
  ----------
  inference_engine : InferenceEngine
  the LLM inferencing engine object. Must implements the chat() method.
@@ -549,18 +565,18 @@ class SentenceFrameExtractor(FrameExtractor):
  from nltk.tokenize.punkt import PunktSentenceTokenizer
  def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
  """
- This class performs sentence-based information extraction.
- A simulated chat follows this process:
+ This class performs sentence-by-sentence information extraction.
+ The process is as follows:
  1. system prompt (optional)
- 2. user instructions (schema, background, full text, few-shot example...)
- 3. user input first sentence
- 4. assistant extract outputs
+ 2. user prompt with instructions (schema, background, full text, few-shot example...)
+ 3. feed a sentence (starting with the first sentence)
+ 4. LLM extracts entities and attributes from the sentence
  5. repeat #3 and #4

  Input system prompt (optional), prompt template (with user instructions),
  and specify a LLM.

- Parameters
+ Parameters:
  ----------
  inference_engine : InferenceEngine
  the LLM inferencing engine object. Must implements the chat() method.
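The numbered steps above correspond to a simulated chat history. A sketch of the message list the extractor assembles (system_prompt, user_prompt, and sentence_text are placeholders; the canned assistant reply appears verbatim later in this diff):

    messages = [
        {'role': 'system', 'content': system_prompt},  # step 1, optional
        {'role': 'user', 'content': user_prompt},      # step 2: instructions, schema, full text
        {'role': 'assistant', 'content': 'Sure, please start with the first sentence.'},
        {'role': 'user', 'content': sentence_text},    # steps 3-5: one sentence per request
    ]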
@@ -577,7 +593,7 @@ class SentenceFrameExtractor(FrameExtractor):
  This method sentence tokenize the input text into a list of sentences
  as dict of {start, end, sentence_text}

- Parameters
+ Parameters:
  ----------
  text : str
  text to sentence tokenize.
@@ -668,10 +684,80 @@ class SentenceFrameExtractor(FrameExtractor):
  return output


+ async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
+ document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
+ """
+ The asynchronous version of the extract() method.
+
+ Parameters:
+ ----------
+ text_content : Union[str, Dict[str,str]]
+ the input text content to put in prompt template.
+ If str, the prompt template must have only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+ If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+ max_new_tokens : int, Optional
+ the max number of new tokens LLM should generate.
+ document_key : str, Optional
+ specify the key in text_content where document text is.
+ If text_content is str, this parameter will be ignored.
+ temperature : float, Optional
+ the temperature for token sampling.
+ concurrent_batch_size : int, Optional
+ the number of sentences to process concurrently.
+ """
+ # Check if self.inference_engine.chat_async() is implemented
+ if not hasattr(self.inference_engine, 'chat_async'):
+ raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+
+ # define output
+ output = []
+ # sentence tokenization
+ if isinstance(text_content, str):
+ sentences = self._get_sentences(text_content)
+ elif isinstance(text_content, dict):
+ sentences = self._get_sentences(text_content[document_key])
+ # construct chat messages
+ base_messages = []
+ if self.system_prompt:
+ base_messages.append({'role': 'system', 'content': self.system_prompt})
+
+ base_messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
+ base_messages.append({'role': 'assistant', 'content': 'Sure, please start with the first sentence.'})
+
+ # generate sentence by sentence
+ tasks = []
+ for i in range(0, len(sentences), concurrent_batch_size):
+ batch = sentences[i:i + concurrent_batch_size]
+ for sent in batch:
+ messages = copy.deepcopy(base_messages)
+ messages.append({'role': 'user', 'content': sent['sentence_text']})
+ task = asyncio.create_task(
+ self.inference_engine.chat_async(
+ messages=messages,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ **kwrs
+ )
+ )
+ tasks.append(task)
+
+ # Wait until the batch is done, collect results and move on to next batch
+ responses = await asyncio.gather(*tasks)
+
+ # Collect outputs
+ for gen_text, sent in zip(responses, sentences):
+ output.append({'sentence_start': sent['start'],
+ 'sentence_end': sent['end'],
+ 'sentence_text': sent['sentence_text'],
+ 'gen_text': gen_text})
+ return output
+
+
  def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=512,
- document_key:str=None, multi_turn:bool=False, temperature:float=0.0, stream:bool=False,
- case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
- **kwrs) -> List[LLMInformationExtractionFrame]:
+ document_key:str=None, multi_turn:bool=False, temperature:float=0.0, stream:bool=False,
+ concurrent:bool=False, concurrent_batch_size:int=32,
+ case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
+ **kwrs) -> List[LLMInformationExtractionFrame]:
  """
  This method inputs a text and outputs a list of LLMInformationExtractionFrame
  It use the extract() method and post-process outputs into frames.
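In the flattened diff above, the indentation of the asyncio.gather() call is ambiguous, but the in-code comment says each batch is awaited before the next starts, so at most concurrent_batch_size requests are in flight at a time. A minimal, self-contained sketch of that batch-then-gather pattern, with chat_stub as a hypothetical stand-in for an engine's chat_async():

    import asyncio

    async def chat_stub(sentence: str) -> str:
        await asyncio.sleep(0.01)  # stand-in for an LLM call
        return f"extracted from: {sentence}"

    async def run_batched(sentences, concurrent_batch_size=32):
        output = []
        for i in range(0, len(sentences), concurrent_batch_size):
            batch = sentences[i:i + concurrent_batch_size]
            tasks = [asyncio.create_task(chat_stub(s)) for s in batch]
            # wait for this batch to finish before launching the next
            output.extend(await asyncio.gather(*tasks))
        return output

    results = asyncio.run(run_batched([f"sentence {i}" for i in range(100)], 8))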
@@ -699,6 +785,10 @@ class SentenceFrameExtractor(FrameExtractor):
  the temperature for token sampling.
  stream : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.
+ concurrent : bool, Optional
+ if True, the sentences will be extracted concurrently.
+ concurrent_batch_size : int, Optional
+ the number of sentences to process concurrently. Only used when `concurrent` is True.
  case_sensitive : bool, Optional
  if True, entity text matching will be case-sensitive.
  fuzzy_match : bool, Optional
@@ -712,15 +802,30 @@ class SentenceFrameExtractor(FrameExtractor):
  Return : str
  a list of frames.
  """
- llm_output_sentence = self.extract(text_content=text_content,
- max_new_tokens=max_new_tokens,
- document_key=document_key,
- multi_turn=multi_turn,
- temperature=temperature,
- stream=stream,
- **kwrs)
+ if concurrent:
+ if stream:
+ warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
+ if multi_turn:
+ warnings.warn("multi_turn=True is not supported in concurrent mode.", RuntimeWarning)
+
+ nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+ llm_output_sentences = asyncio.run(self.extract_async(text_content=text_content,
+ max_new_tokens=max_new_tokens,
+ document_key=document_key,
+ temperature=temperature,
+ concurrent_batch_size=concurrent_batch_size,
+ **kwrs)
+ )
+ else:
+ llm_output_sentences = self.extract(text_content=text_content,
+ max_new_tokens=max_new_tokens,
+ document_key=document_key,
+ multi_turn=multi_turn,
+ temperature=temperature,
+ stream=stream,
+ **kwrs)
  frame_list = []
- for sent in llm_output_sentence:
+ for sent in llm_output_sentences:
  entity_json = []
  for entity in self._extract_json(gen_text=sent['gen_text']):
  if entity_key in entity:
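For callers, the concurrent path above is a drop-in switch. A hypothetical usage sketch: engine, prompt_template, note_text, and the "entity_text" key are assumptions rather than values from this diff, and the engine must implement chat_async():

    from llm_ie.extractors import SentenceFrameExtractor

    extractor = SentenceFrameExtractor(inference_engine=engine,
                                       prompt_template=prompt_template)
    frames = extractor.extract_frames(text_content=note_text,
                                      entity_key="entity_text",  # hypothetical key from the prompt's schema
                                      concurrent=True,           # stream/multi_turn trigger a warning here
                                      concurrent_batch_size=16)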
@@ -885,6 +990,121 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
  'sentence_text': sent['sentence_text'],
  'gen_text': gen_text})
  return output
+
+ async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
+ document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
+ """
+ The asynchronous version of the extract() method.
+
+ Parameters:
+ ----------
+ text_content : Union[str, Dict[str,str]]
+ the input text content to put in prompt template.
+ If str, the prompt template must have only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+ If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+ max_new_tokens : int, Optional
+ the max number of new tokens LLM should generate.
+ document_key : str, Optional
+ specify the key in text_content where document text is.
+ If text_content is str, this parameter will be ignored.
+ temperature : float, Optional
+ the temperature for token sampling.
+ concurrent_batch_size : int, Optional
+ the number of sentences to process concurrently.
+
+ Return : List[Dict[str,str]]
+ the output from LLM. Needs post-processing.
+ """
+ # Check if self.inference_engine.chat_async() is implemented
+ if not hasattr(self.inference_engine, 'chat_async'):
+ raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+
+ # define output
+ output = []
+ # sentence tokenization
+ if isinstance(text_content, str):
+ sentences = self._get_sentences(text_content)
+ elif isinstance(text_content, dict):
+ sentences = self._get_sentences(text_content[document_key])
+ # construct chat messages
+ base_messages = []
+ if self.system_prompt:
+ base_messages.append({'role': 'system', 'content': self.system_prompt})
+
+ base_messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
+ base_messages.append({'role': 'assistant', 'content': 'Sure, please start with the first sentence.'})
+
+ # generate initial outputs sentence by sentence
+ initials = []
+ tasks = []
+ message_list = []
+ for i in range(0, len(sentences), concurrent_batch_size):
+ batch = sentences[i:i + concurrent_batch_size]
+ for sent in batch:
+ messages = copy.deepcopy(base_messages)
+ messages.append({'role': 'user', 'content': sent['sentence_text']})
+ message_list.append(messages)
+ task = asyncio.create_task(
+ self.inference_engine.chat_async(
+ messages=messages,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ **kwrs
+ )
+ )
+ tasks.append(task)
+
+ # Wait until the batch is done, collect results and move on to next batch
+ responses = await asyncio.gather(*tasks)
+ # Collect initials
+ for gen_text, sent, message in zip(responses, sentences, message_list):
+ initials.append({'sentence_start': sent['start'],
+ 'sentence_end': sent['end'],
+ 'sentence_text': sent['sentence_text'],
+ 'gen_text': gen_text,
+ 'messages': message})
+
+ # Review
+ reviews = []
+ tasks = []
+ for i in range(0, len(initials), concurrent_batch_size):
+ batch = initials[i:i + concurrent_batch_size]
+ for init in batch:
+ messages = init["messages"]
+ initial = init["gen_text"]
+ messages.append({'role': 'assistant', 'content': initial})
+ messages.append({'role': 'user', 'content': self.review_prompt})
+ task = asyncio.create_task(
+ self.inference_engine.chat_async(
+ messages=messages,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ **kwrs
+ )
+ )
+ tasks.append(task)
+
+ responses = await asyncio.gather(*tasks)
+
+ # Collect reviews
+ for gen_text, sent in zip(responses, sentences):
+ reviews.append({'sentence_start': sent['start'],
+ 'sentence_end': sent['end'],
+ 'sentence_text': sent['sentence_text'],
+ 'gen_text': gen_text})
+
+ for init, rev in zip(initials, reviews):
+ if self.review_mode == "revision":
+ gen_text = rev['gen_text']
+ elif self.review_mode == "addition":
+ gen_text = init['gen_text'] + '\n' + rev['gen_text']
+
+ # add to output
+ output.append({'sentence_start': init['sentence_start'],
+ 'sentence_end': init['sentence_end'],
+ 'sentence_text': init['sentence_text'],
+ 'gen_text': gen_text})
+ return output


  class SentenceCoTFrameExtractor(SentenceFrameExtractor):
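The review pass above reuses each sentence's chat history, appends the initial answer and self.review_prompt, and queries the LLM once more. How review_mode combines the two generations, with hypothetical values:

    initial = '[{"entity_text": "metformin"}]'
    review = '[{"entity_text": "metformin"}, {"entity_text": "insulin"}]'

    # review_mode == "revision": the reviewed answer replaces the initial one
    gen_text = review

    # review_mode == "addition": the reviewed answer is appended to the initial one
    gen_text = initial + '\n' + review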
@@ -1124,19 +1344,34 @@ class BinaryRelationExtractor(RelationExtractor):
  self.possible_relation_func = possible_relation_func


- def _extract_relation(self, frame_1:LLMInformationExtractionFrame, frame_2:LLMInformationExtractionFrame,
- text:str, buffer_size:int=100, max_new_tokens:int=128, temperature:float=0.0, stream:bool=False, **kwrs) -> bool:
+ def _post_process(self, rel_json:str) -> bool:
+ if len(rel_json) > 0:
+ if "Relation" in rel_json[0]:
+ rel = rel_json[0]["Relation"]
+ if isinstance(rel, bool):
+ return rel
+ elif isinstance(rel, str) and rel in {"True", "False"}:
+ return eval(rel)
+ else:
+ warnings.warn('Extractor output JSON "Relation" key does not have bool or {"True", "False"} as value.' + \
+ 'Following default, relation = False.', RuntimeWarning)
+ else:
+ warnings.warn('Extractor output JSON without "Relation" key. Following default, relation = False.', RuntimeWarning)
+ else:
+ warnings.warn('Extractor did not output a JSON list. Following default, relation = False.', RuntimeWarning)
+ return False
+
+
+ def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+ temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
  """
- This method inputs two frames and a ROI text, extracts the binary relation.
+ This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
+ Outputs pairs that are related.

  Parameters:
  -----------
- frame_1 : LLMInformationExtractionFrame
- a frame
- frame_2 : LLMInformationExtractionFrame
- the other frame
- text : str
- the entire document text
+ doc : LLMInformationExtractionDocument
+ a document with frames.
  buffer_size : int, Optional
  the number of characters before and after the two frames in the ROI text.
  max_new_tokens : str
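The new _post_process() centralizes the tolerant parsing that _extract_relation() previously inlined. Reconstructed from its branches, these are the input shapes it accepts, where extractor is a BinaryRelationExtractor instance and each argument is the list returned by _extract_json():

    extractor._post_process([{"Relation": True}])     # -> True
    extractor._post_process([{"Relation": "False"}])  # -> False, via eval("False")
    extractor._post_process([{"Relation": "maybe"}])  # -> warns, defaults to False
    extractor._post_process([{}])                     # -> warns: no "Relation" key, False
    extractor._post_process([])                       # -> warns: no JSON list, False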
@@ -1146,51 +1381,111 @@ class BinaryRelationExtractor(RelationExtractor):
  stream : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.

- Return : bool
- a relation indicator
+ Return : List[Dict]
+ a list of dict with {"frame_1_id", "frame_2_id"}.
  """
- roi_text = self._get_ROI(frame_1, frame_2, text, buffer_size=buffer_size)
- if stream:
- print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
- print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
-
- messages = []
- if self.system_prompt:
- messages.append({'role': 'system', 'content': self.system_prompt})
+ pairs = itertools.combinations(doc.frames, 2)
+ output = []
+ for frame_1, frame_2 in pairs:
+ pos_rel = self.possible_relation_func(frame_1, frame_2)

- messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
- "frame_1": str(frame_1.to_dict()),
- "frame_2": str(frame_2.to_dict())}
- )})
- response = self.inference_engine.chat(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- stream=stream,
- **kwrs
- )
-
- rel_json = self._extract_json(response)
- if len(rel_json) > 0:
- if "Relation" in rel_json[0]:
- rel = rel_json[0]["Relation"]
- if isinstance(rel, bool):
- return rel
- elif isinstance(rel, str) and rel in {"True", "False"}:
- return eval(rel)
- else:
- warnings.warn('Extractor output JSON "Relation" key does not have bool or {"True", "False"} as value.' + \
- 'Following default, relation = False.', RuntimeWarning)
- else:
- warnings.warn('Extractor output JSON without "Relation" key. Following default, relation = False.', RuntimeWarning)
- else:
- warnings.warn("Extractor did not output a JSON. Following default, relation = False.", RuntimeWarning)
+ if pos_rel:
+ roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+ if stream:
+ print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
+ print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+ messages = []
+ if self.system_prompt:
+ messages.append({'role': 'system', 'content': self.system_prompt})
+
+ messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+ "frame_1": str(frame_1.to_dict()),
+ "frame_2": str(frame_2.to_dict())}
+ )})
+
+ gen_text = self.inference_engine.chat(
+ messages=messages,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=stream,
+ **kwrs
+ )
+ rel_json = self._extract_json(gen_text)
+ if self._post_process(rel_json):
+ output.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id})

- return False
+ return output
+

+ async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+ temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
+ """
+ This is the asynchronous version of the extract() method.
+
+ Parameters:
+ -----------
+ doc : LLMInformationExtractionDocument
+ a document with frames.
+ buffer_size : int, Optional
+ the number of characters before and after the two frames in the ROI text.
+ max_new_tokens : int, Optional
+ the max number of new tokens LLM should generate.
+ temperature : float, Optional
+ the temperature for token sampling.
+ concurrent_batch_size : int, Optional
+ the number of frame pairs to process concurrently.
+
+ Return : List[Dict]
+ a list of dict with {"frame_1", "frame_2"}.
+ """
+ # Check if self.inference_engine.chat_async() is implemented
+ if not hasattr(self.inference_engine, 'chat_async'):
+ raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+
+ pairs = itertools.combinations(doc.frames, 2)
+ n_frames = len(doc.frames)
+ num_pairs = (n_frames * (n_frames-1)) // 2
+ rel_pair_list = []
+ tasks = []
+ for i in range(0, num_pairs, concurrent_batch_size):
+ batch = list(itertools.islice(pairs, concurrent_batch_size))
+ for frame_1, frame_2 in batch:
+ pos_rel = self.possible_relation_func(frame_1, frame_2)

+ if pos_rel:
+ rel_pair_list.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})
+ roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+ messages = []
+ if self.system_prompt:
+ messages.append({'role': 'system', 'content': self.system_prompt})
+
+ messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+ "frame_1": str(frame_1.to_dict()),
+ "frame_2": str(frame_2.to_dict())}
+ )})
+ task = asyncio.create_task(
+ self.inference_engine.chat_async(
+ messages=messages,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ **kwrs
+ )
+ )
+ tasks.append(task)
+
+ responses = await asyncio.gather(*tasks)
+
+ output = []
+ for d, response in zip(rel_pair_list, responses):
+ rel_json = self._extract_json(response)
+ if self._post_process(rel_json):
+ output.append(d)
+
+ return output
+
+
  def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
- temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+ temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32, stream:bool=False, **kwrs) -> List[Dict]:
  """
  This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.

@@ -1204,6 +1499,10 @@ class BinaryRelationExtractor(RelationExtractor):
  the max number of new tokens LLM should generate.
  temperature : float, Optional
  the temperature for token sampling.
+ concurrent : bool, Optional
+ if True, the extraction will be done concurrently.
+ concurrent_batch_size : int, Optional
+ the number of frame pairs to process concurrently.
  stream : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.

@@ -1216,19 +1515,26 @@ class BinaryRelationExtractor(RelationExtractor):
  if doc.has_duplicate_frame_ids():
  raise ValueError("All frame_ids in the input document must be unique.")

- pairs = itertools.combinations(doc.frames, 2)
- rel_pair_list = []
- for frame_1, frame_2 in pairs:
- pos_rel = self.possible_relation_func(frame_1, frame_2)
- if pos_rel:
- rel = self._extract_relation(frame_1=frame_1, frame_2=frame_2, text=doc.text, buffer_size=buffer_size,
- max_new_tokens=max_new_tokens, temperature=temperature, stream=stream, **kwrs)
- if rel:
- rel_pair_list.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id})
-
- return rel_pair_list
-
-
+ if concurrent:
+ if stream:
+ warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
+
+ nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+ return asyncio.run(self.extract_async(doc=doc,
+ buffer_size=buffer_size,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ concurrent_batch_size=concurrent_batch_size,
+ **kwrs)
+ )
+ else:
+ return self.extract(doc=doc,
+ buffer_size=buffer_size,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=stream,
+ **kwrs)
+

  class MultiClassRelationExtractor(RelationExtractor):
  def __init__(self, inference_engine:InferenceEngine, prompt_template:str, possible_relation_types_func: Callable,
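A hypothetical end-to-end call for the binary extractor, assuming its __init__ mirrors the MultiClassRelationExtractor signature shown below; engine, prompt_template, and doc are assumptions, and possible_relation_func is any callable mapping two frames to a bool:

    from llm_ie.extractors import BinaryRelationExtractor

    extractor = BinaryRelationExtractor(inference_engine=engine,
                                        prompt_template=prompt_template,
                                        possible_relation_func=lambda f1, f2: True)
    pairs = extractor.extract_relations(doc=doc,
                                        concurrent=True,
                                        concurrent_batch_size=16)
    # -> e.g. [{'frame_1': '0', 'frame_2': '3'}, ...]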
@@ -1273,22 +1579,41 @@ class MultiClassRelationExtractor(RelationExtractor):

  self.possible_relation_types_func = possible_relation_types_func

-
- def _extract_relation(self, frame_1:LLMInformationExtractionFrame, frame_2:LLMInformationExtractionFrame,
- pos_rel_types:List[str], text:str, buffer_size:int=100, max_new_tokens:int=128, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
+
+ def _post_process(self, rel_json:List[Dict], pos_rel_types:List[str]) -> Union[str, None]:
  """
- This method inputs two frames and a ROI text, extracts the relation.
+ This method post-processes the extracted relation JSON.

  Parameters:
  -----------
- frame_1 : LLMInformationExtractionFrame
- a frame
- frame_2 : LLMInformationExtractionFrame
- the other frame
+ rel_json : List[Dict]
+ the extracted relation JSON.
  pos_rel_types : List[str]
- possible relation types.
- text : str
- the entire document text
+ possible relation types by the possible_relation_types_func.
+
+ Return : Union[str, None]
+ the relation type (str) or None for no relation.
+ """
+ if len(rel_json) > 0:
+ if "RelationType" in rel_json[0]:
+ if rel_json[0]["RelationType"] in pos_rel_types:
+ return rel_json[0]["RelationType"]
+ else:
+ warnings.warn('Extractor output JSON without "RelationType" key. Following default, relation = "No Relation".', RuntimeWarning)
+ else:
+ warnings.warn('Extractor did not output a JSON. Following default, relation = "No Relation".', RuntimeWarning)
+ return None
+
+
+ def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+ temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+ """
+ This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
+
+ Parameters:
+ -----------
+ doc : LLMInformationExtractionDocument
+ a document with frames.
  buffer_size : int, Optional
  the number of characters before and after the two frames in the ROI text.
  max_new_tokens : str
@@ -1298,54 +1623,117 @@ class MultiClassRelationExtractor(RelationExtractor):
  stream : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.

- Return : str
- a relation type
+ Return : List[Dict]
+ a list of dict with {"frame_1", "frame_2", "relation"} for all frame pairs.
  """
- roi_text = self._get_ROI(frame_1, frame_2, text, buffer_size=buffer_size)
- if stream:
- print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
- print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+ pairs = itertools.combinations(doc.frames, 2)
+ output = []
+ for frame_1, frame_2 in pairs:
+ pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)

- messages = []
- if self.system_prompt:
- messages.append({'role': 'system', 'content': self.system_prompt})
+ if pos_rel_types:
+ roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+ if stream:
+ print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
+ print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+ messages = []
+ if self.system_prompt:
+ messages.append({'role': 'system', 'content': self.system_prompt})
+
+ messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+ "frame_1": str(frame_1.to_dict()),
+ "frame_2": str(frame_2.to_dict()),
+ "pos_rel_types":str(pos_rel_types)}
+ )})
+
+ gen_text = self.inference_engine.chat(
+ messages=messages,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=stream,
+ **kwrs
+ )
+ rel_json = self._extract_json(gen_text)
+ rel = self._post_process(rel_json, pos_rel_types)
+ if rel:
+ output.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'relation':rel})

- messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
- "frame_1": str(frame_1.to_dict()),
- "frame_2": str(frame_2.to_dict()),
- "pos_rel_types":str(pos_rel_types)})})
- response = self.inference_engine.chat(
- messages=messages,
- max_new_tokens=max_new_tokens,
- temperature=temperature,
- stream=stream,
- **kwrs
- )
-
- rel_json = self._extract_json(response)
- if len(rel_json) > 0:
- if "RelationType" in rel_json[0]:
- rel = rel_json[0]["RelationType"]
- if rel in pos_rel_types or rel == "No Relation":
- return rel_json[0]["RelationType"]
- else:
- warnings.warn(f'Extracted relation type "{rel}", which is not in the return of possible_relation_types_func: {pos_rel_types}.'+ \
- 'Following default, relation = "No Relation".', RuntimeWarning)
-
- else:
- warnings.warn('Extractor output JSON without "RelationType" key. Following default, relation = "No Relation".', RuntimeWarning)
+ return output
+
+
+ async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+ temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
+ """
+ This is the asynchronous version of the extract() method.
+
+ Parameters:
+ -----------
+ doc : LLMInformationExtractionDocument
+ a document with frames.
+ buffer_size : int, Optional
+ the number of characters before and after the two frames in the ROI text.
+ max_new_tokens : int, Optional
+ the max number of new tokens LLM should generate.
+ temperature : float, Optional
+ the temperature for token sampling.
+ concurrent_batch_size : int, Optional
+ the number of frame pairs to process concurrently.
+
+ Return : List[Dict]
+ a list of dict with {"frame_1", "frame_2", "relation"} for all frame pairs.
+ """
+ # Check if self.inference_engine.chat_async() is implemented
+ if not hasattr(self.inference_engine, 'chat_async'):
+ raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")

- else:
- warnings.warn('Extractor did not output a JSON. Following default, relation = "No Relation".', RuntimeWarning)
+ pairs = itertools.combinations(doc.frames, 2)
+ n_frames = len(doc.frames)
+ num_pairs = (n_frames * (n_frames-1)) // 2
+ rel_pair_list = []
+ tasks = []
+ for i in range(0, num_pairs, concurrent_batch_size):
+ batch = list(itertools.islice(pairs, concurrent_batch_size))
+ for frame_1, frame_2 in batch:
+ pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
+
+ if pos_rel_types:
+ rel_pair_list.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'pos_rel_types':pos_rel_types})
+ roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+ messages = []
+ if self.system_prompt:
+ messages.append({'role': 'system', 'content': self.system_prompt})
+
+ messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+ "frame_1": str(frame_1.to_dict()),
+ "frame_2": str(frame_2.to_dict()),
+ "pos_rel_types":str(pos_rel_types)}
+ )})
+ task = asyncio.create_task(
+ self.inference_engine.chat_async(
+ messages=messages,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ **kwrs
+ )
+ )
+ tasks.append(task)
+
+ responses = await asyncio.gather(*tasks)
+
+ output = []
+ for d, response in zip(rel_pair_list, responses):
+ rel_json = self._extract_json(response)
+ rel = self._post_process(rel_json, d['pos_rel_types'])
+ if rel:
+ output.append({'frame_1':d['frame_1'], 'frame_2':d['frame_2'], 'relation':rel})

- return "No Relation"
+ return output


  def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
- temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+ temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32, stream:bool=False, **kwrs) -> List[Dict]:
  """
- This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs
- and to provide possible relation types between two frames.
+ This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.

  Parameters:
  -----------
@@ -1357,6 +1745,10 @@ class MultiClassRelationExtractor(RelationExtractor):
  the max number of new tokens LLM should generate.
  temperature : float, Optional
  the temperature for token sampling.
+ concurrent : bool, Optional
+ if True, the extraction will be done concurrently.
+ concurrent_batch_size : int, Optional
+ the number of frame pairs to process concurrently.
  stream : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.

@@ -1369,15 +1761,23 @@ class MultiClassRelationExtractor(RelationExtractor):
  if doc.has_duplicate_frame_ids():
  raise ValueError("All frame_ids in the input document must be unique.")

- pairs = itertools.combinations(doc.frames, 2)
- rel_pair_list = []
- for frame_1, frame_2 in pairs:
- pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
- if pos_rel_types:
- rel = self._extract_relation(frame_1=frame_1, frame_2=frame_2, pos_rel_types=pos_rel_types, text=doc.text,
- buffer_size=buffer_size, max_new_tokens=max_new_tokens, temperature=temperature, stream=stream, **kwrs)
-
- if rel != "No Relation":
- rel_pair_list.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, "relation":rel})
-
- return rel_pair_list
+ if concurrent:
+ if stream:
+ warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
+
+ nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+ return asyncio.run(self.extract_async(doc=doc,
+ buffer_size=buffer_size,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ concurrent_batch_size=concurrent_batch_size,
+ **kwrs)
+ )
+ else:
+ return self.extract(doc=doc,
+ buffer_size=buffer_size,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=stream,
+ **kwrs)
+
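A matching sketch for the multi-class extractor; engine, prompt_template, doc, and the relation type names are assumptions. possible_relation_types_func returns the candidate types for a pair (an empty list filters the pair out), and _post_process() discards any generated type outside that list:

    from llm_ie.extractors import MultiClassRelationExtractor

    def possible_relation_types_func(frame_1, frame_2):
        # hypothetical schema: every pair may hold one of two relation types
        return ["Strength-Drug", "Dosage-Drug"]

    extractor = MultiClassRelationExtractor(inference_engine=engine,
                                            prompt_template=prompt_template,
                                            possible_relation_types_func=possible_relation_types_func)
    relations = extractor.extract_relations(doc=doc, concurrent=True)
    # -> e.g. [{'frame_1': '0', 'frame_2': '3', 'relation': 'Strength-Drug'}, ...]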