llm-ie 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
llm_ie/extractors.py CHANGED
@@ -8,11 +8,12 @@ import warnings
  import itertools
  import asyncio
  import nest_asyncio
- from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional
- from llm_ie.data_types import FrameExtractionUnit, FrameExtractionUnitResult, LLMInformationExtractionFrame, LLMInformationExtractionDocument
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional, AsyncGenerator
+ from llm_ie.data_types import FrameExtractionUnit, LLMInformationExtractionFrame, LLMInformationExtractionDocument
  from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
  from llm_ie.chunkers import ContextChunker, NoContextChunker, WholeDocumentContextChunker, SlideWindowContextChunker
- from llm_ie.engines import InferenceEngine
+ from llm_ie.engines import InferenceEngine, MessagesLogger
  from colorama import Fore, Style

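The `MessagesLogger` imported here replaces the ad-hoc `messages_log` lists used throughout this module. A minimal usage sketch, relying only on the calls visible in this diff (the constructor, the `messages_logger=` keyword on `chat()`, and `get_messages_log()`); `engine` is a hypothetical `InferenceEngine` instance:

```python
from llm_ie.engines import MessagesLogger

# `engine` is a stand-in for any InferenceEngine; not part of this diff.
messages_logger = MessagesLogger()
gen = engine.chat(
    messages=[{"role": "user", "content": "Extract entities from ..."}],
    stream=False,
    messages_logger=messages_logger,   # engine records the exchange in the logger
)
print(gen["response"])                 # chat() now returns a dict, not a str
log = messages_logger.get_messages_log()
```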
@@ -405,7 +406,7 @@ class DirectFrameExtractor(FrameExtractor):


  def extract(self, text_content:Union[str, Dict[str,str]],
- document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+ document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
  """
  This method inputs a text and outputs a list of outputs per unit.

@@ -423,11 +424,9 @@ class DirectFrameExtractor(FrameExtractor):
  return_messages_log : bool, Optional
  if True, a list of messages will be returned.

- Return : List[FrameExtractionUnitResult]
+ Return : List[FrameExtractionUnit]
  the output from LLM for each unit. Contains the start, end, text, and generated text.
  """
- # define output
- output = []
  # unit chunking
  if isinstance(text_content, str):
  doc_text = text_content
@@ -440,73 +439,70 @@ class DirectFrameExtractor(FrameExtractor):
  units = self.unit_chunker.chunk(doc_text)
  # context chunker init
  self.context_chunker.fit(doc_text, units)
+
  # messages log
- if return_messages_log:
- messages_log = []
+ messages_logger = MessagesLogger() if return_messages_log else None

  # generate unit by unit
  for i, unit in enumerate(units):
- # construct chat messages
- messages = []
- if self.system_prompt:
- messages.append({'role': 'system', 'content': self.system_prompt})
-
- context = self.context_chunker.chunk(unit)
-
- if context == "":
- # no context, just place unit in user prompt
- if isinstance(text_content, str):
- messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
- else:
- unit_content = text_content.copy()
- unit_content[document_key] = unit.text
- messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
- else:
- # insert context to user prompt
- if isinstance(text_content, str):
- messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
- else:
- context_content = text_content.copy()
- context_content[document_key] = context
- messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
- # simulate conversation where assistant confirms
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
- # place unit of interest
- messages.append({'role': 'user', 'content': unit.text})
+ try:
+ # construct chat messages
+ messages = []
+ if self.system_prompt:
+ messages.append({'role': 'system', 'content': self.system_prompt})

- if verbose:
- print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
- if context != "":
- print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+ context = self.context_chunker.chunk(unit)

- print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+ if context == "":
+ # no context, just place unit in user prompt
+ if isinstance(text_content, str):
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+ else:
+ unit_content = text_content.copy()
+ unit_content[document_key] = unit.text
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+ else:
+ # insert context to user prompt
+ if isinstance(text_content, str):
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+ else:
+ context_content = text_content.copy()
+ context_content[document_key] = context
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+ # simulate conversation where assistant confirms
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+ # place unit of interest
+ messages.append({'role': 'user', 'content': unit.text})

-
- gen_text = self.inference_engine.chat(
- messages=messages,
- verbose=verbose,
- stream=False
- )
+ if verbose:
+ print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
+ if context != "":
+ print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+
+ print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")

- if return_messages_log:
- messages.append({"role": "assistant", "content": gen_text})
- messages_log.append(messages)
-
- # add to output
- result = FrameExtractionUnitResult(
- start=unit.start,
- end=unit.end,
- text=unit.text,
- gen_text=gen_text)
- output.append(result)
+
+ gen_text = self.inference_engine.chat(
+ messages=messages,
+ verbose=verbose,
+ stream=False,
+ messages_logger=messages_logger
+ )
+
+ # add generated text to unit
+ unit.set_generated_text(gen_text["response"])
+ unit.set_status("success")
+ except Exception as e:
+ unit.set_status("fail")
+ warnings.warn(f"LLM inference failed for unit {i} ({unit.start}, {unit.end}): {e}", RuntimeWarning)

  if return_messages_log:
- return output, messages_log
+ return units, messages_logger.get_messages_log()

- return output
+ return units

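The net effect of the hunk above: `extract()` no longer builds `FrameExtractionUnitResult` objects; it annotates each `FrameExtractionUnit` in place via `set_generated_text()`/`set_status()` and returns the units, downgrading per-unit LLM failures to warnings. A consumption sketch (`extractor` and `doc_text` are hypothetical):

```python
# Attribute names (status, gen_text, start, end) are taken from this diff.
units = extractor.extract(text_content=doc_text)
for unit in units:
    if unit.status == "success":
        print(unit.start, unit.end, unit.gen_text)
    else:
        print(f"unit ({unit.start}, {unit.end}) failed; no generated text")
```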
  def stream(self, text_content: Union[str, Dict[str, str]],
- document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnitResult]]:
+ document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnit]]:
  """
  Streams LLM responses per unit with structured event types,
  and returns collected data for post-processing.
@@ -522,12 +518,10 @@ class DirectFrameExtractor(FrameExtractor):

  Returns:
  --------
- List[FrameExtractionUnitResult]:
- A list of FrameExtractionUnitResult objects, each containing the
+ List[FrameExtractionUnit]:
+ A list of FrameExtractionUnit objects, each containing the
  original unit details and the fully accumulated 'gen_text' from the LLM.
  """
- collected_results: List[FrameExtractionUnitResult] = []
-
  if isinstance(text_content, str):
  doc_text = text_content
  elif isinstance(text_content, dict):
@@ -581,22 +575,18 @@ class DirectFrameExtractor(FrameExtractor):
  )
  for chunk in response_stream:
  yield chunk
- current_gen_text += chunk
+ if chunk["type"] == "response":
+ current_gen_text += chunk["data"]

  # Store the result for this unit
- result_for_unit = FrameExtractionUnitResult(
- start=unit.start,
- end=unit.end,
- text=unit.text,
- gen_text=current_gen_text
- )
- collected_results.append(result_for_unit)
+ unit.set_generated_text(current_gen_text)
+ unit.set_status("success")

  yield {"type": "info", "data": "All units processed by LLM."}
- return collected_results
+ return units

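`stream()` now emits typed event dicts instead of raw string chunks, and only `"response"` events are accumulated into the unit's generated text. A consumer sketch, with the event shape (`{"type": ..., "data": ...}`) inferred from this hunk:

```python
for chunk in extractor.stream(text_content=doc_text):
    if chunk["type"] == "response":
        print(chunk["data"], end="")        # incremental LLM output
    elif chunk["type"] == "info":
        print(f"\n[info] {chunk['data']}")  # e.g. "All units processed by LLM."
```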
  async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
- concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+ concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
  """
  This is the asynchronous version of the extract() method.

@@ -614,7 +604,7 @@ class DirectFrameExtractor(FrameExtractor):
  return_messages_log : bool, Optional
  if True, a list of messages will be returned.

- Return : List[FrameExtractionUnitResult]
+ Return : List[FrameExtractionUnit]
  the output from LLM for each unit. Contains the start, end, text, and generated text.
  """
  if isinstance(text_content, str):
@@ -633,6 +623,9 @@ class DirectFrameExtractor(FrameExtractor):
  # context chunker init
  self.context_chunker.fit(doc_text, units)

+ # messages logger init
+ messages_logger = MessagesLogger() if return_messages_log else None
+
  # Prepare inputs for all units first
  tasks_input = []
  for i, unit in enumerate(units):
@@ -673,13 +666,15 @@ class DirectFrameExtractor(FrameExtractor):
  async def semaphore_helper(task_data: Dict, **kwrs):
  unit = task_data["unit"]
  messages = task_data["messages"]
- original_index = task_data["original_index"]

  async with semaphore:
  gen_text = await self.inference_engine.chat_async(
- messages=messages
+ messages=messages,
+ messages_logger=messages_logger
  )
- return {"original_index": original_index, "unit": unit, "gen_text": gen_text, "messages": messages}
+
+ unit.set_generated_text(gen_text["response"])
+ unit.set_status("success")

  # Create and gather tasks
  tasks = []
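The concurrency pattern in `semaphore_helper` is unchanged in spirit: a semaphore caps in-flight `chat_async()` calls at `concurrent_batch_size` while `asyncio.gather` schedules every unit at once. A self-contained sketch of the same pattern, with `fake_chat` standing in for the LLM request:

```python
import asyncio

async def fake_chat(i: int) -> str:
    await asyncio.sleep(0.1)               # stand-in for an LLM request
    return f"response {i}"

async def main(concurrent_batch_size: int = 32) -> None:
    semaphore = asyncio.Semaphore(concurrent_batch_size)

    async def helper(i: int) -> str:
        async with semaphore:              # at most N requests in flight
            return await fake_chat(i)

    results = await asyncio.gather(*(helper(i) for i in range(100)))
    print(len(results))

asyncio.run(main())
```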
@@ -689,37 +684,13 @@ class DirectFrameExtractor(FrameExtractor):
  ))
  tasks.append(task)

- results_raw = await asyncio.gather(*tasks)
-
- # Sort results back into original order using the index stored
- results_raw.sort(key=lambda x: x["original_index"])
-
- # Restructure the results
- output: List[FrameExtractionUnitResult] = []
- messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
-
- for result_data in results_raw:
- unit = result_data["unit"]
- gen_text = result_data["gen_text"]
-
- # Create result object
- result = FrameExtractionUnitResult(
- start=unit.start,
- end=unit.end,
- text=unit.text,
- gen_text=gen_text
- )
- output.append(result)
-
- # Append to messages log if requested
- if return_messages_log:
- final_messages = result_data["messages"] + [{"role": "assistant", "content": gen_text}]
- messages_log.append(final_messages)
+ await asyncio.gather(*tasks)

+ # Return units
  if return_messages_log:
- return output, messages_log
+ return units, messages_logger.get_messages_log()
  else:
- return output
+ return units

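Because the helper now mutates each unit in place, the post-`gather` sorting and restructuring machinery disappears: the `units` list itself already preserves document order. The calling convention, sketched:

```python
# Inside an event loop; `extractor` and `doc_text` are hypothetical.
units, messages_log = await extractor.extract_async(
    text_content=doc_text,
    concurrent_batch_size=32,
    return_messages_log=True,
)
```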

  def extract_frames(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
@@ -727,7 +698,7 @@ class DirectFrameExtractor(FrameExtractor):
  case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
  allow_overlap_entities:bool=False, return_messages_log:bool=False) -> List[LLMInformationExtractionFrame]:
  """
- This method inputs a text and outputs a list of LLMInformationExtractionFrame
+ This method inputs a document text and outputs a list of LLMInformationExtractionFrame
  It use the extract() method and post-process outputs into frames.

  Parameters:
@@ -780,18 +751,21 @@ class DirectFrameExtractor(FrameExtractor):
  verbose=verbose,
  return_messages_log=return_messages_log)

- llm_output_results, messages_log = extraction_results if return_messages_log else (extraction_results, None)
+ units, messages_log = extraction_results if return_messages_log else (extraction_results, None)

  frame_list = []
- for res in llm_output_results:
+ for unit in units:
  entity_json = []
- for entity in self._extract_json(gen_text=res.gen_text):
+ if unit.status != "success":
+ warnings.warn(f"Skipping failed unit ({unit.start}, {unit.end}): {unit.text}", RuntimeWarning)
+ continue
+ for entity in self._extract_json(gen_text=unit.gen_text):
  if ENTITY_KEY in entity:
  entity_json.append(entity)
  else:
  warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)

- spans = self._find_entity_spans(text=res.text,
+ spans = self._find_entity_spans(text=unit.text,
  entities=[e[ENTITY_KEY] for e in entity_json],
  case_sensitive=case_sensitive,
  fuzzy_match=fuzzy_match,
@@ -801,9 +775,9 @@ class DirectFrameExtractor(FrameExtractor):
  for ent, span in zip(entity_json, spans):
  if span is not None:
  start, end = span
- entity_text = res.text[start:end]
- start += res.start
- end += res.start
+ entity_text = unit.text[start:end]
+ start += unit.start
+ end += unit.start
  attr = {}
  if "attr" in ent and ent["attr"] is not None:
  attr = ent["attr"]
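`extract_frames()` now tolerates per-unit failures (skipping any unit whose status is not `"success"`) instead of aborting the whole document. A call sketch using the fuzzy-matching defaults from the signature above (`extractor` and `doc_text` are hypothetical):

```python
# Argument values mirror the defaults shown in this diff.
frames = extractor.extract_frames(
    text_content=doc_text,
    fuzzy_match=True,
    fuzzy_buffer_size=0.2,    # buffer = 20% of entity text length
    fuzzy_score_cutoff=0.8,   # minimum similarity score for a span match
)
```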
@@ -820,6 +794,208 @@ class DirectFrameExtractor(FrameExtractor):
  return frame_list


+ async def extract_frames_from_documents(self, text_contents:List[Union[str,Dict[str, any]]], document_key:str="text",
+ cpu_concurrency:int=4, llm_concurrency:int=32, case_sensitive:bool=False,
+ fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
+ allow_overlap_entities:bool=False, return_messages_log:bool=False) -> AsyncGenerator[Dict[str, any], None]:
+ """
+ This method inputs a list of documents and yields the results for each document as soon as it is complete.
+
+ Parameters:
+ -----------
+ text_contents : List[Union[str,Dict[str, any]]]
+ a list of input text contents to put in prompt template.
+ If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+ If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+ document_key: str, optional
+ The key in the `text_contents` dictionaries that holds the document text.
+ cpu_concurrency: int, optional
+ The number of parallel threads to use for CPU-bound tasks like chunking.
+ llm_concurrency: int, optional
+ The number of concurrent requests to make to the LLM.
+ case_sensitive : bool, Optional
+ if True, entity text matching will be case-sensitive.
+ fuzzy_match : bool, Optional
+ if True, fuzzy matching will be applied to find entity text.
+ fuzzy_buffer_size : float, Optional
+ the buffer size for fuzzy matching. Default is 20% of entity text length.
+ fuzzy_score_cutoff : float, Optional
+ the Jaccard score cutoff for fuzzy matching.
+ Matched entity text must have a score higher than this value or a None will be returned.
+ allow_overlap_entities : bool, Optional
+ if True, entities can overlap in the text.
+ return_messages_log : bool, Optional
+ if True, a list of messages will be returned.
+
+ Yields:
+ -------
+ AsyncGenerator[Dict[str, any], None]
+ A dictionary for each completed document, containing its 'idx' and extracted 'frames'.
+ """
+ # Validate text_contents must be a list of str or dict, and not both
+ if not isinstance(text_contents, list):
+ raise ValueError("text_contents must be a list of strings or dictionaries.")
+ if all(isinstance(doc, str) for doc in text_contents):
+ pass
+ elif all(isinstance(doc, dict) for doc in text_contents):
+ pass
+ # Set CPU executor and queues
+ cpu_executor = ThreadPoolExecutor(max_workers=cpu_concurrency)
+ tasks_queue = asyncio.Queue(maxsize=llm_concurrency * 2)
+ # Store to track units and pending counts
+ results_store = {
+ idx: {'pending': 0, 'units': [], 'text': doc if isinstance(doc, str) else doc.get(document_key, "")}
+ for idx, doc in enumerate(text_contents)
+ }
+
+ output_queue = asyncio.Queue()
+ messages_logger = MessagesLogger() if return_messages_log else None
+
+ async def producer():
+ try:
+ for idx, text_content in enumerate(text_contents):
+ text = text_content if isinstance(text_content, str) else text_content.get(document_key, "")
+ if not text:
+ warnings.warn(f"Document at index {idx} is empty or missing the document key '{document_key}'.")
+ # signal that this document is done
+ await output_queue.put({'idx': idx, 'frames': []})
+ continue
+
+ units = await self.unit_chunker.chunk_async(text, cpu_executor)
+ await self.context_chunker.fit_async(text, units, cpu_executor)
+ results_store[idx]['pending'] = len(units)
+
+ # Handle cases where a document yields no units
+ if not units:
+ # signal that this document is done
+ await output_queue.put({'idx': idx, 'frames': []})
+ continue
+
+ # Iterate through units
+ for unit in units:
+ context = await self.context_chunker.chunk_async(unit, cpu_executor)
+ messages = []
+ if self.system_prompt:
+ messages.append({'role': 'system', 'content': self.system_prompt})
+
+ if not context:
+ if isinstance(text_content, str):
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+ else:
+ unit_content = text_content.copy()
+ unit_content[document_key] = unit.text
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+ else:
+ # insert context to user prompt
+ if isinstance(text_content, str):
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+ else:
+ context_content = text_content.copy()
+ context_content[document_key] = context
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+ # simulate conversation where assistant confirms
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+ # place unit of interest
+ messages.append({'role': 'user', 'content': unit.text})
+
+ await tasks_queue.put({'idx': idx, 'unit': unit, 'messages': messages})
+ finally:
+ for _ in range(llm_concurrency):
+ await tasks_queue.put(None)
+
+ async def worker():
+ while True:
+ task_item = await tasks_queue.get()
+ if task_item is None:
+ tasks_queue.task_done()
+ break
+
+ idx = task_item['idx']
+ unit = task_item['unit']
+ doc_results = results_store[idx]
+
+ try:
+ gen_text = await self.inference_engine.chat_async(
+ messages=task_item['messages'], messages_logger=messages_logger
+ )
+ unit.set_generated_text(gen_text["response"])
+ unit.set_status("success")
+ doc_results['units'].append(unit)
+ except Exception as e:
+ warnings.warn(f"Error processing unit for doc idx {idx}: {e}")
+ finally:
+ doc_results['pending'] -= 1
+ if doc_results['pending'] <= 0:
+ final_frames = self._post_process_and_create_frames(doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities)
+ output_payload = {'idx': idx, 'frames': final_frames}
+ if return_messages_log:
+ output_payload['messages_log'] = messages_logger.get_messages_log()
+ await output_queue.put(output_payload)
+
+ tasks_queue.task_done()
+
+ # Start producer and workers
+ producer_task = asyncio.create_task(producer())
+ worker_tasks = [asyncio.create_task(worker()) for _ in range(llm_concurrency)]
+
+ # Main loop to gather results
+ docs_completed = 0
+ while docs_completed < len(text_contents):
+ result = await output_queue.get()
+ yield result
+ docs_completed += 1
+
+ # Final cleanup
+ await producer_task
+ await tasks_queue.join()
+
+ # Cancel any lingering worker tasks
+ for task in worker_tasks:
+ task.cancel()
+ await asyncio.gather(*worker_tasks, return_exceptions=True)
+
+ cpu_executor.shutdown(wait=False)
+
+
+ def _post_process_and_create_frames(self, doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities):
+ """Helper function to run post-processing logic for a completed document."""
+ ENTITY_KEY = "entity_text"
+ frame_list = []
+ for res in sorted(doc_results['units'], key=lambda r: r.start):
+ entity_json = []
+ for entity in self._extract_json(gen_text=res.gen_text):
+ if ENTITY_KEY in entity:
+ entity_json.append(entity)
+ else:
+ warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
+
+ spans = self._find_entity_spans(
+ text=res.text,
+ entities=[e[ENTITY_KEY] for e in entity_json],
+ case_sensitive=case_sensitive,
+ fuzzy_match=fuzzy_match,
+ fuzzy_buffer_size=fuzzy_buffer_size,
+ fuzzy_score_cutoff=fuzzy_score_cutoff,
+ allow_overlap_entities=allow_overlap_entities
+ )
+ for ent, span in zip(entity_json, spans):
+ if span is not None:
+ start, end = span
+ entity_text = res.text[start:end]
+ start += res.start
+ end += res.start
+ attr = ent.get("attr", {}) or {}
+ frame = LLMInformationExtractionFrame(
+ frame_id=f"{len(frame_list)}",
+ start=start,
+ end=end,
+ entity_text=entity_text,
+ attr=attr
+ )
+ frame_list.append(frame)
+ return frame_list
+
+
  class ReviewFrameExtractor(DirectFrameExtractor):
  def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker, inference_engine:InferenceEngine,
  prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None):
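`extract_frames_from_documents()` is the headline addition of this release: a producer/worker pipeline over `asyncio.Queue` that yields each document's frames as soon as its last unit finishes, so results arrive out of input order. A consumption sketch; the payload keys (`'idx'`, `'frames'`) come from the docstring above, everything else is hypothetical:

```python
import asyncio

async def run(extractor, docs):
    results = [None] * len(docs)
    async for payload in extractor.extract_frames_from_documents(docs, llm_concurrency=32):
        results[payload["idx"]] = payload["frames"]   # realign by input index
    return results

# frames_per_doc = asyncio.run(run(extractor, ["first note ...", "second note ..."]))
```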
@@ -891,7 +1067,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  raise ValueError(f"Cannot find review prompt for {self.__class__.__name__} in the package. Please provide a review_prompt.")

  def extract(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
- verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+ verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
  """
  This method inputs a text and outputs a list of outputs per unit.

@@ -912,8 +1088,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  Return : List[FrameExtractionUnitResult]
  the output from LLM for each unit. Contains the start, end, text, and generated text.
  """
- # define output
- output = []
  # unit chunking
  if isinstance(text_content, str):
  doc_text = text_content
@@ -926,9 +1100,9 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  units = self.unit_chunker.chunk(doc_text)
  # context chunker init
  self.context_chunker.fit(doc_text, units)
- # messages log
- if return_messages_log:
- messages_log = []
+
+ # messages logger init
+ messages_logger = MessagesLogger() if return_messages_log else None

  # generate unit by unit
  for i, unit in enumerate(units):
@@ -962,7 +1136,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  messages.append({'role': 'user', 'content': unit.text})

  if verbose:
- print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
+ print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
  if context != "":
  print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")

@@ -972,48 +1146,38 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  initial = self.inference_engine.chat(
  messages=messages,
  verbose=verbose,
- stream=False
+ stream=False,
+ messages_logger=messages_logger
  )

- if return_messages_log:
- messages.append({"role": "assistant", "content": initial})
- messages_log.append(messages)
-
  # <--- Review step --->
  if verbose:
  print(f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}")

- messages.append({'role': 'assistant', 'content': initial})
+ messages.append({'role': 'assistant', 'content': initial["response"]})
  messages.append({'role': 'user', 'content': self.review_prompt})

  review = self.inference_engine.chat(
  messages=messages,
  verbose=verbose,
- stream=False
+ stream=False,
+ messages_logger=messages_logger
  )

  # Output
  if self.review_mode == "revision":
- gen_text = review
+ gen_text = review["response"]
  elif self.review_mode == "addition":
- gen_text = initial + '\n' + review
+ gen_text = initial["response"] + '\n' + review["response"]

- if return_messages_log:
- messages.append({"role": "assistant", "content": review})
- messages_log.append(messages)
-
- # add to output
- result = FrameExtractionUnitResult(
- start=unit.start,
- end=unit.end,
- text=unit.text,
- gen_text=gen_text)
- output.append(result)
+ # add generated text to unit
+ unit.set_generated_text(gen_text)
+ unit.set_status("success")

  if return_messages_log:
- return output, messages_log
+ return units, messages_logger.get_messages_log()

- return output
+ return units


  def stream(self, text_content:Union[str, Dict[str,str]], document_key:str=None) -> Generator[str, None, None]:
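For reference, the two review modes handled above: `"revision"` keeps only the reviewed answer, while `"addition"` appends the review to the initial answer. In miniature, with `initial` and `review` being the dicts returned by `chat()` in this version:

```python
if review_mode == "revision":
    gen_text = review["response"]
elif review_mode == "addition":
    gen_text = initial["response"] + "\n" + review["response"]
```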
@@ -1109,7 +1273,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  yield chunk

  async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
- concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
+ concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnit]:
  """
  This is the asynchronous version of the extract() method with the review step.

@@ -1141,11 +1305,15 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  else:
  raise TypeError("text_content must be a string or a dictionary.")

+ # unit chunking
  units = self.unit_chunker.chunk(doc_text)

  # context chunker init
  self.context_chunker.fit(doc_text, units)

+ # messages logger init
+ messages_logger = MessagesLogger() if return_messages_log else None
+
  # <--- Initial generation step --->
  initial_tasks_input = []
  for i, unit in enumerate(units):
@@ -1189,10 +1357,14 @@ class ReviewFrameExtractor(DirectFrameExtractor):

  async with semaphore:
  gen_text = await self.inference_engine.chat_async(
- messages=messages
+ messages=messages,
+ messages_logger=messages_logger
  )
  # Return initial generation result along with the messages used and the unit
- return {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text, "initial_messages": messages}
+ out = {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text["response"], "initial_messages": messages}
+ if "reasoning" in gen_text:
+ out["reasoning"] = gen_text["reasoning"]
+ return out

  # Create and gather initial generation tasks
  initial_tasks = [
@@ -1218,28 +1390,30 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  {'role': 'user', 'content': self.review_prompt}
  ]
  # Store data needed for review task
+ if "reasoning" in result_data:
+ message = {'role': 'assistant', 'content': initial_gen_text, "reasoning": result_data["reasoning"]}
+ else:
+ message = {'role': 'assistant', 'content': initial_gen_text}
+
  review_tasks_input.append({
  "unit": result_data["unit"],
  "initial_gen_text": initial_gen_text,
  "messages": review_messages,
  "original_index": result_data["original_index"],
- "full_initial_log": initial_messages + [{'role': 'assistant', 'content': initial_gen_text}] if return_messages_log else None # Log up to initial generation
+ "full_initial_log": initial_messages + [message] + [{'role': 'user', 'content': self.review_prompt}] if return_messages_log else None
  })


  async def review_semaphore_helper(task_data: Dict, **kwrs):
  messages = task_data["messages"]
- original_index = task_data["original_index"]

  async with semaphore:
  review_gen_text = await self.inference_engine.chat_async(
- messages=messages
+ messages=messages,
+ messages_logger=messages_logger
  )
  # Combine initial and review results
- task_data["review_gen_text"] = review_gen_text
- if return_messages_log:
- # Log for the review call itself
- task_data["full_review_log"] = messages + [{'role': 'assistant', 'content': review_gen_text}]
+ task_data["review_gen_text"] = review_gen_text["response"]
  return task_data # Return the augmented dictionary

  # Create and gather review tasks
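The new branch threads an optional reasoning trace from the initial generation into the logged assistant message. Isolated, the shape is:

```python
# Optional reasoning passthrough, per the hunk above; "reasoning" is only
# present when the engine returns a reasoning trace alongside the response.
message = {"role": "assistant", "content": initial_gen_text}
if "reasoning" in result_data:
    message["reasoning"] = result_data["reasoning"]
```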
@@ -1256,9 +1430,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  final_results_raw.sort(key=lambda x: x["original_index"])

  # <--- Process final results --->
- output: List[FrameExtractionUnitResult] = []
- messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
-
  for result_data in final_results_raw:
  unit = result_data["unit"]
  initial_gen = result_data["initial_gen_text"]
@@ -1273,23 +1444,13 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  final_gen_text = review_gen # Default to revision if mode is somehow invalid

  # Create final result object
- result = FrameExtractionUnitResult(
- start=unit.start,
- end=unit.end,
- text=unit.text,
- gen_text=final_gen_text # Use the combined/reviewed text
- )
- output.append(result)
-
- # Append full conversation log if requested
- if return_messages_log:
- full_log_for_unit = result_data.get("full_initial_log", []) + [{'role': 'user', 'content': self.review_prompt}] + [{'role': 'assistant', 'content': review_gen}]
- messages_log.append(full_log_for_unit)
+ unit.set_generated_text(final_gen_text)
+ unit.set_status("success")

  if return_messages_log:
- return output, messages_log
+ return units, messages_logger.get_messages_log()
  else:
- return output
+ return units


  class BasicFrameExtractor(DirectFrameExtractor):
@@ -1526,6 +1687,9 @@ class AttributeExtractor(Extractor):
  a dictionary of attributes extracted from the frame.
  If return_messages_log is True, a list of messages will be returned as well.
  """
+ # messages logger init
+ messages_logger = MessagesLogger() if return_messages_log else None
+
  # construct chat messages
  messages = []
  if self.system_prompt:
@@ -1541,19 +1705,18 @@ class AttributeExtractor(Extractor):

  print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")

- get_text = self.inference_engine.chat(
+ gen_text = self.inference_engine.chat(
  messages=messages,
  verbose=verbose,
- stream=False
+ stream=False,
+ messages_logger=messages_logger
  )
- if return_messages_log:
- messages.append({"role": "assistant", "content": get_text})

- attribute_list = self._extract_json(gen_text=get_text)
+ attribute_list = self._extract_json(gen_text=gen_text["response"])
  if isinstance(attribute_list, list) and len(attribute_list) > 0:
  attributes = attribute_list[0]
  if return_messages_log:
- return attributes, messages
+ return attributes, messages_logger.get_messages_log()
  return attributes

@@ -1594,7 +1757,7 @@ class AttributeExtractor(Extractor):
  if return_messages_log:
  attr, messages = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
  verbose=verbose, return_messages_log=return_messages_log)
- messages_log.append(messages)
+ messages_log.extend(messages)
  else:
  attr = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
  verbose=verbose, return_messages_log=return_messages_log)
@@ -1643,6 +1806,9 @@ class AttributeExtractor(Extractor):
  if not isinstance(text, str):
  raise TypeError(f"Expect text as str, received {type(text)} instead.")

+ # messages logger init
+ messages_logger = MessagesLogger() if return_messages_log else None
+
  # async helper
  semaphore = asyncio.Semaphore(concurrent_batch_size)

@@ -1655,12 +1821,8 @@ class AttributeExtractor(Extractor):
  context = self._get_context(frame, text, context_size)
  messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})

- gen_text = await self.inference_engine.chat_async(messages=messages)
-
- if return_messages_log:
- messages.append({"role": "assistant", "content": gen_text})
-
- attribute_list = self._extract_json(gen_text=gen_text)
+ gen_text = await self.inference_engine.chat_async(messages=messages, messages_logger=messages_logger)
+ attribute_list = self._extract_json(gen_text=gen_text["response"])
  attributes = attribute_list[0] if isinstance(attribute_list, list) and len(attribute_list) > 0 else {}
  return {"frame": frame, "attributes": attributes, "messages": messages}

@@ -1670,12 +1832,8 @@ class AttributeExtractor(Extractor):

  # process results
  new_frames = []
- messages_log = [] if return_messages_log else None

  for result in results:
- if return_messages_log:
- messages_log.append(result["messages"])
-
  if inplace:
  result["frame"].attr.update(result["attributes"])
  else:
@@ -1685,9 +1843,9 @@ class AttributeExtractor(Extractor):

  # output
  if inplace:
- return messages_log if return_messages_log else None
+ return messages_logger.get_messages_log() if return_messages_log else None
  else:
- return (new_frames, messages_log) if return_messages_log else new_frames
+ return (new_frames, messages_logger.get_messages_log()) if return_messages_log else new_frames

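`AttributeExtractor` gets the same logger plumbing; note that per-frame logs are now flattened into `messages_log` with `extend()` rather than nested with `append()`. A call sketch for the public entry point in the context lines below (`attribute_extractor` and `doc` are hypothetical):

```python
new_frames = attribute_extractor.extract_attributes(
    frames=doc.frames,
    text=doc.text,
    context_size=256,       # characters of context around each frame
    concurrent=True,
    concurrent_batch_size=32,
)
```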
  def extract_attributes(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
  concurrent:bool=False, concurrent_batch_size:int=32, verbose:bool=False,
@@ -1810,7 +1968,7 @@ class RelationExtractor(Extractor):
  return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
  pairs = itertools.combinations(doc.frames, 2)
  relations = []
- messages_log = [] if return_messages_log else None
+ messages_logger = MessagesLogger() if return_messages_log else None

  for frame_1, frame_2 in pairs:
  task_payload = self._get_task_if_possible(frame_1, frame_2, doc.text, buffer_size)
@@ -1822,17 +1980,14 @@ class RelationExtractor(Extractor):

  gen_text = self.inference_engine.chat(
  messages=task_payload['messages'],
- verbose=verbose
+ verbose=verbose,
+ messages_logger=messages_logger
  )
- relation = self._post_process_result(gen_text, task_payload)
+ relation = self._post_process_result(gen_text["response"], task_payload)
  if relation:
  relations.append(relation)

- if return_messages_log:
- task_payload['messages'].append({"role": "assistant", "content": gen_text})
- messages_log.append(task_payload['messages'])
-
- return (relations, messages_log) if return_messages_log else relations
+ return (relations, messages_logger.get_messages_log()) if return_messages_log else relations

  async def _extract_async(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent_batch_size: int = 32, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
  pairs = list(itertools.combinations(doc.frames, 2))
@@ -1841,27 +1996,23 @@ class RelationExtractor(Extractor):
  tasks_input = [task for task in tasks_input if task is not None]

  relations = []
- messages_log = [] if return_messages_log else None
+ messages_logger = MessagesLogger() if return_messages_log else None
  semaphore = asyncio.Semaphore(concurrent_batch_size)

  async def semaphore_helper(task_payload: Dict):
  async with semaphore:
- gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'])
+ gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'], messages_logger=messages_logger)
  return gen_text, task_payload

  tasks = [asyncio.create_task(semaphore_helper(payload)) for payload in tasks_input]
  results = await asyncio.gather(*tasks)

  for gen_text, task_payload in results:
- relation = self._post_process_result(gen_text, task_payload)
+ relation = self._post_process_result(gen_text["response"], task_payload)
  if relation:
  relations.append(relation)

- if return_messages_log:
- task_payload['messages'].append({"role": "assistant", "content": gen_text})
- messages_log.append(task_payload['messages'])
-
- return (relations, messages_log) if return_messages_log else relations
+ return (relations, messages_logger.get_messages_log()) if return_messages_log else relations

  def extract_relations(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent: bool = False, concurrent_batch_size: int = 32, verbose: bool = False, return_messages_log: bool = False) -> List[Dict]:
  if not doc.has_frame():
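`RelationExtractor` receives the same treatment: `chat()`/`chat_async()` responses are dicts, and logs flow through the shared `MessagesLogger`. The public entry point shown in the context above, sketched (`relation_extractor` and `doc` are hypothetical):

```python
relations = relation_extractor.extract_relations(
    doc=doc,                  # an LLMInformationExtractionDocument with frames
    buffer_size=128,
    concurrent=True,
    concurrent_batch_size=32,
)
```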