llm-ie 0.4.6__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_ie/extractors.py CHANGED
@@ -8,10 +8,12 @@ import warnings
8
8
  import itertools
9
9
  import asyncio
10
10
  import nest_asyncio
11
- from typing import Set, List, Dict, Tuple, Union, Callable
12
- from llm_ie.data_types import LLMInformationExtractionFrame, LLMInformationExtractionDocument
11
+ from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional
12
+ from llm_ie.data_types import FrameExtractionUnit, FrameExtractionUnitResult, LLMInformationExtractionFrame, LLMInformationExtractionDocument
13
+ from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
14
+ from llm_ie.chunkers import ContextChunker, NoContextChunker, WholeDocumentContextChunker, SlideWindowContextChunker
13
15
  from llm_ie.engines import InferenceEngine
14
- from colorama import Fore, Style
16
+ from colorama import Fore, Style
15
17
 
16
18
 
17
19
  class Extractor:
@@ -38,15 +40,46 @@ class Extractor:
38
40
  def get_prompt_guide(cls) -> str:
39
41
  """
40
42
  This method returns the pre-defined prompt guideline for the extractor from the package asset.
43
+ It searches for a guide specific to the current class first; if none is found, it searches
44
+ for the guide in its ancestors by traversing the class's method resolution order (MRO).
41
45
  """
42
- # Check if the prompt guide is available
43
- file_path = importlib.resources.files('llm_ie.asset.prompt_guide').joinpath(f"{cls.__name__}_prompt_guide.txt")
44
- try:
45
- with open(file_path, 'r', encoding="utf-8") as f:
46
- return f.read()
47
- except FileNotFoundError:
48
- warnings.warn(f"Prompt guide for {cls.__name__} is not available. Is it a customed extractor?", UserWarning)
49
- return None
46
+ original_class_name = cls.__name__
47
+
48
+ for current_class_in_mro in cls.__mro__:
49
+ if current_class_in_mro is object:
50
+ continue
51
+
52
+ current_class_name = current_class_in_mro.__name__
53
+
54
+ try:
55
+ file_path_obj = importlib.resources.files('llm_ie.asset.prompt_guide').joinpath(f"{current_class_name}_prompt_guide.txt")
56
+
57
+ with open(file_path_obj, 'r', encoding="utf-8") as f:
58
+ prompt_content = f.read()
59
+ # If the guide was found for an ancestor, not the original class, issue a warning.
60
+ if cls is not current_class_in_mro:
61
+ warnings.warn(
62
+ f"Prompt guide for '{original_class_name}' not found. "
63
+ f"Using guide from ancestor: '{current_class_name}_prompt_guide.txt'.",
64
+ UserWarning
65
+ )
66
+ return prompt_content
67
+ except FileNotFoundError:
68
+ pass
69
+
70
+ except Exception as e:
71
+ warnings.warn(
72
+ f"Error attempting to read prompt guide for '{current_class_name}' "
73
+ f"from '{str(file_path_obj)}': {e}. Trying next in MRO.",
74
+ UserWarning
75
+ )
76
+ continue
77
+
78
+ # If the loop completes, no prompt guide was found for the original class or any of its ancestors.
79
+ raise FileNotFoundError(
80
+ f"Prompt guide for '{original_class_name}' not found in the package asset. "
81
+ f"Is it a custom extractor?"
82
+ )
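In practice, the rewrite changes what a subclass without its own packaged guide gets back from get_prompt_guide(): instead of a warning plus None, it now inherits the nearest ancestor's guide and only raises FileNotFoundError when no class in the MRO ships one. A minimal sketch of that behavior, assuming get_prompt_guide() remains callable on the class (as the cls signature suggests) and that a guide is packaged for DirectFrameExtractor; the subclass is hypothetical:

    from llm_ie.extractors import DirectFrameExtractor

    class MyCustomExtractor(DirectFrameExtractor):
        # Hypothetical subclass that ships no MyCustomExtractor_prompt_guide.txt asset.
        pass

    # 0.4.6 behavior: warn and return None.
    # 1.0.0 behavior: return DirectFrameExtractor's guide with a fallback UserWarning;
    # raise FileNotFoundError only if no ancestor has a packaged guide.
    guide = MyCustomExtractor.get_prompt_guide()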
50
83
 
51
84
  def _get_user_prompt(self, text_content:Union[str, Dict[str,str]]) -> str:
52
85
  """
@@ -138,7 +171,8 @@ class Extractor:
138
171
 
139
172
  class FrameExtractor(Extractor):
140
173
  from nltk.tokenize import RegexpTokenizer
141
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
174
+ def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
175
+ prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None, **kwrs):
142
176
  """
143
177
  This is the abstract class for frame extraction.
144
178
  Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
@@ -147,15 +181,25 @@ class FrameExtractor(Extractor):
147
181
  ----------
148
182
  inference_engine : InferenceEngine
149
183
  the LLM inferencing engine object. Must implements the chat() method.
184
+ unit_chunker : UnitChunker
185
+ the unit chunker object that determines how to chunk the document text into units.
150
186
  prompt_template : str
151
187
  prompt template with "{{<placeholder name>}}" placeholder.
152
188
  system_prompt : str, Optional
153
189
  system prompt.
190
+ context_chunker : ContextChunker
191
+ the context chunker object that determines how to get context for each unit.
154
192
  """
155
193
  super().__init__(inference_engine=inference_engine,
156
194
  prompt_template=prompt_template,
157
195
  system_prompt=system_prompt,
158
196
  **kwrs)
197
+
198
+ self.unit_chunker = unit_chunker
199
+ if context_chunker is None:
200
+ self.context_chunker = NoContextChunker()
201
+ else:
202
+ self.context_chunker = context_chunker
159
203
 
160
204
  self.tokenizer = self.RegexpTokenizer(r'\w+|[^\w\s]')
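A rough sketch of how the reworked constructor is wired, using chunkers imported at the top of this file; the zero-argument chunker constructors and the pre-built engine and prompt_template objects are assumptions, not signatures taken from this diff:

    from llm_ie.chunkers import SentenceUnitChunker, WholeDocumentContextChunker
    from llm_ie.extractors import DirectFrameExtractor

    # engine: any concrete InferenceEngine implementation (construction omitted).
    # prompt_template: a template string containing a {{...}} placeholder.
    extractor = DirectFrameExtractor(
        inference_engine=engine,
        unit_chunker=SentenceUnitChunker(),             # one LLM call per sentence
        context_chunker=WholeDocumentContextChunker(),  # whole document as context
        prompt_template=prompt_template,
        system_prompt="You are an information extraction assistant.",
    )
    # Omitting context_chunker falls back to NoContextChunker (no extra context).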
161
205
 
@@ -288,7 +332,7 @@ class FrameExtractor(Extractor):
288
332
  return entity_spans
289
333
 
290
334
  @abc.abstractmethod
291
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, **kwrs) -> str:
335
+ def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, return_messages_log:bool=False, **kwrs) -> str:
292
336
  """
293
337
  This method inputs text content and outputs a string generated by LLM
294
338
 
@@ -300,6 +344,8 @@ class FrameExtractor(Extractor):
300
344
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
301
345
  max_new_tokens : str, Optional
302
346
  the max number of new tokens LLM can generate.
347
+ return_messages_log : bool, Optional
348
+ if True, a list of messages will be returned.
303
349
 
304
350
  Return : str
305
351
  the output from LLM. Need post-processing.
@@ -309,7 +355,7 @@ class FrameExtractor(Extractor):
309
355
 
310
356
  @abc.abstractmethod
311
357
  def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=2048,
312
- document_key:str=None, **kwrs) -> List[LLMInformationExtractionFrame]:
358
+ document_key:str=None, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
313
359
  """
314
360
  This method inputs text content and outputs a list of LLMInformationExtractionFrame
315
361
  It use the extract() method and post-process outputs into frames.
@@ -327,6 +373,8 @@ class FrameExtractor(Extractor):
327
373
  document_key : str, Optional
328
374
  specify the key in text_content where document text is.
329
375
  If text_content is str, this parameter will be ignored.
376
+ return_messages_log : bool, Optional
377
+ if True, a list of messages will be returned.
330
378
 
331
379
  Return : str
332
380
  a list of frames.
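Both abstract methods now accept return_messages_log. With the concrete DirectFrameExtractor defined further below (where entity_key is no longer a parameter), setting it makes the call return the per-unit chat transcripts alongside the frames; the frame attribute names follow the LLMInformationExtractionFrame constructor arguments used in this file, and extractor and note_text are assumed to exist:

    frames, messages_log = extractor.extract_frames(
        text_content=note_text,
        return_messages_log=True,
    )
    for frame in frames:
        print(frame.frame_id, frame.start, frame.end, frame.entity_text, frame.attr)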
@@ -334,332 +382,38 @@ class FrameExtractor(Extractor):
334
382
  return NotImplemented
335
383
 
336
384
 
337
- class BasicFrameExtractor(FrameExtractor):
338
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
385
+ class DirectFrameExtractor(FrameExtractor):
386
+ def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
387
+ prompt_template:str, system_prompt:str=None, context_chunker:ContextChunker=None, **kwrs):
339
388
  """
340
- This class diretly prompt LLM for frame extraction.
341
- Input system prompt (optional), prompt template (with instruction, few-shot examples),
342
- and specify a LLM.
389
+ This class is for general unit-context frame extraction.
390
+ Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
343
391
 
344
392
  Parameters:
345
393
  ----------
346
394
  inference_engine : InferenceEngine
347
395
  the LLM inferencing engine object. Must implements the chat() method.
396
+ unit_chunker : UnitChunker
397
+ the unit chunker object that determines how to chunk the document text into units.
348
398
  prompt_template : str
349
399
  prompt template with "{{<placeholder name>}}" placeholder.
350
400
  system_prompt : str, Optional
351
401
  system prompt.
402
+ context_chunker : ContextChunker
403
+ the context chunker object that determines how to get context for each unit.
352
404
  """
353
- super().__init__(inference_engine=inference_engine,
354
- prompt_template=prompt_template,
355
- system_prompt=system_prompt,
405
+ super().__init__(inference_engine=inference_engine,
406
+ unit_chunker=unit_chunker,
407
+ prompt_template=prompt_template,
408
+ system_prompt=system_prompt,
409
+ context_chunker=context_chunker,
356
410
  **kwrs)
357
-
358
-
359
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048,
360
- temperature:float=0.0, stream:bool=False, **kwrs) -> str:
361
- """
362
- This method inputs a text and outputs a string generated by LLM.
363
-
364
- Parameters:
365
- ----------
366
- text_content : Union[str, Dict[str,str]]
367
- the input text content to put in prompt template.
368
- If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
369
- If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
370
- max_new_tokens : str, Optional
371
- the max number of new tokens LLM can generate.
372
- temperature : float, Optional
373
- the temperature for token sampling.
374
- stream : bool, Optional
375
- if True, LLM generated text will be printed in terminal in real-time.
376
-
377
- Return : str
378
- the output from LLM. Need post-processing.
379
- """
380
- messages = []
381
- if self.system_prompt:
382
- messages.append({'role': 'system', 'content': self.system_prompt})
383
-
384
- messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
385
- response = self.inference_engine.chat(
386
- messages=messages,
387
- max_new_tokens=max_new_tokens,
388
- temperature=temperature,
389
- stream=stream,
390
- **kwrs
391
- )
392
-
393
- return response
394
-
395
-
396
- def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=2048,
397
- temperature:float=0.0, document_key:str=None, stream:bool=False,
398
- case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2,
399
- fuzzy_score_cutoff:float=0.8, allow_overlap_entities:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
400
- """
401
- This method inputs a text and outputs a list of LLMInformationExtractionFrame
402
- It use the extract() method and post-process outputs into frames.
403
-
404
- Parameters:
405
- ----------
406
- text_content : Union[str, Dict[str,str]]
407
- the input text content to put in prompt template.
408
- If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
409
- If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
410
- entity_key : str
411
- the key (in ouptut JSON) for entity text. Any extraction that does not include entity key will be dropped.
412
- max_new_tokens : str, Optional
413
- the max number of new tokens LLM should generate.
414
- temperature : float, Optional
415
- the temperature for token sampling.
416
- document_key : str, Optional
417
- specify the key in text_content where document text is.
418
- If text_content is str, this parameter will be ignored.
419
- stream : bool, Optional
420
- if True, LLM generated text will be printed in terminal in real-time.
421
- case_sensitive : bool, Optional
422
- if True, entity text matching will be case-sensitive.
423
- fuzzy_match : bool, Optional
424
- if True, fuzzy matching will be applied to find entity text.
425
- fuzzy_buffer_size : float, Optional
426
- the buffer size for fuzzy matching. Default is 20% of entity text length.
427
- fuzzy_score_cutoff : float, Optional
428
- the Jaccard score cutoff for fuzzy matching.
429
- Matched entity text must have a score higher than this value or a None will be returned.
430
- allow_overlap_entities : bool, Optional
431
- if True, entities can overlap in the text.
432
- Note that this can cause multiple frames to be generated on the same entity span if they have same entity text.
433
-
434
- Return : str
435
- a list of frames.
436
- """
437
- if isinstance(text_content, str):
438
- text = text_content
439
- elif isinstance(text_content, dict):
440
- if document_key is None:
441
- raise ValueError("document_key must be provided when text_content is dict.")
442
- text = text_content[document_key]
443
-
444
- frame_list = []
445
- gen_text = self.extract(text_content=text_content,
446
- max_new_tokens=max_new_tokens,
447
- temperature=temperature,
448
- stream=stream,
449
- **kwrs)
450
-
451
- entity_json = []
452
- for entity in self._extract_json(gen_text=gen_text):
453
- if entity_key in entity:
454
- entity_json.append(entity)
455
- else:
456
- warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{entity_key}"). This frame will be dropped.', RuntimeWarning)
457
-
458
- spans = self._find_entity_spans(text=text,
459
- entities=[e[entity_key] for e in entity_json],
460
- case_sensitive=case_sensitive,
461
- fuzzy_match=fuzzy_match,
462
- fuzzy_buffer_size=fuzzy_buffer_size,
463
- fuzzy_score_cutoff=fuzzy_score_cutoff,
464
- allow_overlap_entities=allow_overlap_entities)
465
-
466
- for i, (ent, span) in enumerate(zip(entity_json, spans)):
467
- if span is not None:
468
- start, end = span
469
- frame = LLMInformationExtractionFrame(frame_id=f"{i}",
470
- start=start,
471
- end=end,
472
- entity_text=text[start:end],
473
- attr={k: v for k, v in ent.items() if k != entity_key and v != ""})
474
- frame_list.append(frame)
475
- return frame_list
476
-
477
411
 
478
- class ReviewFrameExtractor(BasicFrameExtractor):
479
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
480
- review_mode:str, review_prompt:str=None,system_prompt:str=None, **kwrs):
481
- """
482
- This class add a review step after the BasicFrameExtractor.
483
- The Review process asks LLM to review its output and:
484
- 1. add more frames while keep current. This is efficient for boosting recall.
485
- 2. or, regenerate frames (add new and delete existing).
486
- Use the review_mode parameter to specify. Note that the review_prompt should instruct LLM accordingly.
487
-
488
- Parameters:
489
- ----------
490
- inference_engine : InferenceEngine
491
- the LLM inferencing engine object. Must implements the chat() method.
492
- prompt_template : str
493
- prompt template with "{{<placeholder name>}}" placeholder.
494
- review_prompt : str: Optional
495
- the prompt text that ask LLM to review. Specify addition or revision in the instruction.
496
- if not provided, a default review prompt will be used.
497
- review_mode : str
498
- review mode. Must be one of {"addition", "revision"}
499
- addition mode only ask LLM to add new frames, while revision mode ask LLM to regenerate.
500
- system_prompt : str, Optional
501
- system prompt.
502
- """
503
- super().__init__(inference_engine=inference_engine, prompt_template=prompt_template,
504
- system_prompt=system_prompt, **kwrs)
505
- if review_mode not in {"addition", "revision"}:
506
- raise ValueError('review_mode must be one of {"addition", "revision"}.')
507
- self.review_mode = review_mode
508
-
509
- if review_prompt:
510
- self.review_prompt = review_prompt
511
- else:
512
- file_path = importlib.resources.files('llm_ie.asset.default_prompts').\
513
- joinpath(f"{self.__class__.__name__}_{self.review_mode}_review_prompt.txt")
514
- with open(file_path, 'r', encoding="utf-8") as f:
515
- self.review_prompt = f.read()
516
-
517
- warnings.warn(f'Custom review prompt not provided. The default review prompt is used:\n"{self.review_prompt}"', UserWarning)
518
-
519
-
520
- def extract(self, text_content:Union[str, Dict[str,str]],
521
- max_new_tokens:int=4096, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
522
- """
523
- This method inputs a text and outputs a string generated by LLM.
524
-
525
- Parameters:
526
- ----------
527
- text_content : Union[str, Dict[str,str]]
528
- the input text content to put in prompt template.
529
- If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
530
- If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
531
- max_new_tokens : str, Optional
532
- the max number of new tokens LLM can generate.
533
- temperature : float, Optional
534
- the temperature for token sampling.
535
- stream : bool, Optional
536
- if True, LLM generated text will be printed in terminal in real-time.
537
-
538
- Return : str
539
- the output from LLM. Need post-processing.
540
- """
541
- messages = []
542
- if self.system_prompt:
543
- messages.append({'role': 'system', 'content': self.system_prompt})
544
-
545
- messages.append({'role': 'user', 'content': self._get_user_prompt(text_content)})
546
- # Initial output
547
- if stream:
548
- print(f"{Fore.BLUE}Initial Output:{Style.RESET_ALL}")
549
-
550
- initial = self.inference_engine.chat(
551
- messages=messages,
552
- max_new_tokens=max_new_tokens,
553
- temperature=temperature,
554
- stream=stream,
555
- **kwrs
556
- )
557
-
558
- # Review
559
- messages.append({'role': 'assistant', 'content': initial})
560
- messages.append({'role': 'user', 'content': self.review_prompt})
561
-
562
- if stream:
563
- print(f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}")
564
- review = self.inference_engine.chat(
565
- messages=messages,
566
- max_new_tokens=max_new_tokens,
567
- temperature=temperature,
568
- stream=stream,
569
- **kwrs
570
- )
571
-
572
- # Output
573
- if self.review_mode == "revision":
574
- return review
575
- elif self.review_mode == "addition":
576
- return initial + '\n' + review
577
-
578
-
579
- class SentenceFrameExtractor(FrameExtractor):
580
- from nltk.tokenize.punkt import PunktSentenceTokenizer
581
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None,
582
- context_sentences:Union[str, int]="all", **kwrs):
583
- """
584
- This class performs sentence-by-sentence information extraction.
585
- The process is as follows:
586
- 1. system prompt (optional)
587
- 2. user prompt with instructions (schema, background, full text, few-shot example...)
588
- 3. feed a sentence (start with first sentence)
589
- 4. LLM extract entities and attributes from the sentence
590
- 5. repeat #3 and #4
591
-
592
- Input system prompt (optional), prompt template (with user instructions),
593
- and specify a LLM.
594
-
595
- Parameters:
596
- ----------
597
- inference_engine : InferenceEngine
598
- the LLM inferencing engine object. Must implements the chat() method.
599
- prompt_template : str
600
- prompt template with "{{<placeholder name>}}" placeholder.
601
- system_prompt : str, Optional
602
- system prompt.
603
- context_sentences : Union[str, int], Optional
604
- number of sentences before and after the given sentence to provide additional context.
605
- if "all", the full text will be provided in the prompt as context.
606
- if 0, no additional context will be provided.
607
- This is good for tasks that does not require context beyond the given sentence.
608
- if > 0, the number of sentences before and after the given sentence to provide as context.
609
- This is good for tasks that require context beyond the given sentence.
610
- """
611
- super().__init__(inference_engine=inference_engine, prompt_template=prompt_template,
612
- system_prompt=system_prompt, **kwrs)
613
-
614
- if not isinstance(context_sentences, int) and context_sentences != "all":
615
- raise ValueError('context_sentences must be an integer (>= 0) or "all".')
616
-
617
- if isinstance(context_sentences, int) and context_sentences < 0:
618
- raise ValueError("context_sentences must be a positive integer.")
619
-
620
- self.context_sentences =context_sentences
621
-
622
-
623
- def _get_sentences(self, text:str) -> List[Dict[str,str]]:
624
- """
625
- This method sentence tokenize the input text into a list of sentences
626
- as dict of {start, end, sentence_text}
627
-
628
- Parameters:
629
- ----------
630
- text : str
631
- text to sentence tokenize.
632
-
633
- Returns : List[Dict[str,str]]
634
- a list of sentences as dict with keys: {"sentence_text", "start", "end"}.
635
- """
636
- sentences = []
637
- for start, end in self.PunktSentenceTokenizer().span_tokenize(text):
638
- sentences.append({"sentence_text": text[start:end],
639
- "start": start,
640
- "end": end})
641
- return sentences
642
-
643
412
 
644
- def _get_context_sentences(self, text_content, i:int, sentences:List[Dict[str, str]], document_key:str=None) -> str:
645
- """
646
- This function returns the context sentences for the current sentence of interest (i).
647
- """
648
- if self.context_sentences == "all":
649
- context = text_content if isinstance(text_content, str) else text_content[document_key]
650
- elif self.context_sentences == 0:
651
- context = ""
652
- else:
653
- start = max(0, i - self.context_sentences)
654
- end = min(i + 1 + self.context_sentences, len(sentences))
655
- context = " ".join([s['sentence_text'] for s in sentences[start:end]])
656
- return context
657
-
658
-
659
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
660
- document_key:str=None, temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict[str,str]]:
413
+ def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048,
414
+ document_key:str=None, temperature:float=0.0, verbose:bool=False, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
661
415
  """
662
- This method inputs a text and outputs a list of outputs per sentence.
416
+ This method inputs a text and outputs a list of outputs per unit.
663
417
 
664
418
  Parameters:
665
419
  ----------
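The hunk above removes BasicFrameExtractor, the old ReviewFrameExtractor, and SentenceFrameExtractor in favour of DirectFrameExtractor composed with the new chunkers. The sketch below is a plausible migration mapping rather than something stated in this diff; in particular, the SlideWindowContextChunker argument name (window_size) is an assumption:

    from llm_ie.chunkers import (WholeDocumentUnitChunker, SentenceUnitChunker,
                                 SlideWindowContextChunker)
    from llm_ie.extractors import DirectFrameExtractor

    # 0.4.6: BasicFrameExtractor(engine, prompt_template)
    whole_document = DirectFrameExtractor(
        inference_engine=engine,
        unit_chunker=WholeDocumentUnitChunker(),   # single prompt over the full text
        prompt_template=prompt_template,
    )

    # 0.4.6: SentenceFrameExtractor(engine, prompt_template, context_sentences=2)
    per_sentence = DirectFrameExtractor(
        inference_engine=engine,
        unit_chunker=SentenceUnitChunker(),
        context_chunker=SlideWindowContextChunker(window_size=2),  # parameter name assumed
        prompt_template=prompt_template,
    )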
@@ -667,77 +421,211 @@ class SentenceFrameExtractor(FrameExtractor):
667
421
  the input text content to put in prompt template.
668
422
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
669
423
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
670
- max_new_tokens : str, Optional
424
+ max_new_tokens : int, Optional
671
425
  the max number of new tokens LLM should generate.
672
426
  document_key : str, Optional
673
427
  specify the key in text_content where document text is.
674
428
  If text_content is str, this parameter will be ignored.
675
429
  temperature : float, Optional
676
430
  the temperature for token sampling.
677
- stream : bool, Optional
431
+ verbose : bool, Optional
678
432
  if True, LLM generated text will be printed in terminal in real-time.
433
+ return_messages_log : bool, Optional
434
+ if True, a list of messages will be returned.
679
435
 
680
- Return : str
681
- the output from LLM. Need post-processing.
436
+ Return : List[FrameExtractionUnitResult]
437
+ the output from LLM for each unit. Contains the start, end, text, and generated text.
682
438
  """
683
439
  # define output
684
440
  output = []
685
- # sentence tokenization
441
+ # unit chunking
686
442
  if isinstance(text_content, str):
687
- sentences = self._get_sentences(text_content)
443
+ doc_text = text_content
444
+
688
445
  elif isinstance(text_content, dict):
689
446
  if document_key is None:
690
447
  raise ValueError("document_key must be provided when text_content is dict.")
691
- sentences = self._get_sentences(text_content[document_key])
692
-
693
- # generate sentence by sentence
694
- for i, sent in enumerate(sentences):
448
+ doc_text = text_content[document_key]
449
+
450
+ units = self.unit_chunker.chunk(doc_text)
451
+ # context chunker init
452
+ self.context_chunker.fit(doc_text, units)
453
+ # messages log
454
+ if return_messages_log:
455
+ messages_log = []
456
+
457
+ # generate unit by unit
458
+ for i, unit in enumerate(units):
695
459
  # construct chat messages
696
460
  messages = []
697
461
  if self.system_prompt:
698
462
  messages.append({'role': 'system', 'content': self.system_prompt})
699
463
 
700
- context = self._get_context_sentences(text_content, i, sentences, document_key)
464
+ context = self.context_chunker.chunk(unit)
701
465
 
702
- if self.context_sentences == 0:
703
- # no context, just place sentence of interest
704
- messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
466
+ if context == "":
467
+ # no context, just place unit in user prompt
468
+ if isinstance(text_content, str):
469
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
470
+ else:
471
+ unit_content = text_content.copy()
472
+ unit_content[document_key] = unit.text
473
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
705
474
  else:
706
- # insert context
707
- messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
708
- # simulate conversation
709
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
710
- # place sentence of interest
711
- messages.append({'role': 'user', 'content': sent['sentence_text']})
712
-
713
- if stream:
714
- print(f"\n\n{Fore.GREEN}Sentence {i}:{Style.RESET_ALL}\n{sent['sentence_text']}\n")
715
- if isinstance(self.context_sentences, int) and self.context_sentences > 0:
475
+ # insert context to user prompt
476
+ if isinstance(text_content, str):
477
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
478
+ else:
479
+ context_content = text_content.copy()
480
+ context_content[document_key] = context
481
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
482
+ # simulate conversation where assistant confirms
483
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
484
+ # place unit of interest
485
+ messages.append({'role': 'user', 'content': unit.text})
486
+
487
+ if verbose:
488
+ print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
489
+ if context != "":
716
490
  print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
717
491
 
718
492
  print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
719
493
 
720
- gen_text = self.inference_engine.chat(
721
- messages=messages,
722
- max_new_tokens=max_new_tokens,
723
- temperature=temperature,
724
- stream=stream,
725
- **kwrs
726
- )
494
+ response_stream = self.inference_engine.chat(
495
+ messages=messages,
496
+ max_new_tokens=max_new_tokens,
497
+ temperature=temperature,
498
+ stream=True,
499
+ **kwrs
500
+ )
501
+
502
+ gen_text = ""
503
+ for chunk in response_stream:
504
+ gen_text += chunk
505
+ print(chunk, end='', flush=True)
506
+
507
+ else:
508
+ gen_text = self.inference_engine.chat(
509
+ messages=messages,
510
+ max_new_tokens=max_new_tokens,
511
+ temperature=temperature,
512
+ stream=False,
513
+ **kwrs
514
+ )
515
+
516
+ if return_messages_log:
517
+ messages.append({"role": "assistant", "content": gen_text})
518
+ messages_log.append(messages)
727
519
 
728
520
  # add to output
729
- output.append({'sentence_start': sent['start'],
730
- 'sentence_end': sent['end'],
731
- 'sentence_text': sent['sentence_text'],
732
- 'gen_text': gen_text})
521
+ result = FrameExtractionUnitResult(
522
+ start=unit.start,
523
+ end=unit.end,
524
+ text=unit.text,
525
+ gen_text=gen_text)
526
+ output.append(result)
733
527
 
528
+ if return_messages_log:
529
+ return output, messages_log
530
+
734
531
  return output
735
532
 
533
+ def stream(self, text_content: Union[str, Dict[str, str]], max_new_tokens: int = 2048, document_key: str = None,
534
+ temperature: float = 0.0, **kwrs) -> Generator[Dict[str, Any], None, List[FrameExtractionUnitResult]]:
535
+ """
536
+ Streams LLM responses per unit with structured event types,
537
+ and returns collected data for post-processing.
538
+
539
+ Yields:
540
+ -------
541
+ Dict[str, Any]: (type, data)
542
+ - {"type": "info", "data": str_message}: General informational messages.
543
+ - {"type": "unit", "data": dict_unit_info}: Signals start of a new unit. dict_unit_info contains {'id', 'text', 'start', 'end'}
544
+ - {"type": "context", "data": str_context}: Context string for the current unit.
545
+ - {"type": "llm_chunk", "data": str_chunk}: A raw chunk from the LLM.
736
546
 
737
- async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
738
- document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
547
+ Returns:
548
+ --------
549
+ List[FrameExtractionUnitResult]:
550
+ A list of FrameExtractionUnitResult objects, each containing the
551
+ original unit details and the fully accumulated 'gen_text' from the LLM.
739
552
  """
740
- The asynchronous version of the extract() method.
553
+ collected_results: List[FrameExtractionUnitResult] = []
554
+
555
+ if isinstance(text_content, str):
556
+ doc_text = text_content
557
+ elif isinstance(text_content, dict):
558
+ if document_key is None:
559
+ raise ValueError("document_key must be provided when text_content is dict.")
560
+ if document_key not in text_content:
561
+ raise ValueError(f"document_key '{document_key}' not found in text_content.")
562
+ doc_text = text_content[document_key]
563
+ else:
564
+ raise TypeError("text_content must be a string or a dictionary.")
565
+
566
+ units: List[FrameExtractionUnit] = self.unit_chunker.chunk(doc_text)
567
+ self.context_chunker.fit(doc_text, units)
568
+
569
+ yield {"type": "info", "data": f"Starting LLM processing for {len(units)} units."}
570
+
571
+ for i, unit in enumerate(units):
572
+ unit_info_payload = {"id": i, "text": unit.text, "start": unit.start, "end": unit.end}
573
+ yield {"type": "unit", "data": unit_info_payload}
574
+
575
+ messages = []
576
+ if self.system_prompt:
577
+ messages.append({'role': 'system', 'content': self.system_prompt})
578
+
579
+ context_str = self.context_chunker.chunk(unit)
580
+
581
+ # Construct prompt input based on whether text_content was str or dict
582
+ if context_str:
583
+ yield {"type": "context", "data": context_str}
584
+ prompt_input_for_context = context_str
585
+ if isinstance(text_content, dict):
586
+ context_content_dict = text_content.copy()
587
+ context_content_dict[document_key] = context_str
588
+ prompt_input_for_context = context_content_dict
589
+ messages.append({'role': 'user', 'content': self._get_user_prompt(prompt_input_for_context)})
590
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
591
+ messages.append({'role': 'user', 'content': unit.text})
592
+ else: # No context
593
+ prompt_input_for_unit = unit.text
594
+ if isinstance(text_content, dict):
595
+ unit_content_dict = text_content.copy()
596
+ unit_content_dict[document_key] = unit.text
597
+ prompt_input_for_unit = unit_content_dict
598
+ messages.append({'role': 'user', 'content': self._get_user_prompt(prompt_input_for_unit)})
599
+
600
+ current_gen_text = ""
601
+
602
+ response_stream = self.inference_engine.chat(
603
+ messages=messages,
604
+ max_new_tokens=max_new_tokens,
605
+ temperature=temperature,
606
+ stream=True,
607
+ **kwrs
608
+ )
609
+ for chunk in response_stream:
610
+ yield {"type": "llm_chunk", "data": chunk}
611
+ current_gen_text += chunk
612
+
613
+ # Store the result for this unit
614
+ result_for_unit = FrameExtractionUnitResult(
615
+ start=unit.start,
616
+ end=unit.end,
617
+ text=unit.text,
618
+ gen_text=current_gen_text
619
+ )
620
+ collected_results.append(result_for_unit)
621
+
622
+ yield {"type": "info", "data": "All units processed by LLM."}
623
+ return collected_results
624
+
625
+ async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, document_key:str=None, temperature:float=0.0,
626
+ concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
627
+ """
628
+ This is the asynchronous version of the extract() method.
741
629
 
742
630
  Parameters:
743
631
  ----------
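The stream() generator above yields typed events and, as with any Python generator, hands the collected per-unit results back through StopIteration.value. A minimal consumer, assuming an already-constructed extractor and a note_text string:

    stream = extractor.stream(text_content=note_text)
    while True:
        try:
            event = next(stream)
        except StopIteration as stop:
            unit_results = stop.value           # List[FrameExtractionUnitResult]
            break
        if event["type"] == "unit":
            print(f"\n--- unit {event['data']['id']} ---")
        elif event["type"] == "llm_chunk":
            print(event["data"], end="", flush=True)
        elif event["type"] in ("info", "context"):
            print(event["data"])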
@@ -745,7 +633,7 @@ class SentenceFrameExtractor(FrameExtractor):
745
633
  the input text content to put in prompt template.
746
634
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
747
635
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
748
- max_new_tokens : str, Optional
636
+ max_new_tokens : int, Optional
749
637
  the max number of new tokens LLM should generate.
750
638
  document_key : str, Optional
751
639
  specify the key in text_content where document text is.
@@ -753,73 +641,129 @@ class SentenceFrameExtractor(FrameExtractor):
753
641
  temperature : float, Optional
754
642
  the temperature for token sampling.
755
643
  concurrent_batch_size : int, Optional
756
- the number of sentences to process in concurrent.
757
- """
758
- # Check if self.inference_engine.chat_async() is implemented
759
- if not hasattr(self.inference_engine, 'chat_async'):
760
- raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
644
+ the batch size for concurrent processing.
645
+ return_messages_log : bool, Optional
646
+ if True, a list of messages will be returned.
761
647
 
762
- # define output
763
- output = []
764
- # sentence tokenization
648
+ Return : List[FrameExtractionUnitResult]
649
+ the output from LLM for each unit. Contains the start, end, text, and generated text.
650
+ """
765
651
  if isinstance(text_content, str):
766
- sentences = self._get_sentences(text_content)
652
+ doc_text = text_content
767
653
  elif isinstance(text_content, dict):
768
654
  if document_key is None:
769
655
  raise ValueError("document_key must be provided when text_content is dict.")
770
- sentences = self._get_sentences(text_content[document_key])
656
+ if document_key not in text_content:
657
+ raise ValueError(f"document_key '{document_key}' not found in text_content dictionary.")
658
+ doc_text = text_content[document_key]
659
+ else:
660
+ raise TypeError("text_content must be a string or a dictionary.")
771
661
 
772
- # generate sentence by sentence
773
- for i in range(0, len(sentences), concurrent_batch_size):
774
- tasks = []
775
- batch = sentences[i:i + concurrent_batch_size]
776
- for j, sent in enumerate(batch):
777
- # construct chat messages
778
- messages = []
779
- if self.system_prompt:
780
- messages.append({'role': 'system', 'content': self.system_prompt})
662
+ units = self.unit_chunker.chunk(doc_text)
781
663
 
782
- context = self._get_context_sentences(text_content, i + j, sentences, document_key)
783
-
784
- if self.context_sentences == 0:
785
- # no context, just place sentence of interest
786
- messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
664
+ # context chunker init
665
+ self.context_chunker.fit(doc_text, units)
666
+
667
+ # Prepare inputs for all units first
668
+ tasks_input = []
669
+ for i, unit in enumerate(units):
670
+ # construct chat messages
671
+ messages = []
672
+ if self.system_prompt:
673
+ messages.append({'role': 'system', 'content': self.system_prompt})
674
+
675
+ context = self.context_chunker.chunk(unit)
676
+
677
+ if context == "":
678
+ # no context, just place unit in user prompt
679
+ if isinstance(text_content, str):
680
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
787
681
  else:
788
- # insert context
682
+ unit_content = text_content.copy()
683
+ unit_content[document_key] = unit.text
684
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
685
+ else:
686
+ # insert context to user prompt
687
+ if isinstance(text_content, str):
789
688
  messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
790
- # simulate conversation
791
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
792
- # place sentence of interest
793
- messages.append({'role': 'user', 'content': sent['sentence_text']})
794
-
795
- # add to tasks
796
- task = asyncio.create_task(
797
- self.inference_engine.chat_async(
798
- messages=messages,
799
- max_new_tokens=max_new_tokens,
800
- temperature=temperature,
801
- **kwrs
802
- )
689
+ else:
690
+ context_content = text_content.copy()
691
+ context_content[document_key] = context
692
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
693
+ # simulate conversation where assistant confirms
694
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
695
+ # place unit of interest
696
+ messages.append({'role': 'user', 'content': unit.text})
697
+
698
+ # Store unit and messages together for the task
699
+ tasks_input.append({"unit": unit, "messages": messages, "original_index": i})
700
+
701
+ # Process units concurrently with asyncio.Semaphore
702
+ semaphore = asyncio.Semaphore(concurrent_batch_size)
703
+
704
+ async def semaphore_helper(task_data: Dict, max_new_tokens: int, temperature: float, **kwrs):
705
+ unit = task_data["unit"]
706
+ messages = task_data["messages"]
707
+ original_index = task_data["original_index"]
708
+
709
+ async with semaphore:
710
+ gen_text = await self.inference_engine.chat_async(
711
+ messages=messages,
712
+ max_new_tokens=max_new_tokens,
713
+ temperature=temperature,
714
+ **kwrs
803
715
  )
804
- tasks.append(task)
805
-
806
- # Wait until the batch is done, collect results and move on to next batch
807
- responses = await asyncio.gather(*tasks)
716
+ return {"original_index": original_index, "unit": unit, "gen_text": gen_text, "messages": messages}
717
+
718
+ # Create and gather tasks
719
+ tasks = []
720
+ for task_inp in tasks_input:
721
+ task = asyncio.create_task(semaphore_helper(
722
+ task_inp,
723
+ max_new_tokens=max_new_tokens,
724
+ temperature=temperature,
725
+ **kwrs
726
+ ))
727
+ tasks.append(task)
728
+
729
+ results_raw = await asyncio.gather(*tasks)
730
+
731
+ # Sort results back into original order using the index stored
732
+ results_raw.sort(key=lambda x: x["original_index"])
733
+
734
+ # Restructure the results
735
+ output: List[FrameExtractionUnitResult] = []
736
+ messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
737
+
738
+ for result_data in results_raw:
739
+ unit = result_data["unit"]
740
+ gen_text = result_data["gen_text"]
741
+
742
+ # Create result object
743
+ result = FrameExtractionUnitResult(
744
+ start=unit.start,
745
+ end=unit.end,
746
+ text=unit.text,
747
+ gen_text=gen_text
748
+ )
749
+ output.append(result)
750
+
751
+ # Append to messages log if requested
752
+ if return_messages_log:
753
+ final_messages = result_data["messages"] + [{"role": "assistant", "content": gen_text}]
754
+ messages_log.append(final_messages)
755
+
756
+ if return_messages_log:
757
+ return output, messages_log
758
+ else:
759
+ return output
808
760
 
809
- # Collect outputs
810
- for gen_text, sent in zip(responses, batch):
811
- output.append({'sentence_start': sent['start'],
812
- 'sentence_end': sent['end'],
813
- 'sentence_text': sent['sentence_text'],
814
- 'gen_text': gen_text})
815
- return output
816
-
817
761
 
818
- def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=512,
819
- document_key:str=None, temperature:float=0.0, stream:bool=False,
762
+ def extract_frames(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
763
+ document_key:str=None, temperature:float=0.0, verbose:bool=False,
820
764
  concurrent:bool=False, concurrent_batch_size:int=32,
821
765
  case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
822
- allow_overlap_entities:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
766
+ allow_overlap_entities:bool=False, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
823
767
  """
824
768
  This method inputs a text and outputs a list of LLMInformationExtractionFrame
825
769
  It use the extract() method and post-process outputs into frames.
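extract_async() now builds one task per unit and throttles them with an asyncio.Semaphore of size concurrent_batch_size, then restores the original unit order before returning. Most callers reach it through extract_frames(concurrent=True); note_text is an assumed document string and the engine must implement chat_async():

    # Concurrent path; verbose output is not supported in this mode.
    frames = extractor.extract_frames(
        text_content=note_text,
        concurrent=True,
        concurrent_batch_size=16,
    )

    # Or, inside an async application, await the raw per-unit results directly:
    # unit_results = await extractor.extract_async(text_content=note_text,
    #                                              concurrent_batch_size=16)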
@@ -830,8 +774,6 @@ class SentenceFrameExtractor(FrameExtractor):
830
774
  the input text content to put in prompt template.
831
775
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
832
776
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
833
- entity_key : str
834
- the key (in ouptut JSON) for entity text.
835
777
  max_new_tokens : str, Optional
836
778
  the max number of new tokens LLM should generate.
837
779
  document_key : str, Optional
@@ -839,7 +781,7 @@ class SentenceFrameExtractor(FrameExtractor):
839
781
  If text_content is str, this parameter will be ignored.
840
782
  temperature : float, Optional
841
783
  the temperature for token sampling.
842
- stream : bool, Optional
784
+ verbose : bool, Optional
843
785
  if True, LLM generated text will be printed in terminal in real-time.
844
786
  concurrent : bool, Optional
845
787
  if True, the sentences will be extracted in concurrent.
@@ -857,40 +799,48 @@ class SentenceFrameExtractor(FrameExtractor):
857
799
  allow_overlap_entities : bool, Optional
858
800
  if True, entities can overlap in the text.
859
801
  Note that this can cause multiple frames to be generated on the same entity span if they have same entity text.
802
+ return_messages_log : bool, Optional
803
+ if True, a list of messages will be returned.
860
804
 
861
805
  Return : str
862
806
  a list of frames.
863
807
  """
808
+ ENTITY_KEY = "entity_text"
864
809
  if concurrent:
865
- if stream:
866
- warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
810
+ if verbose:
811
+ warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
867
812
 
868
813
  nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
869
- llm_output_sentences = asyncio.run(self.extract_async(text_content=text_content,
870
- max_new_tokens=max_new_tokens,
871
- document_key=document_key,
872
- temperature=temperature,
873
- concurrent_batch_size=concurrent_batch_size,
874
- **kwrs)
875
- )
814
+ extraction_results = asyncio.run(self.extract_async(text_content=text_content,
815
+ max_new_tokens=max_new_tokens,
816
+ document_key=document_key,
817
+ temperature=temperature,
818
+ concurrent_batch_size=concurrent_batch_size,
819
+ return_messages_log=return_messages_log,
820
+ **kwrs)
821
+ )
876
822
  else:
877
- llm_output_sentences = self.extract(text_content=text_content,
878
- max_new_tokens=max_new_tokens,
879
- document_key=document_key,
880
- temperature=temperature,
881
- stream=stream,
882
- **kwrs)
823
+ extraction_results = self.extract(text_content=text_content,
824
+ max_new_tokens=max_new_tokens,
825
+ document_key=document_key,
826
+ temperature=temperature,
827
+ verbose=verbose,
828
+ return_messages_log=return_messages_log,
829
+ **kwrs)
830
+
831
+ llm_output_results, messages_log = extraction_results if return_messages_log else (extraction_results, None)
832
+
883
833
  frame_list = []
884
- for sent in llm_output_sentences:
834
+ for res in llm_output_results:
885
835
  entity_json = []
886
- for entity in self._extract_json(gen_text=sent['gen_text']):
887
- if entity_key in entity:
836
+ for entity in self._extract_json(gen_text=res.gen_text):
837
+ if ENTITY_KEY in entity:
888
838
  entity_json.append(entity)
889
839
  else:
890
- warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{entity_key}"). This frame will be dropped.', RuntimeWarning)
840
+ warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
891
841
 
892
- spans = self._find_entity_spans(text=sent['sentence_text'],
893
- entities=[e[entity_key] for e in entity_json],
842
+ spans = self._find_entity_spans(text=res.text,
843
+ entities=[e[ENTITY_KEY] for e in entity_json],
894
844
  case_sensitive=case_sensitive,
895
845
  fuzzy_match=fuzzy_match,
896
846
  fuzzy_buffer_size=fuzzy_buffer_size,
@@ -899,31 +849,41 @@ class SentenceFrameExtractor(FrameExtractor):
899
849
  for ent, span in zip(entity_json, spans):
900
850
  if span is not None:
901
851
  start, end = span
902
- entity_text = sent['sentence_text'][start:end]
903
- start += sent['sentence_start']
904
- end += sent['sentence_start']
852
+ entity_text = res.text[start:end]
853
+ start += res.start
854
+ end += res.start
855
+ attr = {}
856
+ if "attr" in ent and ent["attr"] is not None:
857
+ attr = ent["attr"]
858
+
905
859
  frame = LLMInformationExtractionFrame(frame_id=f"{len(frame_list)}",
906
860
  start=start,
907
861
  end=end,
908
862
  entity_text=entity_text,
909
- attr={k: v for k, v in ent.items() if k != entity_key and v != ""})
863
+ attr=attr)
910
864
  frame_list.append(frame)
911
- return frame_list
912
865
 
866
+ if return_messages_log:
867
+ return frame_list, messages_log
868
+ return frame_list
869
+
913
870
 
914
- class SentenceReviewFrameExtractor(SentenceFrameExtractor):
915
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
916
- review_mode:str, review_prompt:str=None, system_prompt:str=None,
917
- context_sentences:Union[str, int]="all", **kwrs):
871
+ class ReviewFrameExtractor(DirectFrameExtractor):
872
+ def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker,
873
+ inference_engine:InferenceEngine, prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None, **kwrs):
918
874
  """
919
- This class adds a review step after the SentenceFrameExtractor.
920
- For each sentence, the review process asks LLM to review its output and:
921
- 1. add more frames while keeping current. This is efficient for boosting recall.
875
+ This class adds a review step after the DirectFrameExtractor.
876
+ The review process asks the LLM to review its output and:
877
+ 1. add more frames while keeping the current ones. This is efficient for boosting recall.
922
878
  2. or, regenerate frames (add new and delete existing).
923
879
  Use the review_mode parameter to specify. Note that the review_prompt should instruct LLM accordingly.
924
880
 
925
881
  Parameters:
926
882
  ----------
883
+ unit_chunker : UnitChunker
884
+ the unit chunker object that determines how to chunk the document text into units.
885
+ context_chunker : ContextChunker
886
+ the context chunker object that determines how to get context for each unit.
927
887
  inference_engine : InferenceEngine
928
888
  the LLM inferencing engine object. Must implements the chat() method.
929
889
  prompt_template : str
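A consequence of the post-processing above is that the output schema is now fixed: each extracted JSON object must carry an "entity_text" key, and attributes live under an optional nested "attr" object instead of extra top-level keys. An illustrative (made-up) comparison of the shapes the two versions parse:

    # 1.0.0 expects, per frame:
    new_style = {"entity_text": "metformin", "attr": {"dosage": "500 mg"}}

    # 0.4.6 accepted an arbitrary entity_key plus flat attributes, e.g. with entity_key="Drug":
    old_style = {"Drug": "metformin", "dosage": "500 mg"}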
@@ -936,36 +896,215 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
936
896
  addition mode only ask LLM to add new frames, while revision mode ask LLM to regenerate.
937
897
  system_prompt : str, Optional
938
898
  system prompt.
939
- context_sentences : Union[str, int], Optional
940
- number of sentences before and after the given sentence to provide additional context.
941
- if "all", the full text will be provided in the prompt as context.
942
- if 0, no additional context will be provided.
943
- This is good for tasks that does not require context beyond the given sentence.
944
- if > 0, the number of sentences before and after the given sentence to provide as context.
945
- This is good for tasks that require context beyond the given sentence.
946
899
  """
947
- super().__init__(inference_engine=inference_engine, prompt_template=prompt_template,
948
- system_prompt=system_prompt, context_sentences=context_sentences, **kwrs)
949
-
900
+ super().__init__(inference_engine=inference_engine,
901
+ unit_chunker=unit_chunker,
902
+ prompt_template=prompt_template,
903
+ system_prompt=system_prompt,
904
+ context_chunker=context_chunker,
905
+ **kwrs)
906
+ # check review mode
950
907
  if review_mode not in {"addition", "revision"}:
951
908
  raise ValueError('review_mode must be one of {"addition", "revision"}.')
952
909
  self.review_mode = review_mode
910
+ # assign review prompt
911
+ if review_prompt:
912
+ self.review_prompt = review_prompt
913
+ else:
914
+ self.review_prompt = None
915
+ original_class_name = self.__class__.__name__
916
+
917
+ current_class_name = original_class_name
918
+ for current_class_in_mro in self.__class__.__mro__:
919
+ if current_class_in_mro is object:
920
+ continue
921
+
922
+ current_class_name = current_class_in_mro.__name__
923
+ try:
924
+ file_path = importlib.resources.files('llm_ie.asset.default_prompts').\
925
+ joinpath(f"{current_class_name}_{self.review_mode}_review_prompt.txt")
926
+ with open(file_path, 'r', encoding="utf-8") as f:
927
+ self.review_prompt = f.read()
928
+ except FileNotFoundError:
929
+ pass
930
+
931
+ except Exception as e:
932
+ warnings.warn(
933
+ f"Error attempting to read default review prompt for '{current_class_name}' "
934
+ f"from '{str(file_path)}': {e}. Trying next in MRO.",
935
+ UserWarning
936
+ )
937
+ continue
938
+
939
+ if self.review_prompt is None:
940
+ raise ValueError(f"Cannot find review prompt for {self.__class__.__name__} in the package. Please provide a review_prompt.")
941
+
942
+ def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, document_key:str=None,
943
+ temperature:float=0.0, verbose:bool=False, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
944
+ """
945
+ This method inputs a text and outputs a list of outputs per unit.
946
+
947
+ Parameters:
948
+ ----------
949
+ text_content : Union[str, Dict[str,str]]
950
+ the input text content to put in prompt template.
951
+ If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
952
+ If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
953
+ max_new_tokens : int, Optional
954
+ the max number of new tokens LLM should generate.
955
+ document_key : str, Optional
956
+ specify the key in text_content where document text is.
957
+ If text_content is str, this parameter will be ignored.
958
+ temperature : float, Optional
959
+ the temperature for token sampling.
960
+ verbose : bool, Optional
961
+ if True, LLM generated text will be printed in terminal in real-time.
962
+ return_messages_log : bool, Optional
963
+ if True, a list of messages will be returned.
964
+
965
+ Return : List[FrameExtractionUnitResult]
966
+ the output from LLM for each unit. Contains the start, end, text, and generated text.
967
+ """
968
+ # define output
969
+ output = []
970
+ # unit chunking
971
+ if isinstance(text_content, str):
972
+ doc_text = text_content
973
+
974
+ elif isinstance(text_content, dict):
975
+ if document_key is None:
976
+ raise ValueError("document_key must be provided when text_content is dict.")
977
+ doc_text = text_content[document_key]
978
+
979
+ units = self.unit_chunker.chunk(doc_text)
980
+ # context chunker init
981
+ self.context_chunker.fit(doc_text, units)
982
+ # messages log
983
+ if return_messages_log:
984
+ messages_log = []
985
+
986
+ # generate unit by unit
987
+ for i, unit in enumerate(units):
988
+ # <--- Initial generation step --->
989
+ # construct chat messages
990
+ messages = []
991
+ if self.system_prompt:
992
+ messages.append({'role': 'system', 'content': self.system_prompt})
993
+
994
+ context = self.context_chunker.chunk(unit)
995
+
996
+ if context == "":
997
+ # no context, just place unit in user prompt
998
+ if isinstance(text_content, str):
999
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
1000
+ else:
1001
+ unit_content = text_content.copy()
1002
+ unit_content[document_key] = unit.text
1003
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
1004
+ else:
1005
+ # insert context to user prompt
1006
+ if isinstance(text_content, str):
1007
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
1008
+ else:
1009
+ context_content = text_content.copy()
1010
+ context_content[document_key] = context
1011
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
1012
+ # simulate conversation where assistant confirms
1013
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
1014
+ # place unit of interest
1015
+ messages.append({'role': 'user', 'content': unit.text})
1016
+
1017
+ if verbose:
1018
+ print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
1019
+ if context != "":
1020
+ print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
1021
+
1022
+ print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
1023
+
1024
+ response_stream = self.inference_engine.chat(
1025
+ messages=messages,
1026
+ max_new_tokens=max_new_tokens,
1027
+ temperature=temperature,
1028
+ stream=True,
1029
+ **kwrs
1030
+ )
1031
+
1032
+ initial = ""
1033
+ for chunk in response_stream:
1034
+ initial += chunk
1035
+ print(chunk, end='', flush=True)
1036
+
1037
+ else:
1038
+ initial = self.inference_engine.chat(
1039
+ messages=messages,
1040
+ max_new_tokens=max_new_tokens,
1041
+ temperature=temperature,
1042
+ stream=False,
1043
+ **kwrs
1044
+ )
1045
+
1046
+ if return_messages_log:
1047
+ messages.append({"role": "assistant", "content": initial})
1048
+ messages_log.append(messages)
1049
+
1050
+ # <--- Review step --->
1051
+ if verbose:
1052
+ print(f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}")
1053
+
1054
+ messages.append({'role': 'assistant', 'content': initial})
1055
+ messages.append({'role': 'user', 'content': self.review_prompt})
1056
+
1057
+ if verbose:
1058
+ response_stream = self.inference_engine.chat(
1059
+ messages=messages,
1060
+ max_new_tokens=max_new_tokens,
1061
+ temperature=temperature,
1062
+ stream=True,
1063
+ **kwrs
1064
+ )
1065
+
1066
+ review = ""
1067
+ for chunk in response_stream:
1068
+ review += chunk
1069
+ print(chunk, end='', flush=True)
1070
+
1071
+ else:
1072
+ review = self.inference_engine.chat(
1073
+ messages=messages,
1074
+ max_new_tokens=max_new_tokens,
1075
+ temperature=temperature,
1076
+ stream=False,
1077
+ **kwrs
1078
+ )
953
1079
 
954
- if review_prompt:
955
- self.review_prompt = review_prompt
956
- else:
957
- file_path = importlib.resources.files('llm_ie.asset.default_prompts').\
958
- joinpath(f"{self.__class__.__name__}_{self.review_mode}_review_prompt.txt")
959
- with open(file_path, 'r', encoding="utf-8") as f:
960
- self.review_prompt = f.read()
1080
+ # Output
1081
+ if self.review_mode == "revision":
1082
+ gen_text = review
1083
+ elif self.review_mode == "addition":
1084
+ gen_text = initial + '\n' + review
1085
+
1086
+ if return_messages_log:
1087
+ messages.append({"role": "assistant", "content": review})
1088
+ messages_log.append(messages)
961
1089
 
962
- warnings.warn(f'Custom review prompt not provided. The default review prompt is used:\n"{self.review_prompt}"', UserWarning)
1090
+ # add to output
1091
+ result = FrameExtractionUnitResult(
1092
+ start=unit.start,
1093
+ end=unit.end,
1094
+ text=unit.text,
1095
+ gen_text=gen_text)
1096
+ output.append(result)
1097
+
1098
+ if return_messages_log:
1099
+ return output, messages_log
1100
+
1101
+ return output
963
1102
 
964
1103
 
965
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
966
- document_key:str=None, temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict[str,str]]:
1104
+ def stream(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048,
1105
+ document_key:str=None, temperature:float=0.0, **kwrs) -> Generator[str, None, None]:
967
1106
  """
968
- This method inputs a text and outputs a list of outputs per sentence.
1107
+ This method inputs a text and outputs a list of outputs per unit.
969
1108
 
970
1109
  Parameters:
971
1110
  ----------
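Putting the review flow together: the rewritten ReviewFrameExtractor sits on top of DirectFrameExtractor and takes the chunkers explicitly. "addition" appends the reviewed output to the initial extraction, "revision" keeps only the reviewed output, and leaving review_prompt unset loads the packaged default (a ValueError is raised if none is found). A sketch, assuming the engine, prompt_template, and note_text objects exist and that the chunker constructors take no arguments:

    from llm_ie.chunkers import SentenceUnitChunker, WholeDocumentContextChunker
    from llm_ie.extractors import ReviewFrameExtractor

    reviewer = ReviewFrameExtractor(
        unit_chunker=SentenceUnitChunker(),
        context_chunker=WholeDocumentContextChunker(),
        inference_engine=engine,
        prompt_template=prompt_template,
        review_mode="addition",          # or "revision"
        # review_prompt=None -> the packaged default review prompt is used
    )
    frames = reviewer.extract_frames(text_content=note_text)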
@@ -973,234 +1112,371 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
973
1112
  the input text content to put in prompt template.
974
1113
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
975
1114
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
976
- max_new_tokens : str, Optional
1115
+ max_new_tokens : int, Optional
977
1116
  the max number of new tokens LLM should generate.
978
1117
  document_key : str, Optional
979
1118
  specify the key in text_content where document text is.
980
1119
  If text_content is str, this parameter will be ignored.
981
1120
  temperature : float, Optional
982
1121
  the temperature for token sampling.
983
- stream : bool, Optional
984
- if True, LLM generated text will be printed in terminal in real-time.
985
1122
 
986
- Return : str
987
- the output from LLM. Need post-processing.
1123
+ Return : List[FrameExtractionUnitResult]
1124
+ the output from LLM for each unit. Contains the start, end, text, and generated text.
988
1125
  """
989
- # define output
990
- output = []
991
- # sentence tokenization
1126
+ # unit chunking
992
1127
  if isinstance(text_content, str):
993
- sentences = self._get_sentences(text_content)
1128
+ doc_text = text_content
1129
+
994
1130
  elif isinstance(text_content, dict):
995
1131
  if document_key is None:
996
1132
  raise ValueError("document_key must be provided when text_content is dict.")
997
- sentences = self._get_sentences(text_content[document_key])
1133
+ doc_text = text_content[document_key]
998
1134
 
999
- # generate sentence by sentence
1000
- for i, sent in enumerate(sentences):
1135
+ units = self.unit_chunker.chunk(doc_text)
1136
+ # context chunker init
1137
+ self.context_chunker.fit(doc_text, units)
1138
+
1139
+ # generate unit by unit
1140
+ for i, unit in enumerate(units):
1141
+ # <--- Initial generation step --->
1001
1142
  # construct chat messages
1002
1143
  messages = []
1003
1144
  if self.system_prompt:
1004
1145
  messages.append({'role': 'system', 'content': self.system_prompt})
1005
1146
 
1006
- context = self._get_context_sentences(text_content, i, sentences, document_key)
1147
+ context = self.context_chunker.chunk(unit)
1007
1148
 
1008
- if self.context_sentences == 0:
1009
- # no context, just place sentence of interest
1010
- messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
1149
+ if context == "":
1150
+ # no context, just place unit in user prompt
1151
+ if isinstance(text_content, str):
1152
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
1153
+ else:
1154
+ unit_content = text_content.copy()
1155
+ unit_content[document_key] = unit.text
1156
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
1011
1157
  else:
1012
- # insert context
1013
- messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
1014
- # simulate conversation
1015
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
1016
- # place sentence of interest
1017
- messages.append({'role': 'user', 'content': sent['sentence_text']})
1018
-
1019
- if stream:
1020
- print(f"\n\n{Fore.GREEN}Sentence {i}: {Style.RESET_ALL}\n{sent['sentence_text']}\n")
1021
- if isinstance(self.context_sentences, int) and self.context_sentences > 0:
1022
- print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
1023
- print(f"{Fore.BLUE}Initial Output:{Style.RESET_ALL}")
1158
+ # insert context to user prompt
1159
+ if isinstance(text_content, str):
1160
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
1161
+ else:
1162
+ context_content = text_content.copy()
1163
+ context_content[document_key] = context
1164
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
1165
+ # simulate conversation where assistant confirms
1166
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
1167
+ # place unit of interest
1168
+ messages.append({'role': 'user', 'content': unit.text})
1169
+
1170
+
1171
+ yield f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n"
1172
+ if context != "":
1173
+ yield f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n"
1174
+
1175
+ yield f"{Fore.BLUE}Extraction:{Style.RESET_ALL}\n"
1024
1176
 
1025
- initial = self.inference_engine.chat(
1177
+ response_stream = self.inference_engine.chat(
1026
1178
  messages=messages,
1027
1179
  max_new_tokens=max_new_tokens,
1028
1180
  temperature=temperature,
1029
- stream=stream,
1181
+ stream=True,
1030
1182
  **kwrs
1031
1183
  )
1032
1184
 
1033
- # Review
1034
- if stream:
1035
- print(f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}")
1185
+ initial = ""
1186
+ for chunk in response_stream:
1187
+ initial += chunk
1188
+ yield chunk
1189
+
1190
+ # <--- Review step --->
1191
+ yield f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}"
1192
+
1036
1193
  messages.append({'role': 'assistant', 'content': initial})
1037
1194
  messages.append({'role': 'user', 'content': self.review_prompt})
1038
1195
 
1039
- review = self.inference_engine.chat(
1196
+ response_stream = self.inference_engine.chat(
1040
1197
  messages=messages,
1041
1198
  max_new_tokens=max_new_tokens,
1042
1199
  temperature=temperature,
1043
- stream=stream,
1200
+ stream=True,
1044
1201
  **kwrs
1045
1202
  )
1046
1203
 
1047
- # Output
1048
- if self.review_mode == "revision":
1049
- gen_text = review
1050
- elif self.review_mode == "addition":
1051
- gen_text = initial + '\n' + review
1204
+ for chunk in response_stream:
1205
+ yield chunk
1052
1206
 
1053
- # add to output
1054
- output.append({'sentence_start': sent['start'],
1055
- 'sentence_end': sent['end'],
1056
- 'sentence_text': sent['sentence_text'],
1057
- 'gen_text': gen_text})
1058
- return output
1059
-
1060
- async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
1061
- document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
1207
+ async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, document_key:str=None, temperature:float=0.0,
1208
+ concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
1062
1209
  """
1063
- The asynchronous version of the extract() method.
1210
+ This is the asynchronous version of the extract() method with the review step.
1064
1211
 
1065
1212
  Parameters:
1066
1213
  ----------
1067
1214
  text_content : Union[str, Dict[str,str]]
1068
- the input text content to put in prompt template.
1215
+ the input text content to put in prompt template.
1069
1216
  If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
1070
1217
  If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
1071
- max_new_tokens : str, Optional
1072
- the max number of new tokens LLM should generate.
1218
+ max_new_tokens : int, Optional
1219
+ the max number of new tokens LLM should generate.
1073
1220
  document_key : str, Optional
1074
- specify the key in text_content where document text is.
1221
+ specify the key in text_content where document text is.
1075
1222
  If text_content is str, this parameter will be ignored.
1076
1223
  temperature : float, Optional
1077
1224
  the temperature for token sampling.
1078
1225
  concurrent_batch_size : int, Optional
1079
- the number of sentences to process in concurrent.
1226
+ the batch size for concurrent processing.
1227
+ return_messages_log : bool, Optional
1228
+ if True, a list of messages will be returned, including review steps.
1080
1229
 
1081
- Return : str
1082
- the output from LLM. Need post-processing.
1230
+ Return : List[FrameExtractionUnitResult]
1231
+ the output from LLM for each unit after review. Contains the start, end, text, and generated text.
1083
1232
  """
1084
- # Check if self.inference_engine.chat_async() is implemented
1085
- if not hasattr(self.inference_engine, 'chat_async'):
1086
- raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
1087
-
1088
- # define output
1089
- output = []
1090
- # sentence tokenization
1091
1233
  if isinstance(text_content, str):
1092
- sentences = self._get_sentences(text_content)
1234
+ doc_text = text_content
1093
1235
  elif isinstance(text_content, dict):
1094
1236
  if document_key is None:
1095
1237
  raise ValueError("document_key must be provided when text_content is dict.")
1096
- sentences = self._get_sentences(text_content[document_key])
1097
-
1098
- # generate initial outputs sentence by sentence
1099
- for i in range(0, len(sentences), concurrent_batch_size):
1100
- messages_list = []
1101
- init_tasks = []
1102
- review_tasks = []
1103
- batch = sentences[i:i + concurrent_batch_size]
1104
- for j, sent in enumerate(batch):
1105
- # construct chat messages
1106
- messages = []
1107
- if self.system_prompt:
1108
- messages.append({'role': 'system', 'content': self.system_prompt})
1238
+ if document_key not in text_content:
1239
+ raise ValueError(f"document_key '{document_key}' not found in text_content dictionary.")
1240
+ doc_text = text_content[document_key]
1241
+ else:
1242
+ raise TypeError("text_content must be a string or a dictionary.")
1109
1243
 
1110
- context = self._get_context_sentences(text_content, i + j, sentences, document_key)
1111
-
1112
- if self.context_sentences == 0:
1113
- # no context, just place sentence of interest
1114
- messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
1244
+ units = self.unit_chunker.chunk(doc_text)
1245
+
1246
+ # context chunker init
1247
+ self.context_chunker.fit(doc_text, units)
1248
+
1249
+ # <--- Initial generation step --->
1250
+ initial_tasks_input = []
1251
+ for i, unit in enumerate(units):
1252
+ # construct chat messages for initial generation
1253
+ messages = []
1254
+ if self.system_prompt:
1255
+ messages.append({'role': 'system', 'content': self.system_prompt})
1256
+
1257
+ context = self.context_chunker.chunk(unit)
1258
+
1259
+ if context == "":
1260
+ # no context, just place unit in user prompt
1261
+ if isinstance(text_content, str):
1262
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
1115
1263
  else:
1116
- # insert context
1264
+ unit_content = text_content.copy()
1265
+ unit_content[document_key] = unit.text
1266
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
1267
+ else:
1268
+ # insert context to user prompt
1269
+ if isinstance(text_content, str):
1117
1270
  messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
1118
- # simulate conversation
1119
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
1120
- # place sentence of interest
1121
- messages.append({'role': 'user', 'content': sent['sentence_text']})
1271
+ else:
1272
+ context_content = text_content.copy()
1273
+ context_content[document_key] = context
1274
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
1275
+ # simulate conversation where assistant confirms
1276
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
1277
+ # place unit of interest
1278
+ messages.append({'role': 'user', 'content': unit.text})
1279
+
1280
+ # Store unit and messages together for the initial task
1281
+ initial_tasks_input.append({"unit": unit, "messages": messages, "original_index": i})
1282
+
1283
+ semaphore = asyncio.Semaphore(concurrent_batch_size)
1284
+
1285
+ async def initial_semaphore_helper(task_data: Dict, max_new_tokens: int, temperature: float, **kwrs):
1286
+ unit = task_data["unit"]
1287
+ messages = task_data["messages"]
1288
+ original_index = task_data["original_index"]
1289
+
1290
+ async with semaphore:
1291
+ gen_text = await self.inference_engine.chat_async(
1292
+ messages=messages,
1293
+ max_new_tokens=max_new_tokens,
1294
+ temperature=temperature,
1295
+ **kwrs
1296
+ )
1297
+ # Return initial generation result along with the messages used and the unit
1298
+ return {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text, "initial_messages": messages}
1299
+
1300
+ # Create and gather initial generation tasks
1301
+ initial_tasks = [
1302
+ asyncio.create_task(initial_semaphore_helper(
1303
+ task_inp,
1304
+ max_new_tokens=max_new_tokens,
1305
+ temperature=temperature,
1306
+ **kwrs
1307
+ ))
1308
+ for task_inp in initial_tasks_input
1309
+ ]
1310
+
1311
+ initial_results_raw = await asyncio.gather(*initial_tasks)
1312
+
1313
+ # Sort initial results back into original order
1314
+ initial_results_raw.sort(key=lambda x: x["original_index"])
1315
+
1316
+ # <--- Review step --->
1317
+ review_tasks_input = []
1318
+ for result_data in initial_results_raw:
1319
+ # Prepare messages for the review step
1320
+ initial_messages = result_data["initial_messages"]
1321
+ initial_gen_text = result_data["initial_gen_text"]
1322
+ review_messages = initial_messages + [
1323
+ {'role': 'assistant', 'content': initial_gen_text},
1324
+ {'role': 'user', 'content': self.review_prompt}
1325
+ ]
1326
+ # Store data needed for review task
1327
+ review_tasks_input.append({
1328
+ "unit": result_data["unit"],
1329
+ "initial_gen_text": initial_gen_text,
1330
+ "messages": review_messages,
1331
+ "original_index": result_data["original_index"],
1332
+ "full_initial_log": initial_messages + [{'role': 'assistant', 'content': initial_gen_text}] if return_messages_log else None # Log up to initial generation
1333
+ })
1334
+
1335
+
1336
+ async def review_semaphore_helper(task_data: Dict, max_new_tokens: int, temperature: float, **kwrs):
1337
+ messages = task_data["messages"]
1338
+ original_index = task_data["original_index"]
1339
+
1340
+ async with semaphore:
1341
+ review_gen_text = await self.inference_engine.chat_async(
1342
+ messages=messages,
1343
+ max_new_tokens=max_new_tokens,
1344
+ temperature=temperature,
1345
+ **kwrs
1346
+ )
1347
+ # Combine initial and review results
1348
+ task_data["review_gen_text"] = review_gen_text
1349
+ if return_messages_log:
1350
+ # Log for the review call itself
1351
+ task_data["full_review_log"] = messages + [{'role': 'assistant', 'content': review_gen_text}]
1352
+ return task_data # Return the augmented dictionary
1353
+
1354
+ # Create and gather review tasks
1355
+ review_tasks = [
1356
+ asyncio.create_task(review_semaphore_helper(
1357
+ task_inp,
1358
+ max_new_tokens=max_new_tokens,
1359
+ temperature=temperature,
1360
+ **kwrs
1361
+ ))
1362
+ for task_inp in review_tasks_input
1363
+ ]
1364
+
1365
+ final_results_raw = await asyncio.gather(*review_tasks)
1366
+
1367
+ # Sort final results back into original order (although gather might preserve order for tasks added sequentially)
1368
+ final_results_raw.sort(key=lambda x: x["original_index"])
1369
+
1370
+ # <--- Process final results --->
1371
+ output: List[FrameExtractionUnitResult] = []
1372
+ messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
1373
+
1374
+ for result_data in final_results_raw:
1375
+ unit = result_data["unit"]
1376
+ initial_gen = result_data["initial_gen_text"]
1377
+ review_gen = result_data["review_gen_text"]
1378
+
1379
+ # Combine based on review mode
1380
+ if self.review_mode == "revision":
1381
+ final_gen_text = review_gen
1382
+ elif self.review_mode == "addition":
1383
+ final_gen_text = initial_gen + '\n' + review_gen
1384
+ else: # Should not happen due to init check
1385
+ final_gen_text = review_gen # Default to revision if mode is somehow invalid
1386
+
1387
+ # Create final result object
1388
+ result = FrameExtractionUnitResult(
1389
+ start=unit.start,
1390
+ end=unit.end,
1391
+ text=unit.text,
1392
+ gen_text=final_gen_text # Use the combined/reviewed text
1393
+ )
1394
+ output.append(result)
1395
+
1396
+ # Append full conversation log if requested
1397
+ if return_messages_log:
1398
+ full_log_for_unit = result_data.get("full_initial_log", []) + [{'role': 'user', 'content': self.review_prompt}] + [{'role': 'assistant', 'content': review_gen}]
1399
+ messages_log.append(full_log_for_unit)
1400
+
1401
+ if return_messages_log:
1402
+ return output, messages_log
1403
+ else:
1404
+ return output
1122
1405
 
1123
- messages_list.append(messages)
1124
1406
 
1125
- task = asyncio.create_task(
1126
- self.inference_engine.chat_async(
1127
- messages=messages,
1128
- max_new_tokens=max_new_tokens,
1129
- temperature=temperature,
1130
- **kwrs
1131
- )
1132
- )
1133
- init_tasks.append(task)
1134
-
1135
- # Wait until the batch is done, collect results and move on to next batch
1136
- init_responses = await asyncio.gather(*init_tasks)
1137
- # Collect initials
1138
- initials = []
1139
- for gen_text, sent, messages in zip(init_responses, batch, messages_list):
1140
- initials.append({'sentence_start': sent['start'],
1141
- 'sentence_end': sent['end'],
1142
- 'sentence_text': sent['sentence_text'],
1143
- 'gen_text': gen_text,
1144
- 'messages': messages})
1145
-
1146
- # Review
1147
- for init in initials:
1148
- messages = init["messages"]
1149
- initial = init["gen_text"]
1150
- messages.append({'role': 'assistant', 'content': initial})
1151
- messages.append({'role': 'user', 'content': self.review_prompt})
1152
- task = asyncio.create_task(
1153
- self.inference_engine.chat_async(
1154
- messages=messages,
1155
- max_new_tokens=max_new_tokens,
1156
- temperature=temperature,
1157
- **kwrs
1158
- )
1159
- )
1160
- review_tasks.append(task)
1161
-
1162
- review_responses = await asyncio.gather(*review_tasks)
1163
-
1164
- # Collect reviews
1165
- reviews = []
1166
- for gen_text, sent in zip(review_responses, batch):
1167
- reviews.append({'sentence_start': sent['start'],
1168
- 'sentence_end': sent['end'],
1169
- 'sentence_text': sent['sentence_text'],
1170
- 'gen_text': gen_text})
1171
-
1172
- for init, rev in zip(initials, reviews):
1173
- if self.review_mode == "revision":
1174
- gen_text = rev['gen_text']
1175
- elif self.review_mode == "addition":
1176
- gen_text = init['gen_text'] + '\n' + rev['gen_text']
1177
-
1178
- # add to output
1179
- output.append({'sentence_start': init['sentence_start'],
1180
- 'sentence_end': init['sentence_end'],
1181
- 'sentence_text': init['sentence_text'],
1182
- 'gen_text': gen_text})
1183
- return output
1407
+ class BasicFrameExtractor(DirectFrameExtractor):
1408
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None, **kwrs):
1409
+ """
1410
+ This class diretly prompt LLM for frame extraction.
1411
+ Input system prompt (optional), prompt template (with instruction, few-shot examples),
1412
+ and specify a LLM.
1413
+
1414
+ Parameters:
1415
+ ----------
1416
+ inference_engine : InferenceEngine
1417
+ the LLM inferencing engine object. Must implements the chat() method.
1418
+ prompt_template : str
1419
+ prompt template with "{{<placeholder name>}}" placeholder.
1420
+ system_prompt : str, Optional
1421
+ system prompt.
1422
+ """
1423
+ super().__init__(inference_engine=inference_engine,
1424
+ unit_chunker=WholeDocumentUnitChunker(),
1425
+ prompt_template=prompt_template,
1426
+ system_prompt=system_prompt,
1427
+ context_chunker=NoContextChunker(),
1428
+ **kwrs)
1429
+
1430
+ class BasicReviewFrameExtractor(ReviewFrameExtractor):
1431
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None, **kwrs):
1432
+ """
1433
+ This class add a review step after the BasicFrameExtractor.
1434
+ The Review process asks LLM to review its output and:
1435
+ 1. add more frames while keep current. This is efficient for boosting recall.
1436
+ 2. or, regenerate frames (add new and delete existing).
1437
+ Use the review_mode parameter to specify. Note that the review_prompt should instruct LLM accordingly.
1438
+
1439
+ Parameters:
1440
+ ----------
1441
+ inference_engine : InferenceEngine
1442
+ the LLM inferencing engine object. Must implements the chat() method.
1443
+ prompt_template : str
1444
+ prompt template with "{{<placeholder name>}}" placeholder.
1445
+ review_prompt : str: Optional
1446
+ the prompt text that ask LLM to review. Specify addition or revision in the instruction.
1447
+ if not provided, a default review prompt will be used.
1448
+ review_mode : str
1449
+ review mode. Must be one of {"addition", "revision"}
1450
+ addition mode only ask LLM to add new frames, while revision mode ask LLM to regenerate.
1451
+ system_prompt : str, Optional
1452
+ system prompt.
1453
+ """
1454
+ super().__init__(inference_engine=inference_engine,
1455
+ unit_chunker=WholeDocumentUnitChunker(),
1456
+ prompt_template=prompt_template,
1457
+ review_mode=review_mode,
1458
+ review_prompt=review_prompt,
1459
+ system_prompt=system_prompt,
1460
+ context_chunker=NoContextChunker(),
1461
+ **kwrs)
1184
1462
 
1185
1463
 
1186
- class SentenceCoTFrameExtractor(SentenceFrameExtractor):
1187
- from nltk.tokenize.punkt import PunktSentenceTokenizer
1188
- def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None,
1464
+ class SentenceFrameExtractor(DirectFrameExtractor):
1465
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None,
1189
1466
  context_sentences:Union[str, int]="all", **kwrs):
1190
1467
  """
1191
- This class performs sentence-based Chain-of-thoughts (CoT) information extraction.
1192
- A simulated chat follows this process:
1468
+ This class performs sentence-by-sentence information extraction.
1469
+ The process is as follows:
1193
1470
  1. system prompt (optional)
1194
- 2. user instructions (schema, background, full text, few-shot example...)
1195
- 3. user input first sentence
1196
- 4. assistant analyze the sentence
1197
- 5. assistant extract outputs
1198
- 6. repeat #3, #4, #5
1471
+ 2. user prompt with instructions (schema, background, full text, few-shot example...)
1472
+ 3. feed a sentence (start with first sentence)
1473
+ 4. LLM extract entities and attributes from the sentence
1474
+ 5. iterate to the next sentence and repeat steps 3-4 until all sentences are processed.
1199
1475
 
1200
1476
  Input system prompt (optional), prompt template (with user instructions),
1201
1477
  and specify a LLM.
1202
1478
 
1203
- Parameters
1479
+ Parameters:
1204
1480
  ----------
1205
1481
  inference_engine : InferenceEngine
1206
1482
  the LLM inferencing engine object. Must implements the chat() method.
@@ -1216,82 +1492,77 @@ class SentenceCoTFrameExtractor(SentenceFrameExtractor):
1216
1492
  if > 0, the number of sentences before and after the given sentence to provide as context.
1217
1493
  This is good for tasks that require context beyond the given sentence.
1218
1494
  """
1219
- super().__init__(inference_engine=inference_engine, prompt_template=prompt_template,
1220
- system_prompt=system_prompt, context_sentences=context_sentences, **kwrs)
1495
+ if not isinstance(context_sentences, int) and context_sentences != "all":
1496
+ raise ValueError('context_sentences must be an integer (>= 0) or "all".')
1497
+
1498
+ if isinstance(context_sentences, int) and context_sentences < 0:
1499
+ raise ValueError("context_sentences must be a positive integer.")
1500
+
1501
+ if isinstance(context_sentences, int):
1502
+ context_chunker = SlideWindowContextChunker(window_size=context_sentences)
1503
+ elif context_sentences == "all":
1504
+ context_chunker = WholeDocumentContextChunker()
1505
+
1506
+ super().__init__(inference_engine=inference_engine,
1507
+ unit_chunker=SentenceUnitChunker(),
1508
+ prompt_template=prompt_template,
1509
+ system_prompt=system_prompt,
1510
+ context_chunker=context_chunker,
1511
+ **kwrs)
1221
1512
 
1222
1513
 
1223
- def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
1224
- document_key:str=None, temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict[str,str]]:
1514
+ class SentenceReviewFrameExtractor(ReviewFrameExtractor):
1515
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
1516
+ review_mode:str, review_prompt:str=None, system_prompt:str=None,
1517
+ context_sentences:Union[str, int]="all", **kwrs):
1225
1518
  """
1226
- This method inputs a text and outputs a list of outputs per sentence.
1519
+ This class adds a review step after the SentenceFrameExtractor.
1520
+ For each sentence, the review process asks LLM to review its output and:
1521
+ 1. add more frames while keeping current. This is efficient for boosting recall.
1522
+ 2. or, regenerate frames (add new and delete existing).
1523
+ Use the review_mode parameter to specify. Note that the review_prompt should instruct LLM accordingly.
1227
1524
 
1228
1525
  Parameters:
1229
1526
  ----------
1230
- text_content : Union[str, Dict[str,str]]
1231
- the input text content to put in prompt template.
1232
- If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
1233
- If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
1234
- max_new_tokens : str, Optional
1235
- the max number of new tokens LLM should generate.
1236
- document_key : str, Optional
1237
- specify the key in text_content where document text is.
1238
- If text_content is str, this parameter will be ignored.
1239
- temperature : float, Optional
1240
- the temperature for token sampling.
1241
- stream : bool, Optional
1242
- if True, LLM generated text will be printed in terminal in real-time.
1243
-
1244
- Return : str
1245
- the output from LLM. Need post-processing.
1527
+ inference_engine : InferenceEngine
1528
+ the LLM inferencing engine object. Must implements the chat() method.
1529
+ prompt_template : str
1530
+ prompt template with "{{<placeholder name>}}" placeholder.
1531
+ review_prompt : str: Optional
1532
+ the prompt text that ask LLM to review. Specify addition or revision in the instruction.
1533
+ if not provided, a default review prompt will be used.
1534
+ review_mode : str
1535
+ review mode. Must be one of {"addition", "revision"}
1536
+ addition mode only ask LLM to add new frames, while revision mode ask LLM to regenerate.
1537
+ system_prompt : str, Optional
1538
+ system prompt.
1539
+ context_sentences : Union[str, int], Optional
1540
+ number of sentences before and after the given sentence to provide additional context.
1541
+ if "all", the full text will be provided in the prompt as context.
1542
+ if 0, no additional context will be provided.
1543
+ This is good for tasks that does not require context beyond the given sentence.
1544
+ if > 0, the number of sentences before and after the given sentence to provide as context.
1545
+ This is good for tasks that require context beyond the given sentence.
1246
1546
  """
1247
- # define output
1248
- output = []
1249
- # sentence tokenization
1250
- if isinstance(text_content, str):
1251
- sentences = self._get_sentences(text_content)
1252
- elif isinstance(text_content, dict):
1253
- sentences = self._get_sentences(text_content[document_key])
1254
-
1255
- # generate sentence by sentence
1256
- for i, sent in enumerate(sentences):
1257
- # construct chat messages
1258
- messages = []
1259
- if self.system_prompt:
1260
- messages.append({'role': 'system', 'content': self.system_prompt})
1261
-
1262
- context = self._get_context_sentences(text_content, i, sentences, document_key)
1263
-
1264
- if self.context_sentences == 0:
1265
- # no context, just place sentence of interest
1266
- messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
1267
- else:
1268
- # insert context
1269
- messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
1270
- # simulate conversation
1271
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
1272
- # place sentence of interest
1273
- messages.append({'role': 'user', 'content': sent['sentence_text']})
1274
-
1275
- if stream:
1276
- print(f"\n\n{Fore.GREEN}Sentence: {Style.RESET_ALL}\n{sent['sentence_text']}\n")
1277
- if isinstance(self.context_sentences, int) and self.context_sentences > 0:
1278
- print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
1279
- print(f"{Fore.BLUE}CoT:{Style.RESET_ALL}")
1280
-
1281
- gen_text = self.inference_engine.chat(
1282
- messages=messages,
1283
- max_new_tokens=max_new_tokens,
1284
- temperature=temperature,
1285
- stream=stream,
1286
- **kwrs
1287
- )
1547
+ if not isinstance(context_sentences, int) and context_sentences != "all":
1548
+ raise ValueError('context_sentences must be an integer (>= 0) or "all".')
1549
+
1550
+ if isinstance(context_sentences, int) and context_sentences < 0:
1551
+ raise ValueError("context_sentences must be a positive integer.")
1552
+
1553
+ if isinstance(context_sentences, int):
1554
+ context_chunker = SlideWindowContextChunker(window_size=context_sentences)
1555
+ elif context_sentences == "all":
1556
+ context_chunker = WholeDocumentContextChunker()
1288
1557
 
1289
- # add to output
1290
- output.append({'sentence_start': sent['start'],
1291
- 'sentence_end': sent['end'],
1292
- 'sentence_text': sent['sentence_text'],
1293
- 'gen_text': gen_text})
1294
- return output
1558
+ super().__init__(inference_engine=inference_engine,
1559
+ unit_chunker=SentenceUnitChunker(),
1560
+ prompt_template=prompt_template,
1561
+ review_mode=review_mode,
1562
+ review_prompt=review_prompt,
1563
+ system_prompt=system_prompt,
1564
+ context_chunker=context_chunker,
1565
+ **kwrs)
1295
1566
 
1296
1567
 
1297
1568
  class RelationExtractor(Extractor):
@@ -1361,7 +1632,7 @@ class RelationExtractor(Extractor):
1361
1632
 
1362
1633
  @abc.abstractmethod
1363
1634
  def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1364
- temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
1635
+ temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1365
1636
  """
1366
1637
  This method considers all combinations of two frames.
1367
1638
 
@@ -1377,6 +1648,8 @@ class RelationExtractor(Extractor):
1377
1648
  the temperature for token sampling.
1378
1649
  stream : bool, Optional
1379
1650
  if True, LLM generated text will be printed in terminal in real-time.
1651
+ return_messages_log : bool, Optional
1652
+ if True, a list of messages will be returned.
1380
1653
 
1381
1654
  Return : List[Dict]
1382
1655
  a list of dict with {"frame_1", "frame_2"} for all relations.
@@ -1446,7 +1719,7 @@ class BinaryRelationExtractor(RelationExtractor):
1446
1719
 
1447
1720
 
1448
1721
  def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1449
- temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
1722
+ temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1450
1723
  """
1451
1724
  This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
1452
1725
  Outputs pairs that are related.
@@ -1463,11 +1736,17 @@ class BinaryRelationExtractor(RelationExtractor):
1463
1736
  the temperature for token sampling.
1464
1737
  stream : bool, Optional
1465
1738
  if True, LLM generated text will be printed in terminal in real-time.
1739
+ return_messages_log : bool, Optional
1740
+ if True, a list of messages will be returned.
1466
1741
 
1467
1742
  Return : List[Dict]
1468
1743
  a list of dict with {"frame_1_id", "frame_2_id"}.
1469
1744
  """
1470
1745
  pairs = itertools.combinations(doc.frames, 2)
1746
+
1747
+ if return_messages_log:
1748
+ messages_log = []
1749
+
1471
1750
  output = []
1472
1751
  for frame_1, frame_2 in pairs:
1473
1752
  pos_rel = self.possible_relation_func(frame_1, frame_2)
@@ -1495,13 +1774,19 @@ class BinaryRelationExtractor(RelationExtractor):
1495
1774
  )
1496
1775
  rel_json = self._extract_json(gen_text)
1497
1776
  if self._post_process(rel_json):
1498
- output.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id})
1777
+ output.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})
1499
1778
 
1779
+ if return_messages_log:
1780
+ messages.append({"role": "assistant", "content": gen_text})
1781
+ messages_log.append(messages)
1782
+
1783
+ if return_messages_log:
1784
+ return output, messages_log
1500
1785
  return output
1501
1786
 
1502
1787
 
1503
1788
  async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1504
- temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
1789
+ temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1505
1790
  """
1506
1791
  This is the asynchronous version of the extract() method.
1507
1792
 
@@ -1517,6 +1802,8 @@ class BinaryRelationExtractor(RelationExtractor):
1517
1802
  the temperature for token sampling.
1518
1803
  concurrent_batch_size : int, Optional
1519
1804
  the number of frame pairs to process in concurrent.
1805
+ return_messages_log : bool, Optional
1806
+ if True, a list of messages will be returned.
1520
1807
 
1521
1808
  Return : List[Dict]
1522
1809
  a list of dict with {"frame_1", "frame_2"}.
@@ -1526,12 +1813,17 @@ class BinaryRelationExtractor(RelationExtractor):
1526
1813
  raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
1527
1814
 
1528
1815
  pairs = itertools.combinations(doc.frames, 2)
1816
+ if return_messages_log:
1817
+ messages_log = []
1818
+
1529
1819
  n_frames = len(doc.frames)
1530
1820
  num_pairs = (n_frames * (n_frames-1)) // 2
1531
- rel_pair_list = []
1532
- tasks = []
1821
+ output = []
1533
1822
  for i in range(0, num_pairs, concurrent_batch_size):
1823
+ rel_pair_list = []
1824
+ tasks = []
1534
1825
  batch = list(itertools.islice(pairs, concurrent_batch_size))
1826
+ batch_messages = []
1535
1827
  for frame_1, frame_2 in batch:
1536
1828
  pos_rel = self.possible_relation_func(frame_1, frame_2)
1537
1829
 
@@ -1546,6 +1838,7 @@ class BinaryRelationExtractor(RelationExtractor):
1546
1838
  "frame_1": str(frame_1.to_dict()),
1547
1839
  "frame_2": str(frame_2.to_dict())}
1548
1840
  )})
1841
+
1549
1842
  task = asyncio.create_task(
1550
1843
  self.inference_engine.chat_async(
1551
1844
  messages=messages,
@@ -1555,20 +1848,27 @@ class BinaryRelationExtractor(RelationExtractor):
1555
1848
  )
1556
1849
  )
1557
1850
  tasks.append(task)
1851
+ batch_messages.append(messages)
1558
1852
 
1559
1853
  responses = await asyncio.gather(*tasks)
1560
1854
 
1561
- output = []
1562
- for d, response in zip(rel_pair_list, responses):
1563
- rel_json = self._extract_json(response)
1564
- if self._post_process(rel_json):
1565
- output.append(d)
1855
+ for d, response, messages in zip(rel_pair_list, responses, batch_messages):
1856
+ if return_messages_log:
1857
+ messages.append({"role": "assistant", "content": response})
1858
+ messages_log.append(messages)
1859
+
1860
+ rel_json = self._extract_json(response)
1861
+ if self._post_process(rel_json):
1862
+ output.append(d)
1566
1863
 
1864
+ if return_messages_log:
1865
+ return output, messages_log
1567
1866
  return output
1568
1867
 
1569
1868
 
1570
1869
  def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1571
- temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32, stream:bool=False, **kwrs) -> List[Dict]:
1870
+ temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32,
1871
+ stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1572
1872
  """
1573
1873
  This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
1574
1874
 
@@ -1588,6 +1888,8 @@ class BinaryRelationExtractor(RelationExtractor):
1588
1888
  the number of frame pairs to process in concurrent.
1589
1889
  stream : bool, Optional
1590
1890
  if True, LLM generated text will be printed in terminal in real-time.
1891
+ return_messages_log : bool, Optional
1892
+ if True, a list of messages will be returned.
1591
1893
 
1592
1894
  Return : List[Dict]
1593
1895
  a list of dict with {"frame_1", "frame_2"} for all relations.
@@ -1608,6 +1910,7 @@ class BinaryRelationExtractor(RelationExtractor):
1608
1910
  max_new_tokens=max_new_tokens,
1609
1911
  temperature=temperature,
1610
1912
  concurrent_batch_size=concurrent_batch_size,
1913
+ return_messages_log=return_messages_log,
1611
1914
  **kwrs)
1612
1915
  )
1613
1916
  else:
@@ -1616,6 +1919,7 @@ class BinaryRelationExtractor(RelationExtractor):
1616
1919
  max_new_tokens=max_new_tokens,
1617
1920
  temperature=temperature,
1618
1921
  stream=stream,
1922
+ return_messages_log=return_messages_log,
1619
1923
  **kwrs)
1620
1924
 
1621
1925
 
@@ -1689,7 +1993,7 @@ class MultiClassRelationExtractor(RelationExtractor):
1689
1993
 
1690
1994
 
1691
1995
  def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1692
- temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
1996
+ temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1693
1997
  """
1694
1998
  This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
1695
1999
 
@@ -1705,11 +2009,17 @@ class MultiClassRelationExtractor(RelationExtractor):
1705
2009
  the temperature for token sampling.
1706
2010
  stream : bool, Optional
1707
2011
  if True, LLM generated text will be printed in terminal in real-time.
2012
+ return_messages_log : bool, Optional
2013
+ if True, a list of messages will be returned.
1708
2014
 
1709
2015
  Return : List[Dict]
1710
- a list of dict with {"frame_1", "frame_2", "relation"} for all frame pairs.
2016
+ a list of dict with {"frame_1_id", "frame_2_id", "relation"} for all frame pairs.
1711
2017
  """
1712
2018
  pairs = itertools.combinations(doc.frames, 2)
2019
+
2020
+ if return_messages_log:
2021
+ messages_log = []
2022
+
1713
2023
  output = []
1714
2024
  for frame_1, frame_2 in pairs:
1715
2025
  pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
@@ -1736,16 +2046,23 @@ class MultiClassRelationExtractor(RelationExtractor):
1736
2046
  stream=stream,
1737
2047
  **kwrs
1738
2048
  )
2049
+
2050
+ if return_messages_log:
2051
+ messages.append({"role": "assistant", "content": gen_text})
2052
+ messages_log.append(messages)
2053
+
1739
2054
  rel_json = self._extract_json(gen_text)
1740
2055
  rel = self._post_process(rel_json, pos_rel_types)
1741
2056
  if rel:
1742
- output.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'relation':rel})
2057
+ output.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id, 'relation':rel})
1743
2058
 
2059
+ if return_messages_log:
2060
+ return output, messages_log
1744
2061
  return output
1745
2062
 
1746
2063
 
1747
2064
  async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1748
- temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
2065
+ temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1749
2066
  """
1750
2067
  This is the asynchronous version of the extract() method.
1751
2068
 
@@ -1761,21 +2078,28 @@ class MultiClassRelationExtractor(RelationExtractor):
1761
2078
  the temperature for token sampling.
1762
2079
  concurrent_batch_size : int, Optional
1763
2080
  the number of frame pairs to process in concurrent.
2081
+ return_messages_log : bool, Optional
2082
+ if True, a list of messages will be returned.
1764
2083
 
1765
2084
  Return : List[Dict]
1766
- a list of dict with {"frame_1", "frame_2", "relation"} for all frame pairs.
2085
+ a list of dict with {"frame_1_id", "frame_2_id", "relation"} for all frame pairs.
1767
2086
  """
1768
2087
  # Check if self.inference_engine.chat_async() is implemented
1769
2088
  if not hasattr(self.inference_engine, 'chat_async'):
1770
2089
  raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
1771
2090
 
1772
2091
  pairs = itertools.combinations(doc.frames, 2)
2092
+ if return_messages_log:
2093
+ messages_log = []
2094
+
1773
2095
  n_frames = len(doc.frames)
1774
2096
  num_pairs = (n_frames * (n_frames-1)) // 2
1775
- rel_pair_list = []
1776
- tasks = []
2097
+ output = []
1777
2098
  for i in range(0, num_pairs, concurrent_batch_size):
2099
+ rel_pair_list = []
2100
+ tasks = []
1778
2101
  batch = list(itertools.islice(pairs, concurrent_batch_size))
2102
+ batch_messages = []
1779
2103
  for frame_1, frame_2 in batch:
1780
2104
  pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
1781
2105
 
@@ -1800,21 +2124,28 @@ class MultiClassRelationExtractor(RelationExtractor):
1800
2124
  )
1801
2125
  )
1802
2126
  tasks.append(task)
2127
+ batch_messages.append(messages)
1803
2128
 
1804
2129
  responses = await asyncio.gather(*tasks)
1805
2130
 
1806
- output = []
1807
- for d, response in zip(rel_pair_list, responses):
1808
- rel_json = self._extract_json(response)
1809
- rel = self._post_process(rel_json, d['pos_rel_types'])
1810
- if rel:
1811
- output.append({'frame_1':d['frame_1'], 'frame_2':d['frame_2'], 'relation':rel})
2131
+ for d, response, messages in zip(rel_pair_list, responses, batch_messages):
2132
+ if return_messages_log:
2133
+ messages.append({"role": "assistant", "content": response})
2134
+ messages_log.append(messages)
2135
+
2136
+ rel_json = self._extract_json(response)
2137
+ rel = self._post_process(rel_json, d['pos_rel_types'])
2138
+ if rel:
2139
+ output.append({'frame_1_id':d['frame_1'], 'frame_2_id':d['frame_2'], 'relation':rel})
1812
2140
 
2141
+ if return_messages_log:
2142
+ return output, messages_log
1813
2143
  return output
1814
2144
 
1815
2145
 
1816
2146
  def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
1817
- temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32, stream:bool=False, **kwrs) -> List[Dict]:
2147
+ temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32,
2148
+ stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
1818
2149
  """
1819
2150
  This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
1820
2151
 
@@ -1834,6 +2165,8 @@ class MultiClassRelationExtractor(RelationExtractor):
1834
2165
  the number of frame pairs to process in concurrent.
1835
2166
  stream : bool, Optional
1836
2167
  if True, LLM generated text will be printed in terminal in real-time.
2168
+ return_messages_log : bool, Optional
2169
+ if True, a list of messages will be returned.
1837
2170
 
1838
2171
  Return : List[Dict]
1839
2172
  a list of dict with {"frame_1", "frame_2", "relation"} for all relations.
@@ -1854,6 +2187,7 @@ class MultiClassRelationExtractor(RelationExtractor):
1854
2187
  max_new_tokens=max_new_tokens,
1855
2188
  temperature=temperature,
1856
2189
  concurrent_batch_size=concurrent_batch_size,
2190
+ return_messages_log=return_messages_log,
1857
2191
  **kwrs)
1858
2192
  )
1859
2193
  else:
@@ -1862,5 +2196,6 @@ class MultiClassRelationExtractor(RelationExtractor):
1862
2196
  max_new_tokens=max_new_tokens,
1863
2197
  temperature=temperature,
1864
2198
  stream=stream,
2199
+ return_messages_log=return_messages_log,
1865
2200
  **kwrs)
1866
2201