llm-ie 0.4.5__tar.gz → 0.4.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. {llm_ie-0.4.5 → llm_ie-0.4.7}/PKG-INFO +10 -6
  2. {llm_ie-0.4.5 → llm_ie-0.4.7}/README.md +9 -5
  3. {llm_ie-0.4.5 → llm_ie-0.4.7}/pyproject.toml +3 -2
  4. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/extractors.py +351 -116
  5. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/__init__.py +0 -0
  6. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/PromptEditor_prompts/chat.txt +0 -0
  7. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/PromptEditor_prompts/comment.txt +0 -0
  8. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/PromptEditor_prompts/rewrite.txt +0 -0
  9. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/PromptEditor_prompts/system.txt +0 -0
  10. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/default_prompts/ReviewFrameExtractor_addition_review_prompt.txt +0 -0
  11. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/default_prompts/ReviewFrameExtractor_revision_review_prompt.txt +0 -0
  12. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/default_prompts/SentenceReviewFrameExtractor_addition_review_prompt.txt +0 -0
  13. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/default_prompts/SentenceReviewFrameExtractor_revision_review_prompt.txt +0 -0
  14. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/prompt_guide/BasicFrameExtractor_prompt_guide.txt +0 -0
  15. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/prompt_guide/BinaryRelationExtractor_prompt_guide.txt +0 -0
  16. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/prompt_guide/MultiClassRelationExtractor_prompt_guide.txt +0 -0
  17. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/prompt_guide/ReviewFrameExtractor_prompt_guide.txt +0 -0
  18. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/prompt_guide/SentenceCoTFrameExtractor_prompt_guide.txt +0 -0
  19. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/prompt_guide/SentenceFrameExtractor_prompt_guide.txt +0 -0
  20. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/asset/prompt_guide/SentenceReviewFrameExtractor_prompt_guide.txt +0 -0
  21. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/data_types.py +0 -0
  22. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/engines.py +0 -0
  23. {llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/prompt_editor.py +0 -0
{llm_ie-0.4.5 → llm_ie-0.4.7}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: llm-ie
-Version: 0.4.5
+Version: 0.4.7
 Summary: An LLM-powered tool that transforms everyday language into robust information extraction pipelines.
 License: MIT
 Author: Enshuo (David) Hsu
@@ -44,7 +44,7 @@ An LLM-powered tool that transforms everyday language into robust information ex
 - [v0.4.5](https://github.com/daviden1013/llm-ie/releases/tag/v0.4.5) (Feb 16, 2025):
   - Added option to adjust number of context sentences in sentence-based extractors.
   - Added support for OpenAI reasoning models ("o" series).
-
+- [v0.4.6](https://github.com/daviden1013/llm-ie/releases/tag/v0.4.6) (Mar 1, 2025): Allow LLM to output overlapping frames.
 
 ## Table of Contents
 - [Overview](#overview)
@@ -1206,10 +1206,14 @@ We benchmarked the frame and relation extractors on biomedical information extra
 ## Citation
 For more information and benchmarks, please check our paper:
 ```bibtex
-@article{hsu2024llm,
-  title={LLM-IE: A Python Package for Generative Information Extraction with Large Language Models},
+@article{hsu2025llm,
+  title={LLM-IE: a python package for biomedical generative information extraction with large language models},
   author={Hsu, Enshuo and Roberts, Kirk},
-  journal={arXiv preprint arXiv:2411.11779},
-  year={2024}
+  journal={JAMIA open},
+  volume={8},
+  number={2},
+  pages={ooaf012},
+  year={2025},
+  publisher={Oxford University Press}
 }
 ```
{llm_ie-0.4.5 → llm_ie-0.4.7}/README.md
@@ -27,7 +27,7 @@ An LLM-powered tool that transforms everyday language into robust information ex
 - [v0.4.5](https://github.com/daviden1013/llm-ie/releases/tag/v0.4.5) (Feb 16, 2025):
   - Added option to adjust number of context sentences in sentence-based extractors.
   - Added support for OpenAI reasoning models ("o" series).
-
+- [v0.4.6](https://github.com/daviden1013/llm-ie/releases/tag/v0.4.6) (Mar 1, 2025): Allow LLM to output overlapping frames.
 
 ## Table of Contents
 - [Overview](#overview)
@@ -1189,10 +1189,14 @@ We benchmarked the frame and relation extractors on biomedical information extra
 ## Citation
 For more information and benchmarks, please check our paper:
 ```bibtex
-@article{hsu2024llm,
-  title={LLM-IE: A Python Package for Generative Information Extraction with Large Language Models},
+@article{hsu2025llm,
+  title={LLM-IE: a python package for biomedical generative information extraction with large language models},
   author={Hsu, Enshuo and Roberts, Kirk},
-  journal={arXiv preprint arXiv:2411.11779},
-  year={2024}
+  journal={JAMIA open},
+  volume={8},
+  number={2},
+  pages={ooaf012},
+  year={2025},
+  publisher={Oxford University Press}
 }
 ```
{llm_ie-0.4.5 → llm_ie-0.4.7}/pyproject.toml
@@ -1,13 +1,14 @@
 [tool.poetry]
 name = "llm-ie"
-version = "0.4.5"
+version = "0.4.7"
 description = "An LLM-powered tool that transforms everyday language into robust information extraction pipelines."
 authors = ["Enshuo (David) Hsu"]
 license = "MIT"
 readme = "README.md"
 
 exclude = [
-    "test/**"
+    "test/**",
+    "develop/**"
 ]
 
 
{llm_ie-0.4.5 → llm_ie-0.4.7}/src/llm_ie/extractors.py
@@ -224,7 +224,8 @@ class FrameExtractor(Extractor):
 
 
     def _find_entity_spans(self, text: str, entities: List[str], case_sensitive:bool=False,
-                           fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8) -> List[Tuple[int]]:
+                           fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
+                           allow_overlap_entities:bool=False) -> List[Tuple[int]]:
         """
         This function inputs a text and a list of entity text,
         outputs a list of spans (2-tuple) for each entity.
@@ -245,6 +246,8 @@ class FrameExtractor(Extractor):
         fuzzy_score_cutoff : float, Optional
             the Jaccard score cutoff for fuzzy matching.
            Matched entity text must have a score higher than this value or a None will be returned.
+        allow_overlap_entities : bool, Optional
+            if True, entities can overlap in the text.
         """
        # Handle case sensitivity
        if not case_sensitive:
@@ -264,15 +267,17 @@ class FrameExtractor(Extractor):
             if match and entity:
                 start, end = match.span()
                 entity_spans.append((start, end))
-                # Replace the found entity with spaces to avoid finding the same instance again
-                text = text[:start] + ' ' * (end - start) + text[end:]
+                if not allow_overlap_entities:
+                    # Replace the found entity with spaces to avoid finding the same instance again
+                    text = text[:start] + ' ' * (end - start) + text[end:]
             # Fuzzy match
             elif fuzzy_match:
                 closest_substring_span, best_score = self._get_closest_substring(text, entity, buffer_size=fuzzy_buffer_size)
                 if closest_substring_span and best_score >= fuzzy_score_cutoff:
                     entity_spans.append(closest_substring_span)
-                    # Replace the found entity with spaces to avoid finding the same instance again
-                    text = text[:closest_substring_span[0]] + ' ' * (closest_substring_span[1] - closest_substring_span[0]) + text[closest_substring_span[1]:]
+                    if not allow_overlap_entities:
+                        # Replace the found entity with spaces to avoid finding the same instance again
+                        text = text[:closest_substring_span[0]] + ' ' * (closest_substring_span[1] - closest_substring_span[0]) + text[closest_substring_span[1]:]
                 else:
                     entity_spans.append(None)
 
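The space-masking above is what maps repeated mentions of the same entity text to successive occurrences; with allow_overlap_entities=True the mask is skipped, which is also why identical entity strings can then resolve to the same span. A minimal, self-contained sketch of the idea (an exact-match simplification, not the package's fuzzy-matching code):

```python
from typing import List, Optional, Tuple

def find_spans(text: str, entities: List[str],
               allow_overlap: bool = False) -> List[Optional[Tuple[int, int]]]:
    """Locate each entity string in text, exact match only."""
    spans = []
    for entity in entities:
        idx = text.find(entity)
        if idx == -1:
            spans.append(None)
            continue
        start, end = idx, idx + len(entity)
        spans.append((start, end))
        if not allow_overlap:
            # Mask the matched region so a repeated mention of the same
            # entity text resolves to its next occurrence.
            text = text[:start] + " " * (end - start) + text[end:]
    return spans

print(find_spans("aspirin 81 mg; aspirin 325 mg", ["aspirin", "aspirin"]))
# [(0, 7), (15, 22)]
print(find_spans("aspirin 81 mg; aspirin 325 mg", ["aspirin", "aspirin"],
                 allow_overlap=True))
# [(0, 7), (0, 7)]  <- duplicate spans, the caveat the docstrings warn about
```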
@@ -283,7 +288,7 @@ class FrameExtractor(Extractor):
         return entity_spans
 
     @abc.abstractmethod
-    def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, **kwrs) -> str:
+    def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048, return_messages_log:bool=False, **kwrs) -> str:
         """
         This method inputs text content and outputs a string generated by LLM
 
@@ -295,6 +300,8 @@ class FrameExtractor(Extractor):
             If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
         max_new_tokens : str, Optional
             the max number of new tokens LLM can generate.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             the output from LLM. Need post-processing.
@@ -304,7 +311,7 @@ class FrameExtractor(Extractor):
 
     @abc.abstractmethod
     def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=2048,
-                       document_key:str=None, **kwrs) -> List[LLMInformationExtractionFrame]:
+                       document_key:str=None, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
         """
         This method inputs text content and outputs a list of LLMInformationExtractionFrame
         It use the extract() method and post-process outputs into frames.
@@ -322,6 +329,8 @@ class FrameExtractor(Extractor):
         document_key : str, Optional
             specify the key in text_content where document text is.
             If text_content is str, this parameter will be ignored.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             a list of frames.
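Across all of these signatures, return_messages_log follows one convention: when the flag is False the method returns its usual result, and when True it returns a (result, messages_log) tuple, where messages_log is a list of chat transcripts, each ending with the assistant's reply. A hedged sketch of the caller-side pattern (extractor and note are placeholders, not objects defined in this diff):

```python
# Plain call: just the result.
result = extractor.extract_frames(note, entity_key="entity_text")

# Opt-in call: result plus one transcript per LLM call.
result, messages_log = extractor.extract_frames(
    note, entity_key="entity_text", return_messages_log=True)
for transcript in messages_log:
    assert transcript[-1]["role"] == "assistant"
```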
@@ -352,7 +361,7 @@ class BasicFrameExtractor(FrameExtractor):
 
 
     def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=2048,
-                temperature:float=0.0, stream:bool=False, **kwrs) -> str:
+                temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> str:
         """
         This method inputs a text and outputs a string generated by LLM.
 
@@ -368,6 +377,8 @@ class BasicFrameExtractor(FrameExtractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             the output from LLM. Need post-processing.
@@ -385,13 +396,19 @@ class BasicFrameExtractor(FrameExtractor):
                                               **kwrs
                                               )
 
+        if return_messages_log:
+            messages.append({"role": "assistant", "content": response})
+            messages_log = [messages]
+            return response, messages_log
+
         return response
 
 
     def extract_frames(self, text_content:Union[str, Dict[str,str]], entity_key:str, max_new_tokens:int=2048,
                        temperature:float=0.0, document_key:str=None, stream:bool=False,
                        case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2,
-                       fuzzy_score_cutoff:float=0.8, **kwrs) -> List[LLMInformationExtractionFrame]:
+                       fuzzy_score_cutoff:float=0.8, allow_overlap_entities:bool=False,
+                       return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
         """
         This method inputs a text and outputs a list of LLMInformationExtractionFrame
         It use the extract() method and post-process outputs into frames.
@@ -422,6 +439,11 @@ class BasicFrameExtractor(FrameExtractor):
         fuzzy_score_cutoff : float, Optional
             the Jaccard score cutoff for fuzzy matching.
             Matched entity text must have a score higher than this value or a None will be returned.
+        allow_overlap_entities : bool, Optional
+            if True, entities can overlap in the text.
+            Note that this can cause multiple frames to be generated on the same entity span if they have same entity text.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             a list of frames.
@@ -434,11 +456,13 @@ class BasicFrameExtractor(FrameExtractor):
             text = text_content[document_key]
 
         frame_list = []
-        gen_text = self.extract(text_content=text_content,
-                                max_new_tokens=max_new_tokens,
-                                temperature=temperature,
-                                stream=stream,
-                                **kwrs)
+        extraction_results = self.extract(text_content=text_content,
+                                          max_new_tokens=max_new_tokens,
+                                          temperature=temperature,
+                                          stream=stream,
+                                          return_messages_log=return_messages_log,
+                                          **kwrs)
+        gen_text, messages_log = extraction_results if return_messages_log else (extraction_results, None)
 
         entity_json = []
         for entity in self._extract_json(gen_text=gen_text):
@@ -452,7 +476,8 @@ class BasicFrameExtractor(FrameExtractor):
                                         case_sensitive=case_sensitive,
                                         fuzzy_match=fuzzy_match,
                                         fuzzy_buffer_size=fuzzy_buffer_size,
-                                        fuzzy_score_cutoff=fuzzy_score_cutoff)
+                                        fuzzy_score_cutoff=fuzzy_score_cutoff,
+                                        allow_overlap_entities=allow_overlap_entities)
 
         for i, (ent, span) in enumerate(zip(entity_json, spans)):
             if span is not None:
@@ -463,6 +488,10 @@ class BasicFrameExtractor(FrameExtractor):
                                                      entity_text=text[start:end],
                                                      attr={k: v for k, v in ent.items() if k != entity_key and v != ""})
                 frame_list.append(frame)
+
+        if return_messages_log:
+            return frame_list, messages_log
+
         return frame_list
 
 
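Put together, a caller can opt in to both new behaviors in one call. A hypothetical end-to-end example (the engine, prompt template, and note_text are placeholders; the constructor arguments are assumptions, not taken from this diff):

```python
from llm_ie.extractors import BasicFrameExtractor

# `engine` stands for a configured inference engine from llm_ie.engines,
# and `prompt_template` for a prompt following the package's prompt guide.
extractor = BasicFrameExtractor(engine, prompt_template)

frames, messages_log = extractor.extract_frames(
    text_content=note_text,        # placeholder document string
    entity_key="entity_text",
    allow_overlap_entities=True,   # v0.4.6: overlapping spans permitted
    return_messages_log=True,      # v0.4.7: also return chat transcripts
)
for frame in frames:
    print(frame.entity_text, frame.attr)
```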
@@ -509,7 +538,7 @@ class ReviewFrameExtractor(BasicFrameExtractor):
 
 
     def extract(self, text_content:Union[str, Dict[str,str]],
-                max_new_tokens:int=4096, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
+                max_new_tokens:int=4096, temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> str:
         """
         This method inputs a text and outputs a string generated by LLM.
 
@@ -525,6 +554,8 @@ class ReviewFrameExtractor(BasicFrameExtractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             the output from LLM. Need post-processing.
@@ -561,10 +592,18 @@ class ReviewFrameExtractor(BasicFrameExtractor):
                                              )
 
         # Output
+        output_text = ""
         if self.review_mode == "revision":
-            return review
+            output_text = review
         elif self.review_mode == "addition":
-            return initial + '\n' + review
+            output_text = initial + '\n' + review
+
+        if return_messages_log:
+            messages.append({"role": "assistant", "content": review})
+            messages_log = [messages]
+            return output_text, messages_log
+
+        return output_text
 
 
 class SentenceFrameExtractor(FrameExtractor):
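For orientation, the two review modes only differ in how the final text is assembled: "revision" keeps the reviewed output alone, while "addition" concatenates the initial pass and the review. A toy illustration with invented strings:

```python
initial = '{"entity_text": "aspirin"}'    # first-pass LLM output
review = '{"entity_text": "warfarin"}'    # output after the review prompt

def assemble(review_mode: str) -> str:
    if review_mode == "revision":
        return review                      # review replaces the initial output
    elif review_mode == "addition":
        return initial + "\n" + review     # review is appended to the initial
    raise ValueError(review_mode)

print(assemble("addition"))
```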
@@ -648,7 +687,7 @@ class SentenceFrameExtractor(FrameExtractor):
 
 
     def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
-                document_key:str=None, temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict[str,str]]:
+                document_key:str=None, temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict[str,str]]:
         """
         This method inputs a text and outputs a list of outputs per sentence.
 
@@ -667,6 +706,8 @@ class SentenceFrameExtractor(FrameExtractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             the output from LLM. Need post-processing.
@@ -681,6 +722,9 @@ class SentenceFrameExtractor(FrameExtractor):
                 raise ValueError("document_key must be provided when text_content is dict.")
             sentences = self._get_sentences(text_content[document_key])
 
+        if return_messages_log:
+            messages_log = []
+
         # generate sentence by sentence
         for i, sent in enumerate(sentences):
             # construct chat messages
@@ -692,10 +736,20 @@ class SentenceFrameExtractor(FrameExtractor):
 
             if self.context_sentences == 0:
                 # no context, just place sentence of interest
-                messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
+                else:
+                    sentence_content = text_content.copy()
+                    sentence_content[document_key] = sent['sentence_text']
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(sentence_content)})
             else:
                 # insert context
-                messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                else:
+                    context_content = text_content.copy()
+                    context_content[document_key] = context
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
             # simulate conversation
             messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
             # place sentence of interest
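This new dict branch matters for prompt templates with more than one placeholder: instead of passing the bare sentence string, the extractor copies the original text_content dict and swaps only the document_key entry, so the other placeholder values survive per-sentence prompting. A minimal sketch of the pattern (key names are invented):

```python
def build_sentence_content(text_content, document_key, sentence_text):
    """Substitute the document text while keeping other placeholder values."""
    if isinstance(text_content, str):
        return sentence_text
    sentence_content = text_content.copy()   # shallow copy keeps other keys
    sentence_content[document_key] = sentence_text
    return sentence_content

text_content = {"input": "full note text...", "patient_age": "54"}
print(build_sentence_content(text_content, "input", "BP was 120/80."))
# {'input': 'BP was 120/80.', 'patient_age': '54'}
```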
@@ -715,6 +769,10 @@ class SentenceFrameExtractor(FrameExtractor):
                                                    stream=stream,
                                                    **kwrs
                                                    )
+
+            if return_messages_log:
+                messages.append({"role": "assistant", "content": gen_text})
+                messages_log.append(messages)
 
             # add to output
             output.append({'sentence_start': sent['start'],
@@ -722,11 +780,15 @@ class SentenceFrameExtractor(FrameExtractor):
                            'sentence_text': sent['sentence_text'],
                            'gen_text': gen_text})
 
+        if return_messages_log:
+            return output, messages_log
+
         return output
 
 
     async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
-                            document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
+                            document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32,
+                            return_messages_log:bool=False, **kwrs) -> List[Dict[str,str]]:
         """
         The asynchronous version of the extract() method.
 
@@ -745,6 +807,11 @@ class SentenceFrameExtractor(FrameExtractor):
             the temperature for token sampling.
         concurrent_batch_size : int, Optional
             the number of sentences to process in concurrent.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
+
+        Return : str
+            the output from LLM. Need post-processing.
         """
         # Check if self.inference_engine.chat_async() is implemented
         if not hasattr(self.inference_engine, 'chat_async'):
@@ -760,10 +827,14 @@ class SentenceFrameExtractor(FrameExtractor):
                 raise ValueError("document_key must be provided when text_content is dict.")
             sentences = self._get_sentences(text_content[document_key])
 
+        if return_messages_log:
+            messages_log = []
+
         # generate sentence by sentence
-        tasks = []
         for i in range(0, len(sentences), concurrent_batch_size):
+            tasks = []
             batch = sentences[i:i + concurrent_batch_size]
+            batch_messages = []
             for j, sent in enumerate(batch):
                 # construct chat messages
                 messages = []
@@ -774,10 +845,20 @@ class SentenceFrameExtractor(FrameExtractor):
 
                 if self.context_sentences == 0:
                     # no context, just place sentence of interest
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
+                    else:
+                        sentence_content = text_content.copy()
+                        sentence_content[document_key] = sent['sentence_text']
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(sentence_content)})
                 else:
                     # insert context
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                    else:
+                        context_content = text_content.copy()
+                        context_content[document_key] = context
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
                 # simulate conversation
                 messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
                 # place sentence of interest
@@ -793,16 +874,25 @@ class SentenceFrameExtractor(FrameExtractor):
                     )
                 )
                 tasks.append(task)
+                batch_messages.append(messages)
 
             # Wait until the batch is done, collect results and move on to next batch
             responses = await asyncio.gather(*tasks)
 
-        # Collect outputs
-        for gen_text, sent in zip(responses, sentences):
-            output.append({'sentence_start': sent['start'],
-                           'sentence_end': sent['end'],
-                           'sentence_text': sent['sentence_text'],
-                           'gen_text': gen_text})
+            # Collect outputs
+            for gen_text, sent, messages in zip(responses, batch, batch_messages):
+                if return_messages_log:
+                    messages.append({"role": "assistant", "content": gen_text})
+                    messages_log.append(messages)
+
+                output.append({'sentence_start': sent['start'],
+                               'sentence_end': sent['end'],
+                               'sentence_text': sent['sentence_text'],
+                               'gen_text': gen_text})
+
+        if return_messages_log:
+            return output, messages_log
+
         return output
 
 
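Note the structural fix in this hunk: tasks is now reset for every batch and responses are zipped against batch rather than the full sentences list, so each response stays aligned with the sentence that produced it. A self-contained sketch of the per-batch gather pattern (fake_llm is an invented stand-in for chat_async):

```python
import asyncio

async def fake_llm(sentence: str) -> str:
    await asyncio.sleep(0)               # stands in for an LLM call
    return f"extracted from: {sentence}"

async def run_batched(sentences, batch_size=2):
    output = []
    for i in range(0, len(sentences), batch_size):
        batch = sentences[i:i + batch_size]
        tasks = [asyncio.create_task(fake_llm(s)) for s in batch]  # fresh per batch
        responses = await asyncio.gather(*tasks)
        # zip against the batch, not the full list, so pairs stay aligned
        output.extend({"sentence": s, "gen_text": r}
                      for s, r in zip(batch, responses))
    return output

print(asyncio.run(run_batched(["S1.", "S2.", "S3."])))
```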
@@ -810,7 +900,7 @@ class SentenceFrameExtractor(FrameExtractor):
                        document_key:str=None, temperature:float=0.0, stream:bool=False,
                        concurrent:bool=False, concurrent_batch_size:int=32,
                        case_sensitive:bool=False, fuzzy_match:bool=True, fuzzy_buffer_size:float=0.2, fuzzy_score_cutoff:float=0.8,
-                       **kwrs) -> List[LLMInformationExtractionFrame]:
+                       allow_overlap_entities:bool=False, return_messages_log:bool=False, **kwrs) -> List[LLMInformationExtractionFrame]:
         """
         This method inputs a text and outputs a list of LLMInformationExtractionFrame
         It use the extract() method and post-process outputs into frames.
@@ -845,6 +935,11 @@ class SentenceFrameExtractor(FrameExtractor):
         fuzzy_score_cutoff : float, Optional
             the Jaccard score cutoff for fuzzy matching.
             Matched entity text must have a score higher than this value or a None will be returned.
+        allow_overlap_entities : bool, Optional
+            if True, entities can overlap in the text.
+            Note that this can cause multiple frames to be generated on the same entity span if they have same entity text.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             a list of frames.
@@ -854,20 +949,25 @@ class SentenceFrameExtractor(FrameExtractor):
                 warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
 
             nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
-            llm_output_sentences = asyncio.run(self.extract_async(text_content=text_content,
-                                                                  max_new_tokens=max_new_tokens,
-                                                                  document_key=document_key,
-                                                                  temperature=temperature,
-                                                                  concurrent_batch_size=concurrent_batch_size,
-                                                                  **kwrs)
-                                               )
+            extraction_results = asyncio.run(self.extract_async(text_content=text_content,
+                                                                max_new_tokens=max_new_tokens,
+                                                                document_key=document_key,
+                                                                temperature=temperature,
+                                                                concurrent_batch_size=concurrent_batch_size,
+                                                                return_messages_log=return_messages_log,
+                                                                **kwrs)
+                                             )
         else:
-            llm_output_sentences = self.extract(text_content=text_content,
+            extraction_results = self.extract(text_content=text_content,
                                                 max_new_tokens=max_new_tokens,
                                                 document_key=document_key,
                                                 temperature=temperature,
                                                 stream=stream,
+                                                return_messages_log=return_messages_log,
                                                 **kwrs)
+
+        llm_output_sentences, messages_log = extraction_results if return_messages_log else (extraction_results, None)
+
         frame_list = []
         for sent in llm_output_sentences:
             entity_json = []
@@ -882,7 +982,8 @@ class SentenceFrameExtractor(FrameExtractor):
                                             case_sensitive=case_sensitive,
                                             fuzzy_match=fuzzy_match,
                                             fuzzy_buffer_size=fuzzy_buffer_size,
-                                            fuzzy_score_cutoff=fuzzy_score_cutoff)
+                                            fuzzy_score_cutoff=fuzzy_score_cutoff,
+                                            allow_overlap_entities=allow_overlap_entities)
             for ent, span in zip(entity_json, spans):
                 if span is not None:
                     start, end = span
@@ -895,6 +996,9 @@ class SentenceFrameExtractor(FrameExtractor):
                                                          entity_text=entity_text,
                                                          attr={k: v for k, v in ent.items() if k != entity_key and v != ""})
                     frame_list.append(frame)
+
+        if return_messages_log:
+            return frame_list, messages_log
         return frame_list
 
 
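With those pieces in place, sentence-level extraction can run concurrently and still surface per-sentence transcripts. A hypothetical call (engine, prompt_template, and note_text are placeholders):

```python
from llm_ie.extractors import SentenceFrameExtractor

extractor = SentenceFrameExtractor(engine, prompt_template)  # assumed setup

frames, messages_log = extractor.extract_frames(
    text_content={"input": note_text},
    entity_key="entity_text",
    document_key="input",
    concurrent=True,              # routes through extract_async()
    concurrent_batch_size=16,
    return_messages_log=True,
)
print(len(frames), len(messages_log))  # one transcript per sentence
```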
@@ -950,7 +1054,7 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
 
 
     def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
-                document_key:str=None, temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict[str,str]]:
+                document_key:str=None, temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict[str,str]]:
         """
         This method inputs a text and outputs a list of outputs per sentence.
 
@@ -969,6 +1073,8 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             the output from LLM. Need post-processing.
@@ -983,6 +1089,9 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
                 raise ValueError("document_key must be provided when text_content is dict.")
             sentences = self._get_sentences(text_content[document_key])
 
+        if return_messages_log:
+            messages_log = []
+
         # generate sentence by sentence
         for i, sent in enumerate(sentences):
             # construct chat messages
@@ -994,10 +1103,20 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
 
             if self.context_sentences == 0:
                 # no context, just place sentence of interest
-                messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
+                else:
+                    sentence_content = text_content.copy()
+                    sentence_content[document_key] = sent['sentence_text']
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(sentence_content)})
             else:
                 # insert context
-                messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                else:
+                    context_content = text_content.copy()
+                    context_content[document_key] = context
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
             # simulate conversation
             messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
             # place sentence of interest
@@ -1020,6 +1139,7 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
             # Review
             if stream:
                 print(f"\n{Fore.YELLOW}Review:{Style.RESET_ALL}")
+
             messages.append({'role': 'assistant', 'content': initial})
             messages.append({'role': 'user', 'content': self.review_prompt})
 
@@ -1037,15 +1157,23 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
             elif self.review_mode == "addition":
                 gen_text = initial + '\n' + review
 
+            if return_messages_log:
+                messages.append({"role": "assistant", "content": review})
+                messages_log.append(messages)
+
             # add to output
             output.append({'sentence_start': sent['start'],
                            'sentence_end': sent['end'],
                            'sentence_text': sent['sentence_text'],
                            'gen_text': gen_text})
+
+        if return_messages_log:
+            return output, messages_log
+
         return output
 
     async def extract_async(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
-                            document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict[str,str]]:
+                            document_key:str=None, temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict[str,str]]:
         """
         The asynchronous version of the extract() method.
 
@@ -1064,6 +1192,8 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
             the temperature for token sampling.
         concurrent_batch_size : int, Optional
             the number of sentences to process in concurrent.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             the output from LLM. Need post-processing.
@@ -1082,10 +1212,14 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
                 raise ValueError("document_key must be provided when text_content is dict.")
             sentences = self._get_sentences(text_content[document_key])
 
+        if return_messages_log:
+            messages_log = []
+
         # generate initial outputs sentence by sentence
-        tasks = []
-        messages_list = []
         for i in range(0, len(sentences), concurrent_batch_size):
+            messages_list = []
+            init_tasks = []
+            review_tasks = []
             batch = sentences[i:i + concurrent_batch_size]
             for j, sent in enumerate(batch):
                 # construct chat messages
@@ -1097,10 +1231,20 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
 
                 if self.context_sentences == 0:
                     # no context, just place sentence of interest
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
+                    else:
+                        sentence_content = text_content.copy()
+                        sentence_content[document_key] = sent['sentence_text']
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(sentence_content)})
                 else:
                     # insert context
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                    else:
+                        context_content = text_content.copy()
+                        context_content[document_key] = context
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
                 # simulate conversation
                 messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
                 # place sentence of interest
@@ -1116,24 +1260,21 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
                         **kwrs
                     )
                 )
-                tasks.append(task)
+                init_tasks.append(task)
 
-        # Wait until the batch is done, collect results and move on to next batch
-        responses = await asyncio.gather(*tasks)
-        # Collect initials
-        initials = []
-        for gen_text, sent, messages in zip(responses, sentences, messages_list):
-            initials.append({'sentence_start': sent['start'],
-                             'sentence_end': sent['end'],
-                             'sentence_text': sent['sentence_text'],
-                             'gen_text': gen_text,
-                             'messages': messages})
+            # Wait until the batch is done, collect results and move on to next batch
+            init_responses = await asyncio.gather(*init_tasks)
+            # Collect initials
+            initials = []
+            for gen_text, sent, messages in zip(init_responses, batch, messages_list):
+                initials.append({'sentence_start': sent['start'],
+                                 'sentence_end': sent['end'],
+                                 'sentence_text': sent['sentence_text'],
+                                 'gen_text': gen_text,
+                                 'messages': messages})
 
-        # Review
-        tasks = []
-        for i in range(0, len(initials), concurrent_batch_size):
-            batch = initials[i:i + concurrent_batch_size]
-            for init in batch:
+            # Review
+            for init in initials:
                 messages = init["messages"]
                 initial = init["gen_text"]
                 messages.append({'role': 'assistant', 'content': initial})
@@ -1146,29 +1287,37 @@ class SentenceReviewFrameExtractor(SentenceFrameExtractor):
                         **kwrs
                     )
                 )
-                tasks.append(task)
+                review_tasks.append(task)
 
-        responses = await asyncio.gather(*tasks)
-
-        # Collect reviews
-        reviews = []
-        for gen_text, sent in zip(responses, sentences):
-            reviews.append({'sentence_start': sent['start'],
-                            'sentence_end': sent['end'],
-                            'sentence_text': sent['sentence_text'],
-                            'gen_text': gen_text})
-
-        for init, rev in zip(initials, reviews):
-            if self.review_mode == "revision":
-                gen_text = rev['gen_text']
-            elif self.review_mode == "addition":
-                gen_text = init['gen_text'] + '\n' + rev['gen_text']
-
-            # add to output
-            output.append({'sentence_start': init['sentence_start'],
-                           'sentence_end': init['sentence_end'],
-                           'sentence_text': init['sentence_text'],
-                           'gen_text': gen_text})
+            review_responses = await asyncio.gather(*review_tasks)
+
+            # Collect reviews
+            reviews = []
+            for gen_text, sent in zip(review_responses, batch):
+                reviews.append({'sentence_start': sent['start'],
+                                'sentence_end': sent['end'],
+                                'sentence_text': sent['sentence_text'],
+                                'gen_text': gen_text})
+
+            for init, rev in zip(initials, reviews):
+                if self.review_mode == "revision":
+                    gen_text = rev['gen_text']
+                elif self.review_mode == "addition":
+                    gen_text = init['gen_text'] + '\n' + rev['gen_text']
+
+                if return_messages_log:
+                    messages = init["messages"]
+                    messages.append({"role": "assistant", "content": rev['gen_text']})
+                    messages_log.append(messages)
+
+                # add to output
+                output.append({'sentence_start': init['sentence_start'],
+                               'sentence_end': init['sentence_end'],
+                               'sentence_text': init['sentence_text'],
+                               'gen_text': gen_text})
+
+        if return_messages_log:
+            return output, messages_log
         return output
 
 
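The restructured async review runs both phases inside each batch: the initial extractions for a batch are gathered, then the matching review calls are gathered, before moving on (the old code gathered across batches and zipped against the full sentence list, misaligning results). A compact sketch of the two-phase, per-batch pattern with invented coroutines:

```python
import asyncio

async def initial_pass(sentence: str) -> str:
    return f"initial({sentence})"      # stand-in for the first chat_async call

async def review_pass(initial: str) -> str:
    return f"review({initial})"        # stand-in for the review chat_async call

async def run(sentences, batch_size=2):
    output = []
    for i in range(0, len(sentences), batch_size):
        batch = sentences[i:i + batch_size]
        initials = await asyncio.gather(*(initial_pass(s) for s in batch))
        reviews = await asyncio.gather(*(review_pass(x) for x in initials))
        output.extend(zip(batch, initials, reviews))  # aligned per batch
    return output

print(asyncio.run(run(["S1.", "S2.", "S3."])))
```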
@@ -1210,7 +1359,7 @@ class SentenceCoTFrameExtractor(SentenceFrameExtractor):
 
 
     def extract(self, text_content:Union[str, Dict[str,str]], max_new_tokens:int=512,
-                document_key:str=None, temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict[str,str]]:
+                document_key:str=None, temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict[str,str]]:
         """
         This method inputs a text and outputs a list of outputs per sentence.
 
@@ -1229,6 +1378,8 @@ class SentenceCoTFrameExtractor(SentenceFrameExtractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : str
             the output from LLM. Need post-processing.
@@ -1241,6 +1392,9 @@ class SentenceCoTFrameExtractor(SentenceFrameExtractor):
         elif isinstance(text_content, dict):
             sentences = self._get_sentences(text_content[document_key])
 
+        if return_messages_log:
+            messages_log = []
+
         # generate sentence by sentence
         for i, sent in enumerate(sentences):
             # construct chat messages
@@ -1252,10 +1406,20 @@ class SentenceCoTFrameExtractor(SentenceFrameExtractor):
 
             if self.context_sentences == 0:
                 # no context, just place sentence of interest
-                messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(sent['sentence_text'])})
+                else:
+                    sentence_content = text_content.copy()
+                    sentence_content[document_key] = sent['sentence_text']
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(sentence_content)})
             else:
                 # insert context
-                messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                else:
+                    context_content = text_content.copy()
+                    context_content[document_key] = context
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
             # simulate conversation
             messages.append({'role': 'assistant', 'content': 'Sure, please provide the sentence of interest.'})
             # place sentence of interest
@@ -1275,11 +1439,18 @@ class SentenceCoTFrameExtractor(SentenceFrameExtractor):
                                                    **kwrs
                                                    )
 
+            if return_messages_log:
+                messages.append({"role": "assistant", "content": gen_text})
+                messages_log.append(messages)
+
             # add to output
             output.append({'sentence_start': sent['start'],
                            'sentence_end': sent['end'],
                            'sentence_text': sent['sentence_text'],
                            'gen_text': gen_text})
+
+        if return_messages_log:
+            return output, messages_log
         return output
 
 
@@ -1350,7 +1521,7 @@ class RelationExtractor(Extractor):
 
     @abc.abstractmethod
     def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                          temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+                          temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This method considers all combinations of two frames.
 
@@ -1366,6 +1537,8 @@ class RelationExtractor(Extractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : List[Dict]
             a list of dict with {"frame_1", "frame_2"} for all relations.
@@ -1435,7 +1608,7 @@ class BinaryRelationExtractor(RelationExtractor):
 
 
     def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+                temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
         Outputs pairs that are related.
@@ -1452,11 +1625,17 @@ class BinaryRelationExtractor(RelationExtractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : List[Dict]
             a list of dict with {"frame_1_id", "frame_2_id"}.
         """
         pairs = itertools.combinations(doc.frames, 2)
+
+        if return_messages_log:
+            messages_log = []
+
         output = []
         for frame_1, frame_2 in pairs:
             pos_rel = self.possible_relation_func(frame_1, frame_2)
@@ -1484,13 +1663,19 @@ class BinaryRelationExtractor(RelationExtractor):
                                                )
             rel_json = self._extract_json(gen_text)
             if self._post_process(rel_json):
-                output.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id})
+                output.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})
 
+            if return_messages_log:
+                messages.append({"role": "assistant", "content": gen_text})
+                messages_log.append(messages)
+
+        if return_messages_log:
+            return output, messages_log
         return output
 
 
     async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                            temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
+                            temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This is the asynchronous version of the extract() method.
 
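The key rename here is a small breaking change: relation pairs are now emitted as frame_1_id/frame_2_id instead of frame_1/frame_2. Downstream code indexing the old keys needs a one-line update; a hypothetical consumer:

```python
# `relations` is assumed to come from extract_relations() in 0.4.7.
for rel in relations:
    # 0.4.5 emitted rel['frame_1'] / rel['frame_2']
    print(rel["frame_1_id"], "->", rel["frame_2_id"])
```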
@@ -1506,6 +1691,8 @@ class BinaryRelationExtractor(RelationExtractor):
             the temperature for token sampling.
         concurrent_batch_size : int, Optional
             the number of frame pairs to process in concurrent.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : List[Dict]
             a list of dict with {"frame_1", "frame_2"}.
@@ -1515,12 +1702,17 @@ class BinaryRelationExtractor(RelationExtractor):
             raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
 
         pairs = itertools.combinations(doc.frames, 2)
+        if return_messages_log:
+            messages_log = []
+
         n_frames = len(doc.frames)
         num_pairs = (n_frames * (n_frames-1)) // 2
-        rel_pair_list = []
-        tasks = []
+        output = []
         for i in range(0, num_pairs, concurrent_batch_size):
+            rel_pair_list = []
+            tasks = []
             batch = list(itertools.islice(pairs, concurrent_batch_size))
+            batch_messages = []
             for frame_1, frame_2 in batch:
                 pos_rel = self.possible_relation_func(frame_1, frame_2)
 
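The batching here leans on itertools.combinations being a lazy generator: each islice call consumes the next concurrent_batch_size pairs from the shared iterator, so the O(n²) pair space is never materialized at once and batches never overlap. A standalone illustration:

```python
import itertools

frames = ["f1", "f2", "f3", "f4"]           # stand-ins for frame objects
pairs = itertools.combinations(frames, 2)   # lazy generator of 6 pairs

n_frames = len(frames)
num_pairs = (n_frames * (n_frames - 1)) // 2
batch_size = 4
for i in range(0, num_pairs, batch_size):
    batch = list(itertools.islice(pairs, batch_size))  # advances the generator
    print(batch)
# [('f1', 'f2'), ('f1', 'f3'), ('f1', 'f4'), ('f2', 'f3')]
# [('f2', 'f4'), ('f3', 'f4')]
```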
@@ -1535,6 +1727,7 @@ class BinaryRelationExtractor(RelationExtractor):
                                                       "frame_1": str(frame_1.to_dict()),
                                                       "frame_2": str(frame_2.to_dict())}
                                                     )})
+
                 task = asyncio.create_task(
                     self.inference_engine.chat_async(
                         messages=messages,
@@ -1544,20 +1737,27 @@ class BinaryRelationExtractor(RelationExtractor):
                     )
                 )
                 tasks.append(task)
+                batch_messages.append(messages)
 
             responses = await asyncio.gather(*tasks)
 
-        output = []
-        for d, response in zip(rel_pair_list, responses):
-            rel_json = self._extract_json(response)
-            if self._post_process(rel_json):
-                output.append(d)
+            for d, response, messages in zip(rel_pair_list, responses, batch_messages):
+                if return_messages_log:
+                    messages.append({"role": "assistant", "content": response})
+                    messages_log.append(messages)
 
+                rel_json = self._extract_json(response)
+                if self._post_process(rel_json):
+                    output.append(d)
+
+        if return_messages_log:
+            return output, messages_log
         return output
 
 
     def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                          temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32, stream:bool=False, **kwrs) -> List[Dict]:
+                          temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32,
+                          stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This method considers all combinations of two frames. Use the possible_relation_func to filter impossible pairs.
 
@@ -1577,6 +1777,8 @@ class BinaryRelationExtractor(RelationExtractor):
             the number of frame pairs to process in concurrent.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : List[Dict]
             a list of dict with {"frame_1", "frame_2"} for all relations.
@@ -1597,6 +1799,7 @@ class BinaryRelationExtractor(RelationExtractor):
                                            max_new_tokens=max_new_tokens,
                                            temperature=temperature,
                                            concurrent_batch_size=concurrent_batch_size,
+                                           return_messages_log=return_messages_log,
                                            **kwrs)
                        )
         else:
@@ -1605,6 +1808,7 @@ class BinaryRelationExtractor(RelationExtractor):
                                 max_new_tokens=max_new_tokens,
                                 temperature=temperature,
                                 stream=stream,
+                                return_messages_log=return_messages_log,
                                 **kwrs)
 
 
@@ -1678,7 +1882,7 @@ class MultiClassRelationExtractor(RelationExtractor):
 
 
     def extract(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                temperature:float=0.0, stream:bool=False, **kwrs) -> List[Dict]:
+                temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
 
@@ -1694,11 +1898,17 @@ class MultiClassRelationExtractor(RelationExtractor):
             the temperature for token sampling.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : List[Dict]
-            a list of dict with {"frame_1", "frame_2", "relation"} for all frame pairs.
+            a list of dict with {"frame_1_id", "frame_2_id", "relation"} for all frame pairs.
         """
         pairs = itertools.combinations(doc.frames, 2)
+
+        if return_messages_log:
+            messages_log = []
+
         output = []
         for frame_1, frame_2 in pairs:
             pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
@@ -1725,16 +1935,23 @@ class MultiClassRelationExtractor(RelationExtractor):
                                                stream=stream,
                                                **kwrs
                                                )
+
+            if return_messages_log:
+                messages.append({"role": "assistant", "content": gen_text})
+                messages_log.append(messages)
+
             rel_json = self._extract_json(gen_text)
             rel = self._post_process(rel_json, pos_rel_types)
             if rel:
-                output.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'relation':rel})
+                output.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id, 'relation':rel})
 
+        if return_messages_log:
+            return output, messages_log
         return output
 
 
     async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                            temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[Dict]:
+                            temperature:float=0.0, concurrent_batch_size:int=32, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This is the asynchronous version of the extract() method.
 
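As context for these hunks, possible_relation_types_func is the user-supplied hook that prunes the pair space before any LLM call: given two frames it returns the relation types that are plausible between them, and, judging from the surrounding code, a pair with no possible types is skipped. A hypothetical filter for a medication schema (entity-type names and attr layout are invented):

```python
def possible_relation_types_func(frame_1, frame_2):
    # attr["Type"] is an assumed attribute written by a frame extractor.
    types = {frame_1.attr.get("Type"), frame_2.attr.get("Type")}
    if types == {"Medication", "Strength"}:
        return ["Strength-of"]
    if types == {"Medication", "Frequency"}:
        return ["Frequency-of"]
    return []   # nothing plausible: the pair is filtered out
```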
@@ -1750,21 +1967,28 @@ class MultiClassRelationExtractor(RelationExtractor):
             the temperature for token sampling.
         concurrent_batch_size : int, Optional
             the number of frame pairs to process in concurrent.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : List[Dict]
-            a list of dict with {"frame_1", "frame_2", "relation"} for all frame pairs.
+            a list of dict with {"frame_1_id", "frame_2_id", "relation"} for all frame pairs.
         """
         # Check if self.inference_engine.chat_async() is implemented
         if not hasattr(self.inference_engine, 'chat_async'):
             raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
 
         pairs = itertools.combinations(doc.frames, 2)
+        if return_messages_log:
+            messages_log = []
+
         n_frames = len(doc.frames)
         num_pairs = (n_frames * (n_frames-1)) // 2
-        rel_pair_list = []
-        tasks = []
+        output = []
         for i in range(0, num_pairs, concurrent_batch_size):
+            rel_pair_list = []
+            tasks = []
             batch = list(itertools.islice(pairs, concurrent_batch_size))
+            batch_messages = []
             for frame_1, frame_2 in batch:
                 pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
 
@@ -1789,21 +2013,28 @@ class MultiClassRelationExtractor(RelationExtractor):
                     )
                 )
                 tasks.append(task)
+                batch_messages.append(messages)
 
             responses = await asyncio.gather(*tasks)
 
-        output = []
-        for d, response in zip(rel_pair_list, responses):
-            rel_json = self._extract_json(response)
-            rel = self._post_process(rel_json, d['pos_rel_types'])
-            if rel:
-                output.append({'frame_1':d['frame_1'], 'frame_2':d['frame_2'], 'relation':rel})
+            for d, response, messages in zip(rel_pair_list, responses, batch_messages):
+                if return_messages_log:
+                    messages.append({"role": "assistant", "content": response})
+                    messages_log.append(messages)
+
+                rel_json = self._extract_json(response)
+                rel = self._post_process(rel_json, d['pos_rel_types'])
+                if rel:
+                    output.append({'frame_1_id':d['frame_1'], 'frame_2_id':d['frame_2'], 'relation':rel})
 
+        if return_messages_log:
+            return output, messages_log
         return output
 
 
     def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                          temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32, stream:bool=False, **kwrs) -> List[Dict]:
+                          temperature:float=0.0, concurrent:bool=False, concurrent_batch_size:int=32,
+                          stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
         """
         This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
 
@@ -1823,6 +2054,8 @@ class MultiClassRelationExtractor(RelationExtractor):
             the number of frame pairs to process in concurrent.
         stream : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
 
         Return : List[Dict]
             a list of dict with {"frame_1", "frame_2", "relation"} for all relations.
@@ -1843,6 +2076,7 @@ class MultiClassRelationExtractor(RelationExtractor):
                                            max_new_tokens=max_new_tokens,
                                            temperature=temperature,
                                            concurrent_batch_size=concurrent_batch_size,
+                                           return_messages_log=return_messages_log,
                                            **kwrs)
                        )
         else:
@@ -1851,5 +2085,6 @@ class MultiClassRelationExtractor(RelationExtractor):
                                 max_new_tokens=max_new_tokens,
                                 temperature=temperature,
                                 stream=stream,
+                                return_messages_log=return_messages_log,
                                 **kwrs)
 