llm-ie 1.2.2-py3-none-any.whl → 1.2.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_ie/extractors.py CHANGED
@@ -1,18 +1,18 @@
  import abc
  import re
- import json
- import json_repair
  import inspect
  import importlib.resources
  import warnings
  import itertools
  import asyncio
  import nest_asyncio
- from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional
- from llm_ie.data_types import FrameExtractionUnit, FrameExtractionUnitResult, LLMInformationExtractionFrame, LLMInformationExtractionDocument
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional, AsyncGenerator
+ from llm_ie.utils import extract_json, apply_prompt_template
+ from llm_ie.data_types import FrameExtractionUnit, LLMInformationExtractionFrame, LLMInformationExtractionDocument
  from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
  from llm_ie.chunkers import ContextChunker, NoContextChunker, WholeDocumentContextChunker, SlideWindowContextChunker
- from llm_ie.engines import InferenceEngine
+ from llm_ie.engines import InferenceEngine, MessagesLogger
  from colorama import Fore, Style
 
 
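The new `MessagesLogger` import is the surface of this release's main refactor: message-log bookkeeping moves out of the extractors and into a logger object handed to the inference engine. The class body is not part of this diff; a minimal stand-in consistent with the usage visible below (no-argument constructor, a `get_messages_log()` accessor, and a `messages_logger=` keyword accepted by `chat`/`chat_async`) might look like this sketch, where the `log()` method name is a guess:

```python
from typing import Dict, List

class MessagesLogger:
    """Sketch only; the real class ships in llm_ie.engines and is not shown in this diff."""

    def __init__(self) -> None:
        # One entry per LLM call: the full conversation, including the assistant reply.
        self._conversations: List[List[Dict[str, str]]] = []

    def log(self, messages: List[Dict[str, str]]) -> None:  # hypothetical method name
        self._conversations.append(messages)

    def get_messages_log(self) -> List[List[Dict[str, str]]]:
        return self._conversations
```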
@@ -95,79 +95,8 @@ class Extractor:
  Returns : str
  a user prompt.
  """
- pattern = re.compile(r'{{(.*?)}}')
- if isinstance(text_content, str):
- matches = pattern.findall(self.prompt_template)
- if len(matches) != 1:
- raise ValueError("When text_content is str, the prompt template must has exactly 1 placeholder {{<placeholder name>}}.")
- text = re.sub(r'\\', r'\\\\', text_content)
- prompt = pattern.sub(text, self.prompt_template)
-
- elif isinstance(text_content, dict):
- # Check if all values are str
- if not all([isinstance(v, str) for v in text_content.values()]):
- raise ValueError("All values in text_content must be str.")
- # Check if all keys are in the prompt template
- placeholders = pattern.findall(self.prompt_template)
- if len(placeholders) != len(text_content):
- raise ValueError(f"Expect text_content ({len(text_content)}) and prompt template placeholder ({len(placeholders)}) to have equal size.")
- if not all([k in placeholders for k, _ in text_content.items()]):
- raise ValueError(f"All keys in text_content ({text_content.keys()}) must match placeholders in prompt template ({placeholders}).")
-
- prompt = pattern.sub(lambda match: re.sub(r'\\', r'\\\\', text_content[match.group(1)]), self.prompt_template)
-
- return prompt
-
- def _find_dict_strings(self, text: str) -> List[str]:
- """
- Extracts balanced JSON-like dictionaries from a string, even if nested.
+ return apply_prompt_template(self.prompt_template, text_content)
 
- Parameters:
- -----------
- text : str
- the input text containing JSON-like structures.
-
- Returns : List[str]
- A list of valid JSON-like strings representing dictionaries.
- """
- open_brace = 0
- start = -1
- json_objects = []
-
- for i, char in enumerate(text):
- if char == '{':
- if open_brace == 0:
- # start of a new JSON object
- start = i
- open_brace += 1
- elif char == '}':
- open_brace -= 1
- if open_brace == 0 and start != -1:
- json_objects.append(text[start:i + 1])
- start = -1
-
- return json_objects
-
-
- def _extract_json(self, gen_text:str) -> List[Dict[str, str]]:
- """
- This method inputs a generated text and output a JSON of information tuples
- """
- out = []
- dict_str_list = self._find_dict_strings(gen_text)
- for dict_str in dict_str_list:
- try:
- dict_obj = json.loads(dict_str)
- out.append(dict_obj)
- except json.JSONDecodeError:
- dict_obj = json_repair.repair_json(dict_str, skip_json_loads=True, return_objects=True)
- if dict_obj:
- warnings.warn(f'JSONDecodeError detected, fixed with repair_json:\n{dict_str}', RuntimeWarning)
- out.append(dict_obj)
- else:
- warnings.warn(f'JSONDecodeError could not be fixed:\n{dict_str}', RuntimeWarning)
- return out
-
 
  class FrameExtractor(Extractor):
  from nltk.tokenize import RegexpTokenizer
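The `_find_dict_strings`/`_extract_json` helpers deleted above move to `llm_ie.utils` as the module-level `extract_json` used throughout the rest of this diff. Assuming the logic was kept as-is, a reconstruction from the removed code would be:

```python
import json
import warnings
import json_repair

def extract_json(gen_text: str) -> list:
    """Reconstruction of llm_ie.utils.extract_json from the removed methods above (a sketch, not the published source)."""
    # Scan for balanced {...} spans, tolerating nesting.
    dict_strs, depth, start = [], 0, -1
    for i, char in enumerate(gen_text):
        if char == '{':
            if depth == 0:
                start = i
            depth += 1
        elif char == '}' and depth > 0:
            depth -= 1
            if depth == 0 and start != -1:
                dict_strs.append(gen_text[start:i + 1])
                start = -1
    # Parse each span, falling back to json_repair on malformed output.
    out = []
    for dict_str in dict_strs:
        try:
            out.append(json.loads(dict_str))
        except json.JSONDecodeError:
            dict_obj = json_repair.repair_json(dict_str, skip_json_loads=True, return_objects=True)
            if dict_obj:
                warnings.warn(f'JSONDecodeError detected, fixed with repair_json:\n{dict_str}', RuntimeWarning)
                out.append(dict_obj)
            else:
                warnings.warn(f'JSONDecodeError could not be fixed:\n{dict_str}', RuntimeWarning)
    return out
```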
@@ -405,7 +334,7 @@ class DirectFrameExtractor(FrameExtractor):
 
 
  def extract(self, text_content:Union[str, Dict[str,str]],
- document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+ document_key:str=None, verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
  """
  This method inputs a text and outputs a list of outputs per unit.
 
@@ -423,11 +352,9 @@ class DirectFrameExtractor(FrameExtractor):
  return_messages_log : bool, Optional
  if True, a list of messages will be returned.
 
- Return : List[FrameExtractionUnitResult]
+ Return : List[FrameExtractionUnit]
  the output from LLM for each unit. Contains the start, end, text, and generated text.
  """
- # define output
- output = []
  # unit chunking
  if isinstance(text_content, str):
  doc_text = text_content
@@ -440,76 +367,70 @@ class DirectFrameExtractor(FrameExtractor):
  units = self.unit_chunker.chunk(doc_text)
  # context chunker init
  self.context_chunker.fit(doc_text, units)
+
  # messages log
- if return_messages_log:
- messages_log = []
+ messages_logger = MessagesLogger() if return_messages_log else None
 
  # generate unit by unit
  for i, unit in enumerate(units):
- # construct chat messages
- messages = []
- if self.system_prompt:
- messages.append({'role': 'system', 'content': self.system_prompt})
-
- context = self.context_chunker.chunk(unit)
-
- if context == "":
- # no context, just place unit in user prompt
- if isinstance(text_content, str):
- messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
- else:
- unit_content = text_content.copy()
- unit_content[document_key] = unit.text
- messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
- else:
- # insert context to user prompt
- if isinstance(text_content, str):
- messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
- else:
- context_content = text_content.copy()
- context_content[document_key] = context
- messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
- # simulate conversation where assistant confirms
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
- # place unit of interest
- messages.append({'role': 'user', 'content': unit.text})
+ try:
+ # construct chat messages
+ messages = []
+ if self.system_prompt:
+ messages.append({'role': 'system', 'content': self.system_prompt})
 
- if verbose:
- print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
- if context != "":
- print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+ context = self.context_chunker.chunk(unit)
 
- print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+ if context == "":
+ # no context, just place unit in user prompt
+ if isinstance(text_content, str):
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+ else:
+ unit_content = text_content.copy()
+ unit_content[document_key] = unit.text
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+ else:
+ # insert context to user prompt
+ if isinstance(text_content, str):
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+ else:
+ context_content = text_content.copy()
+ context_content[document_key] = context
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+ # simulate conversation where assistant confirms
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+ # place unit of interest
+ messages.append({'role': 'user', 'content': unit.text})
 
-
- gen_text = self.inference_engine.chat(
- messages=messages,
- verbose=verbose,
- stream=False
- )
+ if verbose:
+ print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
+ if context != "":
+ print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+
+ print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
 
- if return_messages_log:
- message = {"role": "assistant", "content": gen_text["response"]}
- if "reasoning" in gen_text:
- message["reasoning"] = gen_text["reasoning"]
- messages.append(message)
- messages_log.append(messages)
-
- # add to output
- result = FrameExtractionUnitResult(
- start=unit.start,
- end=unit.end,
- text=unit.text,
- gen_text=gen_text["response"])
- output.append(result)
+
+ gen_text = self.inference_engine.chat(
+ messages=messages,
+ verbose=verbose,
+ stream=False,
+ messages_logger=messages_logger
+ )
+
+ # add generated text to unit
+ unit.set_generated_text(gen_text["response"])
+ unit.set_status("success")
+ except Exception as e:
+ unit.set_status("fail")
+ warnings.warn(f"LLM inference failed for unit {i} ({unit.start}, {unit.end}): {e}", RuntimeWarning)
 
  if return_messages_log:
- return output, messages_log
+ return units, messages_logger.get_messages_log()
 
- return output
+ return units
 
  def stream(self, text_content: Union[str, Dict[str, str]],
- document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnitResult]]:
+ document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnit]]:
  """
  Streams LLM responses per unit with structured event types,
  and returns collected data for post-processing.
@@ -525,12 +446,10 @@ class DirectFrameExtractor(FrameExtractor):
 
  Returns:
  --------
- List[FrameExtractionUnitResult]:
- A list of FrameExtractionUnitResult objects, each containing the
+ List[FrameExtractionUnit]:
+ A list of FrameExtractionUnit objects, each containing the
  original unit details and the fully accumulated 'gen_text' from the LLM.
  """
- collected_results: List[FrameExtractionUnitResult] = []
-
  if isinstance(text_content, str):
  doc_text = text_content
  elif isinstance(text_content, dict):
@@ -588,19 +507,14 @@ class DirectFrameExtractor(FrameExtractor):
  current_gen_text += chunk["data"]
 
  # Store the result for this unit
- result_for_unit = FrameExtractionUnitResult(
- start=unit.start,
- end=unit.end,
- text=unit.text,
- gen_text=current_gen_text
- )
- collected_results.append(result_for_unit)
+ unit.set_generated_text(current_gen_text)
+ unit.set_status("success")
 
  yield {"type": "info", "data": "All units processed by LLM."}
- return collected_results
+ return units
 
  async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
- concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+ concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
  """
  This is the asynchronous version of the extract() method.
 
@@ -618,7 +532,7 @@ class DirectFrameExtractor(FrameExtractor):
  return_messages_log : bool, Optional
  if True, a list of messages will be returned.
 
- Return : List[FrameExtractionUnitResult]
+ Return : List[FrameExtractionUnit]
  the output from LLM for each unit. Contains the start, end, text, and generated text.
  """
  if isinstance(text_content, str):
@@ -637,6 +551,9 @@ class DirectFrameExtractor(FrameExtractor):
  # context chunker init
  self.context_chunker.fit(doc_text, units)
 
+ # messages logger init
+ messages_logger = MessagesLogger() if return_messages_log else None
+
  # Prepare inputs for all units first
  tasks_input = []
  for i, unit in enumerate(units):
@@ -677,17 +594,15 @@ class DirectFrameExtractor(FrameExtractor):
  async def semaphore_helper(task_data: Dict, **kwrs):
  unit = task_data["unit"]
  messages = task_data["messages"]
- original_index = task_data["original_index"]
 
  async with semaphore:
  gen_text = await self.inference_engine.chat_async(
- messages=messages
+ messages=messages,
+ messages_logger=messages_logger
  )
 
- out = {"original_index": original_index, "unit": unit, "gen_text": gen_text["response"], "messages": messages}
- if "reasoning" in gen_text:
- out["reasoning"] = gen_text["reasoning"]
- return out
+ unit.set_generated_text(gen_text["response"])
+ unit.set_status("success")
 
  # Create and gather tasks
  tasks = []
@@ -697,40 +612,13 @@ class DirectFrameExtractor(FrameExtractor):
  ))
  tasks.append(task)
 
- results_raw = await asyncio.gather(*tasks)
-
- # Sort results back into original order using the index stored
- results_raw.sort(key=lambda x: x["original_index"])
-
- # Restructure the results
- output: List[FrameExtractionUnitResult] = []
- messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
-
- for result_data in results_raw:
- unit = result_data["unit"]
- gen_text = result_data["gen_text"]
-
- # Create result object
- result = FrameExtractionUnitResult(
- start=unit.start,
- end=unit.end,
- text=unit.text,
- gen_text=gen_text
- )
- output.append(result)
-
- # Append to messages log if requested
- if return_messages_log:
- message = {"role": "assistant", "content": gen_text}
- if "reasoning" in result_data:
- message["reasoning"] = result_data["reasoning"]
- final_messages = result_data["messages"] + [message]
- messages_log.append(final_messages)
+ await asyncio.gather(*tasks)
 
+ # Return units
  if return_messages_log:
- return output, messages_log
+ return units, messages_logger.get_messages_log()
  else:
- return output
+ return units
 
 
  def extract_frames(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
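With this change, `extract()` and `extract_async()` no longer build `FrameExtractionUnitResult` wrappers; they annotate the `FrameExtractionUnit` objects in place via `set_generated_text()`/`set_status()` and return them, with failed LLM calls marked `"fail"` instead of aborting the run. A hypothetical caller (extractor construction elided) would filter on the status flag, just as `extract_frames()` does in the next hunk:

```python
# `extractor` is an already-configured DirectFrameExtractor (setup not shown).
units = extractor.extract("Patient reports headache for 3 days.")

for unit in units:
    if unit.status != "success":   # units whose LLM call raised are marked "fail"
        print(f"Unit ({unit.start}, {unit.end}) failed; skipping.")
        continue
    print(unit.start, unit.end, unit.gen_text)  # raw LLM output for this unit
```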
@@ -738,7 +626,7 @@ class DirectFrameExtractor(FrameExtractor):
  case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
  allow_overlap_entities:bool=False, return_messages_log:bool=False) -> List[LLMInformationExtractionFrame]:
  """
- This method inputs a text and outputs a list of LLMInformationExtractionFrame
+ This method inputs a document text and outputs a list of LLMInformationExtractionFrame
  It use the extract() method and post-process outputs into frames.
 
  Parameters:
@@ -791,18 +679,21 @@ class DirectFrameExtractor(FrameExtractor):
  verbose=verbose,
  return_messages_log=return_messages_log)
 
- llm_output_results, messages_log = extraction_results if return_messages_log else (extraction_results, None)
+ units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
 
  frame_list = []
- for res in llm_output_results:
+ for unit in units:
  entity_json = []
- for entity in self._extract_json(gen_text=res.gen_text):
+ if unit.status != "success":
+ warnings.warn(f"Skipping failed unit ({unit.start}, {unit.end}): {unit.text}", RuntimeWarning)
+ continue
+ for entity in extract_json(gen_text=unit.gen_text):
  if ENTITY_KEY in entity:
  entity_json.append(entity)
  else:
  warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
 
- spans = self._find_entity_spans(text=res.text,
+ spans = self._find_entity_spans(text=unit.text,
  entities=[e[ENTITY_KEY] for e in entity_json],
  case_sensitive=case_sensitive,
  fuzzy_match=fuzzy_match,
@@ -812,9 +703,9 @@ class DirectFrameExtractor(FrameExtractor):
  for ent, span in zip(entity_json, spans):
  if span is not None:
  start, end = span
- entity_text = res.text[start:end]
- start += res.start
- end += res.start
+ entity_text = unit.text[start:end]
+ start += unit.start
+ end += unit.start
  attr = {}
  if "attr" in ent and ent["attr"] is not None:
  attr = ent["attr"]
@@ -831,6 +722,208 @@ class DirectFrameExtractor(FrameExtractor):
  return frame_list
 
 
+ async def extract_frames_from_documents(self, text_contents:List[Union[str,Dict[str, any]]], document_key:str="text",
+ cpu_concurrency:int=4, llm_concurrency:int=32, case_sensitive:bool=False,
+ fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
+ allow_overlap_entities:bool=False, return_messages_log:bool=False) -> AsyncGenerator[Dict[str, any], None]:
+ """
+ This method inputs a list of documents and yields the results for each document as soon as it is complete.
+
+ Parameters:
+ -----------
+ text_contents : List[Union[str,Dict[str, any]]]
+ a list of input text contents to put in prompt template.
+ If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+ If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+ document_key: str, optional
+ The key in the `text_contents` dictionaries that holds the document text.
+ cpu_concurrency: int, optional
+ The number of parallel threads to use for CPU-bound tasks like chunking.
+ llm_concurrency: int, optional
+ The number of concurrent requests to make to the LLM.
+ case_sensitive : bool, Optional
+ if True, entity text matching will be case-sensitive.
+ fuzzy_match : bool, Optional
+ if True, fuzzy matching will be applied to find entity text.
+ fuzzy_buffer_size : float, Optional
+ the buffer size for fuzzy matching. Default is 20% of entity text length.
+ fuzzy_score_cutoff : float, Optional
+ the Jaccard score cutoff for fuzzy matching.
+ Matched entity text must have a score higher than this value or a None will be returned.
+ allow_overlap_entities : bool, Optional
+ if True, entities can overlap in the text.
+ return_messages_log : bool, Optional
+ if True, a list of messages will be returned.
+
+ Yields:
+ -------
+ AsyncGenerator[Dict[str, any], None]
+ A dictionary for each completed document, containing its 'idx' and extracted 'frames'.
+ """
+ # Validate text_contents must be a list of str or dict, and not both
+ if not isinstance(text_contents, list):
+ raise ValueError("text_contents must be a list of strings or dictionaries.")
+ if all(isinstance(doc, str) for doc in text_contents):
+ pass
+ elif all(isinstance(doc, dict) for doc in text_contents):
+ pass
+ # Set CPU executor and queues
+ cpu_executor = ThreadPoolExecutor(max_workers=cpu_concurrency)
+ tasks_queue = asyncio.Queue(maxsize=llm_concurrency * 2)
+ # Store to track units and pending counts
+ results_store = {
+ idx: {'pending': 0, 'units': [], 'text': doc if isinstance(doc, str) else doc.get(document_key, "")}
+ for idx, doc in enumerate(text_contents)
+ }
+
+ output_queue = asyncio.Queue()
+ messages_logger = MessagesLogger() if return_messages_log else None
+
+ async def producer():
+ try:
+ for idx, text_content in enumerate(text_contents):
+ text = text_content if isinstance(text_content, str) else text_content.get(document_key, "")
+ if not text:
+ warnings.warn(f"Document at index {idx} is empty or missing the document key '{document_key}'.")
+ # signal that this document is done
+ await output_queue.put({'idx': idx, 'frames': []})
+ continue
+
+ units = await self.unit_chunker.chunk_async(text, cpu_executor)
+ await self.context_chunker.fit_async(text, units, cpu_executor)
+ results_store[idx]['pending'] = len(units)
+
+ # Handle cases where a document yields no units
+ if not units:
+ # signal that this document is done
+ await output_queue.put({'idx': idx, 'frames': []})
+ continue
+
+ # Iterate through units
+ for unit in units:
+ context = await self.context_chunker.chunk_async(unit, cpu_executor)
+ messages = []
+ if self.system_prompt:
+ messages.append({'role': 'system', 'content': self.system_prompt})
+
+ if not context:
+ if isinstance(text_content, str):
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+ else:
+ unit_content = text_content.copy()
+ unit_content[document_key] = unit.text
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+ else:
+ # insert context to user prompt
+ if isinstance(text_content, str):
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+ else:
+ context_content = text_content.copy()
+ context_content[document_key] = context
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+ # simulate conversation where assistant confirms
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+ # place unit of interest
+ messages.append({'role': 'user', 'content': unit.text})
+
+ await tasks_queue.put({'idx': idx, 'unit': unit, 'messages': messages})
+ finally:
+ for _ in range(llm_concurrency):
+ await tasks_queue.put(None)
+
+ async def worker():
+ while True:
+ task_item = await tasks_queue.get()
+ if task_item is None:
+ tasks_queue.task_done()
+ break
+
+ idx = task_item['idx']
+ unit = task_item['unit']
+ doc_results = results_store[idx]
+
+ try:
+ gen_text = await self.inference_engine.chat_async(
+ messages=task_item['messages'], messages_logger=messages_logger
+ )
+ unit.set_generated_text(gen_text["response"])
+ unit.set_status("success")
+ doc_results['units'].append(unit)
+ except Exception as e:
+ warnings.warn(f"Error processing unit for doc idx {idx}: {e}")
+ finally:
+ doc_results['pending'] -= 1
+ if doc_results['pending'] <= 0:
+ final_frames = self._post_process_and_create_frames(doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities)
+ output_payload = {'idx': idx, 'frames': final_frames}
+ if return_messages_log:
+ output_payload['messages_log'] = messages_logger.get_messages_log()
+ await output_queue.put(output_payload)
+
+ tasks_queue.task_done()
+
+ # Start producer and workers
+ producer_task = asyncio.create_task(producer())
+ worker_tasks = [asyncio.create_task(worker()) for _ in range(llm_concurrency)]
+
+ # Main loop to gather results
+ docs_completed = 0
+ while docs_completed < len(text_contents):
+ result = await output_queue.get()
+ yield result
+ docs_completed += 1
+
+ # Final cleanup
+ await producer_task
+ await tasks_queue.join()
+
+ # Cancel any lingering worker tasks
+ for task in worker_tasks:
+ task.cancel()
+ await asyncio.gather(*worker_tasks, return_exceptions=True)
+
+ cpu_executor.shutdown(wait=False)
+
+
+ def _post_process_and_create_frames(self, doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities):
+ """Helper function to run post-processing logic for a completed document."""
+ ENTITY_KEY = "entity_text"
+ frame_list = []
+ for res in sorted(doc_results['units'], key=lambda r: r.start):
+ entity_json = []
+ for entity in extract_json(gen_text=res.gen_text):
+ if ENTITY_KEY in entity:
+ entity_json.append(entity)
+ else:
+ warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
+
+ spans = self._find_entity_spans(
+ text=res.text,
+ entities=[e[ENTITY_KEY] for e in entity_json],
+ case_sensitive=case_sensitive,
+ fuzzy_match=fuzzy_match,
+ fuzzy_buffer_size=fuzzy_buffer_size,
+ fuzzy_score_cutoff=fuzzy_score_cutoff,
+ allow_overlap_entities=allow_overlap_entities
+ )
+ for ent, span in zip(entity_json, spans):
+ if span is not None:
+ start, end = span
+ entity_text = res.text[start:end]
+ start += res.start
+ end += res.start
+ attr = ent.get("attr", {}) or {}
+ frame = LLMInformationExtractionFrame(
+ frame_id=f"{len(frame_list)}",
+ start=start,
+ end=end,
+ entity_text=entity_text,
+ attr=attr
+ )
+ frame_list.append(frame)
+ return frame_list
+
+
  class ReviewFrameExtractor(DirectFrameExtractor):
  def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker, inference_engine:InferenceEngine,
  prompt_template:str, review_mode:str, review_prompt:str=None, system_prompt:str=None):
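The new `extract_frames_from_documents()` above is an async generator built on a producer/worker queue: chunking runs on a thread pool (`cpu_concurrency`), up to `llm_concurrency` workers call the LLM, and each document is yielded as soon as its last unit finishes, so results can arrive out of input order and should be realigned via the `idx` key. A minimal consumer, again assuming a configured extractor:

```python
import asyncio

async def main():
    docs = ["Patient denies fever.", "BP was 120/80 with HR 72."]
    frames_by_doc = [None] * len(docs)
    async for payload in extractor.extract_frames_from_documents(docs, llm_concurrency=8):
        frames_by_doc[payload["idx"]] = payload["frames"]  # realign by document index
    for idx, frames in enumerate(frames_by_doc):
        print(idx, [f.entity_text for f in frames])

asyncio.run(main())
```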
@@ -902,7 +995,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  raise ValueError(f"Cannot find review prompt for {self.__class__.__name__} in the package. Please provide a review_prompt.")
 
  def extract(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
- verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnitResult]:
+ verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
  """
  This method inputs a text and outputs a list of outputs per unit.
 
@@ -923,8 +1016,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  Return : List[FrameExtractionUnitResult]
  the output from LLM for each unit. Contains the start, end, text, and generated text.
  """
- # define output
- output = []
  # unit chunking
  if isinstance(text_content, str):
  doc_text = text_content
@@ -937,9 +1028,9 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  units = self.unit_chunker.chunk(doc_text)
  # context chunker init
  self.context_chunker.fit(doc_text, units)
- # messages log
- if return_messages_log:
- messages_log = []
+
+ # messages logger init
+ messages_logger = MessagesLogger() if return_messages_log else None
 
  # generate unit by unit
  for i, unit in enumerate(units):
@@ -973,7 +1064,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  messages.append({'role': 'user', 'content': unit.text})
 
  if verbose:
- print(f"\n\n{Fore.GREEN}Unit {i}:{Style.RESET_ALL}\n{unit.text}\n")
+ print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
  if context != "":
  print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
 
@@ -983,7 +1074,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  initial = self.inference_engine.chat(
  messages=messages,
  verbose=verbose,
- stream=False
+ stream=False,
+ messages_logger=messages_logger
  )
 
  # <--- Review step --->
@@ -996,7 +1088,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  review = self.inference_engine.chat(
  messages=messages,
  verbose=verbose,
- stream=False
+ stream=False,
+ messages_logger=messages_logger
  )
 
  # Output
@@ -1005,28 +1098,14 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  elif self.review_mode == "addition":
  gen_text = initial["response"] + '\n' + review["response"]
 
- if return_messages_log:
- if "reasoning" in initial:
- messages[-2]["reasoning"] = initial["reasoning"]
-
- message = {"role": "assistant", "content": review["response"]}
- if "reasoning" in review:
- message["reasoning"] = review["reasoning"]
- messages.append(message)
- messages_log.append(messages)
-
- # add to output
- result = FrameExtractionUnitResult(
- start=unit.start,
- end=unit.end,
- text=unit.text,
- gen_text=gen_text)
- output.append(result)
+ # add generated text to unit
+ unit.set_generated_text(gen_text)
+ unit.set_status("success")
 
  if return_messages_log:
- return output, messages_log
+ return units, messages_logger.get_messages_log()
 
- return output
+ return units
 
 
  def stream(self, text_content:Union[str, Dict[str,str]], document_key:str=None) -> Generator[str, None, None]:
@@ -1122,7 +1201,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  yield chunk
 
  async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
- concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnitResult]:
+ concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnit]:
  """
  This is the asynchronous version of the extract() method with the review step.
 
@@ -1154,11 +1233,15 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  else:
  raise TypeError("text_content must be a string or a dictionary.")
 
+ # unit chunking
  units = self.unit_chunker.chunk(doc_text)
 
  # context chunker init
  self.context_chunker.fit(doc_text, units)
 
+ # messages logger init
+ messages_logger = MessagesLogger() if return_messages_log else None
+
  # <--- Initial generation step --->
  initial_tasks_input = []
  for i, unit in enumerate(units):
@@ -1202,7 +1285,8 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
  async with semaphore:
  gen_text = await self.inference_engine.chat_async(
- messages=messages
+ messages=messages,
+ messages_logger=messages_logger
  )
  # Return initial generation result along with the messages used and the unit
  out = {"original_index": original_index, "unit": unit, "initial_gen_text": gen_text["response"], "initial_messages": messages}
@@ -1253,16 +1337,11 @@ class ReviewFrameExtractor(DirectFrameExtractor):
 
  async with semaphore:
  review_gen_text = await self.inference_engine.chat_async(
- messages=messages
+ messages=messages,
+ messages_logger=messages_logger
  )
  # Combine initial and review results
  task_data["review_gen_text"] = review_gen_text["response"]
- if return_messages_log:
- # Log for the review call itself
- message = {'role': 'assistant', 'content': review_gen_text["response"]}
- if "reasoning" in review_gen_text:
- message["reasoning"] = review_gen_text["reasoning"]
- task_data["full_review_log"] = task_data["full_initial_log"] + [message]
  return task_data # Return the augmented dictionary
 
  # Create and gather review tasks
@@ -1279,9 +1358,6 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  final_results_raw.sort(key=lambda x: x["original_index"])
 
  # <--- Process final results --->
- output: List[FrameExtractionUnitResult] = []
- messages_log: Optional[List[List[Dict[str, str]]]] = [] if return_messages_log else None
-
  for result_data in final_results_raw:
  unit = result_data["unit"]
  initial_gen = result_data["initial_gen_text"]
@@ -1296,23 +1372,13 @@ class ReviewFrameExtractor(DirectFrameExtractor):
  final_gen_text = review_gen # Default to revision if mode is somehow invalid
 
  # Create final result object
- result = FrameExtractionUnitResult(
- start=unit.start,
- end=unit.end,
- text=unit.text,
- gen_text=final_gen_text # Use the combined/reviewed text
- )
- output.append(result)
-
- # Append full conversation log if requested
- if return_messages_log:
- full_log_for_unit = result_data["full_review_log"]
- messages_log.append(full_log_for_unit)
+ unit.set_generated_text(final_gen_text)
+ unit.set_status("success")
 
  if return_messages_log:
- return output, messages_log
+ return units, messages_logger.get_messages_log()
  else:
- return output
+ return units
 
 
  class BasicFrameExtractor(DirectFrameExtractor):
@@ -1549,6 +1615,9 @@ class AttributeExtractor(Extractor):
  a dictionary of attributes extracted from the frame.
  If return_messages_log is True, a list of messages will be returned as well.
  """
+ # messages logger init
+ messages_logger = MessagesLogger() if return_messages_log else None
+
  # construct chat messages
  messages = []
  if self.system_prompt:
@@ -1567,19 +1636,15 @@ class AttributeExtractor(Extractor):
  gen_text = self.inference_engine.chat(
  messages=messages,
  verbose=verbose,
- stream=False
+ stream=False,
+ messages_logger=messages_logger
  )
- if return_messages_log:
- message = {"role": "assistant", "content": gen_text["response"]}
- if "reasoning" in gen_text:
- message["reasoning"] = gen_text["reasoning"]
- messages.append(message)
 
- attribute_list = self._extract_json(gen_text=gen_text["response"])
+ attribute_list = extract_json(gen_text=gen_text["response"])
  if isinstance(attribute_list, list) and len(attribute_list) > 0:
  attributes = attribute_list[0]
  if return_messages_log:
- return attributes, messages
+ return attributes, messages_logger.get_messages_log()
  return attributes
 
 
@@ -1620,7 +1685,7 @@ class AttributeExtractor(Extractor):
  if return_messages_log:
  attr, messages = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
  verbose=verbose, return_messages_log=return_messages_log)
- messages_log.append(messages)
+ messages_log.extend(messages)
  else:
  attr = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
  verbose=verbose, return_messages_log=return_messages_log)
@@ -1669,6 +1734,9 @@ class AttributeExtractor(Extractor):
  if not isinstance(text, str):
  raise TypeError(f"Expect text as str, received {type(text)} instead.")
 
+ # messages logger init
+ messages_logger = MessagesLogger() if return_messages_log else None
+
  # async helper
  semaphore = asyncio.Semaphore(concurrent_batch_size)
 
@@ -1681,15 +1749,8 @@ class AttributeExtractor(Extractor):
  context = self._get_context(frame, text, context_size)
  messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})
 
- gen_text = await self.inference_engine.chat_async(messages=messages)
-
- if return_messages_log:
- message = {"role": "assistant", "content": gen_text["response"]}
- if "reasoning" in gen_text:
- message["reasoning"] = gen_text["reasoning"]
- messages.append(message)
-
- attribute_list = self._extract_json(gen_text=gen_text["response"])
+ gen_text = await self.inference_engine.chat_async(messages=messages, messages_logger=messages_logger)
+ attribute_list = extract_json(gen_text=gen_text["response"])
  attributes = attribute_list[0] if isinstance(attribute_list, list) and len(attribute_list) > 0 else {}
  return {"frame": frame, "attributes": attributes, "messages": messages}
 
@@ -1699,12 +1760,8 @@ class AttributeExtractor(Extractor):
 
  # process results
  new_frames = []
- messages_log = [] if return_messages_log else None
 
  for result in results:
- if return_messages_log:
- messages_log.append(result["messages"])
-
  if inplace:
  result["frame"].attr.update(result["attributes"])
  else:
@@ -1714,9 +1771,9 @@ class AttributeExtractor(Extractor):
 
  # output
  if inplace:
- return messages_log if return_messages_log else None
+ return messages_logger.get_messages_log() if return_messages_log else None
  else:
- return (new_frames, messages_log) if return_messages_log else new_frames
+ return (new_frames, messages_logger.get_messages_log()) if return_messages_log else new_frames
 
  def extract_attributes(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
  concurrent:bool=False, concurrent_batch_size:int=32, verbose:bool=False,
@@ -1839,7 +1896,7 @@ class RelationExtractor(Extractor):
  return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
  pairs = itertools.combinations(doc.frames, 2)
  relations = []
- messages_log = [] if return_messages_log else None
+ messages_logger = MessagesLogger() if return_messages_log else None
 
  for frame_1, frame_2 in pairs:
  task_payload = self._get_task_if_possible(frame_1, frame_2, doc.text, buffer_size)
@@ -1851,20 +1908,14 @@ class RelationExtractor(Extractor):
 
  gen_text = self.inference_engine.chat(
  messages=task_payload['messages'],
- verbose=verbose
+ verbose=verbose,
+ messages_logger=messages_logger
  )
  relation = self._post_process_result(gen_text["response"], task_payload)
  if relation:
  relations.append(relation)
 
- if return_messages_log:
- message = {"role": "assistant", "content": gen_text["response"]}
- if "reasoning" in gen_text:
- message["reasoning"] = gen_text["reasoning"]
- task_payload['messages'].append(message)
- messages_log.append(task_payload['messages'])
-
- return (relations, messages_log) if return_messages_log else relations
+ return (relations, messages_logger.get_messages_log()) if return_messages_log else relations
 
  async def _extract_async(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent_batch_size: int = 32, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
  pairs = list(itertools.combinations(doc.frames, 2))
@@ -1873,12 +1924,12 @@ class RelationExtractor(Extractor):
  tasks_input = [task for task in tasks_input if task is not None]
 
  relations = []
- messages_log = [] if return_messages_log else None
+ messages_logger = MessagesLogger() if return_messages_log else None
  semaphore = asyncio.Semaphore(concurrent_batch_size)
 
  async def semaphore_helper(task_payload: Dict):
  async with semaphore:
- gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'])
+ gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'], messages_logger=messages_logger)
  return gen_text, task_payload
 
  tasks = [asyncio.create_task(semaphore_helper(payload)) for payload in tasks_input]
@@ -1889,14 +1940,7 @@ class RelationExtractor(Extractor):
  if relation:
  relations.append(relation)
 
- if return_messages_log:
- message = {"role": "assistant", "content": gen_text["response"]}
- if "reasoning" in gen_text:
- message["reasoning"] = gen_text["reasoning"]
- task_payload['messages'].append(message)
- messages_log.append(task_payload['messages'])
-
- return (relations, messages_log) if return_messages_log else relations
+ return (relations, messages_logger.get_messages_log()) if return_messages_log else relations
 
  def extract_relations(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent: bool = False, concurrent_batch_size: int = 32, verbose: bool = False, return_messages_log: bool = False) -> List[Dict]:
  if not doc.has_frame():
@@ -1959,7 +2003,7 @@ class BinaryRelationExtractor(RelationExtractor):
  return None
 
  def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- rel_json = self._extract_json(gen_text)
+ rel_json = extract_json(gen_text)
  if len(rel_json) > 0 and "Relation" in rel_json[0]:
  rel = rel_json[0]["Relation"]
  if (isinstance(rel, bool) and rel) or (isinstance(rel, str) and rel.lower() == 'true'):
@@ -2025,7 +2069,7 @@ class MultiClassRelationExtractor(RelationExtractor):
  return None
 
  def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- rel_json = self._extract_json(gen_text)
+ rel_json = extract_json(gen_text)
  pos_rel_types = pair_data['pos_rel_types']
  if len(rel_json) > 0 and "RelationType" in rel_json[0]:
  rel_type = rel_json[0]["RelationType"]