llm-ie 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_ie/__init__.py +9 -0
- llm_ie/engines.py +151 -9
- llm_ie/extractors.py +552 -152
- llm_ie/prompt_editor.py +17 -2
- {llm_ie-0.3.4.dist-info → llm_ie-0.4.0.dist-info}/METADATA +342 -103
- {llm_ie-0.3.4.dist-info → llm_ie-0.4.0.dist-info}/RECORD +7 -7
- {llm_ie-0.3.4.dist-info → llm_ie-0.4.0.dist-info}/WHEEL +0 -0
llm_ie/extractors.py
CHANGED
@@ -1,10 +1,14 @@
 import abc
 import re
+import copy
 import json
+import json_repair
 import inspect
 import importlib.resources
 import warnings
 import itertools
+import asyncio
+import nest_asyncio
 from typing import Set, List, Dict, Tuple, Union, Callable
 from llm_ie.data_types import LLMInformationExtractionFrame, LLMInformationExtractionDocument
 from llm_ie.engines import InferenceEngine
@@ -18,7 +22,7 @@ class Extractor:
         This is the abstract class for (frame and relation) extractors.
         Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
 
-        Parameters
+        Parameters:
         ----------
         inference_engine : InferenceEngine
             the LLM inferencing engine object. Must implements the chat() method.
@@ -37,16 +41,20 @@ class Extractor:
         """
         This method returns the pre-defined prompt guideline for the extractor from the package asset.
         """
+        # Check if the prompt guide is available
        file_path = importlib.resources.files('llm_ie.asset.prompt_guide').joinpath(f"{cls.__name__}_prompt_guide.txt")
-
-
-
+        try:
+            with open(file_path, 'r', encoding="utf-8") as f:
+                return f.read()
+        except FileNotFoundError:
+            warnings.warn(f"Prompt guide for {cls.__name__} is not available. Is it a customed extractor?", UserWarning)
+            return None
 
     def _get_user_prompt(self, text_content:Union[str, Dict[str,str]]) -> str:
         """
         This method applies text_content to prompt_template and returns a prompt.
 
-        Parameters
+        Parameters:
         ----------
         text_content : Union[str, Dict[str,str]]
             the input text content to put in prompt template.
@@ -117,7 +125,12 @@ class Extractor:
                 dict_obj = json.loads(dict_str)
                 out.append(dict_obj)
             except json.JSONDecodeError:
-
+                dict_obj = json_repair.repair_json(dict_str, skip_json_loads=True, return_objects=True)
+                if dict_obj:
+                    warnings.warn(f'JSONDecodeError detected, fixed with repair_json:\n{dict_str}', RuntimeWarning)
+                    out.append(dict_obj)
+                else:
+                    warnings.warn(f'JSONDecodeError could not be fixed:\n{dict_str}', RuntimeWarning)
         return out
 
 
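To make the new fallback concrete, here is a minimal sketch of the behavior, assuming only that the json_repair package is installed. The malformed string is a made-up example; repair_json(skip_json_loads=True, return_objects=True) is the same call used in the hunk above.

import json
import json_repair

dict_str = "{'entity_text': 'aspirin', 'Dosage': '81 mg',}"   # single quotes and a trailing comma

try:
    dict_obj = json.loads(dict_str)
except json.JSONDecodeError:
    # repair_json tolerates common LLM formatting slips and returns a Python object directly
    dict_obj = json_repair.repair_json(dict_str, skip_json_loads=True, return_objects=True)

print(dict_obj)   # expected: {'entity_text': 'aspirin', 'Dosage': '81 mg'}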
@@ -127,7 +140,7 @@ class FrameExtractor(Extractor):
         This is the abstract class for frame extraction.
         Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
 
-        Parameters
+        Parameters:
         ----------
         inference_engine : InferenceEngine
             the LLM inferencing engine object. Must implements the chat() method.
@@ -169,7 +182,7 @@ class FrameExtractor(Extractor):
         the substring must start with the same word token as the pattern. This is due to the observation that
         LLM often generate the first few words consistently.
 
-        Parameters
+        Parameters:
         ----------
         text : str
             the input text.
@@ -213,7 +226,7 @@ class FrameExtractor(Extractor):
         outputs a list of spans (2-tuple) for each entity.
         Entities that are not found in the text will be None from output.
 
-        Parameters
+        Parameters:
         ----------
         text : str
             text that contains entities
@@ -235,7 +248,10 @@ class FrameExtractor(Extractor):
 
         # Match entities
         entity_spans = []
-        for entity in entities:
+        for entity in entities:
+            if not isinstance(entity, str):
+                entity_spans.append(None)
+                continue
             if not case_sensitive:
                 entity = entity.lower()
 
@@ -316,7 +332,7 @@ class BasicFrameExtractor(FrameExtractor):
         Input system prompt (optional), prompt template (with instruction, few-shot examples),
         and specify a LLM.
 
-        Parameters
+        Parameters:
         ----------
         inference_engine : InferenceEngine
             the LLM inferencing engine object. Must implements the chat() method.
@@ -549,18 +565,18 @@ class SentenceFrameExtractor(FrameExtractor):
     from nltk.tokenize.punkt import PunktSentenceTokenizer
     def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
         """
-        This class performs sentence-
-
+        This class performs sentence-by-sentence information extraction.
+        The process is as follows:
             1. system prompt (optional)
-            2. user instructions (schema, background, full text, few-shot example...)
-            3.
-            4.
+            2. user prompt with instructions (schema, background, full text, few-shot example...)
+            3. feed a sentence (start with first sentence)
+            4. LLM extract entities and attributes from the sentence
            5. repeat #3 and #4
 
         Input system prompt (optional), prompt template (with user instructions),
         and specify a LLM.
 
-        Parameters
+        Parameters:
         ----------
         inference_engine : InferenceEngine
             the LLM inferencing engine object. Must implements the chat() method.
@@ -577,7 +593,7 @@ class SentenceFrameExtractor(FrameExtractor):
         This method sentence tokenize the input text into a list of sentences
         as dict of {start, end, sentence_text}
 
-        Parameters
+        Parameters:
         ----------
         text : str
             text to sentence tokenize.
@@ -668,10 +684,80 @@ class SentenceFrameExtractor(FrameExtractor):
         return output
 
 
+    async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
+                            document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
+        """
+        The asynchronous version of the extract() method.
+
+        Parameters:
+        ----------
+        text_content : Union[str, Dict[str,str]]
+            the input text content to put in prompt template.
+            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+        max_new_tokens : str, Optional
+            the max number of new tokens LLM should generate.
+        document_key : str, Optional
+            specify the key in text_content where document text is.
+            If text_content is str, this parameter will be ignored.
+        temperature : float, Optional
+            the temperature for token sampling.
+        concurrent_batch_size : int, Optional
+            the number of sentences to process in concurrent.
+        """
+        # Check if self.inference_engine.chat_async() is implemented
+        if not hasattr(self.inference_engine, 'chat_async'):
+            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+
+        # define output
+        output = []
+        # sentence tokenization
+        if isinstance(text_content, str):
+            sentences = self._get_sentences(text_content)
+        elif isinstance(text_content, dict):
+            sentences = self._get_sentences(text_content[document_key])
+        # construct chat messages
+        base_messages = []
+        if self.system_prompt:
+            base_messages.append({'role': 'system', 'content': self.system_prompt})
+
+        base_messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
+        base_messages.append({'role': 'assistant', 'content': 'Sure, please start with the first sentence.'})
+
+        # generate sentence by sentence
+        tasks = []
+        for i in range(0, len(sentences), concurrent_batch_size):
+            batch = sentences[i:i + concurrent_batch_size]
+            for sent in batch:
+                messages = copy.deepcopy(base_messages)
+                messages.append({'role': 'user', 'content': sent['sentence_text']})
+                task = asyncio.create_task(
+                    self.inference_engine.chat_async(
+                        messages=messages,
+                        max_new_tokens=max_new_tokens,
+                        temperature=temperature,
+                        **kwrs
+                    )
+                )
+                tasks.append(task)
+
+            # Wait until the batch is done, collect results and move on to next batch
+            responses = await asyncio.gather(*tasks)
+
+        # Collect outputs
+        for gen_text, sent in zip(responses, sentences):
+            output.append({'sentence_start': sent['start'],
+                           'sentence_end': sent['end'],
+                           'sentence_text': sent['sentence_text'],
+                           'gen_text': gen_text})
+        return output
+
+
     def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=512,
-
-
-
+                       document_key:str=None, multi_turn:bool=False, temperature:float=0.0, stream:bool=False,
+                       concurrent:bool=False, concurrent_batch_size:int=32,
+                       case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
+                       **kwrs) -> List[LLMInformationExtractionFrame]:
         """
         This method inputs a text and outputs a list of LLMInformationExtractionFrame
         It use the extract() method and post-process outputs into frames.
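A hedged usage sketch of the new asynchronous path added above. The engine, prompt_template, and note_text objects are assumptions; any InferenceEngine whose chat_async() is implemented should work.

import asyncio
from llm_ie.extractors import SentenceFrameExtractor

# `engine` must implement chat_async(); `prompt_template` carries the task instructions.
extractor = SentenceFrameExtractor(engine, prompt_template)

# Sentences are tokenized, then sent to the LLM in concurrent batches of 32.
sentence_outputs = asyncio.run(
    extractor.extract_async(text_content=note_text, max_new_tokens=512,
                            temperature=0.0, concurrent_batch_size=32)
)
# Each item: {'sentence_start', 'sentence_end', 'sentence_text', 'gen_text'}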
@@ -699,6 +785,10 @@ class SentenceFrameExtractor(FrameExtractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        concurrent : bool, Optional
+            if True, the sentences will be extracted in concurrent.
+        concurrent_batch_size : int, Optional
+            the number of sentences to process in concurrent. Only used when `concurrent` is True.
         case_sensitive : bool, Optional
             if True, entity text matching will be case-sensitive.
         fuzzy_match : bool, Optional
@@ -712,15 +802,30 @@ class SentenceFrameExtractor(FrameExtractor):
         Return : str
             a list of frames.
         """
-
-
-
-
-
-
-
+        if concurrent:
+            if stream:
+                warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
+            if multi_turn:
+                warnings.warn("multi_turn=True is not supported in concurrent mode.", RuntimeWarning)
+
+            nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+            llm_output_sentences = asyncio.run(self.extract_async(text_content=text_content,
+                                                                  max_new_tokens=max_new_tokens,
+                                                                  document_key=document_key,
+                                                                  temperature=temperature,
+                                                                  concurrent_batch_size=concurrent_batch_size,
+                                                                  **kwrs)
+                                               )
+        else:
+            llm_output_sentences = self.extract(text_content=text_content,
+                                                max_new_tokens=max_new_tokens,
+                                                document_key=document_key,
+                                                multi_turn=multi_turn,
+                                                temperature=temperature,
+                                                stream=stream,
+                                                **kwrs)
         frame_list = []
-        for sent in
+        for sent in llm_output_sentences:
             entity_json = []
             for entity in self._extract_json(gen_text=sent['gen_text']):
                 if entity_key in entity:
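In practice the concurrent path is reached through extract_frames() rather than by calling extract_async() directly. A hedged sketch, reusing the assumed extractor and note_text from the sketch above; "entity_text" is an assumed key name that must match the prompt's output schema.

frames = extractor.extract_frames(
    text_content=note_text,
    entity_key="entity_text",      # assumed schema key holding the entity span text
    concurrent=True,               # dispatches to extract_async() via nest_asyncio + asyncio.run()
    concurrent_batch_size=32,
    fuzzy_match=True,
)
for frame in frames:
    print(frame.to_dict())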
@@ -885,6 +990,121 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
                            'sentence_text': sent['sentence_text'],
                            'gen_text': gen_text})
         return output
+
+    async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
+                            document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
+        """
+        The asynchronous version of the extract() method.
+
+        Parameters:
+        ----------
+        text_content : Union[str, Dict[str,str]]
+            the input text content to put in prompt template.
+            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+        max_new_tokens : str, Optional
+            the max number of new tokens LLM should generate.
+        document_key : str, Optional
+            specify the key in text_content where document text is.
+            If text_content is str, this parameter will be ignored.
+        temperature : float, Optional
+            the temperature for token sampling.
+        concurrent_batch_size : int, Optional
+            the number of sentences to process in concurrent.
+
+        Return : str
+            the output from LLM. Need post-processing.
+        """
+        # Check if self.inference_engine.chat_async() is implemented
+        if not hasattr(self.inference_engine, 'chat_async'):
+            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+
+        # define output
+        output = []
+        # sentence tokenization
+        if isinstance(text_content, str):
+            sentences = self._get_sentences(text_content)
+        elif isinstance(text_content, dict):
+            sentences = self._get_sentences(text_content[document_key])
+        # construct chat messages
+        base_messages = []
+        if self.system_prompt:
+            base_messages.append({'role': 'system', 'content': self.system_prompt})
+
+        base_messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
+        base_messages.append({'role': 'assistant', 'content': 'Sure, please start with the first sentence.'})
+
+        # generate initial outputs sentence by sentence
+        initials = []
+        tasks = []
+        message_list = []
+        for i in range(0, len(sentences), concurrent_batch_size):
+            batch = sentences[i:i + concurrent_batch_size]
+            for sent in batch:
+                messages = copy.deepcopy(base_messages)
+                messages.append({'role': 'user', 'content': sent['sentence_text']})
+                message_list.append(messages)
+                task = asyncio.create_task(
+                    self.inference_engine.chat_async(
+                        messages=messages,
+                        max_new_tokens=max_new_tokens,
+                        temperature=temperature,
+                        **kwrs
+                    )
+                )
+                tasks.append(task)
+
+            # Wait until the batch is done, collect results and move on to next batch
+            responses = await asyncio.gather(*tasks)
+        # Collect initials
+        for gen_text, sent, message in zip(responses, sentences, message_list):
+            initials.append({'sentence_start': sent['start'],
+                             'sentence_end': sent['end'],
+                             'sentence_text': sent['sentence_text'],
+                             'gen_text': gen_text,
+                             'messages': message})
+
+        # Review
+        reviews = []
+        tasks = []
+        for i in range(0, len(initials), concurrent_batch_size):
+            batch = initials[i:i + concurrent_batch_size]
+            for init in batch:
+                messages = init["messages"]
+                initial = init["gen_text"]
+                messages.append({'role': 'assistant', 'content': initial})
+                messages.append({'role': 'user', 'content': self.review_prompt})
+                task = asyncio.create_task(
+                    self.inference_engine.chat_async(
+                        messages=messages,
+                        max_new_tokens=max_new_tokens,
+                        temperature=temperature,
+                        **kwrs
+                    )
+                )
+                tasks.append(task)
+
+            responses = await asyncio.gather(*tasks)
+
+        # Collect reviews
+        for gen_text, sent in zip(responses, sentences):
+            reviews.append({'sentence_start': sent['start'],
+                            'sentence_end': sent['end'],
+                            'sentence_text': sent['sentence_text'],
+                            'gen_text': gen_text})
+
+        for init, rev in zip(initials, reviews):
+            if self.review_mode == "revision":
+                gen_text = rev['gen_text']
+            elif self.review_mode == "addition":
+                gen_text = init['gen_text'] + '\n' + rev['gen_text']
+
+            # add to output
+            output.append({'sentence_start': init['sentence_start'],
+                           'sentence_end': init['sentence_end'],
+                           'sentence_text': init['sentence_text'],
+                           'gen_text': gen_text})
+        return output
 
 
 class SentenceCoTFrameExtractor(SentenceFrameExtractor):
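The review variant above runs a second LLM pass per sentence: the initial generation is appended as an assistant turn, the review_prompt is sent as a new user turn, and review_mode decides whether the review output replaces ("revision") or is appended to ("addition") the initial output. A hedged sketch; the attribute names come from the code above, but the constructor argument names and the review prompt wording are assumptions.

reviewer = SentenceReviewFrameExtractor(engine, prompt_template,
                                        review_prompt="Review the frames above and correct any mistakes.",  # assumed wording
                                        review_mode="addition")
results = asyncio.run(reviewer.extract_async(text_content=note_text, concurrent_batch_size=16))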
@@ -1124,19 +1344,34 @@ class BinaryRelationExtractor(RelationExtractor):
         self.possible_relation_func = possible_relation_func
 
 
-    def
-
+    def _post_process(self, rel_json:str) -> bool:
+        if len(rel_json) > 0:
+            if "Relation" in rel_json[0]:
+                rel = rel_json[0]["Relation"]
+                if isinstance(rel, bool):
+                    return rel
+                elif isinstance(rel, str) and rel in {"True", "False"}:
+                    return eval(rel)
+                else:
+                    warnings.warn('Extractor output JSON "Relation" key does not have bool or {"True", "False"} as value.' + \
+                                  'Following default, relation = False.', RuntimeWarning)
+            else:
+                warnings.warn('Extractor output JSON without "Relation" key. Following default, relation = False.', RuntimeWarning)
+        else:
+            warnings.warn('Extractor did not output a JSON list. Following default, relation = False.', RuntimeWarning)
+        return False
+
+
+    def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+                temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
         """
-        This method
+        This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
+        Outputs pairs that are related.
 
         Parameters:
         -----------
-
-        a
-        frame_2 : LLMInformationExtractionFrame
-            the other frame
-        text : str
-            the entire document text
+        doc : LLMInformationExtractionDocument
+            a document with frames.
         buffer_size : int, Optional
             the number of characters before and after the two frames in the ROI text.
         max_new_tokens : str,
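The extract() method documented above prunes the frame-pair combinations with possible_relation_func before any LLM call. A hedged sketch of such a filter; the .start attribute is an assumption about LLMInformationExtractionFrame, and the only real contract is two frames in, bool out.

def possible_relation_func(frame_1, frame_2) -> bool:
    # Hypothetical rule: only frames within 500 characters of each other can be related.
    return abs(frame_1.start - frame_2.start) <= 500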
@@ -1146,51 +1381,111 @@ class BinaryRelationExtractor(RelationExtractor):
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
 
-        Return :
-            a
+        Return : List[Dict]
+            a list of dict with {"frame_1_id", "frame_2_id"}.
         """
-
-
-
-
-
-        messages = []
-        if self.system_prompt:
-            messages.append({'role': 'system', 'content': self.system_prompt})
+        pairs = itertools.combinations(doc.frames, 2)
+        output = []
+        for frame_1, frame_2 in pairs:
+            pos_rel = self.possible_relation_func(frame_1, frame_2)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        warnings.warn('Extractor output JSON without "Relation" key. Following default, relation = False.', RuntimeWarning)
-        else:
-            warnings.warn("Extractor did not output a JSON. Following default, relation = False.", RuntimeWarning)
+            if pos_rel:
+                roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+                if stream:
+                    print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
+                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+                messages = []
+                if self.system_prompt:
+                    messages.append({'role': 'system', 'content': self.system_prompt})
+
+                messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+                                                                                                "frame_1": str(frame_1.to_dict()),
+                                                                                                "frame_2": str(frame_2.to_dict())}
+                                                                                  )})
+
+                gen_text = self.inference_engine.chat(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=stream,
+                    **kwrs
+                )
+                rel_json = self._extract_json(gen_text)
+                if self._post_process(rel_json):
+                    output.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id})
 
-        return
+        return output
+
 
+    async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+                            temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
+        """
+        This is the asynchronous version of the extract() method.
+
+        Parameters:
+        -----------
+        doc : LLMInformationExtractionDocument
+            a document with frames.
+        buffer_size : int, Optional
+            the number of characters before and after the two frames in the ROI text.
+        max_new_tokens : str, Optional
+            the max number of new tokens LLM should generate.
+        temperature : float, Optional
+            the temperature for token sampling.
+        concurrent_batch_size : int, Optional
+            the number of frame pairs to process in concurrent.
+
+        Return : List[Dict]
+            a list of dict with {"frame_1", "frame_2"}.
+        """
+        # Check if self.inference_engine.chat_async() is implemented
+        if not hasattr(self.inference_engine, 'chat_async'):
+            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+
+        pairs = itertools.combinations(doc.frames, 2)
+        n_frames = len(doc.frames)
+        num_pairs = (n_frames * (n_frames-1)) // 2
+        rel_pair_list = []
+        tasks = []
+        for i in range(0, num_pairs, concurrent_batch_size):
+            batch = list(itertools.islice(pairs, concurrent_batch_size))
+            for frame_1, frame_2 in batch:
+                pos_rel = self.possible_relation_func(frame_1, frame_2)
 
+                if pos_rel:
+                    rel_pair_list.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})
+                    roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+                    messages = []
+                    if self.system_prompt:
+                        messages.append({'role': 'system', 'content': self.system_prompt})
+
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+                                                                                                    "frame_1": str(frame_1.to_dict()),
+                                                                                                    "frame_2": str(frame_2.to_dict())}
+                                                                                      )})
+                    task = asyncio.create_task(
+                        self.inference_engine.chat_async(
+                            messages=messages,
+                            max_new_tokens=max_new_tokens,
+                            temperature=temperature,
+                            **kwrs
+                        )
+                    )
+                    tasks.append(task)
+
+            responses = await asyncio.gather(*tasks)
+
+        output = []
+        for d, response in zip(rel_pair_list, responses):
+            rel_json = self._extract_json(response)
+            if self._post_process(rel_json):
+                output.append(d)
+
+        return output
+
+
     def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                          temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+                          temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32, stream:bool=False, **kwrs) -> List[Dict]:
         """
         This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
 
@@ -1204,6 +1499,10 @@ class BinaryRelationExtractor(RelationExtractor):
             the max number of new tokens LLM should generate.
         temperature : float, Optional
             the temperature for token sampling.
+        concurrent: bool, Optional
+            if True, the extraction will be done in concurrent.
+        concurrent_batch_size : int, Optional
+            the number of frame pairs to process in concurrent.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
 
@@ -1216,19 +1515,26 @@ class BinaryRelationExtractor(RelationExtractor):
         if doc.has_duplicate_frame_ids():
             raise ValueError("All frame_ids in the input document must be unique.")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if concurrent:
+            if stream:
+                warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
+
+            nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+            return asyncio.run(self.extract_async(doc=doc,
+                                                  buffer_size=buffer_size,
+                                                  max_new_tokens=max_new_tokens,
+                                                  temperature=temperature,
+                                                  concurrent_batch_size=concurrent_batch_size,
+                                                  **kwrs)
+                               )
+        else:
+            return self.extract(doc=doc,
+                                buffer_size=buffer_size,
+                                max_new_tokens=max_new_tokens,
+                                temperature=temperature,
+                                stream=stream,
+                                **kwrs)
+
 
 class MultiClassRelationExtractor(RelationExtractor):
     def __init__(self, inference_engine:InferenceEngine, prompt_template:str, possible_relation_types_func: Callable,
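A hedged end-to-end sketch of the new concurrent dispatcher shown above. doc is assumed to be an LLMInformationExtractionDocument that already carries frames with unique frame_ids, and the constructor argument names are assumptions based on the attributes set in __init__.

from llm_ie.extractors import BinaryRelationExtractor

extractor = BinaryRelationExtractor(engine, prompt_template,
                                    possible_relation_func=possible_relation_func)
relations = extractor.extract_relations(doc, buffer_size=100,
                                        concurrent=True, concurrent_batch_size=32)
# concurrent path returns e.g. [{'frame_1_id': '0', 'frame_2_id': '3'}, ...]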
@@ -1273,22 +1579,41 @@ class MultiClassRelationExtractor(RelationExtractor):
 
         self.possible_relation_types_func = possible_relation_types_func
 
-
-    def
-        pos_rel_types:List[str], text:str, buffer_size:int=100, max_new_tokens:int=128, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
+
+    def _post_process(self, rel_json:List[Dict], pos_rel_types:List[str]) -> Union[str, None]:
         """
-        This method
+        This method post-processes the extracted relation JSON.
 
         Parameters:
         -----------
-
-
-        frame_2 : LLMInformationExtractionFrame
-            the other frame
+        rel_json : List[Dict]
+            the extracted relation JSON.
         pos_rel_types : List[str]
-            possible relation types.
-
-
+            possible relation types by the possible_relation_types_func.
+
+        Return : Union[str, None]
+            the relation type (str) or None for no relation.
+        """
+        if len(rel_json) > 0:
+            if "RelationType" in rel_json[0]:
+                if rel_json[0]["RelationType"] in pos_rel_types:
+                    return rel_json[0]["RelationType"]
+            else:
+                warnings.warn('Extractor output JSON without "RelationType" key. Following default, relation = "No Relation".', RuntimeWarning)
+        else:
+            warnings.warn('Extractor did not output a JSON. Following default, relation = "No Relation".', RuntimeWarning)
+        return None
+
+
+    def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+                temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+        """
+        This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
+
+        Parameters:
+        -----------
+        doc : LLMInformationExtractionDocument
+            a document with frames.
         buffer_size : int, Optional
             the number of characters before and after the two frames in the ROI text.
         max_new_tokens : str,
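As with the binary extractor, the multi-class extract() continued below filters pairs first, but here the hook returns candidate relation labels instead of a bool. A hedged sketch of a possible_relation_types_func; the .attr dict and its "Type" key are assumptions about the frame objects, and an empty list means the pair never reaches the LLM.

def possible_relation_types_func(frame_1, frame_2) -> list:
    # Hypothetical mapping: a Medication-Strength pair can only hold "Has_strength".
    types = {frame_1.attr.get("Type"), frame_2.attr.get("Type")}
    if types == {"Medication", "Strength"}:
        return ["Has_strength"]
    return []   # falsy, so the pair is skipped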
@@ -1298,54 +1623,117 @@ class MultiClassRelationExtractor(RelationExtractor):
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
 
-        Return :
-            a relation
+        Return : List[Dict]
+            a list of dict with {"frame_1", "frame_2", "relation"} for all frame pairs.
         """
-
-
-
-
+        pairs = itertools.combinations(doc.frames, 2)
+        output = []
+        for frame_1, frame_2 in pairs:
+            pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
 
-
-
-
+            if pos_rel_types:
+                roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+                if stream:
+                    print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
+                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+                messages = []
+                if self.system_prompt:
+                    messages.append({'role': 'system', 'content': self.system_prompt})
+
+                messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+                                                                                                "frame_1": str(frame_1.to_dict()),
+                                                                                                "frame_2": str(frame_2.to_dict()),
+                                                                                                "pos_rel_types":str(pos_rel_types)}
+                                                                                  )})
+
+                gen_text = self.inference_engine.chat(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=stream,
+                    **kwrs
+                )
+                rel_json = self._extract_json(gen_text)
+                rel = self._post_process(rel_json, pos_rel_types)
+                if rel:
+                    output.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'relation':rel})
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return output
+
+
+    async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
+                            temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
+        """
+        This is the asynchronous version of the extract() method.
+
+        Parameters:
+        -----------
+        doc : LLMInformationExtractionDocument
+            a document with frames.
+        buffer_size : int, Optional
+            the number of characters before and after the two frames in the ROI text.
+        max_new_tokens : str, Optional
+            the max number of new tokens LLM should generate.
+        temperature : float, Optional
+            the temperature for token sampling.
+        concurrent_batch_size : int, Optional
+            the number of frame pairs to process in concurrent.
+
+        Return : List[Dict]
+            a list of dict with {"frame_1", "frame_2", "relation"} for all frame pairs.
+        """
+        # Check if self.inference_engine.chat_async() is implemented
+        if not hasattr(self.inference_engine, 'chat_async'):
+            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
 
-
-
+        pairs = itertools.combinations(doc.frames, 2)
+        n_frames = len(doc.frames)
+        num_pairs = (n_frames * (n_frames-1)) // 2
+        rel_pair_list = []
+        tasks = []
+        for i in range(0, num_pairs, concurrent_batch_size):
+            batch = list(itertools.islice(pairs, concurrent_batch_size))
+            for frame_1, frame_2 in batch:
+                pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
+
+                if pos_rel_types:
+                    rel_pair_list.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'pos_rel_types':pos_rel_types})
+                    roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
+                    messages = []
+                    if self.system_prompt:
+                        messages.append({'role': 'system', 'content': self.system_prompt})
+
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
+                                                                                                    "frame_1": str(frame_1.to_dict()),
+                                                                                                    "frame_2": str(frame_2.to_dict()),
+                                                                                                    "pos_rel_types":str(pos_rel_types)}
+                                                                                      )})
+                    task = asyncio.create_task(
+                        self.inference_engine.chat_async(
+                            messages=messages,
+                            max_new_tokens=max_new_tokens,
+                            temperature=temperature,
+                            **kwrs
+                        )
+                    )
+                    tasks.append(task)
+
+            responses = await asyncio.gather(*tasks)
+
+        output = []
+        for d, response in zip(rel_pair_list, responses):
+            rel_json = self._extract_json(response)
+            rel = self._post_process(rel_json, d['pos_rel_types'])
+            if rel:
+                output.append({'frame_1':d['frame_1'], 'frame_2':d['frame_2'], 'relation':rel})
 
-        return
+        return output
 
 
     def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                          temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+                          temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32, stream:bool=False, **kwrs) -> List[Dict]:
         """
-        This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs
-        and to provide possible relation types between two frames.
+        This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
 
         Parameters:
         -----------
@@ -1357,6 +1745,10 @@ class MultiClassRelationExtractor(RelationExtractor):
             the max number of new tokens LLM should generate.
         temperature : float, Optional
             the temperature for token sampling.
+        concurrent: bool, Optional
+            if True, the extraction will be done in concurrent.
+        concurrent_batch_size : int, Optional
+            the number of frame pairs to process in concurrent.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
 
@@ -1369,15 +1761,23 @@ class MultiClassRelationExtractor(RelationExtractor):
         if doc.has_duplicate_frame_ids():
             raise ValueError("All frame_ids in the input document must be unique.")
 
-
-
-
-
-
-
-
-
-
-
-
-
+        if concurrent:
+            if stream:
+                warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
+
+            nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+            return asyncio.run(self.extract_async(doc=doc,
+                                                  buffer_size=buffer_size,
+                                                  max_new_tokens=max_new_tokens,
+                                                  temperature=temperature,
+                                                  concurrent_batch_size=concurrent_batch_size,
+                                                  **kwrs)
+                               )
+        else:
+            return self.extract(doc=doc,
+                                buffer_size=buffer_size,
+                                max_new_tokens=max_new_tokens,
+                                temperature=temperature,
+                                stream=stream,
+                                **kwrs)
+