llm-ie 1.2.4__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_ie/extractors.py CHANGED
@@ -6,8 +6,7 @@ import warnings
  import itertools
  import asyncio
  import nest_asyncio
- from concurrent.futures import ThreadPoolExecutor
- from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional, AsyncGenerator
+ from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional
  from llm_ie.utils import extract_json, apply_prompt_template
  from llm_ie.data_types import FrameExtractionUnit, LLMInformationExtractionFrame, LLMInformationExtractionDocument
  from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
@@ -98,6 +97,451 @@ class Extractor:
          return apply_prompt_template(self.prompt_template, text_content)


+ class StructExtractor(Extractor):
+     def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker, prompt_template:str,
+                  system_prompt:str=None, context_chunker:ContextChunker=None, aggregation_func:Callable=None):
+         """
+         This class is for unanchored structured information extraction.
+         Inputs an LLM inference engine, a system prompt (optional), and a prompt template (with instructions and few-shot examples).
+
+         Parameters:
+         ----------
+         inference_engine : InferenceEngine
+             the LLM inference engine object. Must implement the chat() method.
+         unit_chunker : UnitChunker
+             the unit chunker object that determines how to chunk the document text into units.
+         prompt_template : str
+             prompt template with "{{<placeholder name>}}" placeholders.
+         system_prompt : str, Optional
+             system prompt.
+         context_chunker : ContextChunker
+             the context chunker object that determines how to get context for each unit.
+         aggregation_func : Callable
+             a function that inputs a list of structured information (dict)
+             and outputs an aggregated structured information (dict).
+             If not specified, the default is to merge all dicts by updating keys and overwriting values sequentially.
+         """
+         super().__init__(inference_engine=inference_engine,
+                          prompt_template=prompt_template,
+                          system_prompt=system_prompt)
+
+         self.unit_chunker = unit_chunker
+         self.context_chunker = context_chunker
+         self.aggregation_func = aggregation_func
+
+
+     def extract(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+                 verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
+         """
+         This method inputs text content and outputs a list of FrameExtractionUnit with the LLM-generated text attached.
+
+         Parameters:
+         ----------
+         text_content : Union[str, Dict[str,str]]
+             the input text content to put in prompt template.
+             If str, the prompt template must have exactly 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+             If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+         document_key : str, Optional
+             specify the key in text_content where document text is.
+             If text_content is str, this parameter will be ignored.
+         verbose : bool, Optional
+             if True, LLM generated text will be printed in terminal in real-time.
+         return_messages_log : bool, Optional
+             if True, a list of messages will be returned.
+
+         Return : List[FrameExtractionUnit]
+             the output from LLM for each unit. Needs post-processing.
+         """
+         # unit chunking
+         if isinstance(text_content, str):
+             doc_text = text_content
+
+         elif isinstance(text_content, dict):
+             if document_key is None:
+                 raise ValueError("document_key must be provided when text_content is dict.")
+             doc_text = text_content[document_key]
+
+         units = self.unit_chunker.chunk(doc_text)
+         # context chunker init
+         self.context_chunker.fit(doc_text, units)
+
+         # messages log
+         messages_logger = MessagesLogger() if return_messages_log else None
+
+         # generate unit by unit
+         for i, unit in enumerate(units):
+             try:
+                 # construct chat messages
+                 messages = []
+                 if self.system_prompt:
+                     messages.append({'role': 'system', 'content': self.system_prompt})
+
+                 context = self.context_chunker.chunk(unit)
+
+                 if context == "":
+                     # no context, just place unit in user prompt
+                     if isinstance(text_content, str):
+                         messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                     else:
+                         unit_content = text_content.copy()
+                         unit_content[document_key] = unit.text
+                         messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+                 else:
+                     # insert context to user prompt
+                     if isinstance(text_content, str):
+                         messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                     else:
+                         context_content = text_content.copy()
+                         context_content[document_key] = context
+                         messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                     # simulate conversation where assistant confirms
+                     messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                     # place unit of interest
+                     messages.append({'role': 'user', 'content': unit.text})
+
+                 if verbose:
+                     print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
+                     if context != "":
+                         print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+
+                     print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+
+
+                 gen_text = self.inference_engine.chat(
+                     messages=messages,
+                     verbose=verbose,
+                     stream=False,
+                     messages_logger=messages_logger
+                 )
+
+                 # add generated text to unit
+                 unit.set_generated_text(gen_text["response"])
+                 unit.set_status("success")
+             except Exception as e:
+                 unit.set_status("fail")
+                 warnings.warn(f"LLM inference failed for unit {i} ({unit.start}, {unit.end}): {e}", RuntimeWarning)
+
+         if return_messages_log:
+             return units, messages_logger.get_messages_log()
+
+         return units
+
+     def stream(self, text_content: Union[str, Dict[str, str]],
+                document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnit]]:
+         """
+         Streams LLM responses per unit with structured event types,
+         and returns collected data for post-processing.
+
+         Yields:
+         -------
+         Dict[str, Any]: (type, data)
+             - {"type": "info", "data": str_message}: General informational messages.
+             - {"type": "unit", "data": dict_unit_info}: Signals start of a new unit. dict_unit_info contains {'id', 'text', 'start', 'end'}
+             - {"type": "context", "data": str_context}: Context string for the current unit.
+             - {"type": "reasoning", "data": str_chunk}: A reasoning model thinking chunk from the LLM.
+             - {"type": "response", "data": str_chunk}: A response/answer chunk from the LLM.
+
+         Returns:
+         --------
+         List[FrameExtractionUnit]:
+             A list of FrameExtractionUnit objects, each containing the
+             original unit details and the fully accumulated 'gen_text' from the LLM.
+         """
+         if isinstance(text_content, str):
+             doc_text = text_content
+         elif isinstance(text_content, dict):
+             if document_key is None:
+                 raise ValueError("document_key must be provided when text_content is dict.")
+             if document_key not in text_content:
+                 raise ValueError(f"document_key '{document_key}' not found in text_content.")
+             doc_text = text_content[document_key]
+         else:
+             raise TypeError("text_content must be a string or a dictionary.")
+
+         units: List[FrameExtractionUnit] = self.unit_chunker.chunk(doc_text)
+         self.context_chunker.fit(doc_text, units)
+
+         yield {"type": "info", "data": f"Starting LLM processing for {len(units)} units."}
+
+         for i, unit in enumerate(units):
+             unit_info_payload = {"id": i, "text": unit.text, "start": unit.start, "end": unit.end}
+             yield {"type": "unit", "data": unit_info_payload}
+
+             messages = []
+             if self.system_prompt:
+                 messages.append({'role': 'system', 'content': self.system_prompt})
+
+             context_str = self.context_chunker.chunk(unit)
+
+             # Construct prompt input based on whether text_content was str or dict
+             if context_str:
+                 yield {"type": "context", "data": context_str}
+                 prompt_input_for_context = context_str
+                 if isinstance(text_content, dict):
+                     context_content_dict = text_content.copy()
+                     context_content_dict[document_key] = context_str
+                     prompt_input_for_context = context_content_dict
+                 messages.append({'role': 'user', 'content': self._get_user_prompt(prompt_input_for_context)})
+                 messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                 messages.append({'role': 'user', 'content': unit.text})
+             else:  # No context
+                 prompt_input_for_unit = unit.text
+                 if isinstance(text_content, dict):
+                     unit_content_dict = text_content.copy()
+                     unit_content_dict[document_key] = unit.text
+                     prompt_input_for_unit = unit_content_dict
+                 messages.append({'role': 'user', 'content': self._get_user_prompt(prompt_input_for_unit)})
+
+             current_gen_text = ""
+
+             response_stream = self.inference_engine.chat(
+                 messages=messages,
+                 stream=True
+             )
+             for chunk in response_stream:
+                 yield chunk
+                 if chunk["type"] == "response":
+                     current_gen_text += chunk["data"]
+
+             # Store the result for this unit
+             unit.set_generated_text(current_gen_text)
+             unit.set_status("success")
+
+         yield {"type": "info", "data": "All units processed by LLM."}
+         return units
+
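For illustration, one way a caller might consume these events and also recover the generator's return value (the extractor and input text are placeholders, not part of this diff):

def run_stream(extractor, note_text):
    # stream() is a generator; its return value (the units) arrives via StopIteration.
    gen = extractor.stream(text_content=note_text)
    try:
        while True:
            event = next(gen)
            if event["type"] == "response":
                print(event["data"], end="", flush=True)  # live answer chunks
    except StopIteration as stop:
        return stop.value  # List[FrameExtractionUnit] with gen_text populated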
+     async def _extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+                              concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
+         """
+         This is the asynchronous version of the extract() method.
+
+         Parameters:
+         ----------
+         text_content : Union[str, Dict[str,str]]
+             the input text content to put in prompt template.
+             If str, the prompt template must have exactly 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+             If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+         document_key : str, Optional
+             specify the key in text_content where document text is.
+             If text_content is str, this parameter will be ignored.
+         concurrent_batch_size : int, Optional
+             the batch size for concurrent processing.
+         return_messages_log : bool, Optional
+             if True, a list of messages will be returned.
+
+         Return : List[FrameExtractionUnit]
+             the output from LLM for each unit. Contains the start, end, text, and generated text.
+         """
+         if isinstance(text_content, str):
+             doc_text = text_content
+         elif isinstance(text_content, dict):
+             if document_key is None:
+                 raise ValueError("document_key must be provided when text_content is dict.")
+             if document_key not in text_content:
+                 raise ValueError(f"document_key '{document_key}' not found in text_content dictionary.")
+             doc_text = text_content[document_key]
+         else:
+             raise TypeError("text_content must be a string or a dictionary.")
+
+         units = self.unit_chunker.chunk(doc_text)
+
+         # context chunker init
+         self.context_chunker.fit(doc_text, units)
+
+         # messages logger init
+         messages_logger = MessagesLogger() if return_messages_log else None
+
+         # Prepare inputs for all units first
+         tasks_input = []
+         for i, unit in enumerate(units):
+             # construct chat messages
+             messages = []
+             if self.system_prompt:
+                 messages.append({'role': 'system', 'content': self.system_prompt})
+
+             context = self.context_chunker.chunk(unit)
+
+             if context == "":
+                 # no context, just place unit in user prompt
+                 if isinstance(text_content, str):
+                     messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                 else:
+                     unit_content = text_content.copy()
+                     unit_content[document_key] = unit.text
+                     messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+             else:
+                 # insert context to user prompt
+                 if isinstance(text_content, str):
+                     messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                 else:
+                     context_content = text_content.copy()
+                     context_content[document_key] = context
+                     messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                 # simulate conversation where assistant confirms
+                 messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                 # place unit of interest
+                 messages.append({'role': 'user', 'content': unit.text})
+
+             # Store unit and messages together for the task
+             tasks_input.append({"unit": unit, "messages": messages, "original_index": i})
+
+         # Process units concurrently with asyncio.Semaphore
+         semaphore = asyncio.Semaphore(concurrent_batch_size)
+
+         async def semaphore_helper(task_data: Dict, **kwrs):
+             unit = task_data["unit"]
+             messages = task_data["messages"]
+
+             async with semaphore:
+                 gen_text = await self.inference_engine.chat_async(
+                     messages=messages,
+                     messages_logger=messages_logger
+                 )
+
+             unit.set_generated_text(gen_text["response"])
+             unit.set_status("success")
+
+         # Create and gather tasks
+         tasks = []
+         for task_inp in tasks_input:
+             task = asyncio.create_task(semaphore_helper(
+                 task_inp
+             ))
+             tasks.append(task)
+
+         await asyncio.gather(*tasks)
+
+         # Return units
+         if return_messages_log:
+             return units, messages_logger.get_messages_log()
+         else:
+             return units
+
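Stripped of the extractor details, the concurrency control above is the standard semaphore-bounded gather pattern; a self-contained sketch:

import asyncio

async def bounded_gather(coros, limit=32):
    # At most `limit` coroutines run at once; the rest wait on the semaphore.
    sem = asyncio.Semaphore(limit)

    async def run(coro):
        async with sem:
            return await coro

    return await asyncio.gather(*(run(c) for c in coros))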
+     def _default_struct_aggregate(self, structs: List[Dict[str, Any]]) -> Dict[str, Any]:
+         """
+         Given a list of structured information (dict), aggregate them into a single dict by sequentially updating keys
+         and overwriting values.
+         """
+         aggregated_struct = {}
+         for struct in structs:
+             aggregated_struct.update(struct)
+         return aggregated_struct
+
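Because aggregation_func receives the full list of per-unit dicts, callers can replace this last-write-wins default; for example, a hypothetical aggregator that keeps every value seen per key:

from collections import defaultdict
from typing import Any, Dict, List

def collect_all(structs: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Keep every extracted value per key instead of overwriting.
    merged = defaultdict(list)
    for struct in structs:
        for key, value in struct.items():
            merged[key].append(value)
    return dict(merged)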
+     def _post_process_struct(self, units: List[FrameExtractionUnit]) -> Dict[str, Any]:
+         """
+         Helper method to post-process units into a structured dictionary.
+         Shared by extract_struct and extract_struct_async.
+         """
+         struct_json = []
+         for unit in units:
+             if unit.status != "success":
+                 continue
+             try:
+                 unit_struct_json = extract_json(unit.get_generated_text())
+                 struct_json.extend(unit_struct_json)
+             except Exception as e:
+                 unit.set_status("fail")
+                 warnings.warn(f"Struct extraction failed for unit ({unit.start}, {unit.end}): {e}", RuntimeWarning)
+
+         if self.aggregation_func is None:
+             struct = self._default_struct_aggregate(struct_json)
+         else:
+             struct = self.aggregation_func(struct_json)
+         return struct
+
+
+     def extract_struct(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+                        verbose:bool=False, concurrent:bool=False, concurrent_batch_size:int=32,
+                        return_messages_log:bool=False) -> Dict[str, Any]:
+         """
+         This method inputs a document text and outputs aggregated structured information (dict).
+         It uses the extract() method and post-processes the outputs into structured information.
+
+         Parameters:
+         ----------
+         text_content : Union[str, Dict[str,str]]
+             the input text content to put in prompt template.
+             If str, the prompt template must have exactly 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+             If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+         document_key : str, Optional
+             specify the key in text_content where document text is.
+             If text_content is str, this parameter will be ignored.
+         verbose : bool, Optional
+             if True, LLM generated text will be printed in terminal in real-time.
+         concurrent : bool, Optional
+             if True, the units will be processed concurrently.
+         concurrent_batch_size : int, Optional
+             the number of units to process concurrently. Only used when `concurrent` is True.
+         return_messages_log : bool, Optional
+             if True, a list of messages will be returned.
+
+         Return : Dict[str, Any]
+             the aggregated unanchored structured information.
+         """
+         if concurrent:
+             if verbose:
+                 warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
+
+             nest_asyncio.apply()  # For Jupyter notebook. Terminal does not need this.
+             extraction_results = asyncio.run(self._extract_async(text_content=text_content,
+                                                                  document_key=document_key,
+                                                                  concurrent_batch_size=concurrent_batch_size,
+                                                                  return_messages_log=return_messages_log)
+                                              )
+         else:
+             extraction_results = self.extract(text_content=text_content,
+                                               document_key=document_key,
+                                               verbose=verbose,
+                                               return_messages_log=return_messages_log)
+
+         units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
+
+         struct = self._post_process_struct(units)
+
+         if return_messages_log:
+             return struct, messages_log
+         return struct
+
+     async def extract_struct_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+                                    concurrent_batch_size:int=32, return_messages_log:bool=False) -> Dict[str, Any]:
+         """
+         This is the async version of extract_struct.
+         """
+         extraction_results = await self._extract_async(text_content=text_content,
+                                                        document_key=document_key,
+                                                        concurrent_batch_size=concurrent_batch_size,
+                                                        return_messages_log=return_messages_log)
+
+         units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
+         struct = self._post_process_struct(units)
+
+         if return_messages_log:
+             return struct, messages_log
+         return struct
+
+
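Inside an already-running event loop (a web handler, or a notebook cell with an active loop), extract_struct_async sidesteps the nest_asyncio/asyncio.run path used by extract_struct(concurrent=True). A hedged sketch, with `extractor` and `note_text` as placeholders:

import asyncio

async def main(extractor, note_text):
    # Placeholder names; any configured StructExtractor and document text work here.
    struct = await extractor.extract_struct_async(text_content=note_text,
                                                  concurrent_batch_size=16)
    return struct

# struct = asyncio.run(main(extractor, note_text))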
+ class BasicStructExtractor(StructExtractor):
+     def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
+                  system_prompt:str=None, aggregation_func:Callable=None):
+         """
+         This class prompts the LLM with the whole document at once for structured information extraction.
+         Inputs an LLM inference engine, a system prompt (optional), and a prompt template (with instructions and few-shot examples).
+
+         Parameters:
+         ----------
+         inference_engine : InferenceEngine
+             the LLM inference engine object. Must implement the chat() method.
+         prompt_template : str
+             prompt template with "{{<placeholder name>}}" placeholder.
+         system_prompt : str, Optional
+             system prompt.
+         aggregation_func : Callable
+             a function that inputs a list of structured information (dict)
+             and outputs an aggregated structured information (dict).
+             If not specified, the default is to merge all dicts by updating keys and overwriting values sequentially.
+         """
+         super().__init__(inference_engine=inference_engine,
+                          unit_chunker=WholeDocumentUnitChunker(),
+                          prompt_template=prompt_template,
+                          system_prompt=system_prompt,
+                          context_chunker=WholeDocumentContextChunker(),
+                          aggregation_func=aggregation_func)
+
+
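For illustration, a minimal usage sketch: BasicStructExtractor pairs a WholeDocumentUnitChunker with a WholeDocumentContextChunker, so one call covers the whole document. The engine import path and constructor arguments below are assumptions for illustration, not part of this diff:

from llm_ie.engines import OpenAIInferenceEngine  # assumed import path and engine class

prompt_template = '''Extract the patient's age and diagnosis. Return JSON like {"age": 49, "diagnosis": "..."}.
Document:
{{document}}'''

engine = OpenAIInferenceEngine(model="gpt-4o-mini")  # hypothetical constructor arguments
extractor = BasicStructExtractor(inference_engine=engine, prompt_template=prompt_template)
struct = extractor.extract_struct(text_content="49 y/o male with hypertension.", verbose=True)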
  class FrameExtractor(Extractor):
      from nltk.tokenize import RegexpTokenizer
      def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
@@ -300,11 +744,19 @@ class FrameExtractor(Extractor):
          return_messages_log : bool, Optional
              if True, a list of messages will be returned.

-         Return : str
+         Return : List[LLMInformationExtractionFrame]
              a list of frames.
          """
          return NotImplemented

+     @abc.abstractmethod
+     async def extract_frames_async(self, text_content:Union[str, Dict[str,str]], entity_key:str,
+                                    document_key:str=None, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
+         """
+         This is the async version of extract_frames.
+         """
+         return NotImplemented
+

  class DirectFrameExtractor(FrameExtractor):
      def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker,
@@ -513,7 +965,7 @@ class DirectFrameExtractor(FrameExtractor):
          yield {"type": "info", "data": "All units processed by LLM."}
          return units

-     async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+     async def _extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
                              concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
          """
          This is the asynchronous version of the extract() method.
@@ -620,6 +1072,45 @@ class DirectFrameExtractor(FrameExtractor):
          else:
              return units

+     def _post_process_units_to_frames(self, units, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities):
+         ENTITY_KEY = "entity_text"
+         frame_list = []
+         for unit in units:
+             entity_json = []
+             if unit.status != "success":
+                 warnings.warn(f"Skipping failed unit ({unit.start}, {unit.end}): {unit.text}", RuntimeWarning)
+                 continue
+             for entity in extract_json(gen_text=unit.gen_text):
+                 if ENTITY_KEY in entity:
+                     entity_json.append(entity)
+                 else:
+                     warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
+
+             spans = self._find_entity_spans(text=unit.text,
+                                             entities=[e[ENTITY_KEY] for e in entity_json],
+                                             case_sensitive=case_sensitive,
+                                             fuzzy_match=fuzzy_match,
+                                             fuzzy_buffer_size=fuzzy_buffer_size,
+                                             fuzzy_score_cutoff=fuzzy_score_cutoff,
+                                             allow_overlap_entities=allow_overlap_entities)
+             for ent, span in zip(entity_json, spans):
+                 if span is not None:
+                     start, end = span
+                     entity_text = unit.text[start:end]
+                     start += unit.start
+                     end += unit.start
+                     attr = {}
+                     if "attr" in ent and ent["attr"] is not None:
+                         attr = ent["attr"]
+
+                     frame = LLMInformationExtractionFrame(frame_id=f"{len(frame_list)}",
+                                                           start=start,
+                                                           end=end,
+                                                           entity_text=entity_text,
+                                                           attr=attr)
+                     frame_list.append(frame)
+         return frame_list
+
1114
 
624
1115
  def extract_frames(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
625
1116
  verbose:bool=False, concurrent:bool=False, concurrent_batch_size:int=32,
@@ -659,7 +1150,7 @@ class DirectFrameExtractor(FrameExtractor):
659
1150
  return_messages_log : bool, Optional
660
1151
  if True, a list of messages will be returned.
661
1152
 
662
- Return : str
1153
+ Return : List[LLMInformationExtractionFrame]
663
1154
  a list of frames.
664
1155
  """
665
1156
  ENTITY_KEY = "entity_text"
@@ -668,7 +1159,7 @@ class DirectFrameExtractor(FrameExtractor):
668
1159
  warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
669
1160
 
670
1161
  nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
671
- extraction_results = asyncio.run(self.extract_async(text_content=text_content,
1162
+ extraction_results = asyncio.run(self._extract_async(text_content=text_content,
672
1163
  document_key=document_key,
673
1164
  concurrent_batch_size=concurrent_batch_size,
674
1165
  return_messages_log=return_messages_log)
@@ -681,248 +1172,31 @@ class DirectFrameExtractor(FrameExtractor):
681
1172
 
682
1173
  units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
683
1174
 
684
- frame_list = []
685
- for unit in units:
686
- entity_json = []
687
- if unit.status != "success":
688
- warnings.warn(f"Skipping failed unit ({unit.start}, {unit.end}): {unit.text}", RuntimeWarning)
689
- continue
690
- for entity in extract_json(gen_text=unit.gen_text):
691
- if ENTITY_KEY in entity:
692
- entity_json.append(entity)
693
- else:
694
- warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
695
-
696
- spans = self._find_entity_spans(text=unit.text,
697
- entities=[e[ENTITY_KEY] for e in entity_json],
698
- case_sensitive=case_sensitive,
699
- fuzzy_match=fuzzy_match,
700
- fuzzy_buffer_size=fuzzy_buffer_size,
701
- fuzzy_score_cutoff=fuzzy_score_cutoff,
702
- allow_overlap_entities=allow_overlap_entities)
703
- for ent, span in zip(entity_json, spans):
704
- if span is not None:
705
- start, end = span
706
- entity_text = unit.text[start:end]
707
- start += unit.start
708
- end += unit.start
709
- attr = {}
710
- if "attr" in ent and ent["attr"] is not None:
711
- attr = ent["attr"]
712
-
713
- frame = LLMInformationExtractionFrame(frame_id=f"{len(frame_list)}",
714
- start=start,
715
- end=end,
716
- entity_text=entity_text,
717
- attr=attr)
718
- frame_list.append(frame)
1175
+ frame_list = self._post_process_units_to_frames(units, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities)
719
1176
 
720
1177
  if return_messages_log:
721
1178
  return frame_list, messages_log
722
1179
  return frame_list
723
-
724
1180
 
725
- async def extract_frames_from_documents(self, text_contents:List[Union[str,Dict[str, any]]], document_key:str="text",
726
- cpu_concurrency:int=4, llm_concurrency:int=32, case_sensitive:bool=False,
727
- fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
728
- allow_overlap_entities:bool=False, return_messages_log:bool=False) -> AsyncGenerator[Dict[str, any], None]:
1181
+ async def extract_frames_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
1182
+ concurrent_batch_size:int=32, case_sensitive:bool=False,
1183
+ fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
1184
+ allow_overlap_entities:bool=False, return_messages_log:bool=False) -> List[LLMInformationExtractionFrame]:
729
1185
  """
730
- This method inputs a list of documents and yields the results for each document as soon as it is complete.
731
-
732
- Parameters:
733
- -----------
734
- text_contents : List[Union[str,Dict[str, any]]]
735
- a list of input text contents to put in prompt template.
736
- If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
737
- If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
738
- document_key: str, optional
739
- The key in the `text_contents` dictionaries that holds the document text.
740
- cpu_concurrency: int, optional
741
- The number of parallel threads to use for CPU-bound tasks like chunking.
742
- llm_concurrency: int, optional
743
- The number of concurrent requests to make to the LLM.
744
- case_sensitive : bool, Optional
745
- if True, entity text matching will be case-sensitive.
746
- fuzzy_match : bool, Optional
747
- if True, fuzzy matching will be applied to find entity text.
748
- fuzzy_buffer_size : float, Optional
749
- the buffer size for fuzzy matching. Default is 20% of entity text length.
750
- fuzzy_score_cutoff : float, Optional
751
- the Jaccard score cutoff for fuzzy matching.
752
- Matched entity text must have a score higher than this value or a None will be returned.
753
- allow_overlap_entities : bool, Optional
754
- if True, entities can overlap in the text.
755
- return_messages_log : bool, Optional
756
- if True, a list of messages will be returned.
757
-
758
- Yields:
759
- -------
760
- AsyncGenerator[Dict[str, any], None]
761
- A dictionary for each completed document, containing its 'idx' and extracted 'frames'.
762
- """
763
- # Validate text_contents must be a list of str or dict, and not both
764
- if not isinstance(text_contents, list):
765
- raise ValueError("text_contents must be a list of strings or dictionaries.")
766
- if all(isinstance(doc, str) for doc in text_contents):
767
- pass
768
- elif all(isinstance(doc, dict) for doc in text_contents):
769
- pass
770
- # Set CPU executor and queues
771
- cpu_executor = ThreadPoolExecutor(max_workers=cpu_concurrency)
772
- tasks_queue = asyncio.Queue(maxsize=llm_concurrency * 2)
773
- # Store to track units and pending counts
774
- results_store = {
775
- idx: {'pending': 0, 'units': [], 'text': doc if isinstance(doc, str) else doc.get(document_key, "")}
776
- for idx, doc in enumerate(text_contents)
777
- }
778
-
779
- output_queue = asyncio.Queue()
780
- messages_logger = MessagesLogger() if return_messages_log else None
781
-
782
- async def producer():
783
- try:
784
- for idx, text_content in enumerate(text_contents):
785
- text = text_content if isinstance(text_content, str) else text_content.get(document_key, "")
786
- if not text:
787
- warnings.warn(f"Document at index {idx} is empty or missing the document key '{document_key}'.")
788
- # signal that this document is done
789
- await output_queue.put({'idx': idx, 'frames': []})
790
- continue
791
-
792
- units = await self.unit_chunker.chunk_async(text, cpu_executor)
793
- await self.context_chunker.fit_async(text, units, cpu_executor)
794
- results_store[idx]['pending'] = len(units)
795
-
796
- # Handle cases where a document yields no units
797
- if not units:
798
- # signal that this document is done
799
- await output_queue.put({'idx': idx, 'frames': []})
800
- continue
801
-
802
- # Iterate through units
803
- for unit in units:
804
- context = await self.context_chunker.chunk_async(unit, cpu_executor)
805
- messages = []
806
- if self.system_prompt:
807
- messages.append({'role': 'system', 'content': self.system_prompt})
808
-
809
- if not context:
810
- if isinstance(text_content, str):
811
- messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
812
- else:
813
- unit_content = text_content.copy()
814
- unit_content[document_key] = unit.text
815
- messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
816
- else:
817
- # insert context to user prompt
818
- if isinstance(text_content, str):
819
- messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
820
- else:
821
- context_content = text_content.copy()
822
- context_content[document_key] = context
823
- messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
824
- # simulate conversation where assistant confirms
825
- messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
826
- # place unit of interest
827
- messages.append({'role': 'user', 'content': unit.text})
828
-
829
- await tasks_queue.put({'idx': idx, 'unit': unit, 'messages': messages})
830
- finally:
831
- for _ in range(llm_concurrency):
832
- await tasks_queue.put(None)
833
-
834
- async def worker():
835
- while True:
836
- task_item = await tasks_queue.get()
837
- if task_item is None:
838
- tasks_queue.task_done()
839
- break
840
-
841
- idx = task_item['idx']
842
- unit = task_item['unit']
843
- doc_results = results_store[idx]
844
-
845
- try:
846
- gen_text = await self.inference_engine.chat_async(
847
- messages=task_item['messages'], messages_logger=messages_logger
848
- )
849
- unit.set_generated_text(gen_text["response"])
850
- unit.set_status("success")
851
- doc_results['units'].append(unit)
852
- except Exception as e:
853
- warnings.warn(f"Error processing unit for doc idx {idx}: {e}")
854
- finally:
855
- doc_results['pending'] -= 1
856
- if doc_results['pending'] <= 0:
857
- final_frames = self._post_process_and_create_frames(doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities)
858
- output_payload = {'idx': idx, 'frames': final_frames}
859
- if return_messages_log:
860
- output_payload['messages_log'] = messages_logger.get_messages_log()
861
- await output_queue.put(output_payload)
862
-
863
- tasks_queue.task_done()
864
-
865
- # Start producer and workers
866
- producer_task = asyncio.create_task(producer())
867
- worker_tasks = [asyncio.create_task(worker()) for _ in range(llm_concurrency)]
868
-
869
- # Main loop to gather results
870
- docs_completed = 0
871
- while docs_completed < len(text_contents):
872
- result = await output_queue.get()
873
- yield result
874
- docs_completed += 1
875
-
876
- # Final cleanup
877
- await producer_task
878
- await tasks_queue.join()
879
-
880
- # Cancel any lingering worker tasks
881
- for task in worker_tasks:
882
- task.cancel()
883
- await asyncio.gather(*worker_tasks, return_exceptions=True)
884
-
885
- cpu_executor.shutdown(wait=False)
886
-
887
-
888
- def _post_process_and_create_frames(self, doc_results, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities):
889
- """Helper function to run post-processing logic for a completed document."""
890
- ENTITY_KEY = "entity_text"
891
- frame_list = []
892
- for res in sorted(doc_results['units'], key=lambda r: r.start):
893
- entity_json = []
894
- for entity in extract_json(gen_text=res.gen_text):
895
- if ENTITY_KEY in entity:
896
- entity_json.append(entity)
897
- else:
898
- warnings.warn(f'Extractor output "{entity}" does not have entity_key ("{ENTITY_KEY}"). This frame will be dropped.', RuntimeWarning)
1186
+ This is the async version of extract_frames.
1187
+ """
1188
+ extraction_results = await self._extract_async(text_content=text_content,
1189
+ document_key=document_key,
1190
+ concurrent_batch_size=concurrent_batch_size,
1191
+ return_messages_log=return_messages_log)
1192
+
1193
+ units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
1194
+ frame_list = self._post_process_units_to_frames(units, case_sensitive, fuzzy_match, fuzzy_buffer_size, fuzzy_score_cutoff, allow_overlap_entities)
899
1195
 
900
- spans = self._find_entity_spans(
901
- text=res.text,
902
- entities=[e[ENTITY_KEY] for e in entity_json],
903
- case_sensitive=case_sensitive,
904
- fuzzy_match=fuzzy_match,
905
- fuzzy_buffer_size=fuzzy_buffer_size,
906
- fuzzy_score_cutoff=fuzzy_score_cutoff,
907
- allow_overlap_entities=allow_overlap_entities
908
- )
909
- for ent, span in zip(entity_json, spans):
910
- if span is not None:
911
- start, end = span
912
- entity_text = res.text[start:end]
913
- start += res.start
914
- end += res.start
915
- attr = ent.get("attr", {}) or {}
916
- frame = LLMInformationExtractionFrame(
917
- frame_id=f"{len(frame_list)}",
918
- start=start,
919
- end=end,
920
- entity_text=entity_text,
921
- attr=attr
922
- )
923
- frame_list.append(frame)
1196
+ if return_messages_log:
1197
+ return frame_list, messages_log
924
1198
  return frame_list
925
-
1199
+
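A hedged sketch of driving the new extract_frames_async from async code, recovering the multi-document behavior of the removed extract_frames_from_documents by gathering one call per document (`extractor` and `notes` are placeholders):

import asyncio

async def extract_all(extractor, notes):
    # One extract_frames_async call per document; per-call concurrency is
    # bounded internally by concurrent_batch_size.
    results = await asyncio.gather(*(
        extractor.extract_frames_async(text_content=note, concurrent_batch_size=16)
        for note in notes
    ))
    return results  # List[List[LLMInformationExtractionFrame]]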

  class ReviewFrameExtractor(DirectFrameExtractor):
      def __init__(self, unit_chunker:UnitChunker, context_chunker:ContextChunker, inference_engine:InferenceEngine,
@@ -1200,7 +1474,7 @@ class ReviewFrameExtractor(DirectFrameExtractor):
          for chunk in response_stream:
              yield chunk

-     async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+     async def _extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
                              concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[FrameExtractionUnit]:
          """
          This is the asynchronous version of the extract() method with the review step.
@@ -1703,7 +1977,7 @@ class AttributeExtractor(Extractor):
          return (new_frames, messages_log) if return_messages_log else new_frames


-     async def extract_async(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
+     async def _extract_async(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
                              concurrent_batch_size:int=32, inplace:bool=True, return_messages_log:bool=False) -> Union[None, List[LLMInformationExtractionFrame]]:
          """
          This method extracts attributes from the document asynchronously.
@@ -1775,6 +2049,16 @@ class AttributeExtractor(Extractor):
          else:
              return (new_frames, messages_logger.get_messages_log()) if return_messages_log else new_frames

+     async def extract_attributes_async(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
+                                        concurrent_batch_size:int=32, inplace:bool=True,
+                                        return_messages_log:bool=False) -> Union[None, List[LLMInformationExtractionFrame]]:
+         """
+         This is the async version of extract_attributes.
+         """
+         return await self._extract_async(frames=frames, text=text, context_size=context_size,
+                                          concurrent_batch_size=concurrent_batch_size, inplace=inplace, return_messages_log=return_messages_log)
+
+
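extract_attributes_async is a thin public wrapper over the renamed _extract_async; a minimal sketch of awaiting it (placeholder names):

async def annotate(extractor, frames, text):
    # inplace=False returns new frames instead of mutating the input ones.
    return await extractor.extract_attributes_async(frames=frames, text=text,
                                                    context_size=256, inplace=False)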
      def extract_attributes(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
                             concurrent:bool=False, concurrent_batch_size:int=32, verbose:bool=False,
                             return_messages_log:bool=False, inplace:bool=True) -> Union[None, List[LLMInformationExtractionFrame]]:
@@ -1810,7 +2094,7 @@ class AttributeExtractor(Extractor):

              nest_asyncio.apply()  # For Jupyter notebook. Terminal does not need this.

-             return asyncio.run(self.extract_async(frames=frames, text=text, context_size=context_size,
+             return asyncio.run(self._extract_async(frames=frames, text=text, context_size=context_size,
                                                    concurrent_batch_size=concurrent_batch_size,
                                                    inplace=inplace, return_messages_log=return_messages_log))
          else:
@@ -1955,6 +2239,17 @@ class RelationExtractor(Extractor):
              return asyncio.run(self._extract_async(doc, buffer_size, concurrent_batch_size, return_messages_log))
          else:
              return self._extract(doc, buffer_size, verbose, return_messages_log)
+
+     async def extract_relations_async(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent_batch_size: int = 32, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
+         """
+         This is the async version of extract_relations.
+         """
+         if not doc.has_frame():
+             raise ValueError("Input document must have frames.")
+         if doc.has_duplicate_frame_ids():
+             raise ValueError("All frame_ids in the input document must be unique.")
+
+         return await self._extract_async(doc, buffer_size, concurrent_batch_size, return_messages_log)
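Likewise for relations: the async entry point keeps the same guards as the sync path, so invalid documents fail fast before any LLM calls. A sketch with placeholder names:

async def get_relations(extractor, doc):
    # Raises ValueError if `doc` has no frames or duplicate frame_ids.
    return await extractor.extract_relations_async(doc, buffer_size=128,
                                                   concurrent_batch_size=16)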


  class BinaryRelationExtractor(RelationExtractor):