llm-ie 0.3.5__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_ie/__init__.py +9 -0
- llm_ie/engines.py +151 -9
- llm_ie/extractors.py +545 -151
- llm_ie/prompt_editor.py +17 -2
- {llm_ie-0.3.5.dist-info → llm_ie-0.4.0.dist-info}/METADATA +341 -103
- {llm_ie-0.3.5.dist-info → llm_ie-0.4.0.dist-info}/RECORD +7 -7
- {llm_ie-0.3.5.dist-info → llm_ie-0.4.0.dist-info}/WHEEL +0 -0
llm_ie/extractors.py CHANGED
@@ -1,11 +1,14 @@
 import abc
 import re
+import copy
 import json
 import json_repair
 import inspect
 import importlib.resources
 import warnings
 import itertools
+import asyncio
+import nest_asyncio
 from typing import Set, List, Dict, Tuple, Union, Callable
 from llm_ie.data_types import LLMInformationExtractionFrame, LLMInformationExtractionDocument
 from llm_ie.engines import InferenceEngine
@@ -19,7 +22,7 @@ class Extractor:
 This is the abstract class for (frame and relation) extractors.
 Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).

-Parameters
+Parameters:
 ----------
 inference_engine : InferenceEngine
 the LLM inferencing engine object. Must implements the chat() method.
@@ -38,16 +41,20 @@ class Extractor:
 """
 This method returns the pre-defined prompt guideline for the extractor from the package asset.
 """
+# Check if the prompt guide is available
 file_path = importlib.resources.files('llm_ie.asset.prompt_guide').joinpath(f"{cls.__name__}_prompt_guide.txt")
-
-
-
+try:
+with open(file_path, 'r', encoding="utf-8") as f:
+return f.read()
+except FileNotFoundError:
+warnings.warn(f"Prompt guide for {cls.__name__} is not available. Is it a customed extractor?", UserWarning)
+return None

 def _get_user_prompt(self, text_content:Union[str, Dict[str,str]]) -> str:
 """
 This method applies text_content to prompt_template and returns a prompt.

-Parameters
+Parameters:
 ----------
 text_content : Union[str, Dict[str,str]]
 the input text content to put in prompt template.
@@ -133,7 +140,7 @@ class FrameExtractor(Extractor):
 This is the abstract class for frame extraction.
 Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).

-Parameters
+Parameters:
 ----------
 inference_engine : InferenceEngine
 the LLM inferencing engine object. Must implements the chat() method.
@@ -175,7 +182,7 @@ class FrameExtractor(Extractor):
 the substring must start with the same word token as the pattern. This is due to the observation that
 LLM often generate the first few words consistently.

-Parameters
+Parameters:
 ----------
 text : str
 the input text.
@@ -219,7 +226,7 @@ class FrameExtractor(Extractor):
 outputs a list of spans (2-tuple) for each entity.
 Entities that are not found in the text will be None from output.

-Parameters
+Parameters:
 ----------
 text : str
 text that contains entities
@@ -241,7 +248,10 @@ class FrameExtractor(Extractor):

 # Match entities
 entity_spans = []
-for entity in entities:
+for entity in entities:
+if not isinstance(entity, str):
+entity_spans.append(None)
+continue
 if not case_sensitive:
 entity = entity.lower()

@@ -322,7 +332,7 @@ class BasicFrameExtractor(FrameExtractor):
 Input system prompt (optional), prompt template (with instruction, few-shot examples),
 and specify a LLM.

-Parameters
+Parameters:
 ----------
 inference_engine : InferenceEngine
 the LLM inferencing engine object. Must implements the chat() method.
@@ -555,18 +565,18 @@ class SentenceFrameExtractor(FrameExtractor):
 from nltk.tokenize.punkt import PunktSentenceTokenizer
 def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
 """
-This class performs sentence-
-
+This class performs sentence-by-sentence information extraction.
+The process is as follows:
 1. system prompt (optional)
-2. user instructions (schema, background, full text, few-shot example...)
-3.
-4.
+2. user prompt with instructions (schema, background, full text, few-shot example...)
+3. feed a sentence (start with first sentence)
+4. LLM extract entities and attributes from the sentence
 5. repeat #3 and #4

 Input system prompt (optional), prompt template (with user instructions),
 and specify a LLM.

-Parameters
+Parameters:
 ----------
 inference_engine : InferenceEngine
 the LLM inferencing engine object. Must implements the chat() method.
@@ -583,7 +593,7 @@ class SentenceFrameExtractor(FrameExtractor):
 This method sentence tokenize the input text into a list of sentences
 as dict of {start, end, sentence_text}

-Parameters
+Parameters:
 ----------
 text : str
 text to sentence tokenize.
@@ -674,10 +684,80 @@ class SentenceFrameExtractor(FrameExtractor):
 return output


+async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
+document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
+"""
+The asynchronous version of the extract() method.
+
+Parameters:
+----------
+text_content : Union[str, Dict[str,str]]
+the input text content to put in prompt template.
+If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+max_new_tokens : str, Optional
+the max number of new tokens LLM should generate.
+document_key : str, Optional
+specify the key in text_content where document text is.
+If text_content is str, this parameter will be ignored.
+temperature : float, Optional
+the temperature for token sampling.
+concurrent_batch_size : int, Optional
+the number of sentences to process in concurrent.
+"""
+# Check if self.inference_engine.chat_async() is implemented
+if not hasattr(self.inference_engine, 'chat_async'):
+raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+
+# define output
+output = []
+# sentence tokenization
+if isinstance(text_content, str):
+sentences = self._get_sentences(text_content)
+elif isinstance(text_content, dict):
+sentences = self._get_sentences(text_content[document_key])
+# construct chat messages
+base_messages = []
+if self.system_prompt:
+base_messages.append({'role': 'system', 'content': self.system_prompt})
+
+base_messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
+base_messages.append({'role': 'assistant', 'content': 'Sure, please start with the first sentence.'})
+
+# generate sentence by sentence
+tasks = []
+for i in range(0, len(sentences), concurrent_batch_size):
+batch = sentences[i:i + concurrent_batch_size]
+for sent in batch:
+messages = copy.deepcopy(base_messages)
+messages.append({'role': 'user', 'content': sent['sentence_text']})
+task = asyncio.create_task(
+self.inference_engine.chat_async(
+messages=messages,
+max_new_tokens=max_new_tokens,
+temperature=temperature,
+**kwrs
+)
+)
+tasks.append(task)
+
+# Wait until the batch is done, collect results and move on to next batch
+responses = await asyncio.gather(*tasks)
+
+# Collect outputs
+for gen_text, sent in zip(responses, sentences):
+output.append({'sentence_start': sent['start'],
+'sentence_end': sent['end'],
+'sentence_text': sent['sentence_text'],
+'gen_text': gen_text})
+return output
+
+
 def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=512,
-
-
-
+document_key:str=None, multi_turn:bool=False, temperature:float=0.0, stream:bool=False,
+concurrent:bool=False, concurrent_batch_size:int=32,
+case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
+**kwrs) -> List[LLMInformationExtractionFrame]:
 """
 This method inputs a text and outputs a list of LLMInformationExtractionFrame
 It use the extract() method and post-process outputs into frames.
@@ -705,6 +785,10 @@ class SentenceFrameExtractor(FrameExtractor):
 the temperature for token sampling.
 stream : bool, Optional
 if True, LLM generated text will be printed in terminal in real-time.
+concurrent : bool, Optional
+if True, the sentences will be extracted in concurrent.
+concurrent_batch_size : int, Optional
+the number of sentences to process in concurrent. Only used when `concurrent` is True.
 case_sensitive : bool, Optional
 if True, entity text matching will be case-sensitive.
 fuzzy_match : bool, Optional
@@ -718,15 +802,30 @@ class SentenceFrameExtractor(FrameExtractor):
 Return : str
 a list of frames.
 """
-
-
-
-
-
-
-
+if concurrent:
+if stream:
+warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
+if multi_turn:
+warnings.warn("multi_turn=True is not supported in concurrent mode.", RuntimeWarning)
+
+nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+llm_output_sentences = asyncio.run(self.extract_async(text_content=text_content,
+max_new_tokens=max_new_tokens,
+document_key=document_key,
+temperature=temperature,
+concurrent_batch_size=concurrent_batch_size,
+**kwrs)
+)
+else:
+llm_output_sentences = self.extract(text_content=text_content,
+max_new_tokens=max_new_tokens,
+document_key=document_key,
+multi_turn=multi_turn,
+temperature=temperature,
+stream=stream,
+**kwrs)
 frame_list = []
-for sent in
+for sent in llm_output_sentences:
 entity_json = []
 for entity in self._extract_json(gen_text=sent['gen_text']):
 if entity_key in entity:
@@ -891,6 +990,121 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
 'sentence_text': sent['sentence_text'],
 'gen_text': gen_text})
 return output
+
+async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
+document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
+"""
+The asynchronous version of the extract() method.
+
+Parameters:
+----------
+text_content : Union[str, Dict[str,str]]
+the input text content to put in prompt template.
+If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+max_new_tokens : str, Optional
+the max number of new tokens LLM should generate.
+document_key : str, Optional
+specify the key in text_content where document text is.
+If text_content is str, this parameter will be ignored.
+temperature : float, Optional
+the temperature for token sampling.
+concurrent_batch_size : int, Optional
+the number of sentences to process in concurrent.
+
+Return : str
+the output from LLM. Need post-processing.
+"""
+# Check if self.inference_engine.chat_async() is implemented
+if not hasattr(self.inference_engine, 'chat_async'):
+raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+
+# define output
+output = []
+# sentence tokenization
+if isinstance(text_content, str):
+sentences = self._get_sentences(text_content)
+elif isinstance(text_content, dict):
+sentences = self._get_sentences(text_content[document_key])
+# construct chat messages
+base_messages = []
+if self.system_prompt:
+base_messages.append({'role': 'system', 'content': self.system_prompt})
+
+base_messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
+base_messages.append({'role': 'assistant', 'content': 'Sure, please start with the first sentence.'})
+
+# generate initial outputs sentence by sentence
+initials = []
+tasks = []
+message_list = []
+for i in range(0, len(sentences), concurrent_batch_size):
+batch = sentences[i:i + concurrent_batch_size]
+for sent in batch:
+messages = copy.deepcopy(base_messages)
+messages.append({'role': 'user', 'content': sent['sentence_text']})
+message_list.append(messages)
+task = asyncio.create_task(
+self.inference_engine.chat_async(
+messages=messages,
+max_new_tokens=max_new_tokens,
+temperature=temperature,
+**kwrs
+)
+)
+tasks.append(task)
+
+# Wait until the batch is done, collect results and move on to next batch
+responses = await asyncio.gather(*tasks)
+# Collect initials
+for gen_text, sent, message in zip(responses, sentences, message_list):
+initials.append({'sentence_start': sent['start'],
+'sentence_end': sent['end'],
+'sentence_text': sent['sentence_text'],
+'gen_text': gen_text,
+'messages': message})
+
+# Review
+reviews = []
+tasks = []
+for i in range(0, len(initials), concurrent_batch_size):
+batch = initials[i:i + concurrent_batch_size]
+for init in batch:
+messages = init["messages"]
+initial = init["gen_text"]
+messages.append({'role': 'assistant', 'content': initial})
+messages.append({'role': 'user', 'content': self.review_prompt})
+task = asyncio.create_task(
+self.inference_engine.chat_async(
+messages=messages,
+max_new_tokens=max_new_tokens,
+temperature=temperature,
+**kwrs
+)
+)
+tasks.append(task)
+
+responses = await asyncio.gather(*tasks)
+
+# Collect reviews
+for gen_text, sent in zip(responses, sentences):
+reviews.append({'sentence_start': sent['start'],
+'sentence_end': sent['end'],
+'sentence_text': sent['sentence_text'],
+'gen_text': gen_text})
+
+for init, rev in zip(initials, reviews):
+if self.review_mode == "revision":
+gen_text = rev['gen_text']
+elif self.review_mode == "addition":
+gen_text = init['gen_text'] + '\n' + rev['gen_text']
+
+# add to output
+output.append({'sentence_start': init['sentence_start'],
+'sentence_end': init['sentence_end'],
+'sentence_text': init['sentence_text'],
+'gen_text': gen_text})
+return output


 class SentenceCoTFrameExtractor(SentenceFrameExtractor):
@@ -1130,19 +1344,34 @@ class BinaryRelationExtractor(RelationExtractor):
 self.possible_relation_func = possible_relation_func


-def
-
+def _post_process(self, rel_json:str) -> bool:
+if len(rel_json) > 0:
+if "Relation" in rel_json[0]:
+rel = rel_json[0]["Relation"]
+if isinstance(rel, bool):
+return rel
+elif isinstance(rel, str) and rel in {"True", "False"}:
+return eval(rel)
+else:
+warnings.warn('Extractor output JSON "Relation" key does not have bool or {"True", "False"} as value.' + \
+'Following default, relation = False.', RuntimeWarning)
+else:
+warnings.warn('Extractor output JSON without "Relation" key. Following default, relation = False.', RuntimeWarning)
+else:
+warnings.warn('Extractor did not output a JSON list. Following default, relation = False.', RuntimeWarning)
+return False
+
+
+def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
 """
-This method
+This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
+Outputs pairs that are related.

 Parameters:
 -----------
-
-a
-frame_2 : LLMInformationExtractionFrame
-the other frame
-text : str
-the entire document text
+doc : LLMInformationExtractionDocument
+a document with frames.
 buffer_size : int, Optional
 the number of characters before and after the two frames in the ROI text.
 max_new_tokens : str, Optional
@@ -1152,51 +1381,111 @@ class BinaryRelationExtractor(RelationExtractor):
 stream : bool, Optional
 if True, LLM generated text will be printed in terminal in real-time.

-Return :
-a
+Return : List[Dict]
+a list of dict with {"frame_1_id", "frame_2_id"}.
 """
-
-
-
-
-
-messages = []
-if self.system_prompt:
-messages.append({'role': 'system', 'content': self.system_prompt})
+pairs = itertools.combinations(doc.frames, 2)
+output = []
+for frame_1, frame_2 in pairs:
+pos_rel = self.possible_relation_func(frame_1, frame_2)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-warnings.warn('Extractor output JSON without "Relation" key. Following default, relation = False.', RuntimeWarning)
-else:
-warnings.warn("Extractor did not output a JSON. Following default, relation = False.", RuntimeWarning)
+if pos_rel:
+roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+if stream:
+print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
+print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+messages = []
+if self.system_prompt:
+messages.append({'role': 'system', 'content': self.system_prompt})
+
+messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+"frame_1": str(frame_1.to_dict()),
+"frame_2": str(frame_2.to_dict())}
+)})
+
+gen_text = self.inference_engine.chat(
+messages=messages,
+max_new_tokens=max_new_tokens,
+temperature=temperature,
+stream=stream,
+**kwrs
+)
+rel_json = self._extract_json(gen_text)
+if self._post_process(rel_json):
+output.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id})

-return
+return output
+

+async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
+"""
+This is the asynchronous version of the extract() method.
+
+Parameters:
+-----------
+doc : LLMInformationExtractionDocument
+a document with frames.
+buffer_size : int, Optional
+the number of characters before and after the two frames in the ROI text.
+max_new_tokens : str, Optional
+the max number of new tokens LLM should generate.
+temperature : float, Optional
+the temperature for token sampling.
+concurrent_batch_size : int, Optional
+the number of frame pairs to process in concurrent.
+
+Return : List[Dict]
+a list of dict with {"frame_1", "frame_2"}.
+"""
+# Check if self.inference_engine.chat_async() is implemented
+if not hasattr(self.inference_engine, 'chat_async'):
+raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+
+pairs = itertools.combinations(doc.frames, 2)
+n_frames = len(doc.frames)
+num_pairs = (n_frames * (n_frames-1)) // 2
+rel_pair_list = []
+tasks = []
+for i in range(0, num_pairs, concurrent_batch_size):
+batch = list(itertools.islice(pairs, concurrent_batch_size))
+for frame_1, frame_2 in batch:
+pos_rel = self.possible_relation_func(frame_1, frame_2)

+if pos_rel:
+rel_pair_list.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})
+roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+messages = []
+if self.system_prompt:
+messages.append({'role': 'system', 'content': self.system_prompt})
+
+messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+"frame_1": str(frame_1.to_dict()),
+"frame_2": str(frame_2.to_dict())}
+)})
+task = asyncio.create_task(
+self.inference_engine.chat_async(
+messages=messages,
+max_new_tokens=max_new_tokens,
+temperature=temperature,
+**kwrs
+)
+)
+tasks.append(task)
+
+responses = await asyncio.gather(*tasks)
+
+output = []
+for d, response in zip(rel_pair_list, responses):
+rel_json = self._extract_json(response)
+if self._post_process(rel_json):
+output.append(d)
+
+return output
+
+
 def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32, stream:bool=False, **kwrs) -> List[Dict]:
 """
 This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.

@@ -1210,6 +1499,10 @@ class BinaryRelationExtractor(RelationExtractor):
 the max number of new tokens LLM should generate.
 temperature : float, Optional
 the temperature for token sampling.
+concurrent: bool, Optional
+if True, the extraction will be done in concurrent.
+concurrent_batch_size : int, Optional
+the number of frame pairs to process in concurrent.
 stream : bool, Optional
 if True, LLM generated text will be printed in terminal in real-time.

@@ -1222,19 +1515,26 @@ class BinaryRelationExtractor(RelationExtractor):
 if doc.has_duplicate_frame_ids():
 raise ValueError("All frame_ids in the input document must be unique.")

-
-
-
-
-
-
-
-
-
-
-
-
-
+if concurrent:
+if stream:
+warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
+
+nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+return asyncio.run(self.extract_async(doc=doc,
+buffer_size=buffer_size,
+max_new_tokens=max_new_tokens,
+temperature=temperature,
+concurrent_batch_size=concurrent_batch_size,
+**kwrs)
+)
+else:
+return self.extract(doc=doc,
+buffer_size=buffer_size,
+max_new_tokens=max_new_tokens,
+temperature=temperature,
+stream=stream,
+**kwrs)
+

 class MultiClassRelationExtractor(RelationExtractor):
 def __init__(self, inference_engine:InferenceEngine, prompt_template:str, possible_relation_types_func: Callable,
@@ -1279,22 +1579,41 @@ class MultiClassRelationExtractor(RelationExtractor):

 self.possible_relation_types_func = possible_relation_types_func

-
-def
-pos_rel_types:List[str], text:str, buffer_size:int=100, max_new_tokens:int=128, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
+
+def _post_process(self, rel_json:List[Dict], pos_rel_types:List[str]) -> Union[str, None]:
 """
-This method
+This method post-processes the extracted relation JSON.

 Parameters:
 -----------
-
-
-frame_2 : LLMInformationExtractionFrame
-the other frame
+rel_json : List[Dict]
+the extracted relation JSON.
 pos_rel_types : List[str]
-possible relation types.
-
-
+possible relation types by the possible_relation_types_func.
+
+Return : Union[str, None]
+the relation type (str) or None for no relation.
+"""
+if len(rel_json) > 0:
+if "RelationType" in rel_json[0]:
+if rel_json[0]["RelationType"] in pos_rel_types:
+return rel_json[0]["RelationType"]
+else:
+warnings.warn('Extractor output JSON without "RelationType" key. Following default, relation = "No Relation".', RuntimeWarning)
+else:
+warnings.warn('Extractor did not output a JSON. Following default, relation = "No Relation".', RuntimeWarning)
+return None
+
+
+def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+"""
+This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
+
+Parameters:
+-----------
+doc : LLMInformationExtractionDocument
+a document with frames.
 buffer_size : int, Optional
 the number of characters before and after the two frames in the ROI text.
 max_new_tokens : str, Optional
@@ -1304,54 +1623,117 @@ class MultiClassRelationExtractor(RelationExtractor):
 stream : bool, Optional
 if True, LLM generated text will be printed in terminal in real-time.

-Return :
-a relation
+Return : List[Dict]
+a list of dict with {"frame_1", "frame_2", "relation"} for all frame pairs.
 """
-
-
-
-
+pairs = itertools.combinations(doc.frames, 2)
+output = []
+for frame_1, frame_2 in pairs:
+pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)

-
-
-
+if pos_rel_types:
+roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+if stream:
+print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
+print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+messages = []
+if self.system_prompt:
+messages.append({'role': 'system', 'content': self.system_prompt})
+
+messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+"frame_1": str(frame_1.to_dict()),
+"frame_2": str(frame_2.to_dict()),
+"pos_rel_types":str(pos_rel_types)}
+)})
+
+gen_text = self.inference_engine.chat(
+messages=messages,
+max_new_tokens=max_new_tokens,
+temperature=temperature,
+stream=stream,
+**kwrs
+)
+rel_json = self._extract_json(gen_text)
+rel = self._post_process(rel_json, pos_rel_types)
+if rel:
+output.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'relation':rel})

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+return output
+
+
+async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
+"""
+This is the asynchronous version of the extract() method.
+
+Parameters:
+-----------
+doc : LLMInformationExtractionDocument
+a document with frames.
+buffer_size : int, Optional
+the number of characters before and after the two frames in the ROI text.
+max_new_tokens : str, Optional
+the max number of new tokens LLM should generate.
+temperature : float, Optional
+the temperature for token sampling.
+concurrent_batch_size : int, Optional
+the number of frame pairs to process in concurrent.
+
+Return : List[Dict]
+a list of dict with {"frame_1", "frame_2", "relation"} for all frame pairs.
+"""
+# Check if self.inference_engine.chat_async() is implemented
+if not hasattr(self.inference_engine, 'chat_async'):
+raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")

-
-
+pairs = itertools.combinations(doc.frames, 2)
+n_frames = len(doc.frames)
+num_pairs = (n_frames * (n_frames-1)) // 2
+rel_pair_list = []
+tasks = []
+for i in range(0, num_pairs, concurrent_batch_size):
+batch = list(itertools.islice(pairs, concurrent_batch_size))
+for frame_1, frame_2 in batch:
+pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
+
+if pos_rel_types:
+rel_pair_list.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'pos_rel_types':pos_rel_types})
+roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+messages = []
+if self.system_prompt:
+messages.append({'role': 'system', 'content': self.system_prompt})
+
+messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+"frame_1": str(frame_1.to_dict()),
+"frame_2": str(frame_2.to_dict()),
+"pos_rel_types":str(pos_rel_types)}
+)})
+task = asyncio.create_task(
+self.inference_engine.chat_async(
+messages=messages,
+max_new_tokens=max_new_tokens,
+temperature=temperature,
+**kwrs
+)
+)
+tasks.append(task)
+
+responses = await asyncio.gather(*tasks)
+
+output = []
+for d, response in zip(rel_pair_list, responses):
+rel_json = self._extract_json(response)
+rel = self._post_process(rel_json, d['pos_rel_types'])
+if rel:
+output.append({'frame_1':d['frame_1'], 'frame_2':d['frame_2'], 'relation':rel})

-return
+return output


 def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32, stream:bool=False, **kwrs) -> List[Dict]:
 """
-This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs
-and to provide possible relation types between two frames.
+This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.

 Parameters:
 -----------
@@ -1363,6 +1745,10 @@ class MultiClassRelationExtractor(RelationExtractor):
 the max number of new tokens LLM should generate.
 temperature : float, Optional
 the temperature for token sampling.
+concurrent: bool, Optional
+if True, the extraction will be done in concurrent.
+concurrent_batch_size : int, Optional
+the number of frame pairs to process in concurrent.
 stream : bool, Optional
 if True, LLM generated text will be printed in terminal in real-time.

@@ -1375,15 +1761,23 @@ class MultiClassRelationExtractor(RelationExtractor):
 if doc.has_duplicate_frame_ids():
 raise ValueError("All frame_ids in the input document must be unique.")

-
-
-
-
-
-
-
-
-
-
-
-
+if concurrent:
+if stream:
+warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
+
+nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+return asyncio.run(self.extract_async(doc=doc,
+buffer_size=buffer_size,
+max_new_tokens=max_new_tokens,
+temperature=temperature,
+concurrent_batch_size=concurrent_batch_size,
+**kwrs)
+)
+else:
+return self.extract(doc=doc,
+buffer_size=buffer_size,
+max_new_tokens=max_new_tokens,
+temperature=temperature,
+stream=stream,
+**kwrs)
+
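For orientation, the sketch below shows one way the concurrent extraction path added in 0.4.0 might be called from user code. It is a hedged illustration assembled only from the signatures visible in this diff, not from the package documentation: the stub engine class, the prompt template, and the `entity_key` value are assumptions made for demonstration.

```python
# Illustrative only: exercises the concurrent path added to SentenceFrameExtractor in 0.4.0.
# `StubEngine`, the prompt template, and the entity key are hypothetical; the diff only
# requires chat() on an engine, and chat_async() to enable concurrent mode.
from llm_ie.engines import InferenceEngine
from llm_ie.extractors import SentenceFrameExtractor

class StubEngine(InferenceEngine):
    # Assumed minimal interface: chat() for sequential calls, chat_async() for concurrent calls.
    def chat(self, messages, max_new_tokens=512, temperature=0.0, stream=False, **kwrs):
        return '[{"entity_text": "hypertension"}]'

    async def chat_async(self, messages, max_new_tokens=512, temperature=0.0, **kwrs):
        return '[{"entity_text": "hypertension"}]'

extractor = SentenceFrameExtractor(
    inference_engine=StubEngine(),
    # Placeholder template with a single {{...}} slot, as required when text_content is a str.
    prompt_template="Extract diagnoses from the text below as a JSON list.\n{{text}}",
)

# New in 0.4.0: concurrent=True routes sentences through extract_async() in batches;
# stream=True and multi_turn=True are ignored with a RuntimeWarning in this mode.
frames = extractor.extract_frames(
    text_content="Patient has hypertension. Denies chest pain.",
    entity_key="entity_text",        # assumed key emitted by the prompt's JSON schema
    concurrent=True,
    concurrent_batch_size=32,
)
```

The relation extractors gain the same switch: `BinaryRelationExtractor.extract_relations()` and `MultiClassRelationExtractor.extract_relations()` accept `concurrent` and `concurrent_batch_size` and dispatch to their new `extract_async()` methods when `concurrent=True`.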