llm-ie 0.4.7__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_ie/extractors.py CHANGED
@@ -8,14 +8,16 @@ import warnings
8
8
  import itertools
9
9
  import asyncio
10
10
  import nest_asyncio
11
- from typing import Set, List, Dict, Tuple, Union, Callable
12
- from llm_ie.data_types import LLMInformationExtractionFrame, LLMInformationExtractionDocument
11
+ from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional
12
+ from llm_ie.data_types import FrameExtractionUnit, FrameExtractionUnitResult, LLMInformationExtractionFrame, LLMInformationExtractionDocument
13
+ from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
14
+ from llm_ie.chunkers import ContextChunker, NoContextChunker, WholeDocumentContextChunker, SlideWindowContextChunker
13
15
  from llm_ie.engines import InferenceEngine
14
- from colorama import Fore, Style
16
+ from colorama import Fore, Style
15
17
 
16
18
 
17
19
  class Extractor:
18
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
20
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
19
21
  """
20
22
  This is the abstract class for (frame and relation) extractors.
21
23
  Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -38,15 +40,46 @@ class Extractor:
38
40
  def get_prompt_guide(cls) -> str:
39
41
  """
40
42
  This method returns the pre-defined prompt guideline for the extractor from the package asset.
43
+ It searches for a guide specific to the current class first; if not found, it searches
44
+ for the guide in its ancestors by traversing the class's method resolution order (MRO).
41
45
  """
42
- # Check if the prompt guide is available
43
- file_path = importlib.resources.files('llm_ie.asset.prompt_guide').joinpath(f"{cls.__name__}_prompt_guide.txt")
44
- try:
45
- with open(file_path, 'r', encoding="utf-8") as f:
46
- return f.read()
47
- except FileNotFoundError:
48
- warnings.warn(f"Prompt guide for {cls.__name__} is not available. Is it a customed extractor?", UserWarning)
49
- return None
46
+ original_class_name = cls.__name__
47
+
48
+ for current_class_in_mro in cls.__mro__:
49
+ if current_class_in_mro is object:
50
+ continue
51
+
52
+ current_class_name = current_class_in_mro.__name__
53
+
54
+ try:
55
+ file_path_obj = importlib.resources.files('llm_ie.asset.prompt_guide').joinpath(f"{current_class_name}_prompt_guide.txt")
56
+
57
+ with open(file_path_obj, 'r', encoding="utf-8") as f:
58
+ prompt_content = f.read()
59
+ # If the guide was found for an ancestor, not the original class, issue a warning.
60
+ if cls is not current_class_in_mro:
61
+ warnings.warn(
62
+ f"Prompt guide for '{original_class_name}' not found. "
63
+ f"Using guide from ancestor: '{current_class_name}_prompt_guide.txt'.",
64
+ UserWarning
65
+ )
66
+ return prompt_content
67
+ except FileNotFoundError:
68
+ pass
69
+
70
+ except Exception as e:
71
+ warnings.warn(
72
+ f"Error attempting to read prompt guide for '{current_class_name}' "
73
+ f"from '{str(file_path_obj)}': {e}. Trying next in MRO.",
74
+ UserWarning
75
+ )
76
+ continue
77
+
78
+ # If the loop completes, no prompt guide was found for the original class or any of its ancestors.
79
+ raise FileNotFoundError(
80
+ f"Prompt guide for '{original_class_name}' not found in the package asset. "
81
+ f"Is it a custom extractor?"
82
+ )
50
83
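As an illustration of the new lookup behavior (not itself part of the diff): a custom subclass that ships no guide file of its own now inherits its parent's guide instead of returning None. A minimal sketch, assuming DirectFrameExtractor (defined later in this file) is importable from llm_ie.extractors and ships a packaged guide:

from llm_ie.extractors import DirectFrameExtractor

class MyCustomExtractor(DirectFrameExtractor):
    # no MyCustomExtractor_prompt_guide.txt exists in llm_ie.asset.prompt_guide
    pass

# Resolves to the ancestor's guide and emits a UserWarning naming the file used;
# only if no class in the MRO has a guide does it raise FileNotFoundError.
guide = MyCustomExtractor.get_prompt_guide()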
 
51
84
  def _get_user_prompt(self, text_content:Union[str, Dict[str,str]]) -> str:
52
85
  """
@@ -138,7 +171,8 @@ class Extractor:
138
171
 
139
172
  class FrameExtractor(Extractor):
140
173
  from nltk.tokenize import RegexpTokenizer
141
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
174
+ def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
175
+ prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None):
142
176
  """
143
177
  This is the abstract class for frame extraction.
144
178
  Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -147,15 +181,24 @@ class FrameExtractor(Extractor):
147
181
  ----------
148
182
  inference_engine : InferenceEngine
149
183
  the LLM inference engine object. Must implement the chat() method.
184
+ unit_chunker : UnitChunker
185
+ the unit chunker object that determines how to chunk the document text into units.
150
186
  prompt_template : str
151
187
  prompt template with "{{<placeholder name>}}" placeholder.
152
188
  system_prompt : str, Optional
153
189
  system prompt.
190
+ context_chunker : ContextChunker
191
+ the context chunker object that determines how to get context for each unit.
154
192
  """
155
193
  super().__init__(inference_engine=inference_engine,
156
194
  prompt_template=prompt_template,
157
- system_prompt=system_prompt,
158
- **kwrs)
195
+ system_prompt=system_prompt)
196
+
197
+ self.unit_chunker = unit_chunker
198
+ if context_chunker is None:
199
+ self.context_chunker = NoContextChunker()
200
+ else:
201
+ self.context_chunker = context_chunker
159
202
 
160
203
  self.tokenizer = self.RegexpTokenizer(r'\w+|[^\w\s]')
161
204
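As an illustration of the new constructor (not itself part of the diff): the unit/context chunker pair replaces the old sentence-based configuration. A minimal sketch of wiring it up, assuming DirectFrameExtractor from this file, an already-constructed InferenceEngine named engine, and a prompt_template string; the SlideWindowContextChunker argument is a guess, since its signature lives in llm_ie.chunkers and is not shown here:

from llm_ie.chunkers import SentenceUnitChunker, SlideWindowContextChunker
from llm_ie.extractors import DirectFrameExtractor

extractor = DirectFrameExtractor(
    inference_engine=engine,                       # any InferenceEngine with chat()/chat_async()
    unit_chunker=SentenceUnitChunker(),            # one LLM call per sentence
    context_chunker=SlideWindowContextChunker(2),  # neighboring units as context (constructor args assumed)
    prompt_template=prompt_template,
    system_prompt="You are an information extraction assistant.",
)
# Omitting context_chunker falls back to NoContextChunker().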
 
@@ -288,7 +331,7 @@ class FrameExtractor(Extractor):
288
331
  return entity_spans
289
332
 
290
333
  @abc.abstractmethod
291
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, return_messages_log:bool=False, **kwrs) -> str:
334
+ def extract(self, text_content:Union[str, Dict[str,str]], return_messages_log:bool=False, **kwrs) -> str:
292
335
  """
293
336
  This method inputs text content and outputs a string generated by LLM
294
337
 
@@ -298,8 +341,6 @@ class FrameExtractor(Extractor):
298
341
  the input text content to put in prompt template.
299
342
  If str, the prompt template must have only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
300
343
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
301
- max_new_tokens : str, Optional
302
- the max number of new tokens LLM can generate.
303
344
  return_messages_log : bool, Optional
304
345
  if True, a list of messages will be returned.
305
346
 
@@ -310,7 +351,7 @@ class FrameExtractor(Extractor):
310
351
 
311
352
 
312
353
  @abc.abstractmethod
313
- def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=2048,
354
+ def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str,
314
355
  document_key:str=None, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
315
356
  """
316
357
  This method inputs text content and outputs a list of LLMInformationExtractionFrame
@@ -324,8 +365,6 @@ class FrameExtractor(Extractor):
324
365
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
325
366
  entity_key : str
326
367
  the key (in output JSON) for entity text. Any extraction that does not include the entity key will be dropped.
327
- max_new_tokens : str, Optional
328
- the max number of new tokens LLM should generate.
329
368
  document_key : str, Optional
330
369
  specify the key in text_content where document text is.
331
370
  If text_content is str, this parameter will be ignored.
@@ -338,209 +377,37 @@ class FrameExtractor(Extractor):
338
377
  return NotImplemented
339
378
 
340
379
 
341
- class BasicFrameExtractor(FrameExtractor):
342
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
343
- """
344
- This class diretly prompt LLM for frame extraction.
345
- Input system prompt (optional), prompt template (with instruction, few-shot examples),
346
- and specify a LLM.
347
-
348
- Parameters:
349
- ----------
350
- inference_engine : InferenceEngine
351
- the LLM inferencing engine object. Must implements the chat() method.
352
- prompt_template : str
353
- prompt template with "{{<placeholder name>}}" placeholder.
354
- system_prompt : str, Optional
355
- system prompt.
356
- """
357
- super().__init__(inference_engine=inference_engine,
358
- prompt_template=prompt_template,
359
- system_prompt=system_prompt,
360
- **kwrs)
361
-
362
-
363
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048,
364
- temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> str:
365
- """
366
- This method inputs a text and outputs a string generated by LLM.
367
-
368
- Parameters:
369
- ----------
370
- text_content : Union[str, Dict[str,str]]
371
- the input text content to put in prompt template.
372
- If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
373
- If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
374
- max_new_tokens : str, Optional
375
- the max number of new tokens LLM can generate.
376
- temperature : float, Optional
377
- the temperature for token sampling.
378
- stream : bool, Optional
379
- if True, LLM generated text will be printed in terminal in real-time.
380
- return_messages_log : bool, Optional
381
- if True, a list of messages will be returned.
382
-
383
- Return : str
384
- the output from LLM. Need post-processing.
385
- """
386
- messages = []
387
- if self.system_prompt:
388
- messages.append({'role': 'system', 'content': self.system_prompt})
389
-
390
- messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
391
- response = self.inference_engine.chat(
392
- messages=messages,
393
- max_new_tokens=max_new_tokens,
394
- temperature=temperature,
395
- stream=stream,
396
- **kwrs
397
- )
398
-
399
- if return_messages_log:
400
- messages.append({"role": "assistant", "content": response})
401
- messages_log = [messages]
402
- return response, messages_log
403
-
404
- return response
405
-
406
-
407
- def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=2048,
408
- temperature:float=0.0, document_key:str=None, stream:bool=False,
409
- case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2,
410
- fuzzy_score_cutoff:float=0.8, allow_overlap_entities:bool=False,
411
- return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
412
- """
413
- This method inputs a text and outputs a list of LLMInformationExtractionFrame
414
- It use the extract() method and post-process outputs into frames.
415
-
416
- Parameters:
417
- ----------
418
- text_content : Union[str, Dict[str,str]]
419
- the input text content to put in prompt template.
420
- If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
421
- If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
422
- entity_key : str
423
- the key (in ouptut JSON) for entity text. Any extraction that does not include entity key will be dropped.
424
- max_new_tokens : str, Optional
425
- the max number of new tokens LLM should generate.
426
- temperature : float, Optional
427
- the temperature for token sampling.
428
- document_key : str, Optional
429
- specify the key in text_content where document text is.
430
- If text_content is str, this parameter will be ignored.
431
- stream : bool, Optional
432
- if True, LLM generated text will be printed in terminal in real-time.
433
- case_sensitive : bool, Optional
434
- if True, entity text matching will be case-sensitive.
435
- fuzzy_match : bool, Optional
436
- if True, fuzzy matching will be applied to find entity text.
437
- fuzzy_buffer_size : float, Optional
438
- the buffer size for fuzzy matching. Default is 20% of entity text length.
439
- fuzzy_score_cutoff : float, Optional
440
- the Jaccard score cutoff for fuzzy matching.
441
- Matched entity text must have a score higher than this value or a None will be returned.
442
- allow_overlap_entities : bool, Optional
443
- if True, entities can overlap in the text.
444
- Note that this can cause multiple frames to be generated on the same entity span if they have same entity text.
445
- return_messages_log : bool, Optional
446
- if True, a list of messages will be returned.
447
-
448
- Return : str
449
- a list of frames.
450
- """
451
- if isinstance(text_content, str):
452
- text = text_content
453
- elif isinstance(text_content, dict):
454
- if document_key is None:
455
- raise ValueError("document_key must be provided when text_content is dict.")
456
- text = text_content[document_key]
457
-
458
- frame_list = []
459
- extraction_results = self.extract(text_content=text_content,
460
- max_new_tokens=max_new_tokens,
461
- temperature=temperature,
462
- stream=stream,
463
- return_messages_log=return_messages_log,
464
- **kwrs)
465
- gen_text, messages_log = extraction_results if return_messages_log else (extraction_results, None)
466
-
467
- entity_json = []
468
- for entity in self._extract_json(gen_text=gen_text):
469
- if entity_key in entity:
470
- entity_json.append(entity)
471
- else:
472
- warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{entity_key}"). This frame will be dropped.', RuntimeWarning)
473
-
474
- spans = self._find_entity_spans(text=text,
475
- entities=[e[entity_key] for e in entity_json],
476
- case_sensitive=case_sensitive,
477
- fuzzy_match=fuzzy_match,
478
- fuzzy_buffer_size=fuzzy_buffer_size,
479
- fuzzy_score_cutoff=fuzzy_score_cutoff,
480
- allow_overlap_entities=allow_overlap_entities)
481
-
482
- for i, (ent, span) in enumerate(zip(entity_json, spans)):
483
- if span is not None:
484
- start, end = span
485
- frame = LLMInformationExtractionFrame(frame_id=f"{i}",
486
- start=start,
487
- end=end,
488
- entity_text=text[start:end],
489
- attr={k: v for k, v in ent.items() if k != entity_key and v != ""})
490
- frame_list.append(frame)
491
-
492
- if return_messages_log:
493
- return frame_list, messages_log
494
-
495
- return frame_list
496
-
497
-
498
- class ReviewFrameExtractor(BasicFrameExtractor):
499
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
500
- review_mode:str, review_prompt:str=None,system_prompt:str=None, **kwrs):
380
+ class DirectFrameExtractor(FrameExtractor):
381
+ def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
382
+ prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None):
501
383
  """
502
- This class add a review step after the BasicFrameExtractor.
503
- The Review process asks LLM to review its output and:
504
- 1. add more frames while keep current. This is efficient for boosting recall.
505
- 2. or, regenerate frames (add new and delete existing).
506
- Use the review_mode parameter to specify. Note that the review_prompt should instruct LLM accordingly.
384
+ This class is for general unit-context frame extraction.
385
+ Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
507
386
 
508
387
  Parameters:
509
388
  ----------
510
389
  inference_engine : InferenceEngine
511
390
  the LLM inference engine object. Must implement the chat() method.
391
+ unit_chunker : UnitChunker
392
+ the unit chunker object that determines how to chunk the document text into units.
512
393
  prompt_template : str
513
394
  prompt template with "{{<placeholder name>}}" placeholder.
514
- review_prompt : str: Optional
515
- the prompt text that ask LLM to review. Specify addition or revision in the instruction.
516
- if not provided, a default review prompt will be used.
517
- review_mode : str
518
- review mode. Must be one of {"addition", "revision"}
519
- addition mode only ask LLM to add new frames, while revision mode ask LLM to regenerate.
520
395
  system_prompt : str, Optional
521
396
  system prompt.
397
+ context_chunker : ContextChunker
398
+ the context chunker object that determines how to get context for each unit.
522
399
  """
523
- super().__init__(inference_engine=inference_engine, prompt_template=prompt_template,
524
- system_prompt=system_prompt, **kwrs)
525
- if review_mode not in {"addition", "revision"}:
526
- raise ValueError('review_mode must be one of {"addition", "revision"}.')
527
- self.review_mode = review_mode
528
-
529
- if review_prompt:
530
- self.review_prompt = review_prompt
531
- else:
532
- file_path = importlib.resources.files('llm_ie.asset.default_prompts').\
533
- joinpath(f"{self.__class__.__name__}_{self.review_mode}_review_prompt.txt")
534
- with open(file_path, 'r', encoding="utf-8") as f:
535
- self.review_prompt = f.read()
536
-
537
- warnings.warn(f'Custom review prompt not provided. The default review prompt is used:\n"{self.review_prompt}"', UserWarning)
400
+ super().__init__(inference_engine=inference_engine,
401
+ unit_chunker=unit_chunker,
402
+ prompt_template=prompt_template,
403
+ system_prompt=system_prompt,
404
+ context_chunker=context_chunker)
538
405
 
539
406
 
540
407
  def extract(self, text_content:Union[str, Dict[str,str]],
541
- max_new_tokens:int=4096, temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> str:
408
+ document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
542
409
  """
543
- This method inputs a text and outputs a string generated by LLM.
410
+ This method inputs a text and outputs a list of outputs per unit.
544
411
 
545
412
  Parameters:
546
413
  ----------
@@ -548,249 +415,190 @@ class ReviewFrameExtractor(BasicFrameExtractor):
548
415
  the input text content to put in prompt template.
549
416
  If str, the prompt template must have only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
550
417
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
551
- max_new_tokens : str, Optional
552
- the max number of new tokens LLM can generate.
553
- temperature : float, Optional
554
- the temperature for token sampling.
555
- stream : bool, Optional
556
- if True, LLM generated text will be printed in terminal in real-time.
557
- return_messages_log : bool, Optional
558
- if True, a list of messages will be returned.
559
-
560
- Return : str
561
- the output from LLM. Need post-processing.
562
- """
563
- messages = []
564
- if self.system_prompt:
565
- messages.append({'role': 'system', 'content': self.system_prompt})
566
-
567
- messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
568
- # Initial output
569
- if stream:
570
- print(f"{Fore.BLUE}Initial Output:{Style.RESET_ALL}")
571
-
572
- initial = self.inference_engine.chat(
573
- messages=messages,
574
- max_new_tokens=max_new_tokens,
575
- temperature=temperature,
576
- stream=stream,
577
- **kwrs
578
- )
579
-
580
- # Review
581
- messages.append({'role': 'assistant', 'content': initial})
582
- messages.append({'role': 'user', 'content': self.review_prompt})
583
-
584
- if stream:
585
- print(f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}")
586
- review = self.inference_engine.chat(
587
- messages=messages,
588
- max_new_tokens=max_new_tokens,
589
- temperature=temperature,
590
- stream=stream,
591
- **kwrs
592
- )
593
-
594
- # Output
595
- output_text = ""
596
- if self.review_mode == "revision":
597
- output_text = review
598
- elif self.review_mode == "addition":
599
- output_text = initial + '\n' + review
600
-
601
- if return_messages_log:
602
- messages.append({"role": "assistant", "content": review})
603
- messages_log = [messages]
604
- return output_text, messages_log
605
-
606
- return output_text
607
-
608
-
609
- class SentenceFrameExtractor(FrameExtractor):
610
- from nltk.tokenize.punkt import PunktSentenceTokenizer
611
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None,
612
- context_sentences:Union[str, int]="all", **kwrs):
613
- """
614
- This class performs sentence-by-sentence information extraction.
615
- The process is as follows:
616
- 1. system prompt (optional)
617
- 2. user prompt with instructions (schema, background, full text, few-shot example...)
618
- 3. feed a sentence (start with first sentence)
619
- 4. LLM extract entities and attributes from the sentence
620
- 5. repeat #3 and #4
621
-
622
- Input system prompt (optional), prompt template (with user instructions),
623
- and specify a LLM.
624
-
625
- Parameters:
626
- ----------
627
- inference_engine : InferenceEngine
628
- the LLM inferencing engine object. Must implements the chat() method.
629
- prompt_template : str
630
- prompt template with "{{<placeholder name>}}" placeholder.
631
- system_prompt : str, Optional
632
- system prompt.
633
- context_sentences : Union[str, int], Optional
634
- number of sentences before and after the given sentence to provide additional context.
635
- if "all", the full text will be provided in the prompt as context.
636
- if 0, no additional context will be provided.
637
- This is good for tasks that does not require context beyond the given sentence.
638
- if > 0, the number of sentences before and after the given sentence to provide as context.
639
- This is good for tasks that require context beyond the given sentence.
640
- """
641
- super().__init__(inference_engine=inference_engine, prompt_template=prompt_template,
642
- system_prompt=system_prompt, **kwrs)
643
-
644
- if not isinstance(context_sentences, int) and context_sentences != "all":
645
- raise ValueError('context_sentences must be an integer (>= 0) or "all".')
646
-
647
- if isinstance(context_sentences, int) and context_sentences < 0:
648
- raise ValueError("context_sentences must be a positive integer.")
649
-
650
- self.context_sentences =context_sentences
651
-
652
-
653
- def _get_sentences(self, text:str) -> List[Dict[str,str]]:
654
- """
655
- This method sentence tokenize the input text into a list of sentences
656
- as dict of {start, end, sentence_text}
657
-
658
- Parameters:
659
- ----------
660
- text : str
661
- text to sentence tokenize.
662
-
663
- Returns : List[Dict[str,str]]
664
- a list of sentences as dict with keys: {"sentence_text", "start", "end"}.
665
- """
666
- sentences = []
667
- for start, end in self.PunktSentenceTokenizer().span_tokenize(text):
668
- sentences.append({"sentence_text": text[start:end],
669
- "start": start,
670
- "end": end})
671
- return sentences
672
-
673
-
674
- def _get_context_sentences(self, text_content, i:int, sentences:List[Dict[str, str]], document_key:str=None) -> str:
675
- """
676
- This function returns the context sentences for the current sentence of interest (i).
677
- """
678
- if self.context_sentences == "all":
679
- context = text_content if isinstance(text_content, str) else text_content[document_key]
680
- elif self.context_sentences == 0:
681
- context = ""
682
- else:
683
- start = max(0, i - self.context_sentences)
684
- end = min(i + 1 + self.context_sentences, len(sentences))
685
- context = " ".join([s['sentence_text'] for s in sentences[start:end]])
686
- return context
687
-
688
-
689
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
690
- document_key:str=None, temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict[str,str]]:
691
- """
692
- This method inputs a text and outputs a list of outputs per sentence.
693
-
694
- Parameters:
695
- ----------
696
- text_content : Union[str, Dict[str,str]]
697
- the input text content to put in prompt template.
698
- If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
699
- If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
700
- max_new_tokens : str, Optional
701
- the max number of new tokens LLM should generate.
702
418
  document_key : str, Optional
703
419
  specify the key in text_content where document text is.
704
420
  If text_content is str, this parameter will be ignored.
705
- temperature : float, Optional
706
- the temperature for token sampling.
707
- stream : bool, Optional
421
+ verbose : bool, Optional
708
422
  if True, LLM generated text will be printed in terminal in real-time.
709
423
  return_messages_log : bool, Optional
710
424
  if True, a list of messages will be returned.
711
425
 
712
- Return : str
713
- the output from LLM. Need post-processing.
426
+ Return : List[FrameExtractionUnitResult]
427
+ the output from LLM for each unit. Contains the start, end, text, and generated text.
714
428
  """
715
429
  # define output
716
430
  output = []
717
- # sentence tokenization
431
+ # unit chunking
718
432
  if isinstance(text_content, str):
719
- sentences = self._get_sentences(text_content)
433
+ doc_text = text_content
434
+
720
435
  elif isinstance(text_content, dict):
721
436
  if document_key is None:
722
437
  raise ValueError("document_key must be provided when text_content is dict.")
723
- sentences = self._get_sentences(text_content[document_key])
724
-
438
+ doc_text = text_content[document_key]
439
+
440
+ units = self.unit_chunker.chunk(doc_text)
441
+ # context chunker init
442
+ self.context_chunker.fit(doc_text, units)
443
+ # messages log
725
444
  if return_messages_log:
726
445
  messages_log = []
727
446
 
728
- # generate sentence by sentence
729
- for i, sent in enumerate(sentences):
447
+ # generate unit by unit
448
+ for i, unit in enumerate(units):
730
449
  # construct chat messages
731
450
  messages = []
732
451
  if self.system_prompt:
733
452
  messages.append({'role': 'system', 'content': self.system_prompt})
734
453
 
735
- context = self._get_context_sentences(text_content, i, sentences, document_key)
454
+ context = self.context_chunker.chunk(unit)
736
455
 
737
- if self.context_sentences == 0:
738
- # no context, just place sentence of interest
456
+ if context == "":
457
+ # no context, just place unit in user prompt
739
458
  if isinstance(text_content, str):
740
- messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
459
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
741
460
  else:
742
- sentence_content = text_content.copy()
743
- sentence_content[document_key] = sent['sentence_text']
744
- messages.append({'role': 'user', 'content': self._get_user_prompt(sentence_content)})
461
+ unit_content = text_content.copy()
462
+ unit_content[document_key] = unit.text
463
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
745
464
  else:
746
- # insert context
465
+ # insert context to user prompt
747
466
  if isinstance(text_content, str):
748
467
  messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
749
468
  else:
750
469
  context_content = text_content.copy()
751
470
  context_content[document_key] = context
752
471
  messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
753
- # simulate conversation
754
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
755
- # place sentence of interest
756
- messages.append({'role': 'user', 'content': sent['sentence_text']})
757
-
758
- if stream:
759
- print(f"\n\n{Fore.GREEN}Sentence {i}:{Style.RESET_ALL}\n{sent['sentence_text']}\n")
760
- if isinstance(self.context_sentences, int) and self.context_sentences > 0:
472
+ # simulate conversation where assistant confirms
473
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
474
+ # place unit of interest
475
+ messages.append({'role': 'user', 'content': unit.text})
476
+
477
+ if verbose:
478
+ print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
479
+ if context != "":
761
480
  print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
762
481
 
763
482
  print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
764
483
 
484
+
765
485
  gen_text = self.inference_engine.chat(
766
486
  messages=messages,
767
- max_new_tokens=max_new_tokens,
768
- temperature=temperature,
769
- stream=stream,
770
- **kwrs
487
+ verbose=verbose,
488
+ stream=False
771
489
  )
772
-
490
+
773
491
  if return_messages_log:
774
492
  messages.append({"role": "assistant", "content": gen_text})
775
493
  messages_log.append(messages)
776
494
 
777
495
  # add to output
778
- output.append({'sentence_start': sent['start'],
779
- 'sentence_end': sent['end'],
780
- 'sentence_text': sent['sentence_text'],
781
- 'gen_text': gen_text})
496
+ result = FrameExtractionUnitResult(
497
+ start=unit.start,
498
+ end=unit.end,
499
+ text=unit.text,
500
+ gen_text=gen_text)
501
+ output.append(result)
782
502
 
783
503
  if return_messages_log:
784
504
  return output, messages_log
785
505
 
786
506
  return output
787
507
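As an illustration (not itself part of the diff): extract() now returns FrameExtractionUnitResult objects rather than per-sentence dicts. A minimal sketch of consuming them, assuming the extractor and note_text names from the earlier sketch:

results = extractor.extract(text_content=note_text, verbose=True)
for res in results:
    # each result keeps the unit's character span plus the raw LLM generation
    print(res.start, res.end)
    print(res.gen_text)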
 
508
+ def stream(self, text_content: Union[str, Dict[str, str]],
509
+ document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnitResult]]:
510
+ """
511
+ Streams LLM responses per unit with structured event types,
512
+ and returns collected data for post-processing.
513
+
514
+ Yields:
515
+ -------
516
+ Dict[str, Any]: (type, data)
517
+ - {"type": "info", "data": str_message}: General informational messages.
518
+ - {"type": "unit", "data": dict_unit_info}: Signals start of a new unit. dict_unit_info contains {'id', 'text', 'start', 'end'}
519
+ - {"type": "context", "data": str_context}: Context string for the current unit.
520
+ - {"type": "reasoning", "data": str_chunk}: A reasoning model thinking chunk from the LLM.
521
+ - {"type": "response", "data": str_chunk}: A response/answer chunk from the LLM.
522
+
523
+ Returns:
524
+ --------
525
+ List[FrameExtractionUnitResult]:
526
+ A list of FrameExtractionUnitResult objects, each containing the
527
+ original unit details and the fully accumulated 'gen_text' from the LLM.
528
+ """
529
+ collected_results: List[FrameExtractionUnitResult] = []
530
+
531
+ if isinstance(text_content, str):
532
+ doc_text = text_content
533
+ elif isinstance(text_content, dict):
534
+ if document_key is None:
535
+ raise ValueError("document_key must be provided when text_content is dict.")
536
+ if document_key not in text_content:
537
+ raise ValueError(f"document_key '{document_key}' not found in text_content.")
538
+ doc_text = text_content[document_key]
539
+ else:
540
+ raise TypeError("text_content must be a string or a dictionary.")
541
+
542
+ units: List[FrameExtractionUnit] = self.unit_chunker.chunk(doc_text)
543
+ self.context_chunker.fit(doc_text, units)
544
+
545
+ yield {"type": "info", "data": f"Starting LLM processing for {len(units)} units."}
788
546
 
789
- async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
790
- document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32,
791
- return_messages_log:bool=False, **kwrs) -> List[Dict[str,str]]:
547
+ for i, unit in enumerate(units):
548
+ unit_info_payload = {"id": i, "text": unit.text, "start": unit.start, "end": unit.end}
549
+ yield {"type": "unit", "data": unit_info_payload}
550
+
551
+ messages = []
552
+ if self.system_prompt:
553
+ messages.append({'role': 'system', 'content': self.system_prompt})
554
+
555
+ context_str = self.context_chunker.chunk(unit)
556
+
557
+ # Construct prompt input based on whether text_content was str or dict
558
+ if context_str:
559
+ yield {"type": "context", "data": context_str}
560
+ prompt_input_for_context = context_str
561
+ if isinstance(text_content, dict):
562
+ context_content_dict = text_content.copy()
563
+ context_content_dict[document_key] = context_str
564
+ prompt_input_for_context = context_content_dict
565
+ messages.append({'role': 'user', 'content': self._get_user_prompt(prompt_input_for_context)})
566
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
567
+ messages.append({'role': 'user', 'content': unit.text})
568
+ else: # No context
569
+ prompt_input_for_unit = unit.text
570
+ if isinstance(text_content, dict):
571
+ unit_content_dict = text_content.copy()
572
+ unit_content_dict[document_key] = unit.text
573
+ prompt_input_for_unit = unit_content_dict
574
+ messages.append({'role': 'user', 'content': self._get_user_prompt(prompt_input_for_unit)})
575
+
576
+ current_gen_text = ""
577
+
578
+ response_stream = self.inference_engine.chat(
579
+ messages=messages,
580
+ stream=True
581
+ )
582
+ for chunk in response_stream:
583
+ yield chunk
584
+ current_gen_text += chunk
585
+
586
+ # Store the result for this unit
587
+ result_for_unit = FrameExtractionUnitResult(
588
+ start=unit.start,
589
+ end=unit.end,
590
+ text=unit.text,
591
+ gen_text=current_gen_text
592
+ )
593
+ collected_results.append(result_for_unit)
594
+
595
+ yield {"type": "info", "data": "All units processed by LLM."}
596
+ return collected_results
597
+
598
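As an illustration (not itself part of the diff): stream() is a generator, so its return value arrives through StopIteration when the generator is driven manually. A sketch of consuming the event stream, assuming the event types documented in the Yields section above:

gen = extractor.stream(text_content=note_text)
try:
    while True:
        event = next(gen)
        if event["type"] == "unit":
            print(f"\n--- unit {event['data']['id']} ---")
        elif event["type"] in ("reasoning", "response"):
            print(event["data"], end="", flush=True)
except StopIteration as stop:
    collected = stop.value   # List[FrameExtractionUnitResult]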
+ async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
599
+ concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
792
600
  """
793
- The asynchronous version of the extract() method.
601
+ This is the asynchronous version of the extract() method.
794
602
 
795
603
  Parameters:
796
604
  ----------
@@ -798,109 +606,126 @@ class SentenceFrameExtractor(FrameExtractor):
798
606
  the input text content to put in prompt template.
799
607
  If str, the prompt template must have only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
800
608
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
801
- max_new_tokens : str, Optional
802
- the max number of new tokens LLM should generate.
803
609
  document_key : str, Optional
804
610
  specify the key in text_content where document text is.
805
611
  If text_content is str, this parameter will be ignored.
806
- temperature : float, Optional
807
- the temperature for token sampling.
808
612
  concurrent_batch_size : int, Optional
809
- the number of sentences to process in concurrent.
613
+ the batch size for concurrent processing.
810
614
  return_messages_log : bool, Optional
811
615
  if True, a list of messages will be returned.
812
616
 
813
- Return : str
814
- the output from LLM. Need post-processing.
617
+ Return : List[FrameExtractionUnitResult]
618
+ the output from LLM for each unit. Contains the start, end, text, and generated text.
815
619
  """
816
- # Check if self.inference_engine.chat_async() is implemented
817
- if not hasattr(self.inference_engine, 'chat_async'):
818
- raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
819
-
820
- # define output
821
- output = []
822
- # sentence tokenization
823
620
  if isinstance(text_content, str):
824
- sentences = self._get_sentences(text_content)
621
+ doc_text = text_content
825
622
  elif isinstance(text_content, dict):
826
623
  if document_key is None:
827
624
  raise ValueError("document_key must be provided when text_content is dict.")
828
- sentences = self._get_sentences(text_content[document_key])
625
+ if document_key not in text_content:
626
+ raise ValueError(f"document_key '{document_key}' not found in text_content dictionary.")
627
+ doc_text = text_content[document_key]
628
+ else:
629
+ raise TypeError("text_content must be a string or a dictionary.")
829
630
 
830
- if return_messages_log:
831
- messages_log = []
631
+ units = self.unit_chunker.chunk(doc_text)
832
632
 
833
- # generate sentence by sentence
834
- for i in range(0, len(sentences), concurrent_batch_size):
835
- tasks = []
836
- batch = sentences[i:i + concurrent_batch_size]
837
- batch_messages = []
838
- for j, sent in enumerate(batch):
839
- # construct chat messages
840
- messages = []
841
- if self.system_prompt:
842
- messages.append({'role': 'system', 'content': self.system_prompt})
633
+ # context chunker init
634
+ self.context_chunker.fit(doc_text, units)
843
635
 
844
- context = self._get_context_sentences(text_content, i + j, sentences, document_key)
845
-
846
- if self.context_sentences == 0:
847
- # no context, just place sentence of interest
848
- if isinstance(text_content, str):
849
- messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
850
- else:
851
- sentence_content = text_content.copy()
852
- sentence_content[document_key] = sent['sentence_text']
853
- messages.append({'role': 'user', 'content': self._get_user_prompt(sentence_content)})
636
+ # Prepare inputs for all units first
637
+ tasks_input = []
638
+ for i, unit in enumerate(units):
639
+ # construct chat messages
640
+ messages = []
641
+ if self.system_prompt:
642
+ messages.append({'role': 'system', 'content': self.system_prompt})
643
+
644
+ context = self.context_chunker.chunk(unit)
645
+
646
+ if context == "":
647
+ # no context, just place unit in user prompt
648
+ if isinstance(text_content, str):
649
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
854
650
  else:
855
- # insert context
856
- if isinstance(text_content, str):
857
- messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
858
- else:
859
- context_content = text_content.copy()
860
- context_content[document_key] = context
861
- messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
862
- # simulate conversation
863
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
864
- # place sentence of interest
865
- messages.append({'role': 'user', 'content': sent['sentence_text']})
866
-
867
- # add to tasks
868
- task = asyncio.create_task(
869
- self.inference_engine.chat_async(
870
- messages=messages,
871
- max_new_tokens=max_new_tokens,
872
- temperature=temperature,
873
- **kwrs
874
- )
875
- )
876
- tasks.append(task)
877
- batch_messages.append(messages)
651
+ unit_content = text_content.copy()
652
+ unit_content[document_key] = unit.text
653
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
654
+ else:
655
+ # insert context to user prompt
656
+ if isinstance(text_content, str):
657
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
658
+ else:
659
+ context_content = text_content.copy()
660
+ context_content[document_key] = context
661
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
662
+ # simulate conversation where assistant confirms
663
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
664
+ # place unit of interest
665
+ messages.append({'role': 'user', 'content': unit.text})
878
666
 
879
- # Wait until the batch is done, collect results and move on to next batch
880
- responses = await asyncio.gather(*tasks)
667
+ # Store unit and messages together for the task
668
+ tasks_input.append({"unit": unit, "messages": messages, "original_index": i})
881
669
 
882
- # Collect outputs
883
- for gen_text, sent, messages in zip(responses, batch, batch_messages):
884
- if return_messages_log:
885
- messages.append({"role": "assistant", "content": gen_text})
886
- messages_log.append(messages)
670
+ # Process units concurrently with asyncio.Semaphore
671
+ semaphore = asyncio.Semaphore(concurrent_batch_size)
672
+
673
+ async def semaphore_helper(task_data: Dict, **kwrs):
674
+ unit = task_data["unit"]
675
+ messages = task_data["messages"]
676
+ original_index = task_data["original_index"]
677
+
678
+ async with semaphore:
679
+ gen_text = await self.inference_engine.chat_async(
680
+ messages=messages
681
+ )
682
+ return {"original_index": original_index, "unit": unit, "gen_text": gen_text, "messages": messages}
683
+
684
+ # Create and gather tasks
685
+ tasks = []
686
+ for task_inp in tasks_input:
687
+ task = asyncio.create_task(semaphore_helper(
688
+ task_inp
689
+ ))
690
+ tasks.append(task)
691
+
692
+ results_raw = await asyncio.gather(*tasks)
693
+
694
+ # Sort results back into original order using the index stored
695
+ results_raw.sort(key=lambda x: x["original_index"])
696
+
697
+ # Restructure the results
698
+ output: List[FrameExtractionUnitResult] = []
699
+ messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
700
+
701
+ for result_data in results_raw:
702
+ unit = result_data["unit"]
703
+ gen_text = result_data["gen_text"]
704
+
705
+ # Create result object
706
+ result = FrameExtractionUnitResult(
707
+ start=unit.start,
708
+ end=unit.end,
709
+ text=unit.text,
710
+ gen_text=gen_text
711
+ )
712
+ output.append(result)
713
+
714
+ # Append to messages log if requested
715
+ if return_messages_log:
716
+ final_messages = result_data["messages"] + [{"role": "assistant", "content": gen_text}]
717
+ messages_log.append(final_messages)
887
718
 
888
- output.append({'sentence_start': sent['start'],
889
- 'sentence_end': sent['end'],
890
- 'sentence_text': sent['sentence_text'],
891
- 'gen_text': gen_text})
892
-
893
719
  if return_messages_log:
894
720
  return output, messages_log
895
-
896
- return output
897
-
721
+ else:
722
+ return output
898
723
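As an illustration (not itself part of the diff): the async path now bounds concurrency with a semaphore instead of fixed batches, so concurrent_batch_size caps how many units are in flight at once and results come back in the original unit order. A minimal sketch, assuming the engine implements chat_async():

import asyncio

async def main():
    return await extractor.extract_async(
        text_content=note_text,
        concurrent_batch_size=8,   # at most 8 units awaiting the LLM at any time
    )

results = asyncio.run(main())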
 
899
- def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=512,
900
- document_key:str=None, temperature:float=0.0, stream:bool=False,
901
- concurrent:bool=False, concurrent_batch_size:int=32,
724
+
725
+ def extract_frames(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
726
+ verbose:bool=False, concurrent:bool=False, concurrent_batch_size:int=32,
902
727
  case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
903
- allow_overlap_entities:bool=False, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
728
+ allow_overlap_entities:bool=False, return_messages_log:bool=False) -> List[LLMInformationExtractionFrame]:
904
729
  """
905
730
  This method inputs a text and outputs a list of LLMInformationExtractionFrame
906
731
  It uses the extract() method and post-processes its outputs into frames.
@@ -911,16 +736,10 @@ class SentenceFrameExtractor(FrameExtractor):
911
736
  the input text content to put in prompt template.
912
737
  If str, the prompt template must have only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
913
738
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
914
- entity_key : str
915
- the key (in ouptut JSON) for entity text.
916
- max_new_tokens : str, Optional
917
- the max number of new tokens LLM should generate.
918
739
  document_key : str, Optional
919
740
  specify the key in text_content where document text is.
920
741
  If text_content is str, this parameter will be ignored.
921
- temperature : float, Optional
922
- the temperature for token sampling.
923
- stream : bool, Optional
742
+ verbose : bool, Optional
924
743
  if True, LLM generated text will be printed in terminal in real-time.
925
744
  concurrent : bool, Optional
926
745
  if True, the sentences will be extracted concurrently.
@@ -944,41 +763,36 @@ class SentenceFrameExtractor(FrameExtractor):
944
763
  Return : List[LLMInformationExtractionFrame]
945
764
  a list of frames.
946
765
  """
766
+ ENTITY_KEY = "entity_text"
947
767
  if concurrent:
948
- if stream:
949
- warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
768
+ if verbose:
769
+ warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
950
770
 
951
771
  nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
952
772
  extraction_results = asyncio.run(self.extract_async(text_content=text_content,
953
- max_new_tokens=max_new_tokens,
954
773
  document_key=document_key,
955
- temperature=temperature,
956
774
  concurrent_batch_size=concurrent_batch_size,
957
- return_messages_log=return_messages_log,
958
- **kwrs)
775
+ return_messages_log=return_messages_log)
959
776
  )
960
777
  else:
961
778
  extraction_results = self.extract(text_content=text_content,
962
- max_new_tokens=max_new_tokens,
963
- document_key=document_key,
964
- temperature=temperature,
965
- stream=stream,
966
- return_messages_log=return_messages_log,
967
- **kwrs)
779
+ document_key=document_key,
780
+ verbose=verbose,
781
+ return_messages_log=return_messages_log)
968
782
 
969
- llm_output_sentences, messages_log = extraction_results if return_messages_log else (extraction_results, None)
783
+ llm_output_results, messages_log = extraction_results if return_messages_log else (extraction_results, None)
970
784
 
971
785
  frame_list = []
972
- for sent in llm_output_sentences:
786
+ for res in llm_output_results:
973
787
  entity_json = []
974
- for entity in self._extract_json(gen_text=sent['gen_text']):
975
- if entity_key in entity:
788
+ for entity in self._extract_json(gen_text=res.gen_text):
789
+ if ENTITY_KEY in entity:
976
790
  entity_json.append(entity)
977
791
  else:
978
- warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{entity_key}"). This frame will be dropped.', RuntimeWarning)
792
+ warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
979
793
 
980
- spans = self._find_entity_spans(text=sent['sentence_text'],
981
- entities=[e[entity_key] for e in entity_json],
794
+ spans = self._find_entity_spans(text=res.text,
795
+ entities=[e[ENTITY_KEY] for e in entity_json],
982
796
  case_sensitive=case_sensitive,
983
797
  fuzzy_match=fuzzy_match,
984
798
  fuzzy_buffer_size=fuzzy_buffer_size,
@@ -987,34 +801,41 @@ class SentenceFrameExtractor(FrameExtractor):
987
801
  for ent, span in zip(entity_json, spans):
988
802
  if span is not None:
989
803
  start, end = span
990
- entity_text = sent['sentence_text'][start:end]
991
- start += sent['sentence_start']
992
- end += sent['sentence_start']
804
+ entity_text = res.text[start:end]
805
+ start += res.start
806
+ end += res.start
807
+ attr = {}
808
+ if "attr" in ent and ent["attr"] is not None:
809
+ attr = ent["attr"]
810
+
993
811
  frame = LLMInformationExtractionFrame(frame_id=f"{len(frame_list)}",
994
812
  start=start,
995
813
  end=end,
996
814
  entity_text=entity_text,
997
- attr={k: v for k, v in ent.items() if k != entity_key and v != ""})
815
+ attr=attr)
998
816
  frame_list.append(frame)
999
817
 
1000
818
  if return_messages_log:
1001
819
  return frame_list, messages_log
1002
820
  return frame_list
821
+
1003
822
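As an illustration (not itself part of the diff): extract_frames() no longer takes entity_key or max_new_tokens; the entity key is fixed to "entity_text" and frame attributes are read from an "attr" field in the LLM's JSON. A minimal sketch:

frames = extractor.extract_frames(
    text_content=note_text,
    concurrent=True,        # routes through extract_async() under the hood
    fuzzy_match=True,
    case_sensitive=False,
)
for frame in frames:
    print(frame.frame_id, frame.entity_text, frame.start, frame.end, frame.attr)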
 
1004
-
1005
- class SentenceReviewFrameExtractor(SentenceFrameExtractor):
1006
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
1007
- review_mode:str, review_prompt:str=None, system_prompt:str=None,
1008
- context_sentences:Union[str, int]="all", **kwrs):
823
+ class ReviewFrameExtractor(DirectFrameExtractor):
824
+ def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker, inference_engine:InferenceEngine,
825
+ prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None):
1009
826
  """
1010
- This class adds a review step after the SentenceFrameExtractor.
1011
- For each sentence, the review process asks LLM to review its output and:
1012
- 1. add more frames while keeping current. This is efficient for boosting recall.
827
+ This class adds a review step after the DirectFrameExtractor.
828
+ The review process asks the LLM to review its output and:
829
+ 1. add more frames while keeping the current ones. This is efficient for boosting recall.
1013
830
  2. or, regenerate frames (add new and delete existing).
1014
831
  Use the review_mode parameter to specify. Note that the review_prompt should instruct the LLM accordingly.
1015
832
 
1016
833
  Parameters:
1017
834
  ----------
835
+ unit_chunker : UnitChunker
836
+ the unit chunker object that determines how to chunk the document text into units.
837
+ context_chunker : ContextChunker
838
+ the context chunker object that determines how to get context for each unit.
1018
839
  inference_engine : InferenceEngine
1019
840
  the LLM inference engine object. Must implement the chat() method.
1020
841
  prompt_template : str
@@ -1027,36 +848,52 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
1027
848
  addition mode only asks the LLM to add new frames, while revision mode asks the LLM to regenerate.
1028
849
  system_prompt : str, Optional
1029
850
  system prompt.
1030
- context_sentences : Union[str, int], Optional
1031
- number of sentences before and after the given sentence to provide additional context.
1032
- if "all", the full text will be provided in the prompt as context.
1033
- if 0, no additional context will be provided.
1034
- This is good for tasks that does not require context beyond the given sentence.
1035
- if > 0, the number of sentences before and after the given sentence to provide as context.
1036
- This is good for tasks that require context beyond the given sentence.
1037
851
  """
1038
- super().__init__(inference_engine=inference_engine, prompt_template=prompt_template,
1039
- system_prompt=system_prompt, context_sentences=context_sentences, **kwrs)
1040
-
852
+ super().__init__(inference_engine=inference_engine,
853
+ unit_chunker=unit_chunker,
854
+ prompt_template=prompt_template,
855
+ system_prompt=system_prompt,
856
+ context_chunker=context_chunker)
857
+ # check review mode
1041
858
  if review_mode not in {"addition", "revision"}:
1042
859
  raise ValueError('review_mode must be one of {"addition", "revision"}.')
1043
860
  self.review_mode = review_mode
1044
-
861
+ # assign review prompt
1045
862
  if review_prompt:
1046
863
  self.review_prompt = review_prompt
1047
864
  else:
1048
- file_path = importlib.resources.files('llm_ie.asset.default_prompts').\
1049
- joinpath(f"{self.__class__.__name__}_{self.review_mode}_review_prompt.txt")
1050
- with open(file_path, 'r', encoding="utf-8") as f:
1051
- self.review_prompt = f.read()
1052
-
1053
- warnings.warn(f'Custom review prompt not provided. The default review prompt is used:\n"{self.review_prompt}"', UserWarning)
1054
-
865
+ self.review_prompt = None
866
+ original_class_name = self.__class__.__name__
867
+
868
+ current_class_name = original_class_name
869
+ for current_class_in_mro in self.__class__.__mro__:
870
+ if current_class_in_mro is object:
871
+ continue
872
+
873
+ current_class_name = current_class_in_mro.__name__
874
+ try:
875
+ file_path = importlib.resources.files('llm_ie.asset.default_prompts').\
876
+ joinpath(f"{self.__class__.__name__}_{self.review_mode}_review_prompt.txt")
877
+ with open(file_path, 'r', encoding="utf-8") as f:
878
+ self.review_prompt = f.read()
879
+ except FileNotFoundError:
880
+ pass
881
+
882
+ except Exception as e:
883
+ warnings.warn(
884
+ f"Error attempting to read default review prompt for '{current_class_name}' "
885
+ f"from '{str(file_path)}': {e}. Trying next in MRO.",
886
+ UserWarning
887
+ )
888
+ continue
889
+
890
+ if self.review_prompt is None:
891
+ raise ValueError(f"Cannot find review prompt for {self.__class__.__name__} in the package. Please provide a review_prompt.")
1055
892
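As an illustration (not itself part of the diff): ReviewFrameExtractor now takes the chunkers as well, and falls back to a packaged default review prompt for the chosen mode when review_prompt is omitted. A minimal sketch, assuming a default "addition" prompt ships with the package and reusing the engine/prompt_template names from the earlier sketches:

from llm_ie.chunkers import SentenceUnitChunker, WholeDocumentContextChunker
from llm_ie.extractors import ReviewFrameExtractor

reviewer = ReviewFrameExtractor(
    unit_chunker=SentenceUnitChunker(),
    context_chunker=WholeDocumentContextChunker(),
    inference_engine=engine,
    prompt_template=prompt_template,
    review_mode="addition",   # keep the first pass and ask for more frames; "revision" replaces it
)
frames = reviewer.extract_frames(text_content=note_text)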
 
1056
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
1057
- document_key:str=None, temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict[str,str]]:
893
+ def extract(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
894
+ verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
1058
895
  """
1059
- This method inputs a text and outputs a list of outputs per sentence.
896
+ This method inputs a text and outputs a list of outputs per unit.
1060
897
 
1061
898
  Parameters:
1062
899
  ----------
@@ -1064,281 +901,468 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
1064
901
  the input text content to put in prompt template.
1065
902
  If str, the prompt template must have only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
1066
903
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
1067
- max_new_tokens : str, Optional
1068
- the max number of new tokens LLM should generate.
1069
904
  document_key : str, Optional
1070
905
  specify the key in text_content where document text is.
1071
906
  If text_content is str, this parameter will be ignored.
1072
- temperature : float, Optional
1073
- the temperature for token sampling.
1074
- stream : bool, Optional
907
+ verbose : bool, Optional
1075
908
  if True, LLM generated text will be printed in terminal in real-time.
1076
909
  return_messages_log : bool, Optional
1077
910
  if True, a list of messages will be returned.
1078
911
 
1079
- Return : str
1080
- the output from LLM. Need post-processing.
912
+ Return : List[FrameExtractionUnitResult]
913
+ the output from LLM for each unit. Contains the start, end, text, and generated text.
1081
914
  """
1082
915
  # define output
1083
916
  output = []
1084
- # sentence tokenization
917
+ # unit chunking
918
+ if isinstance(text_content, str):
919
+ doc_text = text_content
920
+
921
+ elif isinstance(text_content, dict):
922
+ if document_key is None:
923
+ raise ValueError("document_key must be provided when text_content is dict.")
924
+ doc_text = text_content[document_key]
925
+
926
+ units = self.unit_chunker.chunk(doc_text)
927
+ # context chunker init
928
+ self.context_chunker.fit(doc_text, units)
929
+ # messages log
930
+ if return_messages_log:
931
+ messages_log = []
932
+
933
+ # generate unit by unit
934
+ for i, unit in enumerate(units):
935
+ # <--- Initial generation step --->
936
+ # construct chat messages
937
+ messages = []
938
+ if self.system_prompt:
939
+ messages.append({'role': 'system', 'content': self.system_prompt})
940
+
941
+ context = self.context_chunker.chunk(unit)
942
+
943
+ if context == "":
944
+ # no context, just place unit in user prompt
945
+ if isinstance(text_content, str):
946
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
947
+ else:
948
+ unit_content = text_content.copy()
949
+ unit_content[document_key] = unit.text
950
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
951
+ else:
952
+ # insert context to user prompt
953
+ if isinstance(text_content, str):
954
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
955
+ else:
956
+ context_content = text_content.copy()
957
+ context_content[document_key] = context
958
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
959
+ # simulate conversation where assistant confirms
960
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
961
+ # place unit of interest
962
+ messages.append({'role': 'user', 'content': unit.text})
963
+
964
+ if verbose:
965
+ print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
966
+ if context != "":
967
+ print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
968
+
969
+ print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
970
+
971
+
972
+ initial = self.inference_engine.chat(
973
+ messages=messages,
974
+ verbose=verbose,
975
+ stream=False
976
+ )
977
+
978
+ if return_messages_log:
979
+ messages.append({"role": "assistant", "content": initial})
980
+ messages_log.append(messages)
981
+
982
+ # <--- Review step --->
983
+ if verbose:
984
+ print(f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}")
985
+
986
+ messages.append({'role': 'assistant', 'content': initial})
987
+ messages.append({'role': 'user', 'content': self.review_prompt})
988
+
989
+ review = self.inference_engine.chat(
990
+ messages=messages,
991
+ verbose=verbose,
992
+ stream=False
993
+ )
994
+
995
+ # Output
996
+ if self.review_mode == "revision":
997
+ gen_text = review
998
+ elif self.review_mode == "addition":
999
+ gen_text = initial + '\n' + review
1000
+
1001
+ if return_messages_log:
1002
+ messages.append({"role": "assistant", "content": review})
1003
+ messages_log.append(messages)
1004
+
1005
+ # add to output
1006
+ result = FrameExtractionUnitResult(
1007
+ start=unit.start,
1008
+ end=unit.end,
1009
+ text=unit.text,
1010
+ gen_text=gen_text)
1011
+ output.append(result)
1012
+
1013
+ if return_messages_log:
1014
+ return output, messages_log
1015
+
1016
+ return output
1017
+
1018
+
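# Illustrative usage sketch for the review-style extract() above (assumptions:
# `extractor` is an already-constructed review frame extractor and `note_text` is the
# input document text; both names are placeholders, not package API claims).
results = extractor.extract(text_content=note_text, verbose=True)
for r in results:
    # Each FrameExtractionUnitResult keeps the unit's character offsets, the unit text,
    # and the raw LLM generation, which still needs post-processing into frames.
    print(f"[{r.start}:{r.end}] {r.gen_text}")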
1019
+ def stream(self, text_content:Union[str, Dict[str,str]], document_key:str=None) -> Generator[str, None, None]:
1020
+ """
1021
+ This method inputs a text and streams the LLM output unit by unit as it is generated.
1022
+
1023
+ Parameters:
1024
+ ----------
1025
+ text_content : Union[str, Dict[str,str]]
1026
+ the input text content to put in prompt template.
1027
+ If str, the prompt template must have only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
1028
+ If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
1029
+ document_key : str, Optional
1030
+ specify the key in text_content where document text is.
1031
+ If text_content is str, this parameter will be ignored.
1032
+
1033
+ Return : Generator[str, None, None]
1034
+ yields section headers and chunks of the LLM-generated text as they are produced.
1035
+ """
1036
+ # unit chunking
1085
1037
  if isinstance(text_content, str):
1086
- sentences = self._get_sentences(text_content)
1038
+ doc_text = text_content
1039
+
1087
1040
  elif isinstance(text_content, dict):
1088
1041
  if document_key is None:
1089
1042
  raise ValueError("document_key must be provided when text_content is dict.")
1090
- sentences = self._get_sentences(text_content[document_key])
1043
+ doc_text = text_content[document_key]
1091
1044
 
1092
- if return_messages_log:
1093
- messages_log = []
1045
+ units = self.unit_chunker.chunk(doc_text)
1046
+ # context chunker init
1047
+ self.context_chunker.fit(doc_text, units)
1094
1048
 
1095
- # generate sentence by sentence
1096
- for i, sent in enumerate(sentences):
1049
+ # generate unit by unit
1050
+ for i, unit in enumerate(units):
1051
+ # <--- Initial generation step --->
1097
1052
  # construct chat messages
1098
1053
  messages = []
1099
1054
  if self.system_prompt:
1100
1055
  messages.append({'role': 'system', 'content': self.system_prompt})
1101
1056
 
1102
- context = self._get_context_sentences(text_content, i, sentences, document_key)
1057
+ context = self.context_chunker.chunk(unit)
1103
1058
 
1104
- if self.context_sentences == 0:
1105
- # no context, just place sentence of interest
1059
+ if context == "":
1060
+ # no context, just place unit in user prompt
1106
1061
  if isinstance(text_content, str):
1107
- messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
1062
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
1108
1063
  else:
1109
- sentence_content = text_content.copy()
1110
- sentence_content[document_key] = sent['sentence_text']
1111
- messages.append({'role': 'user', 'content': self._get_user_prompt(sentence_content)})
1064
+ unit_content = text_content.copy()
1065
+ unit_content[document_key] = unit.text
1066
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
1112
1067
  else:
1113
- # insert context
1068
+ # insert context to user prompt
1114
1069
  if isinstance(text_content, str):
1115
1070
  messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
1116
1071
  else:
1117
1072
  context_content = text_content.copy()
1118
1073
  context_content[document_key] = context
1119
1074
  messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
1120
- # simulate conversation
1121
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
1122
- # place sentence of interest
1123
- messages.append({'role': 'user', 'content': sent['sentence_text']})
1124
-
1125
- if stream:
1126
- print(f"\n\n{Fore.GREEN}Sentence {i}: {Style.RESET_ALL}\n{sent['sentence_text']}\n")
1127
- if isinstance(self.context_sentences, int) and self.context_sentences > 0:
1128
- print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
1129
- print(f"{Fore.BLUE}Initial Output:{Style.RESET_ALL}")
1075
+ # simulate conversation where assistant confirms
1076
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
1077
+ # place unit of interest
1078
+ messages.append({'role': 'user', 'content': unit.text})
1130
1079
 
1131
- initial = self.inference_engine.chat(
1080
+
1081
+ yield f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n"
1082
+ if context != "":
1083
+ yield f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n"
1084
+
1085
+ yield f"{Fore.BLUE}Extraction:{Style.RESET_ALL}\n"
1086
+
1087
+ response_stream = self.inference_engine.chat(
1132
1088
  messages=messages,
1133
- max_new_tokens=max_new_tokens,
1134
- temperature=temperature,
1135
- stream=stream,
1136
- **kwrs
1089
+ stream=True
1137
1090
  )
1138
1091
 
1139
- # Review
1140
- if stream:
1141
- print(f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}")
1092
+ initial = ""
1093
+ for chunk in response_stream:
1094
+ initial += chunk
1095
+ yield chunk
1096
+
1097
+ # <--- Review step --->
1098
+ yield f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}"
1142
1099
 
1143
1100
  messages.append({'role': 'assistant', 'content': initial})
1144
1101
  messages.append({'role': 'user', 'content': self.review_prompt})
1145
1102
 
1146
- review = self.inference_engine.chat(
1103
+ response_stream = self.inference_engine.chat(
1147
1104
  messages=messages,
1148
- max_new_tokens=max_new_tokens,
1149
- temperature=temperature,
1150
- stream=stream,
1151
- **kwrs
1105
+ stream=True
1152
1106
  )
1153
1107
 
1154
- # Output
1155
- if self.review_mode == "revision":
1156
- gen_text = review
1157
- elif self.review_mode == "addition":
1158
- gen_text = initial + '\n' + review
1108
+ for chunk in response_stream:
1109
+ yield chunk
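# Illustrative consumption sketch for stream() (assumptions: `extractor` and `note_text`
# are placeholders as above). stream() is a generator that yields colored section
# headers and LLM text chunks as they arrive, so callers can render progress live.
for chunk in extractor.stream(text_content=note_text):
    print(chunk, end="", flush=True)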
1159
1110
 
1160
- if return_messages_log:
1161
- messages.append({"role": "assistant", "content": review})
1162
- messages_log.append(messages)
1163
-
1164
- # add to output
1165
- output.append({'sentence_start': sent['start'],
1166
- 'sentence_end': sent['end'],
1167
- 'sentence_text': sent['sentence_text'],
1168
- 'gen_text': gen_text})
1169
-
1170
- if return_messages_log:
1171
- return output, messages_log
1172
-
1173
- return output
1174
-
1175
- async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
1176
- document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict[str,str]]:
1111
+ async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
1112
+ concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
1177
1113
  """
1178
- The asynchronous version of the extract() method.
1114
+ This is the asynchronous version of the extract() method with the review step.
1179
1115
 
1180
1116
  Parameters:
1181
1117
  ----------
1182
1118
  text_content : Union[str, Dict[str,str]]
1183
- the input text content to put in prompt template.
1119
+ the input text content to put in prompt template.
1184
1120
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
1185
1121
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
1186
- max_new_tokens : str, Optional
1187
- the max number of new tokens LLM should generate.
1188
1122
  document_key : str, Optional
1189
- specify the key in text_content where document text is.
1123
+ specify the key in text_content where document text is.
1190
1124
  If text_content is str, this parameter will be ignored.
1191
- temperature : float, Optional
1192
- the temperature for token sampling.
1193
1125
  concurrent_batch_size : int, Optional
1194
- the number of sentences to process in concurrent.
1126
+ the batch size for concurrent processing.
1195
1127
  return_messages_log : bool, Optional
1196
- if True, a list of messages will be returned.
1128
+ if True, a list of messages will be returned, including review steps.
1197
1129
 
1198
- Return : str
1199
- the output from LLM. Need post-processing.
1130
+ Return : List[FrameExtractionUnitResult]
1131
+ the output from LLM for each unit after review. Contains the start, end, text, and generated text.
1200
1132
  """
1201
- # Check if self.inference_engine.chat_async() is implemented
1202
- if not hasattr(self.inference_engine, 'chat_async'):
1203
- raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
1204
-
1205
- # define output
1206
- output = []
1207
- # sentence tokenization
1208
1133
  if isinstance(text_content, str):
1209
- sentences = self._get_sentences(text_content)
1134
+ doc_text = text_content
1210
1135
  elif isinstance(text_content, dict):
1211
1136
  if document_key is None:
1212
1137
  raise ValueError("document_key must be provided when text_content is dict.")
1213
- sentences = self._get_sentences(text_content[document_key])
1138
+ if document_key not in text_content:
1139
+ raise ValueError(f"document_key '{document_key}' not found in text_content dictionary.")
1140
+ doc_text = text_content[document_key]
1141
+ else:
1142
+ raise TypeError("text_content must be a string or a dictionary.")
1214
1143
 
1215
- if return_messages_log:
1216
- messages_log = []
1144
+ units = self.unit_chunker.chunk(doc_text)
1217
1145
 
1218
- # generate initial outputs sentence by sentence
1219
- for i in range(0, len(sentences), concurrent_batch_size):
1220
- messages_list = []
1221
- init_tasks = []
1222
- review_tasks = []
1223
- batch = sentences[i:i + concurrent_batch_size]
1224
- for j, sent in enumerate(batch):
1225
- # construct chat messages
1226
- messages = []
1227
- if self.system_prompt:
1228
- messages.append({'role': 'system', 'content': self.system_prompt})
1146
+ # context chunker init
1147
+ self.context_chunker.fit(doc_text, units)
1229
1148
 
1230
- context = self._get_context_sentences(text_content, i + j, sentences, document_key)
1231
-
1232
- if self.context_sentences == 0:
1233
- # no context, just place sentence of interest
1234
- if isinstance(text_content, str):
1235
- messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
1236
- else:
1237
- sentence_content = text_content.copy()
1238
- sentence_content[document_key] = sent['sentence_text']
1239
- messages.append({'role': 'user', 'content': self._get_user_prompt(sentence_content)})
1149
+ # <--- Initial generation step --->
1150
+ initial_tasks_input = []
1151
+ for i, unit in enumerate(units):
1152
+ # construct chat messages for initial generation
1153
+ messages = []
1154
+ if self.system_prompt:
1155
+ messages.append({'role': 'system', 'content': self.system_prompt})
1156
+
1157
+ context = self.context_chunker.chunk(unit)
1158
+
1159
+ if context == "":
1160
+ # no context, just place unit in user prompt
1161
+ if isinstance(text_content, str):
1162
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
1240
1163
  else:
1241
- # insert context
1242
- if isinstance(text_content, str):
1243
- messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
1244
- else:
1245
- context_content = text_content.copy()
1246
- context_content[document_key] = context
1247
- messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
1248
- # simulate conversation
1249
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
1250
- # place sentence of interest
1251
- messages.append({'role': 'user', 'content': sent['sentence_text']})
1252
-
1253
- messages_list.append(messages)
1254
-
1255
- task = asyncio.create_task(
1256
- self.inference_engine.chat_async(
1257
- messages=messages,
1258
- max_new_tokens=max_new_tokens,
1259
- temperature=temperature,
1260
- **kwrs
1261
- )
1164
+ unit_content = text_content.copy()
1165
+ unit_content[document_key] = unit.text
1166
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
1167
+ else:
1168
+ # insert context to user prompt
1169
+ if isinstance(text_content, str):
1170
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
1171
+ else:
1172
+ context_content = text_content.copy()
1173
+ context_content[document_key] = context
1174
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
1175
+ # simulate conversation where assistant confirms
1176
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
1177
+ # place unit of interest
1178
+ messages.append({'role': 'user', 'content': unit.text})
1179
+
1180
+ # Store unit and messages together for the initial task
1181
+ initial_tasks_input.append({"unit": unit, "messages": messages, "original_index": i})
1182
+
1183
+ semaphore = asyncio.Semaphore(concurrent_batch_size)
1184
+
1185
+ async def initial_semaphore_helper(task_data: Dict):
1186
+ unit = task_data["unit"]
1187
+ messages = task_data["messages"]
1188
+ original_index = task_data["original_index"]
1189
+
1190
+ async with semaphore:
1191
+ gen_text = await self.inference_engine.chat_async(
1192
+ messages=messages
1262
1193
  )
1263
- init_tasks.append(task)
1264
-
1265
- # Wait until the batch is done, collect results and move on to next batch
1266
- init_responses = await asyncio.gather(*init_tasks)
1267
- # Collect initials
1268
- initials = []
1269
- for gen_text, sent, messages in zip(init_responses, batch, messages_list):
1270
- initials.append({'sentence_start': sent['start'],
1271
- 'sentence_end': sent['end'],
1272
- 'sentence_text': sent['sentence_text'],
1273
- 'gen_text': gen_text,
1274
- 'messages': messages})
1275
-
1276
- # Review
1277
- for init in initials:
1278
- messages = init["messages"]
1279
- initial = init["gen_text"]
1280
- messages.append({'role': 'assistant', 'content': initial})
1281
- messages.append({'role': 'user', 'content': self.review_prompt})
1282
- task = asyncio.create_task(
1283
- self.inference_engine.chat_async(
1284
- messages=messages,
1285
- max_new_tokens=max_new_tokens,
1286
- temperature=temperature,
1287
- **kwrs
1288
- )
1289
- )
1290
- review_tasks.append(task)
1291
-
1292
- review_responses = await asyncio.gather(*review_tasks)
1293
-
1294
- # Collect reviews
1295
- reviews = []
1296
- for gen_text, sent in zip(review_responses, batch):
1297
- reviews.append({'sentence_start': sent['start'],
1298
- 'sentence_end': sent['end'],
1299
- 'sentence_text': sent['sentence_text'],
1300
- 'gen_text': gen_text})
1301
-
1302
- for init, rev in zip(initials, reviews):
1303
- if self.review_mode == "revision":
1304
- gen_text = rev['gen_text']
1305
- elif self.review_mode == "addition":
1306
- gen_text = init['gen_text'] + '\n' + rev['gen_text']
1194
+ # Return initial generation result along with the messages used and the unit
1195
+ return {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text, "initial_messages": messages}
1196
+
1197
+ # Create and gather initial generation tasks
1198
+ initial_tasks = [
1199
+ asyncio.create_task(initial_semaphore_helper(
1200
+ task_inp
1201
+ ))
1202
+ for task_inp in initial_tasks_input
1203
+ ]
1204
+
1205
+ initial_results_raw = await asyncio.gather(*initial_tasks)
1206
+
1207
+ # Sort initial results back into original order
1208
+ initial_results_raw.sort(key=lambda x: x["original_index"])
1209
+
1210
+ # <--- Review step --->
1211
+ review_tasks_input = []
1212
+ for result_data in initial_results_raw:
1213
+ # Prepare messages for the review step
1214
+ initial_messages = result_data["initial_messages"]
1215
+ initial_gen_text = result_data["initial_gen_text"]
1216
+ review_messages = initial_messages + [
1217
+ {'role': 'assistant', 'content': initial_gen_text},
1218
+ {'role': 'user', 'content': self.review_prompt}
1219
+ ]
1220
+ # Store data needed for review task
1221
+ review_tasks_input.append({
1222
+ "unit": result_data["unit"],
1223
+ "initial_gen_text": initial_gen_text,
1224
+ "messages": review_messages,
1225
+ "original_index": result_data["original_index"],
1226
+ "full_initial_log": initial_messages + [{'role': 'assistant', 'content': initial_gen_text}] if return_messages_log else None # Log up to initial generation
1227
+ })
1228
+
1229
+
1230
+ async def review_semaphore_helper(task_data: Dict, **kwrs):
1231
+ messages = task_data["messages"]
1232
+ original_index = task_data["original_index"]
1233
+
1234
+ async with semaphore:
1235
+ review_gen_text = await self.inference_engine.chat_async(
1236
+ messages=messages
1237
+ )
1238
+ # Combine initial and review results
1239
+ task_data["review_gen_text"] = review_gen_text
1240
+ if return_messages_log:
1241
+ # Log for the review call itself
1242
+ task_data["full_review_log"] = messages + [{'role': 'assistant', 'content': review_gen_text}]
1243
+ return task_data # Return the augmented dictionary
1307
1244
 
1308
- if return_messages_log:
1309
- messages = init["messages"]
1310
- messages.append({"role": "assistant", "content": rev['gen_text']})
1311
- messages_log.append(messages)
1245
+ # Create and gather review tasks
1246
+ review_tasks = [
1247
+ asyncio.create_task(review_semaphore_helper(
1248
+ task_inp
1249
+ ))
1250
+ for task_inp in review_tasks_input
1251
+ ]
1312
1252
 
1313
- # add to output
1314
- output.append({'sentence_start': init['sentence_start'],
1315
- 'sentence_end': init['sentence_end'],
1316
- 'sentence_text': init['sentence_text'],
1317
- 'gen_text': gen_text})
1318
-
1319
- if return_messages_log:
1253
+ final_results_raw = await asyncio.gather(*review_tasks)
1254
+
1255
+ # Sort final results back into original order (although gather might preserve order for tasks added sequentially)
1256
+ final_results_raw.sort(key=lambda x: x["original_index"])
1257
+
1258
+ # <--- Process final results --->
1259
+ output: List[FrameExtractionUnitResult] = []
1260
+ messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
1261
+
1262
+ for result_data in final_results_raw:
1263
+ unit = result_data["unit"]
1264
+ initial_gen = result_data["initial_gen_text"]
1265
+ review_gen = result_data["review_gen_text"]
1266
+
1267
+ # Combine based on review mode
1268
+ if self.review_mode == "revision":
1269
+ final_gen_text = review_gen
1270
+ elif self.review_mode == "addition":
1271
+ final_gen_text = initial_gen + '\n' + review_gen
1272
+ else: # Should not happen due to init check
1273
+ final_gen_text = review_gen # Default to revision if mode is somehow invalid
1274
+
1275
+ # Create final result object
1276
+ result = FrameExtractionUnitResult(
1277
+ start=unit.start,
1278
+ end=unit.end,
1279
+ text=unit.text,
1280
+ gen_text=final_gen_text # Use the combined/reviewed text
1281
+ )
1282
+ output.append(result)
1283
+
1284
+ # Append full conversation log if requested
1285
+ if return_messages_log:
1286
+ full_log_for_unit = result_data.get("full_initial_log", []) + [{'role': 'user', 'content': self.review_prompt}] + [{'role': 'assistant', 'content': review_gen}]
1287
+ messages_log.append(full_log_for_unit)
1288
+
1289
+ if return_messages_log:
1320
1290
  return output, messages_log
1321
- return output
1291
+ else:
1292
+ return output
1293
+
1294
+
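# Illustrative driver sketch for extract_async() (assumptions: `extractor` and
# `note_text` are placeholders, and the configured engine implements chat_async()).
# Concurrency is bounded by an asyncio.Semaphore of size concurrent_batch_size,
# first for the initial generations and then for the review calls.
import asyncio

results = asyncio.run(
    extractor.extract_async(
        text_content=note_text,
        concurrent_batch_size=8,  # at most 8 LLM calls in flight per step
    )
)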
1295
+ class BasicFrameExtractor(DirectFrameExtractor):
1296
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
1297
+ """
1298
+ This class directly prompts the LLM for frame extraction.
1299
+ Input system prompt (optional), prompt template (with instruction, few-shot examples),
1300
+ and specify a LLM.
1301
+
1302
+ Parameters:
1303
+ ----------
1304
+ inference_engine : InferenceEngine
1305
+ the LLM inferencing engine object. Must implement the chat() method.
1306
+ prompt_template : str
1307
+ prompt template with "{{<placeholder name>}}" placeholder.
1308
+ system_prompt : str, Optional
1309
+ system prompt.
1310
+ """
1311
+ super().__init__(inference_engine=inference_engine,
1312
+ unit_chunker=WholeDocumentUnitChunker(),
1313
+ prompt_template=prompt_template,
1314
+ system_prompt=system_prompt,
1315
+ context_chunker=NoContextChunker())
1316
+
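# Illustrative construction sketch (assumptions: `engine` is any concrete
# InferenceEngine implementation and the prompt template is a made-up example,
# not the package's recommended prompt). BasicFrameExtractor is DirectFrameExtractor
# wired to WholeDocumentUnitChunker and NoContextChunker, so the whole document is
# prompted as a single unit with no additional context.
from llm_ie.extractors import BasicFrameExtractor

extractor = BasicFrameExtractor(
    inference_engine=engine,
    prompt_template="Extract diagnosis frames as JSON from:\n{{text}}",
    system_prompt="You are an information extraction assistant.",
)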
1317
+ class BasicReviewFrameExtractor(ReviewFrameExtractor):
1318
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None):
1319
+ """
1320
+ This class adds a review step after the BasicFrameExtractor.
1321
+ The review process asks the LLM to review its output and:
1322
+ 1. add more frames while keeping the current ones. This is efficient for boosting recall.
1323
+ 2. or, regenerate frames (add new and delete existing).
1324
+ Use the review_mode parameter to specify. Note that the review_prompt should instruct the LLM accordingly.
1325
+
1326
+ Parameters:
1327
+ ----------
1328
+ inference_engine : InferenceEngine
1329
+ the LLM inferencing engine object. Must implement the chat() method.
1330
+ prompt_template : str
1331
+ prompt template with "{{<placeholder name>}}" placeholder.
1332
+ review_prompt : str, Optional
1333
+ the prompt text that asks the LLM to review. Specify addition or revision in the instruction.
1334
+ if not provided, a default review prompt will be used.
1335
+ review_mode : str
1336
+ review mode. Must be one of {"addition", "revision"}
1337
+ addition mode only asks the LLM to add new frames, while revision mode asks the LLM to regenerate.
1338
+ system_prompt : str, Optional
1339
+ system prompt.
1340
+ """
1341
+ super().__init__(inference_engine=inference_engine,
1342
+ unit_chunker=WholeDocumentUnitChunker(),
1343
+ prompt_template=prompt_template,
1344
+ review_mode=review_mode,
1345
+ review_prompt=review_prompt,
1346
+ system_prompt=system_prompt,
1347
+ context_chunker=NoContextChunker())
1322
1348
 
1323
1349
 
1324
- class SentenceCoTFrameExtractor(SentenceFrameExtractor):
1325
- from nltk.tokenize.punkt import PunktSentenceTokenizer
1326
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None,
1327
- context_sentences:Union[str, int]="all", **kwrs):
1350
+ class SentenceFrameExtractor(DirectFrameExtractor):
1351
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None,
1352
+ context_sentences:Union[str, int]="all"):
1328
1353
  """
1329
- This class performs sentence-based Chain-of-thoughts (CoT) information extraction.
1330
- A simulated chat follows this process:
1354
+ This class performs sentence-by-sentence information extraction.
1355
+ The process is as follows:
1331
1356
  1. system prompt (optional)
1332
- 2. user instructions (schema, background, full text, few-shot example...)
1333
- 3. user input first sentence
1334
- 4. assistant analyze the sentence
1335
- 5. assistant extract outputs
1336
- 6. repeat #3, #4, #5
1357
+ 2. user prompt with instructions (schema, background, full text, few-shot example...)
1358
+ 3. feed a sentence (starting with the first sentence)
1359
+ 4. the LLM extracts entities and attributes from the sentence
1360
+ 5. iterate to the next sentence and repeat steps 3-4 until all sentences are processed.
1337
1361
 
1338
1362
  Input system prompt (optional), prompt template (with user instructions),
1339
1363
  and specify a LLM.
1340
1364
 
1341
- Parameters
1365
+ Parameters:
1342
1366
  ----------
1343
1367
  inference_engine : InferenceEngine
1344
1368
  the LLM inferencing engine object. Must implements the chat() method.
@@ -1354,108 +1378,79 @@ class SentenceCoTFrameExtractor(SentenceFrameExtractor):
1354
1378
  if > 0, the number of sentences before and after the given sentence to provide as context.
1355
1379
  This is good for tasks that require context beyond the given sentence.
1356
1380
  """
1357
- super().__init__(inference_engine=inference_engine, prompt_template=prompt_template,
1358
- system_prompt=system_prompt, context_sentences=context_sentences, **kwrs)
1381
+ if not isinstance(context_sentences, int) and context_sentences != "all":
1382
+ raise ValueError('context_sentences must be an integer (>= 0) or "all".')
1383
+
1384
+ if isinstance(context_sentences, int) and context_sentences < 0:
1385
+ raise ValueError("context_sentences must be a non-negative integer.")
1386
+
1387
+ if isinstance(context_sentences, int):
1388
+ context_chunker = SlideWindowContextChunker(window_size=context_sentences)
1389
+ elif context_sentences == "all":
1390
+ context_chunker = WholeDocumentContextChunker()
1391
+
1392
+ super().__init__(inference_engine=inference_engine,
1393
+ unit_chunker=SentenceUnitChunker(),
1394
+ prompt_template=prompt_template,
1395
+ system_prompt=system_prompt,
1396
+ context_chunker=context_chunker)
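# Illustrative construction sketch (assumptions: `engine` and the template are
# placeholders). context_sentences selects the context chunker built above: an
# integer n becomes SlideWindowContextChunker(window_size=n) (n sentences on each
# side of the unit), while "all" becomes WholeDocumentContextChunker.
from llm_ie.extractors import SentenceFrameExtractor

extractor = SentenceFrameExtractor(
    inference_engine=engine,
    prompt_template="Extract medication frames as JSON from:\n{{text}}",
    context_sentences=2,  # two sentences of context before and after each sentence
)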
1359
1397
 
1360
1398
 
1361
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
1362
- document_key:str=None, temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict[str,str]]:
1399
+ class SentenceReviewFrameExtractor(ReviewFrameExtractor):
1400
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
1401
+ review_mode:str, review_prompt:str=None, system_prompt:str=None,
1402
+ context_sentences:Union[str, int]="all"):
1363
1403
  """
1364
- This method inputs a text and outputs a list of outputs per sentence.
1404
+ This class adds a review step after the SentenceFrameExtractor.
1405
+ For each sentence, the review process asks the LLM to review its output and:
1406
+ 1. add more frames while keeping the current ones. This is efficient for boosting recall.
1407
+ 2. or, regenerate frames (add new and delete existing).
1408
+ Use the review_mode parameter to specify. Note that the review_prompt should instruct the LLM accordingly.
1365
1409
 
1366
1410
  Parameters:
1367
1411
  ----------
1368
- text_content : Union[str, Dict[str,str]]
1369
- the input text content to put in prompt template.
1370
- If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
1371
- If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
1372
- max_new_tokens : str, Optional
1373
- the max number of new tokens LLM should generate.
1374
- document_key : str, Optional
1375
- specify the key in text_content where document text is.
1376
- If text_content is str, this parameter will be ignored.
1377
- temperature : float, Optional
1378
- the temperature for token sampling.
1379
- stream : bool, Optional
1380
- if True, LLM generated text will be printed in terminal in real-time.
1381
- return_messages_log : bool, Optional
1382
- if True, a list of messages will be returned.
1383
-
1384
- Return : str
1385
- the output from LLM. Need post-processing.
1412
+ inference_engine : InferenceEngine
1413
+ the LLM inferencing engine object. Must implement the chat() method.
1414
+ prompt_template : str
1415
+ prompt template with "{{<placeholder name>}}" placeholder.
1416
+ review_prompt : str, Optional
1417
+ the prompt text that asks the LLM to review. Specify addition or revision in the instruction.
1418
+ if not provided, a default review prompt will be used.
1419
+ review_mode : str
1420
+ review mode. Must be one of {"addition", "revision"}
1421
+ addition mode only asks the LLM to add new frames, while revision mode asks the LLM to regenerate.
1422
+ system_prompt : str, Optional
1423
+ system prompt.
1424
+ context_sentences : Union[str, int], Optional
1425
+ number of sentences before and after the given sentence to provide additional context.
1426
+ if "all", the full text will be provided in the prompt as context.
1427
+ if 0, no additional context will be provided.
1428
+ This is good for tasks that do not require context beyond the given sentence.
1429
+ if > 0, the number of sentences before and after the given sentence to provide as context.
1430
+ This is good for tasks that require context beyond the given sentence.
1386
1431
  """
1387
- # define output
1388
- output = []
1389
- # sentence tokenization
1390
- if isinstance(text_content, str):
1391
- sentences = self._get_sentences(text_content)
1392
- elif isinstance(text_content, dict):
1393
- sentences = self._get_sentences(text_content[document_key])
1394
-
1395
- if return_messages_log:
1396
- messages_log = []
1397
-
1398
- # generate sentence by sentence
1399
- for i, sent in enumerate(sentences):
1400
- # construct chat messages
1401
- messages = []
1402
- if self.system_prompt:
1403
- messages.append({'role': 'system', 'content': self.system_prompt})
1404
-
1405
- context = self._get_context_sentences(text_content, i, sentences, document_key)
1406
-
1407
- if self.context_sentences == 0:
1408
- # no context, just place sentence of interest
1409
- if isinstance(text_content, str):
1410
- messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
1411
- else:
1412
- sentence_content = text_content.copy()
1413
- sentence_content[document_key] = sent['sentence_text']
1414
- messages.append({'role': 'user', 'content': self._get_user_prompt(sentence_content)})
1415
- else:
1416
- # insert context
1417
- if isinstance(text_content, str):
1418
- messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
1419
- else:
1420
- context_content = text_content.copy()
1421
- context_content[document_key] = context
1422
- messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
1423
- # simulate conversation
1424
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
1425
- # place sentence of interest
1426
- messages.append({'role': 'user', 'content': sent['sentence_text']})
1427
-
1428
- if stream:
1429
- print(f"\n\n{Fore.GREEN}Sentence: {Style.RESET_ALL}\n{sent['sentence_text']}\n")
1430
- if isinstance(self.context_sentences, int) and self.context_sentences > 0:
1431
- print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
1432
- print(f"{Fore.BLUE}CoT:{Style.RESET_ALL}")
1433
-
1434
- gen_text = self.inference_engine.chat(
1435
- messages=messages,
1436
- max_new_tokens=max_new_tokens,
1437
- temperature=temperature,
1438
- stream=stream,
1439
- **kwrs
1440
- )
1441
-
1442
- if return_messages_log:
1443
- messages.append({"role": "assistant", "content": gen_text})
1444
- messages_log.append(messages)
1445
-
1446
- # add to output
1447
- output.append({'sentence_start': sent['start'],
1448
- 'sentence_end': sent['end'],
1449
- 'sentence_text': sent['sentence_text'],
1450
- 'gen_text': gen_text})
1432
+ if not isinstance(context_sentences, int) and context_sentences != "all":
1433
+ raise ValueError('context_sentences must be an integer (>= 0) or "all".')
1451
1434
 
1452
- if return_messages_log:
1453
- return output, messages_log
1454
- return output
1435
+ if isinstance(context_sentences, int) and context_sentences < 0:
1436
+ raise ValueError("context_sentences must be a non-negative integer.")
1437
+
1438
+ if isinstance(context_sentences, int):
1439
+ context_chunker = SlideWindowContextChunker(window_size=context_sentences)
1440
+ elif context_sentences == "all":
1441
+ context_chunker = WholeDocumentContextChunker()
1442
+
1443
+ super().__init__(inference_engine=inference_engine,
1444
+ unit_chunker=SentenceUnitChunker(),
1445
+ prompt_template=prompt_template,
1446
+ review_mode=review_mode,
1447
+ review_prompt=review_prompt,
1448
+ system_prompt=system_prompt,
1449
+ context_chunker=context_chunker)
1455
1450
 
1456
1451
 
1457
1452
  class RelationExtractor(Extractor):
1458
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
1453
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
1459
1454
  """
1460
1455
  This is the abstract class for relation extraction.
1461
1456
  Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -1471,8 +1466,7 @@ class RelationExtractor(Extractor):
1471
1466
  """
1472
1467
  super().__init__(inference_engine=inference_engine,
1473
1468
  prompt_template=prompt_template,
1474
- system_prompt=system_prompt,
1475
- **kwrs)
1469
+ system_prompt=system_prompt)
1476
1470
 
1477
1471
  def _get_ROI(self, frame_1:LLMInformationExtractionFrame, frame_2:LLMInformationExtractionFrame,
1478
1472
  text:str, buffer_size:int=100) -> str:
@@ -1548,7 +1542,7 @@ class RelationExtractor(Extractor):
1548
1542
 
1549
1543
  class BinaryRelationExtractor(RelationExtractor):
1550
1544
  def __init__(self, inference_engine:InferenceEngine, prompt_template:str, possible_relation_func: Callable,
1551
- system_prompt:str=None, **kwrs):
1545
+ system_prompt:str=None):
1552
1546
  """
1553
1547
  This class extracts binary (yes/no) relations between two entities.
1554
1548
  Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -1566,8 +1560,7 @@ class BinaryRelationExtractor(RelationExtractor):
1566
1560
  """
1567
1561
  super().__init__(inference_engine=inference_engine,
1568
1562
  prompt_template=prompt_template,
1569
- system_prompt=system_prompt,
1570
- **kwrs)
1563
+ system_prompt=system_prompt)
1571
1564
 
1572
1565
  if possible_relation_func:
1573
1566
  # Check if possible_relation_func is a function
@@ -1607,8 +1600,8 @@ class BinaryRelationExtractor(RelationExtractor):
1607
1600
  return False
1608
1601
 
1609
1602
 
1610
- def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1611
- temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1603
+ def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, verbose:bool=False,
1604
+ return_messages_log:bool=False) -> List[Dict]:
1612
1605
  """
1613
1606
  This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
1614
1607
  Outputs pairs that are related.
@@ -1619,11 +1612,7 @@ class BinaryRelationExtractor(RelationExtractor):
1619
1612
  a document with frames.
1620
1613
  buffer_size : int, Optional
1621
1614
  the number of characters before and after the two frames in the ROI text.
1622
- max_new_tokens : str, Optional
1623
- the max number of new tokens LLM should generate.
1624
- temperature : float, Optional
1625
- the temperature for token sampling.
1626
- stream : bool, Optional
1615
+ verbose : bool, Optional
1627
1616
  if True, LLM generated text will be printed in terminal in real-time.
1628
1617
  return_messages_log : bool, Optional
1629
1618
  if True, a list of messages will be returned.
@@ -1642,7 +1631,7 @@ class BinaryRelationExtractor(RelationExtractor):
1642
1631
 
1643
1632
  if pos_rel:
1644
1633
  roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
1645
- if stream:
1634
+ if verbose:
1646
1635
  print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
1647
1636
  print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
1648
1637
  messages = []
@@ -1656,10 +1645,7 @@ class BinaryRelationExtractor(RelationExtractor):
1656
1645
 
1657
1646
  gen_text = self.inference_engine.chat(
1658
1647
  messages=messages,
1659
- max_new_tokens=max_new_tokens,
1660
- temperature=temperature,
1661
- stream=stream,
1662
- **kwrs
1648
+ verbose=verbose
1663
1649
  )
1664
1650
  rel_json = self._extract_json(gen_text)
1665
1651
  if self._post_process(rel_json):
@@ -1674,8 +1660,8 @@ class BinaryRelationExtractor(RelationExtractor):
1674
1660
  return output
1675
1661
 
1676
1662
 
1677
- async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1678
- temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1663
+ async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100,
1664
+ concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[Dict]:
1679
1665
  """
1680
1666
  This is the asynchronous version of the extract() method.
1681
1667
 
@@ -1730,10 +1716,7 @@ class BinaryRelationExtractor(RelationExtractor):
1730
1716
 
1731
1717
  task = asyncio.create_task(
1732
1718
  self.inference_engine.chat_async(
1733
- messages=messages,
1734
- max_new_tokens=max_new_tokens,
1735
- temperature=temperature,
1736
- **kwrs
1719
+ messages=messages
1737
1720
  )
1738
1721
  )
1739
1722
  tasks.append(task)
@@ -1755,9 +1738,9 @@ class BinaryRelationExtractor(RelationExtractor):
1755
1738
  return output
1756
1739
 
1757
1740
 
1758
- def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1759
- temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32,
1760
- stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1741
+ def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100,
1742
+ concurrent:bool=False, concurrent_batch_size:int=32, verbose:bool=False,
1743
+ return_messages_log:bool=False) -> List[Dict]:
1761
1744
  """
1762
1745
  This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
1763
1746
 
@@ -1767,15 +1750,11 @@ class BinaryRelationExtractor(RelationExtractor):
1767
1750
  a document with frames.
1768
1751
  buffer_size : int, Optional
1769
1752
  the number of characters before and after the two frames in the ROI text.
1770
- max_new_tokens : str, Optional
1771
- the max number of new tokens LLM should generate.
1772
- temperature : float, Optional
1773
- the temperature for token sampling.
1774
1753
  concurrent: bool, Optional
1775
1754
  if True, the extraction will be done in concurrent.
1776
1755
  concurrent_batch_size : int, Optional
1777
1756
  the number of frame pairs to process in concurrent.
1778
- stream : bool, Optional
1757
+ verbose : bool, Optional
1779
1758
  if True, LLM generated text will be printed in terminal in real-time.
1780
1759
  return_messages_log : bool, Optional
1781
1760
  if True, a list of messages will be returned.
@@ -1790,31 +1769,25 @@ class BinaryRelationExtractor(RelationExtractor):
1790
1769
  raise ValueError("All frame_ids in the input document must be unique.")
1791
1770
 
1792
1771
  if concurrent:
1793
- if stream:
1772
+ if verbose:
1794
1773
  warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
1795
1774
 
1796
1775
  nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
1797
1776
  return asyncio.run(self.extract_async(doc=doc,
1798
1777
  buffer_size=buffer_size,
1799
- max_new_tokens=max_new_tokens,
1800
- temperature=temperature,
1801
1778
  concurrent_batch_size=concurrent_batch_size,
1802
- return_messages_log=return_messages_log,
1803
- **kwrs)
1779
+ return_messages_log=return_messages_log)
1804
1780
  )
1805
1781
  else:
1806
1782
  return self.extract(doc=doc,
1807
1783
  buffer_size=buffer_size,
1808
- max_new_tokens=max_new_tokens,
1809
- temperature=temperature,
1810
- stream=stream,
1811
- return_messages_log=return_messages_log,
1812
- **kwrs)
1784
+ verbose=verbose,
1785
+ return_messages_log=return_messages_log)
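# Illustrative usage sketch (assumptions: `engine`, `relation_prompt`, and `doc` are
# placeholders; `doc` is an LLMInformationExtractionDocument that already has frames).
# extract_relations() enumerates frame pairs, filters them with possible_relation_func,
# builds an ROI of +/- buffer_size characters around each surviving pair, and asks the
# LLM for a yes/no decision. The filter below is only an assumed example.
from llm_ie.extractors import BinaryRelationExtractor

def may_relate(frame_1, frame_2) -> bool:
    # assumed example filter: allow every distinct pair; a real filter would inspect
    # frame attributes such as entity type
    return frame_1.frame_id != frame_2.frame_id

rel_extractor = BinaryRelationExtractor(
    inference_engine=engine,
    prompt_template=relation_prompt,
    possible_relation_func=may_relate,
)
related_pairs = rel_extractor.extract_relations(doc, buffer_size=100, concurrent=True)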
1813
1786
 
1814
1787
 
1815
1788
  class MultiClassRelationExtractor(RelationExtractor):
1816
1789
  def __init__(self, inference_engine:InferenceEngine, prompt_template:str, possible_relation_types_func: Callable,
1817
- system_prompt:str=None, **kwrs):
1790
+ system_prompt:str=None):
1818
1791
  """
1819
1792
  This class extracts relations with relation types.
1820
1793
  Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -1833,8 +1806,7 @@ class MultiClassRelationExtractor(RelationExtractor):
1833
1806
  """
1834
1807
  super().__init__(inference_engine=inference_engine,
1835
1808
  prompt_template=prompt_template,
1836
- system_prompt=system_prompt,
1837
- **kwrs)
1809
+ system_prompt=system_prompt)
1838
1810
 
1839
1811
  if possible_relation_types_func:
1840
1812
  # Check if possible_relation_types_func is a function
@@ -1881,8 +1853,7 @@ class MultiClassRelationExtractor(RelationExtractor):
1881
1853
  return None
1882
1854
 
1883
1855
 
1884
- def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1885
- temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1856
+ def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, verbose:bool=False, return_messages_log:bool=False) -> List[Dict]:
1886
1857
  """
1887
1858
  This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
1888
1859
 
@@ -1915,7 +1886,7 @@ class MultiClassRelationExtractor(RelationExtractor):
1915
1886
 
1916
1887
  if pos_rel_types:
1917
1888
  roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
1918
- if stream:
1889
+ if verbose:
1919
1890
  print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
1920
1891
  print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
1921
1892
  messages = []
@@ -1930,10 +1901,8 @@ class MultiClassRelationExtractor(RelationExtractor):
1930
1901
 
1931
1902
  gen_text = self.inference_engine.chat(
1932
1903
  messages=messages,
1933
- max_new_tokens=max_new_tokens,
1934
- temperature=temperature,
1935
- stream=stream,
1936
- **kwrs
1904
+ stream=False,
1905
+ verbose=verbose
1937
1906
  )
1938
1907
 
1939
1908
  if return_messages_log:
@@ -1950,8 +1919,8 @@ class MultiClassRelationExtractor(RelationExtractor):
1950
1919
  return output
1951
1920
 
1952
1921
 
1953
- async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1954
- temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1922
+ async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100,
1923
+ concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[Dict]:
1955
1924
  """
1956
1925
  This is the asynchronous version of the extract() method.
1957
1926
 
@@ -2006,10 +1975,7 @@ class MultiClassRelationExtractor(RelationExtractor):
2006
1975
  )})
2007
1976
  task = asyncio.create_task(
2008
1977
  self.inference_engine.chat_async(
2009
- messages=messages,
2010
- max_new_tokens=max_new_tokens,
2011
- temperature=temperature,
2012
- **kwrs
1978
+ messages=messages
2013
1979
  )
2014
1980
  )
2015
1981
  tasks.append(task)
@@ -2032,9 +1998,9 @@ class MultiClassRelationExtractor(RelationExtractor):
2032
1998
  return output
2033
1999
 
2034
2000
 
2035
- def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
2036
- temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32,
2037
- stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
2001
+ def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100,
2002
+ concurrent:bool=False, concurrent_batch_size:int=32,
2003
+ verbose:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
2038
2004
  """
2039
2005
  This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
2040
2006
 
@@ -2067,24 +2033,18 @@ class MultiClassRelationExtractor(RelationExtractor):
2067
2033
  raise ValueError("All frame_ids in the input document must be unique.")
2068
2034
 
2069
2035
  if concurrent:
2070
- if stream:
2036
+ if verbose:
2071
2037
  warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
2072
2038
 
2073
2039
  nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
2074
2040
  return asyncio.run(self.extract_async(doc=doc,
2075
2041
  buffer_size=buffer_size,
2076
- max_new_tokens=max_new_tokens,
2077
- temperature=temperature,
2078
2042
  concurrent_batch_size=concurrent_batch_size,
2079
- return_messages_log=return_messages_log,
2080
- **kwrs)
2043
+ return_messages_log=return_messages_log)
2081
2044
  )
2082
2045
  else:
2083
2046
  return self.extract(doc=doc,
2084
2047
  buffer_size=buffer_size,
2085
- max_new_tokens=max_new_tokens,
2086
- temperature=temperature,
2087
- stream=stream,
2088
- return_messages_log=return_messages_log,
2089
- **kwrs)
2048
+ verbose=verbose,
2049
+ return_messages_log=return_messages_log)
2090
2050
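# Illustrative usage sketch (assumptions: `engine`, `relation_prompt`, and `doc` are
# placeholders as above). MultiClassRelationExtractor follows the same pairwise flow,
# except possible_relation_types_func returns the candidate relation types for a frame
# pair (an empty list meaning the pair is skipped) and the LLM selects among them.
# The function body below is an assumed example, not the package's reference filter.
from llm_ie.extractors import MultiClassRelationExtractor

def possible_types(frame_1, frame_2):
    # assumed example: every pair may hold one of these two relation types
    return ["Strength-Drug", "Frequency-Drug"]

mc_extractor = MultiClassRelationExtractor(
    inference_engine=engine,
    prompt_template=relation_prompt,
    possible_relation_types_func=possible_types,
)
relations = mc_extractor.extract_relations(doc, buffer_size=100, concurrent=True)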