llm-ie 0.3.5__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
llm_ie/extractors.py CHANGED
@@ -1,11 +1,14 @@
 import abc
 import re
+import copy
 import json
 import json_repair
 import inspect
 import importlib.resources
 import warnings
 import itertools
+import asyncio
+import nest_asyncio
 from typing import Set, List, Dict, Tuple, Union, Callable
 from llm_ie.data_types import LLMInformationExtractionFrame, LLMInformationExtractionDocument
 from llm_ie.engines import InferenceEngine
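The two new imports exist to support the concurrent extraction paths added further down: `asyncio.run()` drives the new `extract_async()` methods, and `nest_asyncio.apply()` patches an already-running event loop so that `asyncio.run()` also works inside Jupyter. A minimal sketch of that pattern, outside the package (all names here are illustrative, not package API):

```python
import asyncio
import nest_asyncio

nest_asyncio.apply()  # needed inside Jupyter's running loop; harmless in a plain script

async def echo(i):
    # stand-in for one LLM call
    await asyncio.sleep(0)
    return i

async def gather_batch(coros):
    # run a batch of coroutines concurrently, preserving input order
    return await asyncio.gather(*coros)

print(asyncio.run(gather_batch([echo(i) for i in range(4)])))  # [0, 1, 2, 3]
```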
@@ -19,7 +22,7 @@ class Extractor:
         This is the abstract class for (frame and relation) extractors.
         Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
 
-        Parameters
+        Parameters:
         ----------
         inference_engine : InferenceEngine
             the LLM inferencing engine object. Must implements the chat() method.
@@ -38,16 +41,20 @@ class Extractor:
         """
         This method returns the pre-defined prompt guideline for the extractor from the package asset.
         """
+        # Check if the prompt guide is available
         file_path = importlib.resources.files('llm_ie.asset.prompt_guide').joinpath(f"{cls.__name__}_prompt_guide.txt")
-        with open(file_path, 'r', encoding="utf-8") as f:
-            return f.read()
-
+        try:
+            with open(file_path, 'r', encoding="utf-8") as f:
+                return f.read()
+        except FileNotFoundError:
+            warnings.warn(f"Prompt guide for {cls.__name__} is not available. Is it a customed extractor?", UserWarning)
+            return None
 
     def _get_user_prompt(self, text_content:Union[str, Dict[str,str]]) -> str:
         """
         This method applies text_content to prompt_template and returns a prompt.
 
-        Parameters
+        Parameters:
         ----------
         text_content : Union[str, Dict[str,str]]
             the input text content to put in prompt template.
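With this change, a missing prompt-guide asset now warns and returns None instead of raising. A hedged sketch of what that means for a custom subclass, assuming the body above belongs to the package's `get_prompt_guide()` classmethod (the def line is outside this hunk):

```python
# Hypothetical illustration: a custom subclass has no packaged guide file.
from llm_ie.extractors import BasicFrameExtractor

class MyCustomExtractor(BasicFrameExtractor):
    pass

guide = MyCustomExtractor.get_prompt_guide()  # emits UserWarning, returns None
if guide is None:
    print("No packaged prompt guide; supply your own prompt template.")
```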
@@ -133,7 +140,7 @@ class FrameExtractor(Extractor):
         This is the abstract class for frame extraction.
         Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
 
-        Parameters
+        Parameters:
         ----------
         inference_engine : InferenceEngine
             the LLM inferencing engine object. Must implements the chat() method.
@@ -175,7 +182,7 @@ class FrameExtractor(Extractor):
         the substring must start with the same word token as the pattern. This is due to the observation that
         LLM often generate the first few words consistently.
 
-        Parameters
+        Parameters:
         ----------
         text : str
             the input text.
@@ -219,7 +226,7 @@ class FrameExtractor(Extractor):
         outputs a list of spans (2-tuple) for each entity.
         Entities that are not found in the text will be None from output.
 
-        Parameters
+        Parameters:
         ----------
         text : str
             text that contains entities
@@ -241,7 +248,10 @@ class FrameExtractor(Extractor):
 
         # Match entities
         entity_spans = []
-        for entity in entities:
+        for entity in entities:
+            if not isinstance(entity, str):
+                entity_spans.append(None)
+                continue
             if not case_sensitive:
                 entity = entity.lower()
 
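The new isinstance guard keeps span alignment when the repaired JSON yields a non-string entity (a number, null, or nested object): a None placeholder preserves one output slot per input entity. A standalone sketch of the guard's effect (assumed helper, not package code):

```python
def find_spans(text, entities, case_sensitive=False):
    # one span (or None) per input entity, in order
    spans = []
    text_l = text if case_sensitive else text.lower()
    for entity in entities:
        if not isinstance(entity, str):
            spans.append(None)          # keep one slot per entity
            continue
        if not case_sensitive:
            entity = entity.lower()
        start = text_l.find(entity)
        spans.append((start, start + len(entity)) if start != -1 else None)
    return spans

print(find_spans("Aspirin 81 mg daily", ["aspirin", 81, None]))
# [(0, 7), None, None]
```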
@@ -322,7 +332,7 @@ class BasicFrameExtractor(FrameExtractor):
         Input system prompt (optional), prompt template (with instruction, few-shot examples),
         and specify a LLM.
 
-        Parameters
+        Parameters:
         ----------
         inference_engine : InferenceEngine
             the LLM inferencing engine object. Must implements the chat() method.
@@ -555,18 +565,18 @@ class SentenceFrameExtractor(FrameExtractor):
     from nltk.tokenize.punkt import PunktSentenceTokenizer
     def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
         """
-        This class performs sentence-based information extraction.
-        A simulated chat follows this process:
+        This class performs sentence-by-sentence information extraction.
+        The process is as follows:
             1. system prompt (optional)
-            2. user instructions (schema, background, full text, few-shot example...)
-            3. user input first sentence
-            4. assistant extract outputs
+            2. user prompt with instructions (schema, background, full text, few-shot example...)
+            3. feed a sentence (start with first sentence)
+            4. LLM extract entities and attributes from the sentence
             5. repeat #3 and #4
 
         Input system prompt (optional), prompt template (with user instructions),
         and specify a LLM.
 
-        Parameters
+        Parameters:
         ----------
         inference_engine : InferenceEngine
             the LLM inferencing engine object. Must implements the chat() method.
@@ -583,7 +593,7 @@ class SentenceFrameExtractor(FrameExtractor):
         This method sentence tokenize the input text into a list of sentences
         as dict of {start, end, sentence_text}
 
-        Parameters
+        Parameters:
         ----------
         text : str
             text to sentence tokenize.
@@ -674,10 +684,80 @@ class SentenceFrameExtractor(FrameExtractor):
         return output
 
 
+    async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
+                            document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
+        """
+        The asynchronous version of the extract() method.
+
+        Parameters:
+        ----------
+        text_content : Union[str, Dict[str,str]]
+            the input text content to put in prompt template.
+            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+        max_new_tokens : str, Optional
+            the max number of new tokens LLM should generate.
+        document_key : str, Optional
+            specify the key in text_content where document text is.
+            If text_content is str, this parameter will be ignored.
+        temperature : float, Optional
+            the temperature for token sampling.
+        concurrent_batch_size : int, Optional
+            the number of sentences to process in concurrent.
+        """
+        # Check if self.inference_engine.chat_async() is implemented
+        if not hasattr(self.inference_engine, 'chat_async'):
+            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+
+        # define output
+        output = []
+        # sentence tokenization
+        if isinstance(text_content, str):
+            sentences = self._get_sentences(text_content)
+        elif isinstance(text_content, dict):
+            sentences = self._get_sentences(text_content[document_key])
+        # construct chat messages
+        base_messages = []
+        if self.system_prompt:
+            base_messages.append({'role': 'system', 'content': self.system_prompt})
+
+        base_messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
+        base_messages.append({'role': 'assistant', 'content': 'Sure, please start with the first sentence.'})
+
+        # generate sentence by sentence
+        tasks = []
+        for i in range(0, len(sentences), concurrent_batch_size):
+            batch = sentences[i:i + concurrent_batch_size]
+            for sent in batch:
+                messages = copy.deepcopy(base_messages)
+                messages.append({'role': 'user', 'content': sent['sentence_text']})
+                task = asyncio.create_task(
+                    self.inference_engine.chat_async(
+                        messages=messages,
+                        max_new_tokens=max_new_tokens,
+                        temperature=temperature,
+                        **kwrs
+                    )
+                )
+                tasks.append(task)
+
+        # Wait until the batch is done, collect results and move on to next batch
+        responses = await asyncio.gather(*tasks)
+
+        # Collect outputs
+        for gen_text, sent in zip(responses, sentences):
+            output.append({'sentence_start': sent['start'],
+                           'sentence_end': sent['end'],
+                           'sentence_text': sent['sentence_text'],
+                           'gen_text': gen_text})
+        return output
+
+
     def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=512,
-                       document_key:str=None, multi_turn:bool=False, temperature:float=0.0, stream:bool=False,
-                       case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
-                       **kwrs) -> List[LLMInformationExtractionFrame]:
+                       document_key:str=None, multi_turn:bool=False, temperature:float=0.0, stream:bool=False,
+                       concurrent:bool=False, concurrent_batch_size:int=32,
+                       case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
+                       **kwrs) -> List[LLMInformationExtractionFrame]:
         """
         This method inputs a text and outputs a list of LLMInformationExtractionFrame
         It use the extract() method and post-process outputs into frames.
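Note that every task ends up in one `tasks` list and a single `asyncio.gather()` awaits them all, so results come back in sentence order. A hedged sketch of calling the new `extract_async()` directly from an existing event loop (`engine`, `prompt_template`, and `note_text` are placeholders from your own code; the engine must implement `chat_async()`):

```python
import asyncio
from llm_ie.extractors import SentenceFrameExtractor

async def main(engine, prompt_template, note_text):
    extractor = SentenceFrameExtractor(engine, prompt_template)
    # one LLM call per sentence, created in batches of 16
    per_sentence = await extractor.extract_async(note_text, concurrent_batch_size=16)
    for rec in per_sentence:
        print(rec['sentence_start'], rec['sentence_end'], rec['gen_text'][:80])

# asyncio.run(main(engine, prompt_template, note_text))
```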
@@ -705,6 +785,10 @@ class SentenceFrameExtractor(FrameExtractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        concurrent : bool, Optional
+            if True, the sentences will be extracted in concurrent.
+        concurrent_batch_size : int, Optional
+            the number of sentences to process in concurrent. Only used when `concurrent` is True.
         case_sensitive : bool, Optional
             if True, entity text matching will be case-sensitive.
         fuzzy_match : bool, Optional
@@ -718,15 +802,30 @@ class SentenceFrameExtractor(FrameExtractor):
         Return : str
             a list of frames.
         """
-        llm_output_sentence = self.extract(text_content=text_content,
-                                           max_new_tokens=max_new_tokens,
-                                           document_key=document_key,
-                                           multi_turn=multi_turn,
-                                           temperature=temperature,
-                                           stream=stream,
-                                           **kwrs)
+        if concurrent:
+            if stream:
+                warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
+            if multi_turn:
+                warnings.warn("multi_turn=True is not supported in concurrent mode.", RuntimeWarning)
+
+            nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+            llm_output_sentences = asyncio.run(self.extract_async(text_content=text_content,
+                                                                  max_new_tokens=max_new_tokens,
+                                                                  document_key=document_key,
+                                                                  temperature=temperature,
+                                                                  concurrent_batch_size=concurrent_batch_size,
+                                                                  **kwrs)
+                                               )
+        else:
+            llm_output_sentences = self.extract(text_content=text_content,
+                                                max_new_tokens=max_new_tokens,
+                                                document_key=document_key,
+                                                multi_turn=multi_turn,
+                                                temperature=temperature,
+                                                stream=stream,
+                                                **kwrs)
         frame_list = []
-        for sent in llm_output_sentence:
+        for sent in llm_output_sentences:
             entity_json = []
             for entity in self._extract_json(gen_text=sent['gen_text']):
                 if entity_key in entity:
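From synchronous code the same path is reached via `extract_frames(concurrent=True)`; `stream` and `multi_turn` are ignored with a RuntimeWarning in that mode. A hedged usage sketch (`engine`, `prompt_template`, and `note_text` are your own objects; the `entity_key` value is illustrative and must match your prompt's output schema):

```python
from llm_ie.extractors import SentenceFrameExtractor

def run_concurrent(engine, prompt_template, note_text):
    extractor = SentenceFrameExtractor(engine, prompt_template)
    frames = extractor.extract_frames(note_text,
                                      entity_key="entity_text",  # illustrative key
                                      concurrent=True,
                                      concurrent_batch_size=32)
    for frame in frames:
        print(frame.to_dict())
```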
@@ -891,6 +990,121 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
                            'sentence_text': sent['sentence_text'],
                            'gen_text': gen_text})
         return output
+
+    async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
+                            document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
+        """
+        The asynchronous version of the extract() method.
+
+        Parameters:
+        ----------
+        text_content : Union[str, Dict[str,str]]
+            the input text content to put in prompt template.
+            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+        max_new_tokens : str, Optional
+            the max number of new tokens LLM should generate.
+        document_key : str, Optional
+            specify the key in text_content where document text is.
+            If text_content is str, this parameter will be ignored.
+        temperature : float, Optional
+            the temperature for token sampling.
+        concurrent_batch_size : int, Optional
+            the number of sentences to process in concurrent.
+
+        Return : str
+            the output from LLM. Need post-processing.
+        """
+        # Check if self.inference_engine.chat_async() is implemented
+        if not hasattr(self.inference_engine, 'chat_async'):
+            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+
+        # define output
+        output = []
+        # sentence tokenization
+        if isinstance(text_content, str):
+            sentences = self._get_sentences(text_content)
+        elif isinstance(text_content, dict):
+            sentences = self._get_sentences(text_content[document_key])
+        # construct chat messages
+        base_messages = []
+        if self.system_prompt:
+            base_messages.append({'role': 'system', 'content': self.system_prompt})
+
+        base_messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
+        base_messages.append({'role': 'assistant', 'content': 'Sure, please start with the first sentence.'})
+
+        # generate initial outputs sentence by sentence
+        initials = []
+        tasks = []
+        message_list = []
+        for i in range(0, len(sentences), concurrent_batch_size):
+            batch = sentences[i:i + concurrent_batch_size]
+            for sent in batch:
+                messages = copy.deepcopy(base_messages)
+                messages.append({'role': 'user', 'content': sent['sentence_text']})
+                message_list.append(messages)
+                task = asyncio.create_task(
+                    self.inference_engine.chat_async(
+                        messages=messages,
+                        max_new_tokens=max_new_tokens,
+                        temperature=temperature,
+                        **kwrs
+                    )
+                )
+                tasks.append(task)
+
+        # Wait until the batch is done, collect results and move on to next batch
+        responses = await asyncio.gather(*tasks)
+        # Collect initials
+        for gen_text, sent, message in zip(responses, sentences, message_list):
+            initials.append({'sentence_start': sent['start'],
+                             'sentence_end': sent['end'],
+                             'sentence_text': sent['sentence_text'],
+                             'gen_text': gen_text,
+                             'messages': message})
+
+        # Review
+        reviews = []
+        tasks = []
+        for i in range(0, len(initials), concurrent_batch_size):
+            batch = initials[i:i + concurrent_batch_size]
+            for init in batch:
+                messages = init["messages"]
+                initial = init["gen_text"]
+                messages.append({'role': 'assistant', 'content': initial})
+                messages.append({'role': 'user', 'content': self.review_prompt})
+                task = asyncio.create_task(
+                    self.inference_engine.chat_async(
+                        messages=messages,
+                        max_new_tokens=max_new_tokens,
+                        temperature=temperature,
+                        **kwrs
+                    )
+                )
+                tasks.append(task)
+
+        responses = await asyncio.gather(*tasks)
+
+        # Collect reviews
+        for gen_text, sent in zip(responses, sentences):
+            reviews.append({'sentence_start': sent['start'],
+                            'sentence_end': sent['end'],
+                            'sentence_text': sent['sentence_text'],
+                            'gen_text': gen_text})
+
+        for init, rev in zip(initials, reviews):
+            if self.review_mode == "revision":
+                gen_text = rev['gen_text']
+            elif self.review_mode == "addition":
+                gen_text = init['gen_text'] + '\n' + rev['gen_text']
+
+            # add to output
+            output.append({'sentence_start': init['sentence_start'],
+                           'sentence_end': init['sentence_end'],
+                           'sentence_text': init['sentence_text'],
+                           'gen_text': gen_text})
+        return output
 
 
 class SentenceCoTFrameExtractor(SentenceFrameExtractor):
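The review pass above runs as a second concurrent round, then merges each initial/review pair by `review_mode`. The combination rule, reduced to a standalone sketch:

```python
def combine(initial, review, review_mode):
    # "revision": the reviewed output replaces the initial output
    # "addition": the reviewed output is appended to the initial output
    if review_mode == "revision":
        return review
    elif review_mode == "addition":
        return initial + "\n" + review
    raise ValueError("review_mode must be 'revision' or 'addition'")

print(combine('[{"entity": "aspirin"}]', '[{"entity": "aspirin 81 mg"}]', "revision"))
```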
@@ -1130,19 +1344,34 @@ class BinaryRelationExtractor(RelationExtractor):
         self.possible_relation_func = possible_relation_func
 
 
-    def _extract_relation(self, frame_1:LLMInformationExtractionFrame, frame_2:LLMInformationExtractionFrame,
-                          text:str, buffer_size:int=100, max_new_tokens:int=128, temperature:float=0.0, stream:bool=False, **kwrs) -> bool:
+    def _post_process(self, rel_json:str) -> bool:
+        if len(rel_json) > 0:
+            if "Relation" in rel_json[0]:
+                rel = rel_json[0]["Relation"]
+                if isinstance(rel, bool):
+                    return rel
+                elif isinstance(rel, str) and rel in {"True", "False"}:
+                    return eval(rel)
+                else:
+                    warnings.warn('Extractor output JSON "Relation" key does not have bool or {"True", "False"} as value.' + \
+                                  'Following default, relation = False.', RuntimeWarning)
+            else:
+                warnings.warn('Extractor output JSON without "Relation" key. Following default, relation = False.', RuntimeWarning)
+        else:
+            warnings.warn('Extractor did not output a JSON list. Following default, relation = False.', RuntimeWarning)
+        return False
+
+
+    def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+                temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
         """
-        This method inputs two frames and a ROI text, extracts the binary relation.
+        This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
+        Outputs pairs that are related.
 
         Parameters:
         -----------
-        frame_1 : LLMInformationExtractionFrame
-            a frame
-        frame_2 : LLMInformationExtractionFrame
-            the other frame
-        text : str
-            the entire document text
+        doc : LLMInformationExtractionDocument
+            a document with frames.
         buffer_size : int, Optional
             the number of characters before and after the two frames in the ROI text.
         max_new_tokens : str,
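The refactor hoists JSON validation out of the extraction loop into `_post_process()`, which accepts a bool or the strings "True"/"False" under the "Relation" key and defaults to False otherwise. A sketch of the resulting decision table (inputs are the parsed lists from `_extract_json()`; `extractor` is any BinaryRelationExtractor instance):

```python
cases = {
    "bool accepted":        [{"Relation": True}],    # -> True
    "string form accepted": [{"Relation": "False"}], # -> False
    "other value":          [{"Relation": "maybe"}], # -> False + RuntimeWarning
    "missing key":          [{"Answer": True}],      # -> False + RuntimeWarning
    "no JSON at all":       [],                      # -> False + RuntimeWarning
}
# e.g., extractor._post_process(cases["bool accepted"]) is True
```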
@@ -1152,51 +1381,111 @@ class BinaryRelationExtractor(RelationExtractor):
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
 
-        Return : bool
-            a relation indicator
+        Return : List[Dict]
+            a list of dict with {"frame_1_id", "frame_2_id"}.
         """
-        roi_text = self._get_ROI(frame_1, frame_2, text, buffer_size=buffer_size)
-        if stream:
-            print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
-            print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
-
-        messages = []
-        if self.system_prompt:
-            messages.append({'role': 'system', 'content': self.system_prompt})
+        pairs = itertools.combinations(doc.frames, 2)
+        output = []
+        for frame_1, frame_2 in pairs:
+            pos_rel = self.possible_relation_func(frame_1, frame_2)
 
-        messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
-                                                                                        "frame_1": str(frame_1.to_dict()),
-                                                                                        "frame_2": str(frame_2.to_dict())}
-                                                                          )})
-        response = self.inference_engine.chat(
-            messages=messages,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            stream=stream,
-            **kwrs
-        )
-
-        rel_json = self._extract_json(response)
-        if len(rel_json) > 0:
-            if "Relation" in rel_json[0]:
-                rel = rel_json[0]["Relation"]
-                if isinstance(rel, bool):
-                    return rel
-                elif isinstance(rel, str) and rel in {"True", "False"}:
-                    return eval(rel)
-                else:
-                    warnings.warn('Extractor output JSON "Relation" key does not have bool or {"True", "False"} as value.' + \
-                                  'Following default, relation = False.', RuntimeWarning)
-            else:
-                warnings.warn('Extractor output JSON without "Relation" key. Following default, relation = False.', RuntimeWarning)
-        else:
-            warnings.warn("Extractor did not output a JSON. Following default, relation = False.", RuntimeWarning)
+        if pos_rel:
+                roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+                if stream:
+                    print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
+                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+                messages = []
+                if self.system_prompt:
+                    messages.append({'role': 'system', 'content': self.system_prompt})
+
+                messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+                                                                                                "frame_1": str(frame_1.to_dict()),
+                                                                                                "frame_2": str(frame_2.to_dict())}
+                                                                                  )})
+
+                gen_text = self.inference_engine.chat(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=stream,
+                    **kwrs
+                )
+                rel_json = self._extract_json(gen_text)
+                if self._post_process(rel_json):
+                    output.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id})
 
-        return False
+        return output
+
 
+    async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+                            temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
+        """
+        This is the asynchronous version of the extract() method.
+
+        Parameters:
+        -----------
+        doc : LLMInformationExtractionDocument
+            a document with frames.
+        buffer_size : int, Optional
+            the number of characters before and after the two frames in the ROI text.
+        max_new_tokens : str, Optional
+            the max number of new tokens LLM should generate.
+        temperature : float, Optional
+            the temperature for token sampling.
+        concurrent_batch_size : int, Optional
+            the number of frame pairs to process in concurrent.
+
+        Return : List[Dict]
+            a list of dict with {"frame_1", "frame_2"}.
+        """
+        # Check if self.inference_engine.chat_async() is implemented
+        if not hasattr(self.inference_engine, 'chat_async'):
+            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+
+        pairs = itertools.combinations(doc.frames, 2)
+        n_frames = len(doc.frames)
+        num_pairs = (n_frames * (n_frames-1)) // 2
+        rel_pair_list = []
+        tasks = []
+        for i in range(0, num_pairs, concurrent_batch_size):
+            batch = list(itertools.islice(pairs, concurrent_batch_size))
+            for frame_1, frame_2 in batch:
+                pos_rel = self.possible_relation_func(frame_1, frame_2)
 
+                if pos_rel:
+                    rel_pair_list.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})
+                    roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+                    messages = []
+                    if self.system_prompt:
+                        messages.append({'role': 'system', 'content': self.system_prompt})
+
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+                                                                                                    "frame_1": str(frame_1.to_dict()),
+                                                                                                    "frame_2": str(frame_2.to_dict())}
+                                                                                      )})
+                    task = asyncio.create_task(
+                        self.inference_engine.chat_async(
+                            messages=messages,
+                            max_new_tokens=max_new_tokens,
+                            temperature=temperature,
+                            **kwrs
+                        )
+                    )
+                    tasks.append(task)
+
+        responses = await asyncio.gather(*tasks)
+
+        output = []
+        for d, response in zip(rel_pair_list, responses):
+            rel_json = self._extract_json(response)
+            if self._post_process(rel_json):
+                output.append(d)
+
+        return output
+
+
     def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                          temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+                          temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32, stream:bool=False, **kwrs) -> List[Dict]:
         """
         This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
 
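Worth noting in `extract_async()`: `pairs` is a generator, so `itertools.islice(pairs, concurrent_batch_size)` consumes it chunk by chunk without materializing all O(n²) combinations up front. A sketch of that batching idiom in isolation:

```python
import itertools

pairs = itertools.combinations(range(5), 2)   # 10 pairs, produced lazily
while True:
    batch = list(itertools.islice(pairs, 4))  # take up to 4 at a time
    if not batch:
        break
    print(batch)
# [(0, 1), (0, 2), (0, 3), (0, 4)]
# [(1, 2), (1, 3), (1, 4), (2, 3)]
# [(2, 4), (3, 4)]
```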
@@ -1210,6 +1499,10 @@ class BinaryRelationExtractor(RelationExtractor):
             the max number of new tokens LLM should generate.
         temperature : float, Optional
             the temperature for token sampling.
+        concurrent: bool, Optional
+            if True, the extraction will be done in concurrent.
+        concurrent_batch_size : int, Optional
+            the number of frame pairs to process in concurrent.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
 
@@ -1222,19 +1515,26 @@ class BinaryRelationExtractor(RelationExtractor):
         if doc.has_duplicate_frame_ids():
             raise ValueError("All frame_ids in the input document must be unique.")
 
-        pairs = itertools.combinations(doc.frames, 2)
-        rel_pair_list = []
-        for frame_1, frame_2 in pairs:
-            pos_rel = self.possible_relation_func(frame_1, frame_2)
-            if pos_rel:
-                rel = self._extract_relation(frame_1=frame_1, frame_2=frame_2, text=doc.text, buffer_size=buffer_size,
-                                             max_new_tokens=max_new_tokens, temperature=temperature, stream=stream, **kwrs)
-                if rel:
-                    rel_pair_list.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id})
-
-        return rel_pair_list
-
-
+        if concurrent:
+            if stream:
+                warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
+
+            nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+            return asyncio.run(self.extract_async(doc=doc,
+                                                  buffer_size=buffer_size,
+                                                  max_new_tokens=max_new_tokens,
+                                                  temperature=temperature,
+                                                  concurrent_batch_size=concurrent_batch_size,
+                                                  **kwrs)
+                               )
+        else:
+            return self.extract(doc=doc,
+                                buffer_size=buffer_size,
+                                max_new_tokens=max_new_tokens,
+                                temperature=temperature,
+                                stream=stream,
+                                **kwrs)
+
 
 class MultiClassRelationExtractor(RelationExtractor):
     def __init__(self, inference_engine:InferenceEngine, prompt_template:str, possible_relation_types_func: Callable,
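A hedged end-to-end sketch of the concurrent binary-relation path; `engine`, `prompt_template`, and `doc` come from your own code, and the pre-filter body is illustrative (its signature, two frames in and a bool out, follows the diff):

```python
from llm_ie.extractors import BinaryRelationExtractor

def related_pairs(engine, prompt_template, doc):
    def possible_relation_func(frame_1, frame_2):
        # cheap pre-filter so only plausible pairs reach the LLM
        return frame_1.frame_id != frame_2.frame_id
    extractor = BinaryRelationExtractor(engine, prompt_template,
                                        possible_relation_func=possible_relation_func)
    # concurrent=True routes through extract_async() via asyncio.run()
    return extractor.extract_relations(doc, concurrent=True, concurrent_batch_size=32)
    # -> [{'frame_1': ..., 'frame_2': ...}, ...]
```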
@@ -1279,22 +1579,41 @@ class MultiClassRelationExtractor(RelationExtractor):
 
         self.possible_relation_types_func = possible_relation_types_func
 
-
-    def _extract_relation(self, frame_1:LLMInformationExtractionFrame, frame_2:LLMInformationExtractionFrame,
-                          pos_rel_types:List[str], text:str, buffer_size:int=100, max_new_tokens:int=128, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
+
+    def _post_process(self, rel_json:List[Dict], pos_rel_types:List[str]) -> Union[str, None]:
         """
-        This method inputs two frames and a ROI text, extracts the relation.
+        This method post-processes the extracted relation JSON.
 
         Parameters:
         -----------
-        frame_1 : LLMInformationExtractionFrame
-            a frame
-        frame_2 : LLMInformationExtractionFrame
-            the other frame
+        rel_json : List[Dict]
+            the extracted relation JSON.
         pos_rel_types : List[str]
-            possible relation types.
-        text : str
-            the entire document text
+            possible relation types by the possible_relation_types_func.
+
+        Return : Union[str, None]
+            the relation type (str) or None for no relation.
+        """
+        if len(rel_json) > 0:
+            if "RelationType" in rel_json[0]:
+                if rel_json[0]["RelationType"] in pos_rel_types:
+                    return rel_json[0]["RelationType"]
+            else:
+                warnings.warn('Extractor output JSON without "RelationType" key. Following default, relation = "No Relation".', RuntimeWarning)
+        else:
+            warnings.warn('Extractor did not output a JSON. Following default, relation = "No Relation".', RuntimeWarning)
+        return None
+
+
+    def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+                temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+        """
+        This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
+
+        Parameters:
+        -----------
+        doc : LLMInformationExtractionDocument
+            a document with frames.
         buffer_size : int, Optional
             the number of characters before and after the two frames in the ROI text.
         max_new_tokens : str,
@@ -1304,54 +1623,117 @@ class MultiClassRelationExtractor(RelationExtractor):
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
 
-        Return : str
-            a relation type
+        Return : List[Dict]
+            a list of dict with {"frame_1", "frame_2", "relation"} for all frame pairs.
         """
-        roi_text = self._get_ROI(frame_1, frame_2, text, buffer_size=buffer_size)
-        if stream:
-            print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
-            print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+        pairs = itertools.combinations(doc.frames, 2)
+        output = []
+        for frame_1, frame_2 in pairs:
+            pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
 
-        messages = []
-        if self.system_prompt:
-            messages.append({'role': 'system', 'content': self.system_prompt})
+            if pos_rel_types:
+                roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+                if stream:
+                    print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
+                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+                messages = []
+                if self.system_prompt:
+                    messages.append({'role': 'system', 'content': self.system_prompt})
+
+                messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+                                                                                                "frame_1": str(frame_1.to_dict()),
+                                                                                                "frame_2": str(frame_2.to_dict()),
+                                                                                                "pos_rel_types":str(pos_rel_types)}
+                                                                                  )})
+
+                gen_text = self.inference_engine.chat(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=stream,
+                    **kwrs
+                )
+                rel_json = self._extract_json(gen_text)
+                rel = self._post_process(rel_json, pos_rel_types)
+                if rel:
+                    output.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'relation':rel})
 
-        messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
-                                                                                        "frame_1": str(frame_1.to_dict()),
-                                                                                        "frame_2": str(frame_2.to_dict()),
-                                                                                        "pos_rel_types":str(pos_rel_types)})})
-        response = self.inference_engine.chat(
-            messages=messages,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            stream=stream,
-            **kwrs
-        )
-
-        rel_json = self._extract_json(response)
-        if len(rel_json) > 0:
-            if "RelationType" in rel_json[0]:
-                rel = rel_json[0]["RelationType"]
-                if rel in pos_rel_types or rel == "No Relation":
-                    return rel_json[0]["RelationType"]
-                else:
-                    warnings.warn(f'Extracted relation type "{rel}", which is not in the return of possible_relation_types_func: {pos_rel_types}.'+ \
-                                  'Following default, relation = "No Relation".', RuntimeWarning)
-
-            else:
-                warnings.warn('Extractor output JSON without "RelationType" key. Following default, relation = "No Relation".', RuntimeWarning)
+        return output
+
+
+    async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+                            temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
+        """
+        This is the asynchronous version of the extract() method.
+
+        Parameters:
+        -----------
+        doc : LLMInformationExtractionDocument
+            a document with frames.
+        buffer_size : int, Optional
+            the number of characters before and after the two frames in the ROI text.
+        max_new_tokens : str, Optional
+            the max number of new tokens LLM should generate.
+        temperature : float, Optional
+            the temperature for token sampling.
+        concurrent_batch_size : int, Optional
+            the number of frame pairs to process in concurrent.
+
+        Return : List[Dict]
+            a list of dict with {"frame_1", "frame_2", "relation"} for all frame pairs.
+        """
+        # Check if self.inference_engine.chat_async() is implemented
+        if not hasattr(self.inference_engine, 'chat_async'):
+            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
 
-        else:
-            warnings.warn('Extractor did not output a JSON. Following default, relation = "No Relation".', RuntimeWarning)
+        pairs = itertools.combinations(doc.frames, 2)
+        n_frames = len(doc.frames)
+        num_pairs = (n_frames * (n_frames-1)) // 2
+        rel_pair_list = []
+        tasks = []
+        for i in range(0, num_pairs, concurrent_batch_size):
+            batch = list(itertools.islice(pairs, concurrent_batch_size))
+            for frame_1, frame_2 in batch:
+                pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
+
+                if pos_rel_types:
+                    rel_pair_list.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'pos_rel_types':pos_rel_types})
+                    roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+                    messages = []
+                    if self.system_prompt:
+                        messages.append({'role': 'system', 'content': self.system_prompt})
+
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+                                                                                                    "frame_1": str(frame_1.to_dict()),
+                                                                                                    "frame_2": str(frame_2.to_dict()),
+                                                                                                    "pos_rel_types":str(pos_rel_types)}
+                                                                                      )})
+                    task = asyncio.create_task(
+                        self.inference_engine.chat_async(
+                            messages=messages,
+                            max_new_tokens=max_new_tokens,
+                            temperature=temperature,
+                            **kwrs
+                        )
+                    )
+                    tasks.append(task)
+
+        responses = await asyncio.gather(*tasks)
+
+        output = []
+        for d, response in zip(rel_pair_list, responses):
+            rel_json = self._extract_json(response)
+            rel = self._post_process(rel_json, d['pos_rel_types'])
+            if rel:
+                output.append({'frame_1':d['frame_1'], 'frame_2':d['frame_2'], 'relation':rel})
 
-        return "No Relation"
+        return output
 
 
     def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                          temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+                          temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32, stream:bool=False, **kwrs) -> List[Dict]:
         """
-        This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs
-        and to provide possible relation types between two frames.
+        This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
 
         Parameters:
         -----------
@@ -1363,6 +1745,10 @@ class MultiClassRelationExtractor(RelationExtractor):
             the max number of new tokens LLM should generate.
         temperature : float, Optional
             the temperature for token sampling.
+        concurrent: bool, Optional
+            if True, the extraction will be done in concurrent.
+        concurrent_batch_size : int, Optional
+            the number of frame pairs to process in concurrent.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
 
@@ -1375,15 +1761,23 @@ class MultiClassRelationExtractor(RelationExtractor):
         if doc.has_duplicate_frame_ids():
             raise ValueError("All frame_ids in the input document must be unique.")
 
-        pairs = itertools.combinations(doc.frames, 2)
-        rel_pair_list = []
-        for frame_1, frame_2 in pairs:
-            pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
-            if pos_rel_types:
-                rel = self._extract_relation(frame_1=frame_1, frame_2=frame_2, pos_rel_types=pos_rel_types, text=doc.text,
-                                             buffer_size=buffer_size, max_new_tokens=max_new_tokens, temperature=temperature, stream=stream, **kwrs)
-
-                if rel != "No Relation":
-                    rel_pair_list.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, "relation":rel})
-
-        return rel_pair_list
+        if concurrent:
+            if stream:
+                warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
+
+            nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+            return asyncio.run(self.extract_async(doc=doc,
+                                                  buffer_size=buffer_size,
+                                                  max_new_tokens=max_new_tokens,
+                                                  temperature=temperature,
+                                                  concurrent_batch_size=concurrent_batch_size,
+                                                  **kwrs)
+                               )
+        else:
+            return self.extract(doc=doc,
+                                buffer_size=buffer_size,
+                                max_new_tokens=max_new_tokens,
+                                temperature=temperature,
+                                stream=stream,
+                                **kwrs)
+
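The multi-class variant mirrors the binary one, except `possible_relation_types_func` returns the candidate labels for a pair (a falsy result skips the pair, and `_post_process()` rejects any label outside that list). A hedged usage sketch; the `frame.attr` access is illustrative and depends on how your frames were built:

```python
from llm_ie.extractors import MultiClassRelationExtractor

def typed_relations(engine, prompt_template, doc):
    def possible_relation_types_func(frame_1, frame_2):
        # illustrative: only Medication-Strength pairs can be related
        kinds = {frame_1.attr.get("Type"), frame_2.attr.get("Type")}
        if kinds == {"Medication", "Strength"}:
            return ["Strength-of"]
        return []
    extractor = MultiClassRelationExtractor(engine, prompt_template,
                                            possible_relation_types_func=possible_relation_types_func)
    return extractor.extract_relations(doc, concurrent=True)
    # -> [{'frame_1': ..., 'frame_2': ..., 'relation': 'Strength-of'}, ...]
```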