llm-ie 0.4.6__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- llm_ie/__init__.py +4 -2
- llm_ie/asset/default_prompts/BasicReviewFrameExtractor_addition_review_prompt.txt +3 -0
- llm_ie/asset/default_prompts/BasicReviewFrameExtractor_revision_review_prompt.txt +2 -0
- llm_ie/asset/default_prompts/ReviewFrameExtractor_addition_review_prompt.txt +2 -1
- llm_ie/asset/default_prompts/ReviewFrameExtractor_revision_review_prompt.txt +2 -1
- llm_ie/asset/prompt_guide/BasicFrameExtractor_prompt_guide.txt +104 -86
- llm_ie/asset/prompt_guide/BasicReviewFrameExtractor_prompt_guide.txt +163 -0
- llm_ie/asset/prompt_guide/DirectFrameExtractor_prompt_guide.txt +163 -0
- llm_ie/asset/prompt_guide/ReviewFrameExtractor_prompt_guide.txt +103 -85
- llm_ie/asset/prompt_guide/SentenceFrameExtractor_prompt_guide.txt +103 -86
- llm_ie/asset/prompt_guide/SentenceReviewFrameExtractor_prompt_guide.txt +103 -86
- llm_ie/chunkers.py +191 -0
- llm_ie/data_types.py +75 -1
- llm_ie/engines.py +274 -183
- llm_ie/extractors.py +1062 -727
- llm_ie/prompt_editor.py +39 -6
- llm_ie-1.0.0.dist-info/METADATA +18 -0
- llm_ie-1.0.0.dist-info/RECORD +27 -0
- llm_ie/asset/prompt_guide/SentenceCoTFrameExtractor_prompt_guide.txt +0 -217
- llm_ie-0.4.6.dist-info/METADATA +0 -1215
- llm_ie-0.4.6.dist-info/RECORD +0 -23
- {llm_ie-0.4.6.dist-info → llm_ie-1.0.0.dist-info}/WHEEL +0 -0
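The headline change visible in the `extractors.py` diff below is that the old `BasicFrameExtractor` / `SentenceFrameExtractor` pair is replaced by a single `DirectFrameExtractor` that takes a `UnitChunker` and an optional `ContextChunker` (both from the new `llm_ie/chunkers.py`). A minimal usage sketch, based only on the constructor and `extract_frames()` signatures shown in this diff; the inference-engine instance and the chunker constructor arguments (e.g., `window_size`) are assumptions, not part of the diff:

```python
from llm_ie.chunkers import SentenceUnitChunker, SlideWindowContextChunker
from llm_ie.extractors import DirectFrameExtractor

def build_extractor(engine):
    """engine: any llm_ie InferenceEngine subclass implementing chat()/chat_async()."""
    return DirectFrameExtractor(
        inference_engine=engine,
        unit_chunker=SentenceUnitChunker(),  # one extraction unit per sentence (no-arg constructor assumed)
        context_chunker=SlideWindowContextChunker(window_size=2),  # window_size is an assumed argument
        prompt_template="Extract entities as JSON from: {{input}}",
        system_prompt="You are an information extraction assistant.",
    )

# frames = build_extractor(my_engine).extract_frames(text_content=note_text, concurrent=True)
```

The sketch mirrors the new design: the unit chunker decides what each LLM call sees, the context chunker decides what surrounding text is prepended, and `extract_frames()` post-processes the generated JSON into `LLMInformationExtractionFrame` objects, optionally running units concurrently.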
llm_ie/extractors.py
CHANGED
@@ -8,10 +8,12 @@ import warnings
 import itertools
 import asyncio
 import nest_asyncio
-from typing import Set, List, Dict, Tuple, Union, Callable
-from llm_ie.data_types import LLMInformationExtractionFrame, LLMInformationExtractionDocument
+from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional
+from llm_ie.data_types import FrameExtractionUnit, FrameExtractionUnitResult, LLMInformationExtractionFrame, LLMInformationExtractionDocument
+from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
+from llm_ie.chunkers import ContextChunker, NoContextChunker, WholeDocumentContextChunker, SlideWindowContextChunker
 from llm_ie.engines import InferenceEngine
-from colorama import Fore, Style
+from colorama import Fore, Style
 
 
 class Extractor:
@@ -38,15 +40,46 @@ class Extractor:
     def get_prompt_guide(cls) -> str:
         """
         This method returns the pre-defined prompt guideline for the extractor from the package asset.
+        It searches for a guide specific to the current class first, if not found, it will search
+        for the guide in its ancestors by traversing the class's method resolution order (MRO).
         """
-
-
-
-
-
-
-
-
+        original_class_name = cls.__name__
+
+        for current_class_in_mro in cls.__mro__:
+            if current_class_in_mro is object:
+                continue
+
+            current_class_name = current_class_in_mro.__name__
+
+            try:
+                file_path_obj = importlib.resources.files('llm_ie.asset.prompt_guide').joinpath(f"{current_class_name}_prompt_guide.txt")
+
+                with open(file_path_obj, 'r', encoding="utf-8") as f:
+                    prompt_content = f.read()
+                # If the guide was found for an ancestor, not the original class, issue a warning.
+                if cls is not current_class_in_mro:
+                    warnings.warn(
+                        f"Prompt guide for '{original_class_name}' not found. "
+                        f"Using guide from ancestor: '{current_class_name}_prompt_guide.txt'.",
+                        UserWarning
+                    )
+                return prompt_content
+            except FileNotFoundError:
+                pass
+
+            except Exception as e:
+                warnings.warn(
+                    f"Error attempting to read prompt guide for '{current_class_name}' "
+                    f"from '{str(file_path_obj)}': {e}. Trying next in MRO.",
+                    UserWarning
+                )
+                continue
+
+        # If the loop completes, no prompt guide was found for the original class or any of its ancestors.
+        raise FileNotFoundError(
+            f"Prompt guide for '{original_class_name}' not found in the package asset. "
+            f"Is it a custom extractor?"
+        )
 
     def _get_user_prompt(self, text_content:Union[str, Dict[str,str]]) -> str:
         """
@@ -138,7 +171,8 @@ class Extractor:
 
 class FrameExtractor(Extractor):
     from nltk.tokenize import RegexpTokenizer
-    def __init__(self, inference_engine:InferenceEngine,
+    def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
+                 prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None, **kwrs):
         """
         This is the abstract class for frame extraction.
         Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -147,15 +181,25 @@ class FrameExtractor(Extractor):
         ----------
         inference_engine : InferenceEngine
             the LLM inferencing engine object. Must implements the chat() method.
+        unit_chunker : UnitChunker
+            the unit chunker object that determines how to chunk the document text into units.
         prompt_template : str
             prompt template with "{{<placeholder name>}}" placeholder.
         system_prompt : str, Optional
             system prompt.
+        context_chunker : ContextChunker
+            the context chunker object that determines how to get context for each unit.
         """
         super().__init__(inference_engine=inference_engine,
                          prompt_template=prompt_template,
                          system_prompt=system_prompt,
                          **kwrs)
+
+        self.unit_chunker = unit_chunker
+        if context_chunker is None:
+            self.context_chunker = NoContextChunker()
+        else:
+            self.context_chunker = context_chunker
 
         self.tokenizer = self.RegexpTokenizer(r'\w+|[^\w\s]')
 
@@ -288,7 +332,7 @@ class FrameExtractor(Extractor):
         return entity_spans
 
     @abc.abstractmethod
-    def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, **kwrs) -> str:
+    def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, return_messages_log:bool=False, **kwrs) -> str:
         """
         This method inputs text content and outputs a string generated by LLM
 
@@ -300,6 +344,8 @@ class FrameExtractor(Extractor):
             If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
         max_new_tokens : str, Optional
             the max number of new tokens LLM can generate.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             the output from LLM. Need post-processing.
@@ -309,7 +355,7 @@ class FrameExtractor(Extractor):
 
     @abc.abstractmethod
     def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=2048,
-                       document_key:str=None, **kwrs) -> List[LLMInformationExtractionFrame]:
+                       document_key:str=None, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
         """
         This method inputs text content and outputs a list of LLMInformationExtractionFrame
         It use the extract() method and post-process outputs into frames.
@@ -327,6 +373,8 @@ class FrameExtractor(Extractor):
         document_key : str, Optional
             specify the key in text_content where document text is.
             If text_content is str, this parameter will be ignored.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             a list of frames.
@@ -334,332 +382,38 @@ class FrameExtractor(Extractor):
         return NotImplemented
 
 
-class
-    def __init__(self, inference_engine:InferenceEngine,
+class DirectFrameExtractor(FrameExtractor):
+    def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
+                 prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None, **kwrs):
         """
-        This class
-        Input system prompt (optional), prompt template (with instruction, few-shot examples)
-        and specify a LLM.
+        This class is for general unit-context frame extraction.
+        Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
 
         Parameters:
         ----------
         inference_engine : InferenceEngine
             the LLM inferencing engine object. Must implements the chat() method.
+        unit_chunker : UnitChunker
+            the unit chunker object that determines how to chunk the document text into units.
         prompt_template : str
             prompt template with "{{<placeholder name>}}" placeholder.
         system_prompt : str, Optional
             system prompt.
+        context_chunker : ContextChunker
+            the context chunker object that determines how to get context for each unit.
         """
-        super().__init__(inference_engine=inference_engine,
-
-
+        super().__init__(inference_engine=inference_engine,
+                         unit_chunker=unit_chunker,
+                         prompt_template=prompt_template,
+                         system_prompt=system_prompt,
+                         context_chunker=context_chunker,
                          **kwrs)
-
-
-    def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048,
-                temperature:float=0.0, stream:bool=False, **kwrs) -> str:
-        """
-        This method inputs a text and outputs a string generated by LLM.
-
-        Parameters:
-        ----------
-        text_content : Union[str, Dict[str,str]]
-            the input text content to put in prompt template.
-            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
-            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
-        stream : bool, Optional
-            if True, LLM generated text will be printed in terminal in real-time.
-
-        Return : str
-            the output from LLM. Need post-processing.
-        """
-        messages = []
-        if self.system_prompt:
-            messages.append({'role': 'system', 'content': self.system_prompt})
-
-        messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
-        response = self.inference_engine.chat(
-            messages=messages,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            stream=stream,
-            **kwrs
-        )
-
-        return response
-
-
-    def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=2048,
-                       temperature:float=0.0, document_key:str=None, stream:bool=False,
-                       case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2,
-                       fuzzy_score_cutoff:float=0.8, allow_overlap_entities:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
-        """
-        This method inputs a text and outputs a list of LLMInformationExtractionFrame
-        It use the extract() method and post-process outputs into frames.
-
-        Parameters:
-        ----------
-        text_content : Union[str, Dict[str,str]]
-            the input text content to put in prompt template.
-            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
-            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        entity_key : str
-            the key (in ouptut JSON) for entity text. Any extraction that does not include entity key will be dropped.
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM should generate.
-        temperature : float, Optional
-            the temperature for token sampling.
-        document_key : str, Optional
-            specify the key in text_content where document text is.
-            If text_content is str, this parameter will be ignored.
-        stream : bool, Optional
-            if True, LLM generated text will be printed in terminal in real-time.
-        case_sensitive : bool, Optional
-            if True, entity text matching will be case-sensitive.
-        fuzzy_match : bool, Optional
-            if True, fuzzy matching will be applied to find entity text.
-        fuzzy_buffer_size : float, Optional
-            the buffer size for fuzzy matching. Default is 20% of entity text length.
-        fuzzy_score_cutoff : float, Optional
-            the Jaccard score cutoff for fuzzy matching.
-            Matched entity text must have a score higher than this value or a None will be returned.
-        allow_overlap_entities : bool, Optional
-            if True, entities can overlap in the text.
-            Note that this can cause multiple frames to be generated on the same entity span if they have same entity text.
-
-        Return : str
-            a list of frames.
-        """
-        if isinstance(text_content, str):
-            text = text_content
-        elif isinstance(text_content, dict):
-            if document_key is None:
-                raise ValueError("document_key must be provided when text_content is dict.")
-            text = text_content[document_key]
-
-        frame_list = []
-        gen_text = self.extract(text_content=text_content,
-                                max_new_tokens=max_new_tokens,
-                                temperature=temperature,
-                                stream=stream,
-                                **kwrs)
-
-        entity_json = []
-        for entity in self._extract_json(gen_text=gen_text):
-            if entity_key in entity:
-                entity_json.append(entity)
-            else:
-                warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{entity_key}"). This frame will be dropped.', RuntimeWarning)
-
-        spans = self._find_entity_spans(text=text,
-                                        entities=[e[entity_key] for e in entity_json],
-                                        case_sensitive=case_sensitive,
-                                        fuzzy_match=fuzzy_match,
-                                        fuzzy_buffer_size=fuzzy_buffer_size,
-                                        fuzzy_score_cutoff=fuzzy_score_cutoff,
-                                        allow_overlap_entities=allow_overlap_entities)
-
-        for i, (ent, span) in enumerate(zip(entity_json, spans)):
-            if span is not None:
-                start, end = span
-                frame = LLMInformationExtractionFrame(frame_id=f"{i}",
-                                                      start=start,
-                                                      end=end,
-                                                      entity_text=text[start:end],
-                                                      attr={k: v for k, v in ent.items() if k != entity_key and v != ""})
-                frame_list.append(frame)
-        return frame_list
-
 
-class ReviewFrameExtractor(BasicFrameExtractor):
-    def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
-                 review_mode:str, review_prompt:str=None,system_prompt:str=None, **kwrs):
-        """
-        This class add a review step after the BasicFrameExtractor.
-        The Review process asks LLM to review its output and:
-            1. add more frames while keep current. This is efficient for boosting recall.
-            2. or, regenerate frames (add new and delete existing).
-        Use the review_mode parameter to specify. Note that the review_prompt should instruct LLM accordingly.
-
-        Parameters:
-        ----------
-        inference_engine : InferenceEngine
-            the LLM inferencing engine object. Must implements the chat() method.
-        prompt_template : str
-            prompt template with "{{<placeholder name>}}" placeholder.
-        review_prompt : str: Optional
-            the prompt text that ask LLM to review. Specify addition or revision in the instruction.
-            if not provided, a default review prompt will be used.
-        review_mode : str
-            review mode. Must be one of {"addition", "revision"}
-            addition mode only ask LLM to add new frames, while revision mode ask LLM to regenerate.
-        system_prompt : str, Optional
-            system prompt.
-        """
-        super().__init__(inference_engine=inference_engine, prompt_template=prompt_template,
-                         system_prompt=system_prompt, **kwrs)
-        if review_mode not in {"addition", "revision"}:
-            raise ValueError('review_mode must be one of {"addition", "revision"}.')
-        self.review_mode = review_mode
-
-        if review_prompt:
-            self.review_prompt = review_prompt
-        else:
-            file_path = importlib.resources.files('llm_ie.asset.default_prompts').\
-                joinpath(f"{self.__class__.__name__}_{self.review_mode}_review_prompt.txt")
-            with open(file_path, 'r', encoding="utf-8") as f:
-                self.review_prompt = f.read()
-
-            warnings.warn(f'Custom review prompt not provided. The default review prompt is used:\n"{self.review_prompt}"', UserWarning)
-
-
-    def extract(self, text_content:Union[str, Dict[str,str]],
-                max_new_tokens:int=4096, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
-        """
-        This method inputs a text and outputs a string generated by LLM.
-
-        Parameters:
-        ----------
-        text_content : Union[str, Dict[str,str]]
-            the input text content to put in prompt template.
-            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
-            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
-        stream : bool, Optional
-            if True, LLM generated text will be printed in terminal in real-time.
-
-        Return : str
-            the output from LLM. Need post-processing.
-        """
-        messages = []
-        if self.system_prompt:
-            messages.append({'role': 'system', 'content': self.system_prompt})
-
-        messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
-        # Initial output
-        if stream:
-            print(f"{Fore.BLUE}Initial Output:{Style.RESET_ALL}")
-
-        initial = self.inference_engine.chat(
-            messages=messages,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            stream=stream,
-            **kwrs
-        )
-
-        # Review
-        messages.append({'role': 'assistant', 'content': initial})
-        messages.append({'role': 'user', 'content': self.review_prompt})
-
-        if stream:
-            print(f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}")
-        review = self.inference_engine.chat(
-            messages=messages,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            stream=stream,
-            **kwrs
-        )
-
-        # Output
-        if self.review_mode == "revision":
-            return review
-        elif self.review_mode == "addition":
-            return initial + '\n' + review
-
-
-class SentenceFrameExtractor(FrameExtractor):
-    from nltk.tokenize.punkt import PunktSentenceTokenizer
-    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None,
-                 context_sentences:Union[str, int]="all", **kwrs):
-        """
-        This class performs sentence-by-sentence information extraction.
-        The process is as follows:
-            1. system prompt (optional)
-            2. user prompt with instructions (schema, background, full text, few-shot example...)
-            3. feed a sentence (start with first sentence)
-            4. LLM extract entities and attributes from the sentence
-            5. repeat #3 and #4
-
-        Input system prompt (optional), prompt template (with user instructions),
-        and specify a LLM.
-
-        Parameters:
-        ----------
-        inference_engine : InferenceEngine
-            the LLM inferencing engine object. Must implements the chat() method.
-        prompt_template : str
-            prompt template with "{{<placeholder name>}}" placeholder.
-        system_prompt : str, Optional
-            system prompt.
-        context_sentences : Union[str, int], Optional
-            number of sentences before and after the given sentence to provide additional context.
-            if "all", the full text will be provided in the prompt as context.
-            if 0, no additional context will be provided.
-                This is good for tasks that does not require context beyond the given sentence.
-            if > 0, the number of sentences before and after the given sentence to provide as context.
-                This is good for tasks that require context beyond the given sentence.
-        """
-        super().__init__(inference_engine=inference_engine, prompt_template=prompt_template,
-                         system_prompt=system_prompt, **kwrs)
-
-        if not isinstance(context_sentences, int) and context_sentences != "all":
-            raise ValueError('context_sentences must be an integer (>= 0) or "all".')
-
-        if isinstance(context_sentences, int) and context_sentences < 0:
-            raise ValueError("context_sentences must be a positive integer.")
-
-        self.context_sentences =context_sentences
-
-
-    def _get_sentences(self, text:str) -> List[Dict[str,str]]:
-        """
-        This method sentence tokenize the input text into a list of sentences
-        as dict of {start, end, sentence_text}
-
-        Parameters:
-        ----------
-        text : str
-            text to sentence tokenize.
-
-        Returns : List[Dict[str,str]]
-            a list of sentences as dict with keys: {"sentence_text", "start", "end"}.
-        """
-        sentences = []
-        for start, end in self.PunktSentenceTokenizer().span_tokenize(text):
-            sentences.append({"sentence_text": text[start:end],
-                              "start": start,
-                              "end": end})
-        return sentences
-
 
-    def
-
-        This function returns the context sentences for the current sentence of interest (i).
-        """
-        if self.context_sentences == "all":
-            context = text_content if isinstance(text_content, str) else text_content[document_key]
-        elif self.context_sentences == 0:
-            context = ""
-        else:
-            start = max(0, i - self.context_sentences)
-            end = min(i + 1 + self.context_sentences, len(sentences))
-            context = " ".join([s['sentence_text'] for s in sentences[start:end]])
-        return context
-
-
-    def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
-                document_key:str=None, temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict[str,str]]:
+    def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048,
+                document_key:str=None, temperature:float=0.0, verbose:bool=False, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
         """
-        This method inputs a text and outputs a list of outputs per
+        This method inputs a text and outputs a list of outputs per unit.
 
         Parameters:
         ----------
@@ -667,77 +421,211 @@ class SentenceFrameExtractor(FrameExtractor):
             the input text content to put in prompt template.
             If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
             If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens :
+        max_new_tokens : int, Optional
             the max number of new tokens LLM should generate.
         document_key : str, Optional
             specify the key in text_content where document text is.
             If text_content is str, this parameter will be ignored.
         temperature : float, Optional
             the temperature for token sampling.
-
+        verbose : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
-        Return :
-            the output from LLM.
+        Return : List[FrameExtractionUnitResult]
+            the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
         # define output
         output = []
-        #
+        # unit chunking
         if isinstance(text_content, str):
-
+            doc_text = text_content
+
         elif isinstance(text_content, dict):
             if document_key is None:
                 raise ValueError("document_key must be provided when text_content is dict.")
-
-
-
-
+            doc_text = text_content[document_key]
+
+        units = self.unit_chunker.chunk(doc_text)
+        # context chunker init
+        self.context_chunker.fit(doc_text, units)
+        # messages log
+        if return_messages_log:
+            messages_log = []
+
+        # generate unit by unit
+        for i, unit in enumerate(units):
             # construct chat messages
             messages = []
             if self.system_prompt:
                 messages.append({'role': 'system', 'content': self.system_prompt})
 
-            context = self.
+            context = self.context_chunker.chunk(unit)
 
-            if
-                # no context, just place
-
+            if context == "":
+                # no context, just place unit in user prompt
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                else:
+                    unit_content = text_content.copy()
+                    unit_content[document_key] = unit.text
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
             else:
-                # insert context
-
-
-
-
-
-
-
-
-
+                # insert context to user prompt
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                else:
+                    context_content = text_content.copy()
+                    context_content[document_key] = context
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                # simulate conversation where assistant confirms
+                messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                # place unit of interest
+                messages.append({'role': 'user', 'content': unit.text})
+
+            if verbose:
+                print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
+                if context != "":
                     print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
 
                 print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
 
-
-
-
-
-
-
-
+                response_stream = self.inference_engine.chat(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=True,
+                    **kwrs
+                )
+
+                gen_text = ""
+                for chunk in response_stream:
+                    gen_text += chunk
+                    print(chunk, end='', flush=True)
+
+            else:
+                gen_text = self.inference_engine.chat(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=False,
+                    **kwrs
+                )
+
+            if return_messages_log:
+                messages.append({"role": "assistant", "content": gen_text})
+                messages_log.append(messages)
 
             # add to output
-
-
-
-
+            result = FrameExtractionUnitResult(
+                start=unit.start,
+                end=unit.end,
+                text=unit.text,
+                gen_text=gen_text)
+            output.append(result)
 
+        if return_messages_log:
+            return output, messages_log
+
         return output
 
+    def stream(self, text_content: Union[str, Dict[str, str]], max_new_tokens: int = 2048, document_key: str = None,
+               temperature: float = 0.0, **kwrs) -> Generator[Dict[str, Any], None, List[FrameExtractionUnitResult]]:
+        """
+        Streams LLM responses per unit with structured event types,
+        and returns collected data for post-processing.
+
+        Yields:
+        -------
+        Dict[str, Any]: (type, data)
+            - {"type": "info", "data": str_message}: General informational messages.
+            - {"type": "unit", "data": dict_unit_info}: Signals start of a new unit. dict_unit_info contains {'id', 'text', 'start', 'end'}
+            - {"type": "context", "data": str_context}: Context string for the current unit.
+            - {"type": "llm_chunk", "data": str_chunk}: A raw chunk from the LLM.
 
-
-
+        Returns:
+        --------
+        List[FrameExtractionUnitResult]:
+            A list of FrameExtractionUnitResult objects, each containing the
+            original unit details and the fully accumulated 'gen_text' from the LLM.
         """
-
+        collected_results: List[FrameExtractionUnitResult] = []
+
+        if isinstance(text_content, str):
+            doc_text = text_content
+        elif isinstance(text_content, dict):
+            if document_key is None:
+                raise ValueError("document_key must be provided when text_content is dict.")
+            if document_key not in text_content:
+                raise ValueError(f"document_key '{document_key}' not found in text_content.")
+            doc_text = text_content[document_key]
+        else:
+            raise TypeError("text_content must be a string or a dictionary.")
+
+        units: List[FrameExtractionUnit] = self.unit_chunker.chunk(doc_text)
+        self.context_chunker.fit(doc_text, units)
+
+        yield {"type": "info", "data": f"Starting LLM processing for {len(units)} units."}
+
+        for i, unit in enumerate(units):
+            unit_info_payload = {"id": i, "text": unit.text, "start": unit.start, "end": unit.end}
+            yield {"type": "unit", "data": unit_info_payload}
+
+            messages = []
+            if self.system_prompt:
+                messages.append({'role': 'system', 'content': self.system_prompt})
+
+            context_str = self.context_chunker.chunk(unit)
+
+            # Construct prompt input based on whether text_content was str or dict
+            if context_str:
+                yield {"type": "context", "data": context_str}
+                prompt_input_for_context = context_str
+                if isinstance(text_content, dict):
+                    context_content_dict = text_content.copy()
+                    context_content_dict[document_key] = context_str
+                    prompt_input_for_context = context_content_dict
+                messages.append({'role': 'user', 'content': self._get_user_prompt(prompt_input_for_context)})
+                messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                messages.append({'role': 'user', 'content': unit.text})
+            else: # No context
+                prompt_input_for_unit = unit.text
+                if isinstance(text_content, dict):
+                    unit_content_dict = text_content.copy()
+                    unit_content_dict[document_key] = unit.text
+                    prompt_input_for_unit = unit_content_dict
+                messages.append({'role': 'user', 'content': self._get_user_prompt(prompt_input_for_unit)})
+
+            current_gen_text = ""
+
+            response_stream = self.inference_engine.chat(
+                messages=messages,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                stream=True,
+                **kwrs
+            )
+            for chunk in response_stream:
+                yield {"type": "llm_chunk", "data": chunk}
+                current_gen_text += chunk
+
+            # Store the result for this unit
+            result_for_unit = FrameExtractionUnitResult(
+                start=unit.start,
+                end=unit.end,
+                text=unit.text,
+                gen_text=current_gen_text
+            )
+            collected_results.append(result_for_unit)
+
+        yield {"type": "info", "data": "All units processed by LLM."}
+        return collected_results
+
+    async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, document_key:str=None, temperature:float=0.0,
+                            concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
+        """
+        This is the asynchronous version of the extract() method.
 
         Parameters:
         ----------
@@ -745,7 +633,7 @@ class SentenceFrameExtractor(FrameExtractor):
             the input text content to put in prompt template.
             If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
             If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens :
+        max_new_tokens : int, Optional
            the max number of new tokens LLM should generate.
         document_key : str, Optional
             specify the key in text_content where document text is.
@@ -753,73 +641,129 @@ class SentenceFrameExtractor(FrameExtractor):
         temperature : float, Optional
             the temperature for token sampling.
         concurrent_batch_size : int, Optional
-            the
-
-
-        if not hasattr(self.inference_engine, 'chat_async'):
-            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+            the batch size for concurrent processing.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
-
-
-
+        Return : List[FrameExtractionUnitResult]
+            the output from LLM for each unit. Contains the start, end, text, and generated text.
+        """
         if isinstance(text_content, str):
-
+            doc_text = text_content
         elif isinstance(text_content, dict):
             if document_key is None:
                 raise ValueError("document_key must be provided when text_content is dict.")
-
+            if document_key not in text_content:
+                raise ValueError(f"document_key '{document_key}' not found in text_content dictionary.")
+            doc_text = text_content[document_key]
+        else:
+            raise TypeError("text_content must be a string or a dictionary.")
 
-
-        for i in range(0, len(sentences), concurrent_batch_size):
-            tasks = []
-            batch = sentences[i:i + concurrent_batch_size]
-            for j, sent in enumerate(batch):
-                # construct chat messages
-                messages = []
-                if self.system_prompt:
-                    messages.append({'role': 'system', 'content': self.system_prompt})
+        units = self.unit_chunker.chunk(doc_text)
 
-
-
-
-
-
+        # context chunker init
+        self.context_chunker.fit(doc_text, units)
+
+        # Prepare inputs for all units first
+        tasks_input = []
+        for i, unit in enumerate(units):
+            # construct chat messages
+            messages = []
+            if self.system_prompt:
+                messages.append({'role': 'system', 'content': self.system_prompt})
+
+            context = self.context_chunker.chunk(unit)
+
+            if context == "":
+                # no context, just place unit in user prompt
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
                 else:
-
+                    unit_content = text_content.copy()
+                    unit_content[document_key] = unit.text
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+            else:
+                # insert context to user prompt
+                if isinstance(text_content, str):
                     messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
-
-
-
-                    messages.append({'role': 'user', 'content':
-
-
-
-
-
-
-
-
+                else:
+                    context_content = text_content.copy()
+                    context_content[document_key] = context
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                # simulate conversation where assistant confirms
+                messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                # place unit of interest
+                messages.append({'role': 'user', 'content': unit.text})
+
+            # Store unit and messages together for the task
+            tasks_input.append({"unit": unit, "messages": messages, "original_index": i})
+
+        # Process units concurrently with asyncio.Semaphore
+        semaphore = asyncio.Semaphore(concurrent_batch_size)
+
+        async def semaphore_helper(task_data: Dict, max_new_tokens: int, temperature: float, **kwrs):
+            unit = task_data["unit"]
+            messages = task_data["messages"]
+            original_index = task_data["original_index"]
+
+            async with semaphore:
+                gen_text = await self.inference_engine.chat_async(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    **kwrs
                 )
-
-
-
-
+            return {"original_index": original_index, "unit": unit, "gen_text": gen_text, "messages": messages}
+
+        # Create and gather tasks
+        tasks = []
+        for task_inp in tasks_input:
+            task = asyncio.create_task(semaphore_helper(
+                task_inp,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                **kwrs
+            ))
+            tasks.append(task)
+
+        results_raw = await asyncio.gather(*tasks)
+
+        # Sort results back into original order using the index stored
+        results_raw.sort(key=lambda x: x["original_index"])
+
+        # Restructure the results
+        output: List[FrameExtractionUnitResult] = []
+        messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
+
+        for result_data in results_raw:
+            unit = result_data["unit"]
+            gen_text = result_data["gen_text"]
+
+            # Create result object
+            result = FrameExtractionUnitResult(
+                start=unit.start,
+                end=unit.end,
+                text=unit.text,
+                gen_text=gen_text
+            )
+            output.append(result)
+
+            # Append to messages log if requested
+            if return_messages_log:
+                final_messages = result_data["messages"] + [{"role": "assistant", "content": gen_text}]
+                messages_log.append(final_messages)
+
+        if return_messages_log:
+            return output, messages_log
+        else:
+            return output
 
-        # Collect outputs
-        for gen_text, sent in zip(responses, batch):
-            output.append({'sentence_start': sent['start'],
-                           'sentence_end': sent['end'],
-                           'sentence_text': sent['sentence_text'],
-                           'gen_text': gen_text})
-        return output
-
 
-    def extract_frames(self, text_content:Union[str, Dict[str,str]],
-                       document_key:str=None, temperature:float=0.0,
+    def extract_frames(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
+                       document_key:str=None, temperature:float=0.0, verbose:bool=False,
                        concurrent:bool=False, concurrent_batch_size:int=32,
                        case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
-                       allow_overlap_entities:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
+                       allow_overlap_entities:bool=False, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
         """
         This method inputs a text and outputs a list of LLMInformationExtractionFrame
         It use the extract() method and post-process outputs into frames.
@@ -830,8 +774,6 @@ class SentenceFrameExtractor(FrameExtractor):
             the input text content to put in prompt template.
             If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
             If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        entity_key : str
-            the key (in ouptut JSON) for entity text.
         max_new_tokens : str, Optional
             the max number of new tokens LLM should generate.
         document_key : str, Optional
@@ -839,7 +781,7 @@ class SentenceFrameExtractor(FrameExtractor):
             If text_content is str, this parameter will be ignored.
         temperature : float, Optional
             the temperature for token sampling.
-
+        verbose : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
         concurrent : bool, Optional
             if True, the sentences will be extracted in concurrent.
@@ -857,40 +799,48 @@ class SentenceFrameExtractor(FrameExtractor):
         allow_overlap_entities : bool, Optional
             if True, entities can overlap in the text.
             Note that this can cause multiple frames to be generated on the same entity span if they have same entity text.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             a list of frames.
         """
+        ENTITY_KEY = "entity_text"
         if concurrent:
-            if
-                warnings.warn("
+            if verbose:
+                warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
 
             nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
-
-
-
-
-
-
-
+            extraction_results = asyncio.run(self.extract_async(text_content=text_content,
+                                                                max_new_tokens=max_new_tokens,
+                                                                document_key=document_key,
+                                                                temperature=temperature,
+                                                                concurrent_batch_size=concurrent_batch_size,
+                                                                return_messages_log=return_messages_log,
+                                                                **kwrs)
+                                             )
         else:
-
-
-
-
-
-
+            extraction_results = self.extract(text_content=text_content,
+                                              max_new_tokens=max_new_tokens,
+                                              document_key=document_key,
+                                              temperature=temperature,
+                                              verbose=verbose,
+                                              return_messages_log=return_messages_log,
+                                              **kwrs)
+
+        llm_output_results, messages_log = extraction_results if return_messages_log else (extraction_results, None)
+
         frame_list = []
-        for
+        for res in llm_output_results:
             entity_json = []
-            for entity in self._extract_json(gen_text=
-                if
+            for entity in self._extract_json(gen_text=res.gen_text):
+                if ENTITY_KEY in entity:
                     entity_json.append(entity)
                 else:
-                    warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{
+                    warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
 
-            spans = self._find_entity_spans(text=
-                                            entities=[e[
+            spans = self._find_entity_spans(text=res.text,
+                                            entities=[e[ENTITY_KEY] for e in entity_json],
                                             case_sensitive=case_sensitive,
                                             fuzzy_match=fuzzy_match,
                                             fuzzy_buffer_size=fuzzy_buffer_size,
@@ -899,31 +849,41 @@ class SentenceFrameExtractor(FrameExtractor):
             for ent, span in zip(entity_json, spans):
                 if span is not None:
                     start, end = span
-                    entity_text =
-                    start +=
-                    end +=
+                    entity_text = res.text[start:end]
+                    start += res.start
+                    end += res.start
+                    attr = {}
+                    if "attr" in ent and ent["attr"] is not None:
+                        attr = ent["attr"]
+
                     frame = LLMInformationExtractionFrame(frame_id=f"{len(frame_list)}",
                                                           start=start,
                                                           end=end,
                                                           entity_text=entity_text,
-                                                          attr=
+                                                          attr=attr)
                     frame_list.append(frame)
-        return frame_list
 
+        if return_messages_log:
+            return frame_list, messages_log
+        return frame_list
+
 
-class
-    def __init__(self,
-                 review_mode:str, review_prompt:str=None, system_prompt:str=None,
-                 context_sentences:Union[str, int]="all", **kwrs):
+class ReviewFrameExtractor(DirectFrameExtractor):
+    def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker,
+                 inference_engine:InferenceEngine, prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None, **kwrs):
         """
-        This class
-
-            1. add more frames while
+        This class add a review step after the DirectFrameExtractor.
+        The Review process asks LLM to review its output and:
+            1. add more frames while keep current. This is efficient for boosting recall.
            2. or, regenerate frames (add new and delete existing).
         Use the review_mode parameter to specify. Note that the review_prompt should instruct LLM accordingly.
 
         Parameters:
         ----------
+        unit_chunker : UnitChunker
+            the unit chunker object that determines how to chunk the document text into units.
+        context_chunker : ContextChunker
+            the context chunker object that determines how to get context for each unit.
         inference_engine : InferenceEngine
             the LLM inferencing engine object. Must implements the chat() method.
         prompt_template : str
@@ -936,36 +896,215 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
|
|
|
936
896
|
addition mode only ask LLM to add new frames, while revision mode ask LLM to regenerate.
|
|
937
897
|
system_prompt : str, Optional
|
|
938
898
|
system prompt.
|
|
939
|
-
context_sentences : Union[str, int], Optional
|
|
940
|
-
number of sentences before and after the given sentence to provide additional context.
|
|
941
|
-
if "all", the full text will be provided in the prompt as context.
|
|
942
|
-
if 0, no additional context will be provided.
|
|
943
|
-
This is good for tasks that does not require context beyond the given sentence.
|
|
944
|
-
if > 0, the number of sentences before and after the given sentence to provide as context.
|
|
945
|
-
This is good for tasks that require context beyond the given sentence.
|
|
946
899
|
"""
|
|
947
|
-
super().__init__(inference_engine=inference_engine,
|
|
948
|
-
|
|
949
|
-
|
|
900
|
+
super().__init__(inference_engine=inference_engine,
|
|
901
|
+
unit_chunker=unit_chunker,
|
|
902
|
+
prompt_template=prompt_template,
|
|
903
|
+
system_prompt=system_prompt,
|
|
904
|
+
context_chunker=context_chunker,
|
|
905
|
+
**kwrs)
|
|
906
|
+
# check review mode
|
|
950
907
|
if review_mode not in {"addition", "revision"}:
|
|
951
908
|
raise ValueError('review_mode must be one of {"addition", "revision"}.')
|
|
952
909
|
self.review_mode = review_mode
|
|
910
|
+
# assign review prompt
|
|
911
|
+
if review_prompt:
|
|
912
|
+
self.review_prompt = review_prompt
|
|
913
|
+
else:
|
|
914
|
+
self.review_prompt = None
|
|
915
|
+
original_class_name = self.__class__.__name__
|
|
916
|
+
|
|
917
|
+
current_class_name = original_class_name
|
|
918
|
+
for current_class_in_mro in self.__class__.__mro__:
|
|
919
|
+
if current_class_in_mro is object:
|
|
920
|
+
continue
|
|
921
|
+
|
|
922
|
+
current_class_name = current_class_in_mro.__name__
|
|
923
|
+
try:
|
|
924
|
+
file_path = importlib.resources.files('llm_ie.asset.default_prompts').\
|
|
925
|
+
joinpath(f"{self.__class__.__name__}_{self.review_mode}_review_prompt.txt")
|
|
926
|
+
with open(file_path, 'r', encoding="utf-8") as f:
|
|
927
|
+
self.review_prompt = f.read()
|
|
928
|
+
except FileNotFoundError:
|
|
929
|
+
pass
|
|
930
|
+
|
|
931
|
+
except Exception as e:
|
|
932
|
+
warnings.warn(
|
|
933
|
+
f"Error attempting to read default review prompt for '{current_class_name}' "
|
|
934
|
+
f"from '{str(file_path)}': {e}. Trying next in MRO.",
|
|
935
|
+
UserWarning
|
|
936
|
+
)
|
|
937
|
+
continue
|
|
938
|
+
|
|
939
|
+
if self.review_prompt is None:
|
|
940
|
+
raise ValueError(f"Cannot find review prompt for {self.__class__.__name__} in the package. Please provide a review_prompt.")
|
|
941
|
+
|
|
942
|
+
def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, document_key:str=None,
|
|
943
|
+
+                temperature:float=0.0, verbose:bool=False, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
+        """
+        This method inputs a text and outputs a list of outputs per unit.
+
+        Parameters:
+        ----------
+        text_content : Union[str, Dict[str,str]]
+            the input text content to put in prompt template.
+            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+        max_new_tokens : int, Optional
+            the max number of new tokens LLM should generate.
+        document_key : str, Optional
+            specify the key in text_content where document text is.
+            If text_content is str, this parameter will be ignored.
+        temperature : float, Optional
+            the temperature for token sampling.
+        verbose : bool, Optional
+            if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
+
+        Return : List[FrameExtractionUnitResult]
+            the output from LLM for each unit. Contains the start, end, text, and generated text.
+        """
+        # define output
+        output = []
+        # unit chunking
+        if isinstance(text_content, str):
+            doc_text = text_content
+
+        elif isinstance(text_content, dict):
+            if document_key is None:
+                raise ValueError("document_key must be provided when text_content is dict.")
+            doc_text = text_content[document_key]
+
+        units = self.unit_chunker.chunk(doc_text)
+        # context chunker init
+        self.context_chunker.fit(doc_text, units)
+        # messages log
+        if return_messages_log:
+            messages_log = []
+
+        # generate unit by unit
+        for i, unit in enumerate(units):
+            # <--- Initial generation step --->
+            # construct chat messages
+            messages = []
+            if self.system_prompt:
+                messages.append({'role': 'system', 'content': self.system_prompt})
+
+            context = self.context_chunker.chunk(unit)
+
+            if context == "":
+                # no context, just place unit in user prompt
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                else:
+                    unit_content = text_content.copy()
+                    unit_content[document_key] = unit.text
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+            else:
+                # insert context to user prompt
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                else:
+                    context_content = text_content.copy()
+                    context_content[document_key] = context
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                # simulate conversation where assistant confirms
+                messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                # place unit of interest
+                messages.append({'role': 'user', 'content': unit.text})
+
+            if verbose:
+                print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
+                if context != "":
+                    print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+
+                print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+
+                response_stream = self.inference_engine.chat(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=True,
+                    **kwrs
+                )
+
+                initial = ""
+                for chunk in response_stream:
+                    initial += chunk
+                    print(chunk, end='', flush=True)
+
+            else:
+                initial = self.inference_engine.chat(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=False,
+                    **kwrs
+                )
+
+            if return_messages_log:
+                messages.append({"role": "assistant", "content": initial})
+                messages_log.append(messages)
+
+            # <--- Review step --->
+            if verbose:
+                print(f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}")
+
+            messages.append({'role': 'assistant', 'content': initial})
+            messages.append({'role': 'user', 'content': self.review_prompt})
+
+            if verbose:
+                response_stream = self.inference_engine.chat(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=True,
+                    **kwrs
+                )
+
+                review = ""
+                for chunk in response_stream:
+                    review += chunk
+                    print(chunk, end='', flush=True)
+
+            else:
+                review = self.inference_engine.chat(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=False,
+                    **kwrs
+                )

-
-        self.
-
-
-
-
-
+            # Output
+            if self.review_mode == "revision":
+                gen_text = review
+            elif self.review_mode == "addition":
+                gen_text = initial + '\n' + review
+
+            if return_messages_log:
+                messages.append({"role": "assistant", "content": review})
+                messages_log.append(messages)

-
+            # add to output
+            result = FrameExtractionUnitResult(
+                start=unit.start,
+                end=unit.end,
+                text=unit.text,
+                gen_text=gen_text)
+            output.append(result)
+
+        if return_messages_log:
+            return output, messages_log
+
+        return output


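The reviewed `extract()` above is inherited by the concrete review extractors defined later in this file. A minimal usage sketch follows; it is not part of the diff, the engine class and prompt text are placeholders, and only the `extract()` signature and `FrameExtractionUnitResult` fields shown above are assumed.

```python
# Hypothetical usage of the review-enabled extract() shown above.
# OpenAIInferenceEngine and the prompt are placeholders; check llm_ie.engines
# for the engine classes actually shipped in 1.0.0.
from llm_ie.engines import OpenAIInferenceEngine
from llm_ie.extractors import SentenceReviewFrameExtractor

engine = OpenAIInferenceEngine(model="gpt-4o-mini")  # assumed constructor arguments
extractor = SentenceReviewFrameExtractor(
    inference_engine=engine,
    prompt_template="Extract diagnoses as JSON lines from: {{input}}",
    review_mode="addition",  # keep the initial frames and ask the LLM for more
)

results, log = extractor.extract(
    "Patient denies chest pain but reports dyspnea on exertion.",
    return_messages_log=True,
)
for r in results:
    # each FrameExtractionUnitResult carries the unit span and the generated text
    print(r.start, r.end, r.gen_text)
```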
-    def
-
+    def stream(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048,
+               document_key:str=None, temperature:float=0.0, **kwrs) -> Generator[str, None, None]:
         """
-        This method inputs a text and outputs a list of outputs per
+        This method inputs a text and outputs a list of outputs per unit.

         Parameters:
         ----------
@@ -973,234 +1112,371 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
             the input text content to put in prompt template.
             If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
             If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens :
+        max_new_tokens : int, Optional
             the max number of new tokens LLM should generate.
         document_key : str, Optional
             specify the key in text_content where document text is.
             If text_content is str, this parameter will be ignored.
         temperature : float, Optional
             the temperature for token sampling.
-        stream : bool, Optional
-            if True, LLM generated text will be printed in terminal in real-time.

-        Return :
-            the output from LLM.
+        Return : List[FrameExtractionUnitResult]
+            the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
-        #
-        output = []
-        # sentence tokenization
+        # unit chunking
         if isinstance(text_content, str):
-
+            doc_text = text_content
+
         elif isinstance(text_content, dict):
             if document_key is None:
                 raise ValueError("document_key must be provided when text_content is dict.")
-
+            doc_text = text_content[document_key]

-
-
+        units = self.unit_chunker.chunk(doc_text)
+        # context chunker init
+        self.context_chunker.fit(doc_text, units)
+
+        # generate unit by unit
+        for i, unit in enumerate(units):
+            # <--- Initial generation step --->
             # construct chat messages
             messages = []
             if self.system_prompt:
                 messages.append({'role': 'system', 'content': self.system_prompt})

-            context = self.
+            context = self.context_chunker.chunk(unit)

-            if
-                # no context, just place
-
+            if context == "":
+                # no context, just place unit in user prompt
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                else:
+                    unit_content = text_content.copy()
+                    unit_content[document_key] = unit.text
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
             else:
-                # insert context
-
-
-
-
-
-
-
-
-
-
-
+                # insert context to user prompt
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                else:
+                    context_content = text_content.copy()
+                    context_content[document_key] = context
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                # simulate conversation where assistant confirms
+                messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                # place unit of interest
+                messages.append({'role': 'user', 'content': unit.text})
+
+
+            yield f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n"
+            if context != "":
+                yield f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n"
+
+            yield f"{Fore.BLUE}Extraction:{Style.RESET_ALL}\n"

-
+            response_stream = self.inference_engine.chat(
                 messages=messages,
                 max_new_tokens=max_new_tokens,
                 temperature=temperature,
-                stream=
+                stream=True,
                 **kwrs
             )

-
-
-
+            initial = ""
+            for chunk in response_stream:
+                initial += chunk
+                yield chunk
+
+            # <--- Review step --->
+            yield f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}"
+
             messages.append({'role': 'assistant', 'content': initial})
             messages.append({'role': 'user', 'content': self.review_prompt})

-
+            response_stream = self.inference_engine.chat(
                 messages=messages,
                 max_new_tokens=max_new_tokens,
                 temperature=temperature,
-                stream=
+                stream=True,
                 **kwrs
             )

-
-
-            gen_text = review
-        elif self.review_mode == "addition":
-            gen_text = initial + '\n' + review
+            for chunk in response_stream:
+                yield chunk

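Because `stream()` is a generator, the extraction and review passes can be rendered as they arrive. Continuing the sketch above (only the `stream()` signature in this diff is assumed):

```python
# Stream the unit header, context, extraction, and review text as they are generated.
for chunk in extractor.stream(
    "Patient denies chest pain but reports dyspnea on exertion.",
    max_new_tokens=1024,
):
    print(chunk, end="", flush=True)
```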
-
-
-                'sentence_end': sent['end'],
-                'sentence_text': sent['sentence_text'],
-                'gen_text': gen_text})
-        return output
-
-    async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
-                            document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
+    async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, document_key:str=None, temperature:float=0.0,
+                            concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
         """
-
+        This is the asynchronous version of the extract() method with the review step.

         Parameters:
         ----------
         text_content : Union[str, Dict[str,str]]
-            the input text content to put in prompt template.
+            the input text content to put in prompt template.
             If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
             If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-        max_new_tokens :
-            the max number of new tokens LLM should generate.
+        max_new_tokens : int, Optional
+            the max number of new tokens LLM should generate.
         document_key : str, Optional
-            specify the key in text_content where document text is.
+            specify the key in text_content where document text is.
             If text_content is str, this parameter will be ignored.
         temperature : float, Optional
             the temperature for token sampling.
         concurrent_batch_size : int, Optional
-            the
+            the batch size for concurrent processing.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned, including review steps.

-        Return :
-            the output from LLM.
+        Return : List[FrameExtractionUnitResult]
+            the output from LLM for each unit after review. Contains the start, end, text, and generated text.
         """
-        # Check if self.inference_engine.chat_async() is implemented
-        if not hasattr(self.inference_engine, 'chat_async'):
-            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
-
-        # define output
-        output = []
-        # sentence tokenization
         if isinstance(text_content, str):
-
+            doc_text = text_content
         elif isinstance(text_content, dict):
             if document_key is None:
                 raise ValueError("document_key must be provided when text_content is dict.")
-
-
-
-
-
-            init_tasks = []
-            review_tasks = []
-            batch = sentences[i:i + concurrent_batch_size]
-            for j, sent in enumerate(batch):
-                # construct chat messages
-                messages = []
-                if self.system_prompt:
-                    messages.append({'role': 'system', 'content': self.system_prompt})
+            if document_key not in text_content:
+                raise ValueError(f"document_key '{document_key}' not found in text_content dictionary.")
+            doc_text = text_content[document_key]
+        else:
+            raise TypeError("text_content must be a string or a dictionary.")

-
-
-
-
-
+        units = self.unit_chunker.chunk(doc_text)
+
+        # context chunker init
+        self.context_chunker.fit(doc_text, units)
+
+        # <--- Initial generation step --->
+        initial_tasks_input = []
+        for i, unit in enumerate(units):
+            # construct chat messages for initial generation
+            messages = []
+            if self.system_prompt:
+                messages.append({'role': 'system', 'content': self.system_prompt})
+
+            context = self.context_chunker.chunk(unit)
+
+            if context == "":
+                # no context, just place unit in user prompt
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
                 else:
-
+                    unit_content = text_content.copy()
+                    unit_content[document_key] = unit.text
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+            else:
+                # insert context to user prompt
+                if isinstance(text_content, str):
                     messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
-
-
-
-                messages.append({'role': 'user', 'content':
+                else:
+                    context_content = text_content.copy()
+                    context_content[document_key] = context
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                # simulate conversation where assistant confirms
+                messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                # place unit of interest
+                messages.append({'role': 'user', 'content': unit.text})
+
+            # Store unit and messages together for the initial task
+            initial_tasks_input.append({"unit": unit, "messages": messages, "original_index": i})
+
+        semaphore = asyncio.Semaphore(concurrent_batch_size)
+
+        async def initial_semaphore_helper(task_data: Dict, max_new_tokens: int, temperature: float, **kwrs):
+            unit = task_data["unit"]
+            messages = task_data["messages"]
+            original_index = task_data["original_index"]
+
+            async with semaphore:
+                gen_text = await self.inference_engine.chat_async(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    **kwrs
+                )
+            # Return initial generation result along with the messages used and the unit
+            return {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text, "initial_messages": messages}
+
+        # Create and gather initial generation tasks
+        initial_tasks = [
+            asyncio.create_task(initial_semaphore_helper(
+                task_inp,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                **kwrs
+            ))
+            for task_inp in initial_tasks_input
+        ]
+
+        initial_results_raw = await asyncio.gather(*initial_tasks)
+
+        # Sort initial results back into original order
+        initial_results_raw.sort(key=lambda x: x["original_index"])
+
+        # <--- Review step --->
+        review_tasks_input = []
+        for result_data in initial_results_raw:
+            # Prepare messages for the review step
+            initial_messages = result_data["initial_messages"]
+            initial_gen_text = result_data["initial_gen_text"]
+            review_messages = initial_messages + [
+                {'role': 'assistant', 'content': initial_gen_text},
+                {'role': 'user', 'content': self.review_prompt}
+            ]
+            # Store data needed for review task
+            review_tasks_input.append({
+                "unit": result_data["unit"],
+                "initial_gen_text": initial_gen_text,
+                "messages": review_messages,
+                "original_index": result_data["original_index"],
+                "full_initial_log": initial_messages + [{'role': 'assistant', 'content': initial_gen_text}] if return_messages_log else None # Log up to initial generation
+            })
+
+
+        async def review_semaphore_helper(task_data: Dict, max_new_tokens: int, temperature: float, **kwrs):
+            messages = task_data["messages"]
+            original_index = task_data["original_index"]
+
+            async with semaphore:
+                review_gen_text = await self.inference_engine.chat_async(
+                    messages=messages,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    **kwrs
+                )
+            # Combine initial and review results
+            task_data["review_gen_text"] = review_gen_text
+            if return_messages_log:
+                # Log for the review call itself
+                task_data["full_review_log"] = messages + [{'role': 'assistant', 'content': review_gen_text}]
+            return task_data # Return the augmented dictionary
+
+        # Create and gather review tasks
+        review_tasks = [
+            asyncio.create_task(review_semaphore_helper(
+                task_inp,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                **kwrs
+            ))
+            for task_inp in review_tasks_input
+        ]
+
+        final_results_raw = await asyncio.gather(*review_tasks)
+
+        # Sort final results back into original order (although gather might preserve order for tasks added sequentially)
+        final_results_raw.sort(key=lambda x: x["original_index"])
+
+        # <--- Process final results --->
+        output: List[FrameExtractionUnitResult] = []
+        messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
+
+        for result_data in final_results_raw:
+            unit = result_data["unit"]
+            initial_gen = result_data["initial_gen_text"]
+            review_gen = result_data["review_gen_text"]
+
+            # Combine based on review mode
+            if self.review_mode == "revision":
+                final_gen_text = review_gen
+            elif self.review_mode == "addition":
+                final_gen_text = initial_gen + '\n' + review_gen
+            else: # Should not happen due to init check
+                final_gen_text = review_gen # Default to revision if mode is somehow invalid
+
+            # Create final result object
+            result = FrameExtractionUnitResult(
+                start=unit.start,
+                end=unit.end,
+                text=unit.text,
+                gen_text=final_gen_text # Use the combined/reviewed text
+            )
+            output.append(result)
+
+            # Append full conversation log if requested
+            if return_messages_log:
+                full_log_for_unit = result_data.get("full_initial_log", []) + [{'role': 'user', 'content': self.review_prompt}] + [{'role': 'assistant', 'content': review_gen}]
+                messages_log.append(full_log_for_unit)
+
+        if return_messages_log:
+            return output, messages_log
+        else:
+            return output

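Note that this asynchronous version creates one task per unit for the initial pass and another per unit for the review pass, bounded by an `asyncio.Semaphore`, so `concurrent_batch_size` caps in-flight `chat_async()` calls rather than forming fixed batches. Continuing the earlier sketch (the engine must implement `chat_async()`):

```python
import asyncio

async def run(extractor, note_text: str):
    # concurrent_batch_size limits simultaneous chat_async() calls via the semaphore
    return await extractor.extract_async(
        note_text,
        concurrent_batch_size=8,
        return_messages_log=False,
    )

# results = asyncio.run(run(extractor, "Patient denies chest pain but reports dyspnea."))
```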
-                messages_list.append(messages)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                'sentence_end': init['sentence_end'],
-                'sentence_text': init['sentence_text'],
-                'gen_text': gen_text})
-        return output
+class BasicFrameExtractor(DirectFrameExtractor):
+    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
+        """
+        This class diretly prompt LLM for frame extraction.
+        Input system prompt (optional), prompt template (with instruction, few-shot examples),
+        and specify a LLM.
+
+        Parameters:
+        ----------
+        inference_engine : InferenceEngine
+            the LLM inferencing engine object. Must implements the chat() method.
+        prompt_template : str
+            prompt template with "{{<placeholder name>}}" placeholder.
+        system_prompt : str, Optional
+            system prompt.
+        """
+        super().__init__(inference_engine=inference_engine,
+                         unit_chunker=WholeDocumentUnitChunker(),
+                         prompt_template=prompt_template,
+                         system_prompt=system_prompt,
+                         context_chunker=NoContextChunker(),
+                         **kwrs)
+
+class BasicReviewFrameExtractor(ReviewFrameExtractor):
+    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None, **kwrs):
+        """
+        This class add a review step after the BasicFrameExtractor.
+        The Review process asks LLM to review its output and:
+            1. add more frames while keep current. This is efficient for boosting recall.
+            2. or, regenerate frames (add new and delete existing).
+        Use the review_mode parameter to specify. Note that the review_prompt should instruct LLM accordingly.
+
+        Parameters:
+        ----------
+        inference_engine : InferenceEngine
+            the LLM inferencing engine object. Must implements the chat() method.
+        prompt_template : str
+            prompt template with "{{<placeholder name>}}" placeholder.
+        review_prompt : str: Optional
+            the prompt text that ask LLM to review. Specify addition or revision in the instruction.
+            if not provided, a default review prompt will be used.
+        review_mode : str
+            review mode. Must be one of {"addition", "revision"}
+            addition mode only ask LLM to add new frames, while revision mode ask LLM to regenerate.
+        system_prompt : str, Optional
+            system prompt.
+        """
+        super().__init__(inference_engine=inference_engine,
+                         unit_chunker=WholeDocumentUnitChunker(),
+                         prompt_template=prompt_template,
+                         review_mode=review_mode,
+                         review_prompt=review_prompt,
+                         system_prompt=system_prompt,
+                         context_chunker=NoContextChunker(),
+                         **kwrs)


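`BasicFrameExtractor` and `BasicReviewFrameExtractor` are now thin wrappers that wire the generic extractors with a whole-document unit chunker and no context chunker. A sketch of the equivalent manual composition (engine and prompt are placeholders, as in the earlier sketches):

```python
from llm_ie.chunkers import WholeDocumentUnitChunker, NoContextChunker
from llm_ie.extractors import BasicFrameExtractor, DirectFrameExtractor
from llm_ie.engines import OpenAIInferenceEngine  # placeholder engine class

engine = OpenAIInferenceEngine(model="gpt-4o-mini")  # assumed constructor arguments
prompt = "List adverse events as JSON lines from: {{input}}"

basic = BasicFrameExtractor(inference_engine=engine, prompt_template=prompt)

# expected to behave like composing DirectFrameExtractor by hand:
direct = DirectFrameExtractor(
    inference_engine=engine,
    unit_chunker=WholeDocumentUnitChunker(),  # one unit: the whole document
    prompt_template=prompt,
    context_chunker=NoContextChunker(),       # no extra context around the unit
)
```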
-class
-
-    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None,
+class SentenceFrameExtractor(DirectFrameExtractor):
+    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None,
                  context_sentences:Union[str, int]="all", **kwrs):
         """
-        This class performs sentence-
-
+        This class performs sentence-by-sentence information extraction.
+        The process is as follows:
             1. system prompt (optional)
-            2. user instructions (schema, background, full text, few-shot example...)
-            3.
-            4.
-            5.
-            6. repeat #3, #4, #5
+            2. user prompt with instructions (schema, background, full text, few-shot example...)
+            3. feed a sentence (start with first sentence)
+            4. LLM extract entities and attributes from the sentence
+            5. iterate to the next sentence and repeat steps 3-4 until all sentences are processed.

         Input system prompt (optional), prompt template (with user instructions),
         and specify a LLM.

-        Parameters
+        Parameters:
         ----------
         inference_engine : InferenceEngine
             the LLM inferencing engine object. Must implements the chat() method.
@@ -1216,82 +1492,77 @@ class SentenceCoTFrameExtractor(SentenceFrameExtractor):
             if > 0, the number of sentences before and after the given sentence to provide as context.
             This is good for tasks that require context beyond the given sentence.
         """
-
-
+        if not isinstance(context_sentences, int) and context_sentences != "all":
+            raise ValueError('context_sentences must be an integer (>= 0) or "all".')
+
+        if isinstance(context_sentences, int) and context_sentences < 0:
+            raise ValueError("context_sentences must be a positive integer.")
+
+        if isinstance(context_sentences, int):
+            context_chunker = SlideWindowContextChunker(window_size=context_sentences)
+        elif context_sentences == "all":
+            context_chunker = WholeDocumentContextChunker()
+
+        super().__init__(inference_engine=inference_engine,
+                         unit_chunker=SentenceUnitChunker(),
+                         prompt_template=prompt_template,
+                         system_prompt=system_prompt,
+                         context_chunker=context_chunker,
+                         **kwrs)


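In 1.0.0 the `context_sentences` argument selects the context chunker: an integer `n` maps to `SlideWindowContextChunker(window_size=n)`, `"all"` maps to `WholeDocumentContextChunker()`, and anything else raises a `ValueError`. Continuing the earlier sketch (engine and prompt are placeholders):

```python
from llm_ie.extractors import SentenceFrameExtractor

# context_sentences=2 -> SentenceUnitChunker + SlideWindowContextChunker(window_size=2)
sent_extractor = SentenceFrameExtractor(
    inference_engine=engine,
    prompt_template=prompt,
    context_sentences=2,  # two sentences of context on each side of the unit
)
# SentenceFrameExtractor(..., context_sentences="3") would raise:
# ValueError: context_sentences must be an integer (>= 0) or "all".
```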
-
-
+class SentenceReviewFrameExtractor(ReviewFrameExtractor):
+    def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
+                 review_mode:str, review_prompt:str=None, system_prompt:str=None,
+                 context_sentences:Union[str, int]="all", **kwrs):
         """
-        This
+        This class adds a review step after the SentenceFrameExtractor.
+        For each sentence, the review process asks LLM to review its output and:
+            1. add more frames while keeping current. This is efficient for boosting recall.
+            2. or, regenerate frames (add new and delete existing).
+        Use the review_mode parameter to specify. Note that the review_prompt should instruct LLM accordingly.

         Parameters:
         ----------
-
-            the
-
-
-
-            the
-
-
-
-
-
-
-
-
-
+        inference_engine : InferenceEngine
+            the LLM inferencing engine object. Must implements the chat() method.
+        prompt_template : str
+            prompt template with "{{<placeholder name>}}" placeholder.
+        review_prompt : str: Optional
+            the prompt text that ask LLM to review. Specify addition or revision in the instruction.
+            if not provided, a default review prompt will be used.
+        review_mode : str
+            review mode. Must be one of {"addition", "revision"}
+            addition mode only ask LLM to add new frames, while revision mode ask LLM to regenerate.
+        system_prompt : str, Optional
+            system prompt.
+        context_sentences : Union[str, int], Optional
+            number of sentences before and after the given sentence to provide additional context.
+            if "all", the full text will be provided in the prompt as context.
+            if 0, no additional context will be provided.
+            This is good for tasks that does not require context beyond the given sentence.
+            if > 0, the number of sentences before and after the given sentence to provide as context.
+            This is good for tasks that require context beyond the given sentence.
         """
-
-
-
-        if isinstance(
-
-
-
-
-
-
-            # construct chat messages
-            messages = []
-            if self.system_prompt:
-                messages.append({'role': 'system', 'content': self.system_prompt})
-
-            context = self._get_context_sentences(text_content, i, sentences, document_key)
-
-            if self.context_sentences == 0:
-                # no context, just place sentence of interest
-                messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
-            else:
-                # insert context
-                messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
-            # simulate conversation
-            messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
-            # place sentence of interest
-            messages.append({'role': 'user', 'content': sent['sentence_text']})
-
-            if stream:
-                print(f"\n\n{Fore.GREEN}Sentence: {Style.RESET_ALL}\n{sent['sentence_text']}\n")
-                if isinstance(self.context_sentences, int) and self.context_sentences > 0:
-                    print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
-                print(f"{Fore.BLUE}CoT:{Style.RESET_ALL}")
-
-            gen_text = self.inference_engine.chat(
-                messages=messages,
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                stream=stream,
-                **kwrs
-            )
+        if not isinstance(context_sentences, int) and context_sentences != "all":
+            raise ValueError('context_sentences must be an integer (>= 0) or "all".')
+
+        if isinstance(context_sentences, int) and context_sentences < 0:
+            raise ValueError("context_sentences must be a positive integer.")
+
+        if isinstance(context_sentences, int):
+            context_chunker = SlideWindowContextChunker(window_size=context_sentences)
+        elif context_sentences == "all":
+            context_chunker = WholeDocumentContextChunker()

-
-
-
-
-
-
+        super().__init__(inference_engine=inference_engine,
+                         unit_chunker=SentenceUnitChunker(),
+                         prompt_template=prompt_template,
+                         review_mode=review_mode,
+                         review_prompt=review_prompt,
+                         system_prompt=system_prompt,
+                         context_chunker=context_chunker,
+                         **kwrs)


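`SentenceReviewFrameExtractor` accepts the same `context_sentences` mapping plus the review settings. A sketch with a custom review prompt (the prompt wording is illustrative; only the constructor signature above is assumed):

```python
from llm_ie.extractors import SentenceReviewFrameExtractor

reviewer = SentenceReviewFrameExtractor(
    inference_engine=engine,            # placeholder engine, as in the earlier sketches
    prompt_template=prompt,             # placeholder extraction prompt
    review_mode="addition",
    review_prompt="Review the sentence again and add any frames you missed, keeping the ones above.",
    context_sentences="all",            # full document provided as context for every sentence
)
```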
 class RelationExtractor(Extractor):
@@ -1361,7 +1632,7 @@ class RelationExtractor(Extractor):

     @abc.abstractmethod
     def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                          temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+                          temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This method considers all combinations of two frames.

@@ -1377,6 +1648,8 @@ class RelationExtractor(Extractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.

         Return : List[Dict]
             a list of dict with {"frame_1", "frame_2"} for all relations.
@@ -1446,7 +1719,7 @@ class BinaryRelationExtractor(RelationExtractor):


     def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+                temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
         Outputs pairs that are related.
@@ -1463,11 +1736,17 @@ class BinaryRelationExtractor(RelationExtractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.

         Return : List[Dict]
             a list of dict with {"frame_1_id", "frame_2_id"}.
         """
         pairs = itertools.combinations(doc.frames, 2)
+
+        if return_messages_log:
+            messages_log = []
+
         output = []
         for frame_1, frame_2 in pairs:
             pos_rel = self.possible_relation_func(frame_1, frame_2)
@@ -1495,13 +1774,19 @@ class BinaryRelationExtractor(RelationExtractor):
                 )
                 rel_json = self._extract_json(gen_text)
                 if self._post_process(rel_json):
-                    output.append({'
+                    output.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})

+            if return_messages_log:
+                messages.append({"role": "assistant", "content": gen_text})
+                messages_log.append(messages)
+
+        if return_messages_log:
+            return output, messages_log
         return output


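`BinaryRelationExtractor.extract()` now optionally returns the per-pair chat transcripts. A wiring sketch follows; the constructor argument names and the frame fields used in the filter are assumptions, while the `extract(..., return_messages_log=True)` signature and the `{"frame_1_id", "frame_2_id"}` output shape come from this diff.

```python
from llm_ie.extractors import BinaryRelationExtractor

def may_be_related(frame_1, frame_2) -> bool:
    # illustrative filter: assumes frames carry an attr dict with an entity type
    return {frame_1.attr.get("type"), frame_2.attr.get("type")} == {"Medication", "Dosage"}

rel_extractor = BinaryRelationExtractor(
    inference_engine=engine,                 # placeholder engine
    prompt_template=relation_prompt,         # placeholder relation prompt
    possible_relation_func=may_be_related,   # assumed parameter name
)
pairs, log = rel_extractor.extract(doc, return_messages_log=True)
# pairs -> [{"frame_1_id": "...", "frame_2_id": "..."}, ...]
```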
     def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                      temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
+                      temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This is the asynchronous version of the extract() method.

@@ -1517,6 +1802,8 @@ class BinaryRelationExtractor(RelationExtractor):
             the temperature for token sampling.
         concurrent_batch_size : int, Optional
             the number of frame pairs to process in concurrent.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.

         Return : List[Dict]
             a list of dict with {"frame_1", "frame_2"}.
@@ -1526,12 +1813,17 @@ class BinaryRelationExtractor(RelationExtractor):
             raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")

         pairs = itertools.combinations(doc.frames, 2)
+        if return_messages_log:
+            messages_log = []
+
         n_frames = len(doc.frames)
         num_pairs = (n_frames * (n_frames-1)) // 2
-
-        tasks = []
+        output = []
         for i in range(0, num_pairs, concurrent_batch_size):
+            rel_pair_list = []
+            tasks = []
             batch = list(itertools.islice(pairs, concurrent_batch_size))
+            batch_messages = []
             for frame_1, frame_2 in batch:
                 pos_rel = self.possible_relation_func(frame_1, frame_2)

@@ -1546,6 +1838,7 @@ class BinaryRelationExtractor(RelationExtractor):
                         "frame_1": str(frame_1.to_dict()),
                         "frame_2": str(frame_2.to_dict())}
                     )})
+
                 task = asyncio.create_task(
                     self.inference_engine.chat_async(
                         messages=messages,
@@ -1555,20 +1848,27 @@ class BinaryRelationExtractor(RelationExtractor):
                     )
                 )
                 tasks.append(task)
+                batch_messages.append(messages)

             responses = await asyncio.gather(*tasks)

-
-
-
-
-
+            for d, response, messages in zip(rel_pair_list, responses, batch_messages):
+                if return_messages_log:
+                    messages.append({"role": "assistant", "content": response})
+                    messages_log.append(messages)
+
+                rel_json = self._extract_json(response)
+                if self._post_process(rel_json):
+                    output.append(d)

+        if return_messages_log:
+            return output, messages_log
         return output


     def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                          temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32,
+                          temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32,
+                          stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.

@@ -1588,6 +1888,8 @@ class BinaryRelationExtractor(RelationExtractor):
             the number of frame pairs to process in concurrent.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.

         Return : List[Dict]
             a list of dict with {"frame_1", "frame_2"} for all relations.
@@ -1608,6 +1910,7 @@ class BinaryRelationExtractor(RelationExtractor):
                     max_new_tokens=max_new_tokens,
                     temperature=temperature,
                     concurrent_batch_size=concurrent_batch_size,
+                    return_messages_log=return_messages_log,
                     **kwrs)
             )
         else:
@@ -1616,6 +1919,7 @@ class BinaryRelationExtractor(RelationExtractor):
                     max_new_tokens=max_new_tokens,
                     temperature=temperature,
                     stream=stream,
+                    return_messages_log=return_messages_log,
                     **kwrs)


@@ -1689,7 +1993,7 @@ class MultiClassRelationExtractor(RelationExtractor):


     def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+                temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.

@@ -1705,11 +2009,17 @@ class MultiClassRelationExtractor(RelationExtractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.

         Return : List[Dict]
-            a list of dict with {"
+            a list of dict with {"frame_1_id", "frame_2_id", "relation"} for all frame pairs.
         """
         pairs = itertools.combinations(doc.frames, 2)
+
+        if return_messages_log:
+            messages_log = []
+
         output = []
         for frame_1, frame_2 in pairs:
             pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
@@ -1736,16 +2046,23 @@ class MultiClassRelationExtractor(RelationExtractor):
                     stream=stream,
                     **kwrs
                 )
+
+                if return_messages_log:
+                    messages.append({"role": "assistant", "content": gen_text})
+                    messages_log.append(messages)
+
                 rel_json = self._extract_json(gen_text)
                 rel = self._post_process(rel_json, pos_rel_types)
                 if rel:
-                    output.append({'
+                    output.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id, 'relation':rel})

+        if return_messages_log:
+            return output, messages_log
         return output


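`MultiClassRelationExtractor` follows the same pattern but labels each pair. A sketch of a type filter; the constructor argument name, the frame fields, and the skip-on-empty behavior are assumptions, while the `{"frame_1_id", "frame_2_id", "relation"}` output shape comes from this diff.

```python
from llm_ie.extractors import MultiClassRelationExtractor

def possible_relation_types(frame_1, frame_2):
    # illustrative: assumes frames carry an attr dict with an entity type
    types = {frame_1.attr.get("type"), frame_2.attr.get("type")}
    if types == {"Medication", "Strength"}:
        return ["Strength-of"]
    return []  # an empty return is assumed to skip the pair

mc_extractor = MultiClassRelationExtractor(
    inference_engine=engine,                               # placeholder engine
    prompt_template=relation_prompt,                       # placeholder relation prompt
    possible_relation_types_func=possible_relation_types,  # assumed parameter name
)
relations = mc_extractor.extract(doc)
# relations -> [{"frame_1_id": ..., "frame_2_id": ..., "relation": ...}]
```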
     async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                            temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
+                            temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This is the asynchronous version of the extract() method.

@@ -1761,21 +2078,28 @@ class MultiClassRelationExtractor(RelationExtractor):
             the temperature for token sampling.
         concurrent_batch_size : int, Optional
             the number of frame pairs to process in concurrent.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.

         Return : List[Dict]
-            a list of dict with {"
+            a list of dict with {"frame_1_id", "frame_2_id", "relation"} for all frame pairs.
         """
         # Check if self.inference_engine.chat_async() is implemented
         if not hasattr(self.inference_engine, 'chat_async'):
             raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")

         pairs = itertools.combinations(doc.frames, 2)
+        if return_messages_log:
+            messages_log = []
+
         n_frames = len(doc.frames)
         num_pairs = (n_frames * (n_frames-1)) // 2
-
-        tasks = []
+        output = []
         for i in range(0, num_pairs, concurrent_batch_size):
+            rel_pair_list = []
+            tasks = []
             batch = list(itertools.islice(pairs, concurrent_batch_size))
+            batch_messages = []
             for frame_1, frame_2 in batch:
                 pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)

@@ -1800,21 +2124,28 @@ class MultiClassRelationExtractor(RelationExtractor):
                     )
                 )
                 tasks.append(task)
+                batch_messages.append(messages)

             responses = await asyncio.gather(*tasks)

-
-
-
-
-
-
+            for d, response, messages in zip(rel_pair_list, responses, batch_messages):
+                if return_messages_log:
+                    messages.append({"role": "assistant", "content": response})
+                    messages_log.append(messages)
+
+                rel_json = self._extract_json(response)
+                rel = self._post_process(rel_json, d['pos_rel_types'])
+                if rel:
+                    output.append({'frame_1_id':d['frame_1'], 'frame_2_id':d['frame_2'], 'relation':rel})

+        if return_messages_log:
+            return output, messages_log
         return output


     def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                          temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32,
+                          temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32,
+                          stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.

@@ -1834,6 +2165,8 @@ class MultiClassRelationExtractor(RelationExtractor):
             the number of frame pairs to process in concurrent.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.

         Return : List[Dict]
             a list of dict with {"frame_1", "frame_2", "relation"} for all relations.
@@ -1854,6 +2187,7 @@ class MultiClassRelationExtractor(RelationExtractor):
                     max_new_tokens=max_new_tokens,
                     temperature=temperature,
                     concurrent_batch_size=concurrent_batch_size,
+                    return_messages_log=return_messages_log,
                     **kwrs)
             )
         else:
@@ -1862,5 +2196,6 @@ class MultiClassRelationExtractor(RelationExtractor):
                     max_new_tokens=max_new_tokens,
                     temperature=temperature,
                     stream=stream,
+                    return_messages_log=return_messages_log,
                     **kwrs)
