llm-ie 1.2.2-py3-none-any.whl → 1.2.3-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
llm_ie/extractors.py CHANGED
@@ -8,11 +8,12 @@ import warnings
 import itertools
 import asyncio
 import nest_asyncio
-from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional
-from llm_ie.data_types import FrameExtractionUnit, FrameExtractionUnitResult, LLMInformationExtractionFrame, LLMInformationExtractionDocument
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional, AsyncGenerator
+from llm_ie.data_types import FrameExtractionUnit, LLMInformationExtractionFrame, LLMInformationExtractionDocument
 from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
 from llm_ie.chunkers import ContextChunker, NoContextChunker, WholeDocumentContextChunker, SlideWindowContextChunker
-from llm_ie.engines import InferenceEngine
+from llm_ie.engines import InferenceEngine, MessagesLogger
 from colorama import Fore, Style
 
 
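MessagesLogger is new in 1.2.3 and replaces the hand-rolled messages_log lists that each extractor previously maintained. Its implementation lives in llm_ie/engines.py, which is outside this diff; judging only from how it is used below (constructed with no arguments, passed to chat()/chat_async() as messages_logger=, and drained with get_messages_log()), a hypothetical stand-in might look like this sketch. The log() method name is an assumption; the shipped class may differ.

    # Hypothetical stand-in for llm_ie.engines.MessagesLogger, inferred
    # from its call sites in this diff; not the actual implementation.
    from typing import Dict, List

    class MessagesLogger:
        def __init__(self) -> None:
            # one entry per LLM call: the full message list, including the
            # assistant reply the engine appends after generation
            self._log: List[List[Dict[str, str]]] = []

        def log(self, messages: List[Dict[str, str]]) -> None:  # assumed name
            self._log.append(messages)

        def get_messages_log(self) -> List[List[Dict[str, str]]]:
            return self._log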
@@ -405,7 +406,7 @@ class DirectFrameExtractor(FrameExtractor):
 
 
     def extract(self, text_content:Union[str, Dict[str,str]],
-                document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+                document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
         """
         This method inputs a text and outputs a list of outputs per unit.
 
@@ -423,11 +424,9 @@ class DirectFrameExtractor(FrameExtractor):
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.
 
-        Return : List[FrameExtractionUnitResult]
+        Return : List[FrameExtractionUnit]
             the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
-        # define output
-        output = []
         # unit chunking
         if isinstance(text_content, str):
             doc_text = text_content
@@ -440,76 +439,70 @@ class DirectFrameExtractor(FrameExtractor):
         units = self.unit_chunker.chunk(doc_text)
         # context chunker init
         self.context_chunker.fit(doc_text, units)
+
         # messages log
-        if return_messages_log:
-            messages_log = []
+        messages_logger = MessagesLogger() if return_messages_log else None
 
         # generate unit by unit
         for i, unit in enumerate(units):
-            # construct chat messages
-            messages = []
-            if self.system_prompt:
-                messages.append({'role': 'system', 'content': self.system_prompt})
-
-            context = self.context_chunker.chunk(unit)
-
-            if context == "":
-                # no context, just place unit in user prompt
-                if isinstance(text_content, str):
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
-                else:
-                    unit_content = text_content.copy()
-                    unit_content[document_key] = unit.text
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
-            else:
-                # insert context to user prompt
-                if isinstance(text_content, str):
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
-                else:
-                    context_content = text_content.copy()
-                    context_content[document_key] = context
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
-                # simulate conversation where assistant confirms
-                messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
-                # place unit of interest
-                messages.append({'role': 'user', 'content': unit.text})
+            try:
+                # construct chat messages
+                messages = []
+                if self.system_prompt:
+                    messages.append({'role': 'system', 'content': self.system_prompt})
 
-            if verbose:
-                print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
-                if context != "":
-                    print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+                context = self.context_chunker.chunk(unit)
 
-                print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+                if context == "":
+                    # no context, just place unit in user prompt
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                    else:
+                        unit_content = text_content.copy()
+                        unit_content[document_key] = unit.text
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+                else:
+                    # insert context to user prompt
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                    else:
+                        context_content = text_content.copy()
+                        context_content[document_key] = context
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                    # simulate conversation where assistant confirms
+                    messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                    # place unit of interest
+                    messages.append({'role': 'user', 'content': unit.text})
 
-
-            gen_text = self.inference_engine.chat(
-                messages=messages,
-                verbose=verbose,
-                stream=False
-            )
+                if verbose:
+                    print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
+                    if context != "":
+                        print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+
+                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
 
-            if return_messages_log:
-                message = {"role": "assistant", "content": gen_text["response"]}
-                if "reasoning" in gen_text:
-                    message["reasoning"] = gen_text["reasoning"]
-                messages.append(message)
-                messages_log.append(messages)
-
-            # add to output
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=gen_text["response"])
-            output.append(result)
+
+                gen_text = self.inference_engine.chat(
+                    messages=messages,
+                    verbose=verbose,
+                    stream=False,
+                    messages_logger=messages_logger
+                )
+
+                # add generated text to unit
+                unit.set_generated_text(gen_text["response"])
+                unit.set_status("success")
+            except Exception as e:
+                unit.set_status("fail")
+                warnings.warn(f"LLM inference failed for unit {i} ({unit.start}, {unit.end}): {e}", RuntimeWarning)
 
         if return_messages_log:
-            return output, messages_log
+            return units, messages_logger.get_messages_log()
 
-        return output
+        return units
 
     def stream(self, text_content: Union[str, Dict[str, str]],
-               document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnitResult]]:
+               document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnit]]:
         """
         Streams LLM responses per unit with structured event types,
         and returns collected data for post-processing.
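As the hunk above shows, extract() no longer allocates FrameExtractionUnitResult objects: it writes the generation back onto each FrameExtractionUnit and returns the units themselves, and the per-unit try/except means one failed LLM call now downgrades to a RuntimeWarning instead of aborting the whole document. A minimal consumption sketch, assuming an already-configured DirectFrameExtractor named extractor and an input string note_text (both hypothetical):

    units = extractor.extract(text_content=note_text)
    for unit in units:
        if unit.status == "success":
            print(unit.start, unit.end, unit.gen_text[:80])
        else:
            # failed units survive with status "fail" and no generated text
            print(f"unit ({unit.start}, {unit.end}) failed")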
@@ -525,12 +518,10 @@ class DirectFrameExtractor(FrameExtractor):
 
         Returns:
         --------
-        List[FrameExtractionUnitResult]:
-            A list of FrameExtractionUnitResult objects, each containing the
+        List[FrameExtractionUnit]:
+            A list of FrameExtractionUnit objects, each containing the
             original unit details and the fully accumulated 'gen_text' from the LLM.
         """
-        collected_results: List[FrameExtractionUnitResult] = []
-
         if isinstance(text_content, str):
             doc_text = text_content
         elif isinstance(text_content, dict):
@@ -588,19 +579,14 @@ class DirectFrameExtractor(FrameExtractor):
                     current_gen_text += chunk["data"]
 
             # Store the result for this unit
-            result_for_unit = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=current_gen_text
-            )
-            collected_results.append(result_for_unit)
+            unit.set_generated_text(current_gen_text)
+            unit.set_status("success")
 
         yield {"type": "info", "data": "All units processed by LLM."}
-        return collected_results
+        return units
 
     async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
-                            concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+                            concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
         """
         This is the asynchronous version of the extract() method.
 
@@ -618,7 +604,7 @@ class DirectFrameExtractor(FrameExtractor):
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.
 
-        Return : List[FrameExtractionUnitResult]
+        Return : List[FrameExtractionUnit]
             the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
         if isinstance(text_content, str):
@@ -637,6 +623,9 @@ class DirectFrameExtractor(FrameExtractor):
         # context chunker init
         self.context_chunker.fit(doc_text, units)
 
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # Prepare inputs for all units first
         tasks_input = []
         for i, unit in enumerate(units):
@@ -677,17 +666,15 @@ class DirectFrameExtractor(FrameExtractor):
         async def semaphore_helper(task_data: Dict, **kwrs):
             unit = task_data["unit"]
             messages = task_data["messages"]
-            original_index = task_data["original_index"]
 
             async with semaphore:
                 gen_text = await self.inference_engine.chat_async(
-                    messages=messages
+                    messages=messages,
+                    messages_logger=messages_logger
                 )
 
-            out = {"original_index": original_index, "unit": unit, "gen_text": gen_text["response"], "messages": messages}
-            if "reasoning" in gen_text:
-                out["reasoning"] = gen_text["reasoning"]
-            return out
+            unit.set_generated_text(gen_text["response"])
+            unit.set_status("success")
 
         # Create and gather tasks
         tasks = []
@@ -697,40 +684,13 @@ class DirectFrameExtractor(FrameExtractor):
             ))
             tasks.append(task)
 
-        results_raw = await asyncio.gather(*tasks)
-
-        # Sort results back into original order using the index stored
-        results_raw.sort(key=lambda x: x["original_index"])
-
-        # Restructure the results
-        output: List[FrameExtractionUnitResult] = []
-        messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
-
-        for result_data in results_raw:
-            unit = result_data["unit"]
-            gen_text = result_data["gen_text"]
-
-            # Create result object
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=gen_text
-            )
-            output.append(result)
-
-            # Append to messages log if requested
-            if return_messages_log:
-                message = {"role": "assistant", "content": gen_text}
-                if "reasoning" in result_data:
-                    message["reasoning"] = result_data["reasoning"]
-                final_messages = result_data["messages"] + [message]
-                messages_log.append(final_messages)
+        await asyncio.gather(*tasks)
 
+        # Return units
         if return_messages_log:
-            return output, messages_log
+            return units, messages_logger.get_messages_log()
         else:
-            return output
+            return units
 
 
     def extract_frames(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
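Because each task now mutates its unit in place, the original_index bookkeeping and the post-gather re-sort are gone: asyncio.gather() only drives side effects, and the units list keeps its input order. A sketch of a driver, again assuming a configured extractor and note_text:

    import asyncio

    async def main(text: str):
        units = await extractor.extract_async(
            text_content=text,
            concurrent_batch_size=32,
        )
        return [u for u in units if u.status == "success"]

    ok_units = asyncio.run(main(note_text))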
@@ -738,7 +698,7 @@ class DirectFrameExtractor(FrameExtractor):
                        case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
                        allow_overlap_entities:bool=False, return_messages_log:bool=False) -> List[LLMInformationExtractionFrame]:
         """
-        This method inputs a text and outputs a list of LLMInformationExtractionFrame
+        This method inputs a document text and outputs a list of LLMInformationExtractionFrame
         It use the extract() method and post-process outputs into frames.
 
         Parameters:
@@ -791,18 +751,21 @@ class DirectFrameExtractor(FrameExtractor):
                                            verbose=verbose,
                                            return_messages_log=return_messages_log)
 
-        llm_output_results, messages_log = extraction_results if return_messages_log else (extraction_results, None)
+        units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
 
         frame_list = []
-        for res in llm_output_results:
+        for unit in units:
             entity_json = []
-            for entity in self._extract_json(gen_text=res.gen_text):
+            if unit.status != "success":
+                warnings.warn(f"Skipping failed unit ({unit.start}, {unit.end}): {unit.text}", RuntimeWarning)
+                continue
+            for entity in self._extract_json(gen_text=unit.gen_text):
                 if ENTITY_KEY in entity:
                     entity_json.append(entity)
                 else:
                     warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
 
-            spans = self._find_entity_spans(text=res.text,
+            spans = self._find_entity_spans(text=unit.text,
                                             entities=[e[ENTITY_KEY] for e in entity_json],
                                             case_sensitive=case_sensitive,
                                             fuzzy_match=fuzzy_match,
@@ -812,9 +775,9 @@ class DirectFrameExtractor(FrameExtractor):
             for ent, span in zip(entity_json, spans):
                 if span is not None:
                     start, end = span
-                    entity_text = res.text[start:end]
-                    start += res.start
-                    end += res.start
+                    entity_text = unit.text[start:end]
+                    start += unit.start
+                    end += unit.start
                     attr = {}
                     if "attr" in ent and ent["attr"] is not None:
                         attr = ent["attr"]
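extract_frames() now checks unit.status before JSON post-processing, so a failed unit is skipped with a warning rather than raising. The calling convention itself is unchanged; a sketch with the same hypothetical extractor and note_text:

    frames = extractor.extract_frames(
        text_content=note_text,
        fuzzy_match=True,
        fuzzy_score_cutoff=0.8,
    )
    for f in frames:
        print(f.frame_id, f.start, f.end, f.entity_text, f.attr)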
@@ -831,6 +794,208 @@ class DirectFrameExtractor(FrameExtractor):
         return frame_list
 
 
+    async def extract_frames_from_documents(self, text_contents:List[Union[str,Dict[str, any]]], document_key:str="text",
+                                            cpu_concurrency:int=4, llm_concurrency:int=32, case_sensitive:bool=False,
+                                            fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
+                                            allow_overlap_entities:bool=False, return_messages_log:bool=False) -> AsyncGenerator[Dict[str, any], None]:
+        """
+        This method inputs a list of documents and yields the results for each document as soon as it is complete.
+
+        Parameters:
+        -----------
+        text_contents : List[Union[str,Dict[str, any]]]
+            a list of input text contents to put in prompt template.
+            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+        document_key: str, optional
+            The key in the `text_contents` dictionaries that holds the document text.
+        cpu_concurrency: int, optional
+            The number of parallel threads to use for CPU-bound tasks like chunking.
+        llm_concurrency: int, optional
+            The number of concurrent requests to make to the LLM.
+        case_sensitive : bool, Optional
+            if True, entity text matching will be case-sensitive.
+        fuzzy_match : bool, Optional
+            if True, fuzzy matching will be applied to find entity text.
+        fuzzy_buffer_size : float, Optional
+            the buffer size for fuzzy matching. Default is 20% of entity text length.
+        fuzzy_score_cutoff : float, Optional
+            the Jaccard score cutoff for fuzzy matching.
+            Matched entity text must have a score higher than this value or a None will be returned.
+        allow_overlap_entities : bool, Optional
+            if True, entities can overlap in the text.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
+
+        Yields:
+        -------
+        AsyncGenerator[Dict[str, any], None]
+            A dictionary for each completed document, containing its 'idx' and extracted 'frames'.
+        """
+        # Validate text_contents must be a list of str or dict, and not both
+        if not isinstance(text_contents, list):
+            raise ValueError("text_contents must be a list of strings or dictionaries.")
+        if all(isinstance(doc, str) for doc in text_contents):
+            pass
+        elif all(isinstance(doc, dict) for doc in text_contents):
+            pass
+        # Set CPU executor and queues
+        cpu_executor = ThreadPoolExecutor(max_workers=cpu_concurrency)
+        tasks_queue = asyncio.Queue(maxsize=llm_concurrency * 2)
+        # Store to track units and pending counts
+        results_store = {
+            idx: {'pending': 0, 'units': [], 'text': doc if isinstance(doc, str) else doc.get(document_key, "")}
+            for idx, doc in enumerate(text_contents)
+        }
+
+        output_queue = asyncio.Queue()
+        messages_logger = MessagesLogger() if return_messages_log else None
+
+        async def producer():
+            try:
+                for idx, text_content in enumerate(text_contents):
+                    text = text_content if isinstance(text_content, str) else text_content.get(document_key, "")
+                    if not text:
+                        warnings.warn(f"Document at index {idx} is empty or missing the document key '{document_key}'.")
+                        # signal that this document is done
+                        await output_queue.put({'idx': idx, 'frames': []})
+                        continue
+
+                    units = await self.unit_chunker.chunk_async(text, cpu_executor)
+                    await self.context_chunker.fit_async(text, units, cpu_executor)
+                    results_store[idx]['pending'] = len(units)
+
+                    # Handle cases where a document yields no units
+                    if not units:
+                        # signal that this document is done
+                        await output_queue.put({'idx': idx, 'frames': []})
+                        continue
+
+                    # Iterate through units
+                    for unit in units:
+                        context = await self.context_chunker.chunk_async(unit, cpu_executor)
+                        messages = []
+                        if self.system_prompt:
+                            messages.append({'role': 'system', 'content': self.system_prompt})
+
+                        if not context:
+                            if isinstance(text_content, str):
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                            else:
+                                unit_content = text_content.copy()
+                                unit_content[document_key] = unit.text
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+                        else:
+                            # insert context to user prompt
+                            if isinstance(text_content, str):
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                            else:
+                                context_content = text_content.copy()
+                                context_content[document_key] = context
+                                messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                            # simulate conversation where assistant confirms
+                            messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                            # place unit of interest
+                            messages.append({'role': 'user', 'content': unit.text})
+
+                        await tasks_queue.put({'idx': idx, 'unit': unit, 'messages': messages})
+            finally:
+                for _ in range(llm_concurrency):
+                    await tasks_queue.put(None)
+
+        async def worker():
+            while True:
+                task_item = await tasks_queue.get()
+                if task_item is None:
+                    tasks_queue.task_done()
+                    break
+
+                idx = task_item['idx']
+                unit = task_item['unit']
+                doc_results = results_store[idx]
+
+                try:
+                    gen_text = await self.inference_engine.chat_async(
+                        messages=task_item['messages'], messages_logger=messages_logger
+                    )
+                    unit.set_generated_text(gen_text["response"])
+                    unit.set_status("success")
+                    doc_results['units'].append(unit)
+                except Exception as e:
+                    warnings.warn(f"Error processing unit for doc idx {idx}: {e}")
+                finally:
+                    doc_results['pending'] -= 1
+                    if doc_results['pending'] <= 0:
+                        final_frames = self._post_process_and_create_frames(doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities)
+                        output_payload = {'idx': idx, 'frames': final_frames}
+                        if return_messages_log:
+                            output_payload['messages_log'] = messages_logger.get_messages_log()
+                        await output_queue.put(output_payload)
+
+                    tasks_queue.task_done()
+
+        # Start producer and workers
+        producer_task = asyncio.create_task(producer())
+        worker_tasks = [asyncio.create_task(worker()) for _ in range(llm_concurrency)]
+
+        # Main loop to gather results
+        docs_completed = 0
+        while docs_completed < len(text_contents):
+            result = await output_queue.get()
+            yield result
+            docs_completed += 1
+
+        # Final cleanup
+        await producer_task
+        await tasks_queue.join()
+
+        # Cancel any lingering worker tasks
+        for task in worker_tasks:
+            task.cancel()
+        await asyncio.gather(*worker_tasks, return_exceptions=True)
+
+        cpu_executor.shutdown(wait=False)
+
+
+    def _post_process_and_create_frames(self, doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities):
+        """Helper function to run post-processing logic for a completed document."""
+        ENTITY_KEY = "entity_text"
+        frame_list = []
+        for res in sorted(doc_results['units'], key=lambda r: r.start):
+            entity_json = []
+            for entity in self._extract_json(gen_text=res.gen_text):
+                if ENTITY_KEY in entity:
+                    entity_json.append(entity)
+                else:
+                    warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
+
+            spans = self._find_entity_spans(
+                text=res.text,
+                entities=[e[ENTITY_KEY] for e in entity_json],
+                case_sensitive=case_sensitive,
+                fuzzy_match=fuzzy_match,
+                fuzzy_buffer_size=fuzzy_buffer_size,
+                fuzzy_score_cutoff=fuzzy_score_cutoff,
+                allow_overlap_entities=allow_overlap_entities
+            )
+            for ent, span in zip(entity_json, spans):
+                if span is not None:
+                    start, end = span
+                    entity_text = res.text[start:end]
+                    start += res.start
+                    end += res.start
+                    attr = ent.get("attr", {}) or {}
+                    frame = LLMInformationExtractionFrame(
+                        frame_id=f"{len(frame_list)}",
+                        start=start,
+                        end=end,
+                        entity_text=entity_text,
+                        attr=attr
+                    )
+                    frame_list.append(frame)
+        return frame_list
+
+
 class ReviewFrameExtractor(DirectFrameExtractor):
     def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker, inference_engine:InferenceEngine,
                  prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None):
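The new extract_frames_from_documents() is an async generator built on a producer/worker queue: chunking runs on a ThreadPoolExecutor with cpu_concurrency threads while up to llm_concurrency workers call the LLM, and each document's payload is yielded in completion order, not input order. A consumption sketch (extractor and documents are hypothetical):

    import asyncio

    async def run_batch(docs):
        results = {}
        async for payload in extractor.extract_frames_from_documents(
            text_contents=docs,
            document_key="text",
            cpu_concurrency=4,
            llm_concurrency=32,
        ):
            # 'idx' maps a completion-ordered payload back to its input
            results[payload["idx"]] = payload["frames"]
        return results

    frames_by_doc = asyncio.run(run_batch([{"text": d} for d in documents]))

One design note: with return_messages_log=True, each payload snapshots the single shared MessagesLogger at completion time, so if get_messages_log() returns everything accumulated so far, later payloads will also contain messages from documents that finished earlier; callers needing strictly per-document logs may have to diff consecutive snapshots.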
@@ -902,7 +1067,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             raise ValueError(f"Cannot find review prompt for {self.__class__.__name__} in the package. Please provide a review_prompt.")
 
     def extract(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
-                verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+                verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
         """
         This method inputs a text and outputs a list of outputs per unit.
 
@@ -923,8 +1088,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         Return : List[FrameExtractionUnitResult]
             the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
-        # define output
-        output = []
         # unit chunking
         if isinstance(text_content, str):
             doc_text = text_content
@@ -937,9 +1100,9 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         units = self.unit_chunker.chunk(doc_text)
         # context chunker init
         self.context_chunker.fit(doc_text, units)
-        # messages log
-        if return_messages_log:
-            messages_log = []
+
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
 
         # generate unit by unit
         for i, unit in enumerate(units):
@@ -973,7 +1136,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             messages.append({'role': 'user', 'content': unit.text})
 
             if verbose:
-                print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
+                print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
                 if context != "":
                     print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
 
@@ -983,7 +1146,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             initial = self.inference_engine.chat(
                 messages=messages,
                 verbose=verbose,
-                stream=False
+                stream=False,
+                messages_logger=messages_logger
             )
 
             # <--- Review step --->
@@ -996,7 +1160,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             review = self.inference_engine.chat(
                 messages=messages,
                 verbose=verbose,
-                stream=False
+                stream=False,
+                messages_logger=messages_logger
             )
 
             # Output
@@ -1005,28 +1170,14 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             elif self.review_mode == "addition":
                 gen_text = initial["response"] + '\n' + review["response"]
 
-            if return_messages_log:
-                if "reasoning" in initial:
-                    messages[-2]["reasoning"] = initial["reasoning"]
-
-                message = {"role": "assistant", "content": review["response"]}
-                if "reasoning" in review:
-                    message["reasoning"] = review["reasoning"]
-                messages.append(message)
-                messages_log.append(messages)
-
-            # add to output
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=gen_text)
-            output.append(result)
+            # add generated text to unit
+            unit.set_generated_text(gen_text)
+            unit.set_status("success")
 
         if return_messages_log:
-            return output, messages_log
+            return units, messages_logger.get_messages_log()
 
-        return output
+        return units
 
 
     def stream(self, text_content:Union[str, Dict[str,str]], document_key:str=None) -> Generator[str, None, None]:
@@ -1122,7 +1273,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
             yield chunk
 
     async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
-                            concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
+                            concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnit]:
         """
         This is the asynchronous version of the extract() method with the review step.
 
@@ -1154,11 +1305,15 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         else:
             raise TypeError("text_content must be a string or a dictionary.")
 
+        # unit chunking
         units = self.unit_chunker.chunk(doc_text)
 
         # context chunker init
        self.context_chunker.fit(doc_text, units)
 
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # <--- Initial generation step --->
         initial_tasks_input = []
         for i, unit in enumerate(units):
@@ -1202,7 +1357,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
             async with semaphore:
                 gen_text = await self.inference_engine.chat_async(
-                    messages=messages
+                    messages=messages,
+                    messages_logger=messages_logger
                 )
             # Return initial generation result along with the messages used and the unit
             out = {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text["response"], "initial_messages": messages}
@@ -1253,16 +1409,11 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
             async with semaphore:
                 review_gen_text = await self.inference_engine.chat_async(
-                    messages=messages
+                    messages=messages,
+                    messages_logger=messages_logger
                )
             # Combine initial and review results
             task_data["review_gen_text"] = review_gen_text["response"]
-            if return_messages_log:
-                # Log for the review call itself
-                message = {'role': 'assistant', 'content': review_gen_text["response"]}
-                if "reasoning" in review_gen_text:
-                    message["reasoning"] = review_gen_text["reasoning"]
-                task_data["full_review_log"] = task_data["full_initial_log"] + [message]
             return task_data # Return the augmented dictionary
 
         # Create and gather review tasks
@@ -1279,9 +1430,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
         final_results_raw.sort(key=lambda x: x["original_index"])
 
         # <--- Process final results --->
-        output: List[FrameExtractionUnitResult] = []
-        messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
-
         for result_data in final_results_raw:
             unit = result_data["unit"]
             initial_gen = result_data["initial_gen_text"]
@@ -1296,23 +1444,13 @@ class ReviewFrameExtractor(DirectFrameExtractor):
                 final_gen_text = review_gen # Default to revision if mode is somehow invalid
 
             # Create final result object
-            result = FrameExtractionUnitResult(
-                start=unit.start,
-                end=unit.end,
-                text=unit.text,
-                gen_text=final_gen_text # Use the combined/reviewed text
-            )
-            output.append(result)
-
-            # Append full conversation log if requested
-            if return_messages_log:
-                full_log_for_unit = result_data["full_review_log"]
-                messages_log.append(full_log_for_unit)
+            unit.set_generated_text(final_gen_text)
+            unit.set_status("success")
 
         if return_messages_log:
-            return output, messages_log
+            return units, messages_logger.get_messages_log()
         else:
-            return output
+            return units
 
 
 class BasicFrameExtractor(DirectFrameExtractor):
@@ -1549,6 +1687,9 @@ class AttributeExtractor(Extractor):
             a dictionary of attributes extracted from the frame.
             If return_messages_log is True, a list of messages will be returned as well.
         """
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # construct chat messages
         messages = []
         if self.system_prompt:
@@ -1567,19 +1708,15 @@ class AttributeExtractor(Extractor):
         gen_text = self.inference_engine.chat(
             messages=messages,
             verbose=verbose,
-            stream=False
+            stream=False,
+            messages_logger=messages_logger
         )
-        if return_messages_log:
-            message = {"role": "assistant", "content": gen_text["response"]}
-            if "reasoning" in gen_text:
-                message["reasoning"] = gen_text["reasoning"]
-            messages.append(message)
 
         attribute_list = self._extract_json(gen_text=gen_text["response"])
         if isinstance(attribute_list, list) and len(attribute_list) > 0:
             attributes = attribute_list[0]
         if return_messages_log:
-            return attributes, messages
+            return attributes, messages_logger.get_messages_log()
         return attributes
 
 
@@ -1620,7 +1757,7 @@ class AttributeExtractor(Extractor):
             if return_messages_log:
                 attr, messages = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
                                                           verbose=verbose, return_messages_log=return_messages_log)
-                messages_log.append(messages)
+                messages_log.extend(messages)
             else:
                 attr = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
                                                 verbose=verbose, return_messages_log=return_messages_log)
@@ -1669,6 +1806,9 @@ class AttributeExtractor(Extractor):
         if not isinstance(text, str):
             raise TypeError(f"Expect text as str, received {type(text)} instead.")
 
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
         # async helper
         semaphore = asyncio.Semaphore(concurrent_batch_size)
 
@@ -1681,14 +1821,7 @@ class AttributeExtractor(Extractor):
             context = self._get_context(frame, text, context_size)
             messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})
 
-            gen_text = await self.inference_engine.chat_async(messages=messages)
-
-            if return_messages_log:
-                message = {"role": "assistant", "content": gen_text["response"]}
-                if "reasoning" in gen_text:
-                    message["reasoning"] = gen_text["reasoning"]
-                messages.append(message)
-
+            gen_text = await self.inference_engine.chat_async(messages=messages, messages_logger=messages_logger)
             attribute_list = self._extract_json(gen_text=gen_text["response"])
             attributes = attribute_list[0] if isinstance(attribute_list, list) and len(attribute_list) > 0 else {}
             return {"frame": frame, "attributes": attributes, "messages": messages}
@@ -1699,12 +1832,8 @@ class AttributeExtractor(Extractor):
 
         # process results
         new_frames = []
-        messages_log = [] if return_messages_log else None
 
         for result in results:
-            if return_messages_log:
-                messages_log.append(result["messages"])
-
             if inplace:
                 result["frame"].attr.update(result["attributes"])
             else:
@@ -1714,9 +1843,9 @@ class AttributeExtractor(Extractor):
 
         # output
         if inplace:
-            return messages_log if return_messages_log else None
+            return messages_logger.get_messages_log() if return_messages_log else None
         else:
-            return (new_frames, messages_log) if return_messages_log else new_frames
+            return (new_frames, messages_logger.get_messages_log()) if return_messages_log else new_frames
 
     def extract_attributes(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
                            concurrent:bool=False, concurrent_batch_size:int=32, verbose:bool=False,
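With the logger flattened into one list, the sequential path switches from messages_log.append(messages) to extend(messages), and the concurrent path returns messages_logger.get_messages_log() directly, so the log is one flat list of conversations rather than a list per frame. A sketch, assuming a configured AttributeExtractor named attr_extractor and that the method returns new frames when not operating in place (the full signature is not shown in this diff):

    new_frames, messages_log = attr_extractor.extract_attributes(
        frames=frames,
        text=note_text,
        concurrent=True,
        return_messages_log=True,
    )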
@@ -1839,7 +1968,7 @@ class RelationExtractor(Extractor):
                  return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
         pairs = itertools.combinations(doc.frames, 2)
         relations = []
-        messages_log = [] if return_messages_log else None
+        messages_logger = MessagesLogger() if return_messages_log else None
 
         for frame_1, frame_2 in pairs:
             task_payload = self._get_task_if_possible(frame_1, frame_2, doc.text, buffer_size)
@@ -1851,20 +1980,14 @@ class RelationExtractor(Extractor):
 
             gen_text = self.inference_engine.chat(
                 messages=task_payload['messages'],
-                verbose=verbose
+                verbose=verbose,
+                messages_logger=messages_logger
             )
             relation = self._post_process_result(gen_text["response"], task_payload)
             if relation:
                 relations.append(relation)
 
-            if return_messages_log:
-                message = {"role": "assistant", "content": gen_text["response"]}
-                if "reasoning" in gen_text:
-                    message["reasoning"] = gen_text["reasoning"]
-                task_payload['messages'].append(message)
-                messages_log.append(task_payload['messages'])
-
-        return (relations, messages_log) if return_messages_log else relations
+        return (relations, messages_logger.get_messages_log()) if return_messages_log else relations
 
     async def _extract_async(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent_batch_size: int = 32, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
         pairs = list(itertools.combinations(doc.frames, 2))
  async def _extract_async(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent_batch_size: int = 32, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
1870
1993
  pairs = list(itertools.combinations(doc.frames, 2))
@@ -1873,12 +1996,12 @@ class RelationExtractor(Extractor):
1873
1996
  tasks_input = [task for task in tasks_input if task is not None]
1874
1997
 
1875
1998
  relations = []
1876
- messages_log = [] if return_messages_log else None
1999
+ messages_logger = MessagesLogger() if return_messages_log else None
1877
2000
  semaphore = asyncio.Semaphore(concurrent_batch_size)
1878
2001
 
1879
2002
  async def semaphore_helper(task_payload: Dict):
1880
2003
  async with semaphore:
1881
- gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'])
2004
+ gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'], messages_logger=messages_logger)
1882
2005
  return gen_text, task_payload
1883
2006
 
1884
2007
  tasks = [asyncio.create_task(semaphore_helper(payload)) for payload in tasks_input]
@@ -1889,14 +2012,7 @@ class RelationExtractor(Extractor):
             if relation:
                 relations.append(relation)
 
-            if return_messages_log:
-                message = {"role": "assistant", "content": gen_text["response"]}
-                if "reasoning" in gen_text:
-                    message["reasoning"] = gen_text["reasoning"]
-                task_payload['messages'].append(message)
-                messages_log.append(task_payload['messages'])
-
-        return (relations, messages_log) if return_messages_log else relations
+        return (relations, messages_logger.get_messages_log()) if return_messages_log else relations
 
     def extract_relations(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent: bool = False, concurrent_batch_size: int = 32, verbose: bool = False, return_messages_log: bool = False) -> List[Dict]:
         if not doc.has_frame():
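RelationExtractor follows the same pattern in both its sync and async paths: the shared MessagesLogger is handed to the engine, and its flattened log replaces the manually assembled one. A closing sketch of the public entry point, assuming a configured RelationExtractor named rel_extractor, an LLMInformationExtractionDocument doc that already has frames, and that the tuple return applies when return_messages_log=True:

    relations, messages_log = rel_extractor.extract_relations(
        doc=doc,
        buffer_size=128,
        concurrent=True,
        return_messages_log=True,
    )
    for rel in relations:
        print(rel)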