llm-ie 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_ie/__init__.py +2 -2
- llm_ie/asset/prompt_guide/AttributeExtractor_prompt_guide.txt +52 -0
- llm_ie/extractors.py +409 -460
- {llm_ie-1.1.0.dist-info → llm_ie-1.2.0.dist-info}/METADATA +1 -1
- {llm_ie-1.1.0.dist-info → llm_ie-1.2.0.dist-info}/RECORD +6 -5
- {llm_ie-1.1.0.dist-info → llm_ie-1.2.0.dist-info}/WHEEL +0 -0
llm_ie/__init__.py
CHANGED
@@ -1,11 +1,11 @@
 from .data_types import LLMInformationExtractionFrame, LLMInformationExtractionDocument
 from .engines import BasicLLMConfig, Qwen3LLMConfig, OpenAIReasoningLLMConfig, LlamaCppInferenceEngine, OllamaInferenceEngine, HuggingFaceHubInferenceEngine, OpenAIInferenceEngine, AzureOpenAIInferenceEngine, LiteLLMInferenceEngine
-from .extractors import DirectFrameExtractor, ReviewFrameExtractor, BasicFrameExtractor, BasicReviewFrameExtractor, SentenceFrameExtractor, SentenceReviewFrameExtractor, BinaryRelationExtractor, MultiClassRelationExtractor
+from .extractors import DirectFrameExtractor, ReviewFrameExtractor, BasicFrameExtractor, BasicReviewFrameExtractor, SentenceFrameExtractor, SentenceReviewFrameExtractor, AttributeExtractor, BinaryRelationExtractor, MultiClassRelationExtractor
 from .chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker, TextLineUnitChunker, ContextChunker, NoContextChunker, WholeDocumentContextChunker, SlideWindowContextChunker
 from .prompt_editor import PromptEditor

 __all__ = ["LLMInformationExtractionFrame", "LLMInformationExtractionDocument",
            "BasicLLMConfig", "Qwen3LLMConfig", "OpenAIReasoningLLMConfig", "LlamaCppInferenceEngine", "OllamaInferenceEngine", "HuggingFaceHubInferenceEngine", "OpenAIInferenceEngine", "AzureOpenAIInferenceEngine", "LiteLLMInferenceEngine",
-           "DirectFrameExtractor", "ReviewFrameExtractor", "BasicFrameExtractor", "BasicReviewFrameExtractor", "SentenceFrameExtractor", "SentenceReviewFrameExtractor", "BinaryRelationExtractor", "MultiClassRelationExtractor",
+           "DirectFrameExtractor", "ReviewFrameExtractor", "BasicFrameExtractor", "BasicReviewFrameExtractor", "SentenceFrameExtractor", "SentenceReviewFrameExtractor", "AttributeExtractor", "BinaryRelationExtractor", "MultiClassRelationExtractor",
            "UnitChunker", "WholeDocumentUnitChunker", "SentenceUnitChunker", "TextLineUnitChunker", "ContextChunker", "NoContextChunker", "WholeDocumentContextChunker", "SlideWindowContextChunker",
            "PromptEditor"]
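The practical effect of this change is that the new AttributeExtractor is importable from the package root. A minimal sketch of the 1.2.0 import surface; the engine construction is hypothetical (its parameters are not part of this diff) and the template is a stub that merely satisfies the placeholder check:

from llm_ie import AttributeExtractor, OllamaInferenceEngine

# Both placeholders are mandatory; the AttributeExtractor constructor raises
# ValueError when either is missing (see the extractors.py diff below).
template = (
    "This is an attribute extraction task.\n"
    "### Entity\n{{frame}}\n"
    "### Context\n{{context}}"
)
engine = OllamaInferenceEngine()  # hypothetical construction; see llm_ie/engines.py for the real parameters
extractor = AttributeExtractor(inference_engine=engine, prompt_template=template)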
llm_ie/asset/prompt_guide/AttributeExtractor_prompt_guide.txt
ADDED
@@ -0,0 +1,52 @@
+Prompt Template Design:
+
+1. Task Description:
+    Provide a detailed description of the task, including the background and the type of task (e.g., attribute extraction task).
+
+2. Schema Definition:
+    List the attributes to extract, and provide clear definitions for each one.
+
+3. Output Format Definition:
+    The output should be a JSON list, where each attribute is a key. The values could be any structure (e.g., str, int, List[str]).
+
+4. Optional: Hints:
+    Provide itemized hints for the information extractors to guide the extraction process. Remind the prompted agent to be truthful. Emphasize that the prompted agent is supposed to perform the task itself, instead of writing code or instructing other agents to do it.
+
+5. Optional: Examples:
+    Include examples in the format:
+    Input: ...
+    Output: ...
+
+6. Entity:
+    The template must include a placeholder {{frame}} for the entity.
+
+7. Context:
+    The template must include a placeholder {{context}} for the context. Explain to the prompted agent that <Entity> tags are used to mark the entity in the context.
+
+
+Example:
+
+### Task description
+This is an attribute extraction task. Given a diagnosis entity and the context, you need to generate attributes for the entity.
+
+### Schema definition
+"Date" which is the date when the diagnosis was made in MM/DD/YYYY format,
+"Status" which is the current status of the diagnosis (e.g. active, resolved, etc.)
+
+### Output format definition
+Your output should follow the JSON format:
+{"Date": "<MM/DD/YYYY>", "Status": "<status>"}
+
+I am only interested in the content between []. Do not explain your answer.
+
+### Hints
+- If the date is not complete, use the first available date in the context. For example, if the date is 01/2023, you should return 01/01/2023.
+- If the status is not available, you should return "not specified".
+
+### Entity
+Information about the entity to extract attributes from:
+{{frame}}
+
+### Context
+Context for the entity. The <Entity> tags are used to mark the entity in the context.
+{{context}}
llm_ie/extractors.py
CHANGED
@@ -1449,11 +1449,11 @@ class SentenceReviewFrameExtractor(ReviewFrameExtractor):
                                   context_chunker=context_chunker)


-class
+class AttributeExtractor(Extractor):
     def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
         """
-        This is
-
+        This class is for attribute extraction for frames. Though FrameExtractors can also extract attributes, when
+        the number of attribute increases, it is more efficient to use a dedicated AttributeExtractor.

         Parameters
         ----------
@@ -1467,322 +1467,469 @@ class RelationExtractor(Extractor):
         super().__init__(inference_engine=inference_engine,
                          prompt_template=prompt_template,
                          system_prompt=system_prompt)
+        # validate prompt template
+        if "{{context}}" not in self.prompt_template or "{{frame}}" not in self.prompt_template:
+            raise ValueError("prompt_template must contain both {{context}} and {{frame}} placeholders.")

-    def
-        text:str, buffer_size:int=100) -> str:
+    def _get_context(self, frame:LLMInformationExtractionFrame, text:str, context_size:int=256) -> str:
         """
-        This method returns the
-        The returned text has the
+        This method returns the context that covers the frame. Leaves a context_size of characters before and after.
+        The returned text has the frame inline annotated with <entity>.

         Parameters:
         -----------
-
+        frame : LLMInformationExtractionFrame
             a frame
-        frame_2 : LLMInformationExtractionFrame
-            the other frame
         text : str
             the entire document text
-
-            the number of characters before and after the
+        context_size : int, Optional
+            the number of characters before and after the frame in the context text.

         Return : str
-            the
+            the context text with the frame inline annotated with <entity>.
         """
-
-
-
+        start = max(frame.start - context_size, 0)
+        end = min(frame.end + context_size, len(text))
+        context = text[start:end]

-
-
-
-
-
-                        f'<{left_frame_name}>' + \
-                        roi[left_frame.start - start:left_frame.end - start] + \
-                        f"</{left_frame_name}>" + \
-                        roi[left_frame.end - start:right_frame.start - start] + \
-                        f'<{right_frame_name}>' + \
-                        roi[right_frame.start - start:right_frame.end - start] + \
-                        f"</{right_frame_name}>" + \
-                        roi[right_frame.end - start:end - start]
+        context_annotated = context[0:frame.start - start] + \
+                            f"<entity> " + \
+                            context[frame.start - start:frame.end - start] + \
+                            f" </entity>" + \
+                            context[frame.end - start:end - start]

         if start > 0:
-
+            context_annotated = "..." + context_annotated
         if end < len(text):
-
-        return
+            context_annotated = context_annotated + "..."
+        return context_annotated

-
-
-    def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100, max_new_tokens:int=128,
-                          temperature:float=0.0, stream:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
+    def _extract_from_frame(self, frame:LLMInformationExtractionFrame, text:str,
+                            context_size:int=256, verbose:bool=False, return_messages_log:bool=False) -> Dict[str, Any]:
         """
-        This method
+        This method extracts attributes from a single frame.

         Parameters:
         -----------
-
-            a
-
-            the
-
-            the
-
-
-        stream : bool, Optional
-            if True, LLM generated text will be printed in terminal in real-time.
+        frame : LLMInformationExtractionFrame
+            a frame to extract attributes from.
+        text : str
+            the entire document text.
+        context_size : int, Optional
+            the number of characters before and after the frame in the context text.
+        verbose : bool, Optional
+            if True, LLM generated text will be printed in terminal in real-time.
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.

-        Return :
-            a
+        Return : Dict[str, Any]
+            a dictionary of attributes extracted from the frame.
+            If return_messages_log is True, a list of messages will be returned as well.
         """
-
-
+        # construct chat messages
+        messages = []
+        if self.system_prompt:
+            messages.append({'role': 'system', 'content': self.system_prompt})

-
-
-                 system_prompt:str=None):
-        """
-        This class extracts binary (yes/no) relations between two entities.
-        Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
+        context = self._get_context(frame, text, context_size)
+        messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})

-
-
-
-
-        prompt_template : str
-            prompt template with "{{<placeholder name>}}" placeholder.
-        possible_relation_func : Callable, Optional
-            a function that inputs 2 frames and returns a bool indicating possible relations between them.
-        system_prompt : str, Optional
-            system prompt.
-        """
-        super().__init__(inference_engine=inference_engine,
-                         prompt_template=prompt_template,
-                         system_prompt=system_prompt)
-
-        if possible_relation_func:
-            # Check if possible_relation_func is a function
-            if not callable(possible_relation_func):
-                raise TypeError(f"Expect possible_relation_func as a function, received {type(possible_relation_func)} instead.")
+        if verbose:
+            print(f"\n\n{Fore.GREEN}Frame: {frame.frame_id}{Style.RESET_ALL}\n{frame.to_dict()}\n")
+            if context != "":
+                print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")

-
-            # Check if frame_1, frame_2 are in input parameters
-            if len(sig.parameters) != 2:
-                raise ValueError("The possible_relation_func must have exactly frame_1 and frame_2 as parameters.")
-            if "frame_1" not in sig.parameters.keys():
-                raise ValueError("The possible_relation_func is missing frame_1 as a parameter.")
-            if "frame_2" not in sig.parameters.keys():
-                raise ValueError("The possible_relation_func is missing frame_2 as a parameter.")
-            # Check if output is a bool
-            if sig.return_annotation != bool:
-                raise ValueError(f"Expect possible_relation_func to output a bool, current type hint suggests {sig.return_annotation} instead.")
-
-        self.possible_relation_func = possible_relation_func
-
-
-    def _post_process(self, rel_json:str) -> bool:
-        if len(rel_json) > 0:
-            if "Relation" in rel_json[0]:
-                rel = rel_json[0]["Relation"]
-                if isinstance(rel, bool):
-                    return rel
-                elif isinstance(rel, str) and rel in {"True", "False"}:
-                    return eval(rel)
-                else:
-                    warnings.warn('Extractor output JSON "Relation" key does not have bool or {"True", "False"} as value.' + \
-                                  'Following default, relation = False.', RuntimeWarning)
-            else:
-                warnings.warn('Extractor output JSON without "Relation" key. Following default, relation = False.', RuntimeWarning)
-        else:
-            warnings.warn('Extractor did not output a JSON list. Following default, relation = False.', RuntimeWarning)
-        return False
-
+            print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")

-
-
+        get_text = self.inference_engine.chat(
+            messages=messages,
+            verbose=verbose,
+            stream=False
+        )
+        if return_messages_log:
+            messages.append({"role": "assistant", "content": get_text})
+
+        attribute_list = self._extract_json(gen_text=get_text)
+        if isinstance(attribute_list, list) and len(attribute_list) > 0:
+            attributes = attribute_list[0]
+        if return_messages_log:
+            return attributes, messages
+        return attributes
+
+
+    def extract(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256, verbose:bool=False,
+                return_messages_log:bool=False, inplace:bool=True) -> Union[None, List[LLMInformationExtractionFrame]]:
         """
-        This method
-        Outputs pairs that are related.
+        This method extracts attributes from the document.

         Parameters:
         -----------
-
-            a
-
-            the
+        frames : List[LLMInformationExtractionFrame]
+            a list of frames to extract attributes from.
+        text : str
+            the entire document text.
+        context_size : int, Optional
+            the number of characters before and after the frame in the context text.
         verbose : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.
+        inplace : bool, Optional
+            if True, the method will modify the frames in-place.
+
+        Return : Union[None, List[LLMInformationExtractionFrame]]
+            if inplace is True, the method will modify the frames in-place.
+            if inplace is False, the method will return a list of frames with attributes extracted.
+        """
+        for frame in frames:
+            if not isinstance(frame, LLMInformationExtractionFrame):
+                raise TypeError(f"Expect frame as LLMInformationExtractionFrame, received {type(frame)} instead.")
+        if not isinstance(text, str):
+            raise TypeError(f"Expect text as str, received {type(text)} instead.")
+
+        new_frames = []
+        messages_log = [] if return_messages_log else None

-
-
-
-
+        for frame in frames:
+            if return_messages_log:
+                attr, messages = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
+                                                          verbose=verbose, return_messages_log=return_messages_log)
+                messages_log.append(messages)
+            else:
+                attr = self._extract_from_frame(frame=frame, text=text, context_size=context_size,
+                                                verbose=verbose, return_messages_log=return_messages_log)
+
+            if inplace:
+                frame.attr.update(attr)
+            else:
+                new_frame = frame.copy()
+                new_frame.attr.update(attr)
+                new_frames.append(new_frame)

-        if
-        messages_log
+        if inplace:
+            return messages_log if return_messages_log else None
+        else:
+            return (new_frames, messages_log) if return_messages_log else new_frames

-        output = []
-        for frame_1, frame_2 in pairs:
-            pos_rel = self.possible_relation_func(frame_1, frame_2)

-
-
-
-
-
+    async def extract_async(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
+                            concurrent_batch_size:int=32, inplace:bool=True, return_messages_log:bool=False) -> Union[None, List[LLMInformationExtractionFrame]]:
+        """
+        This method extracts attributes from the document asynchronously.
+
+        Parameters:
+        -----------
+        frames : List[LLMInformationExtractionFrame]
+            a list of frames to extract attributes from.
+        text : str
+            the entire document text.
+        context_size : int, Optional
+            the number of characters before and after the frame in the context text.
+        concurrent_batch_size : int, Optional
+            the batch size for concurrent processing.
+        inplace : bool, Optional
+            if True, the method will modify the frames in-place.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
+
+        Return : Union[None, List[LLMInformationExtractionFrame]]
+            if inplace is True, the method will modify the frames in-place.
+            if inplace is False, the method will return a list of frames with attributes extracted.
+        """
+        # validation
+        for frame in frames:
+            if not isinstance(frame, LLMInformationExtractionFrame):
+                raise TypeError(f"Expect frame as LLMInformationExtractionFrame, received {type(frame)} instead.")
+        if not isinstance(text, str):
+            raise TypeError(f"Expect text as str, received {type(text)} instead.")
+
+        # async helper
+        semaphore = asyncio.Semaphore(concurrent_batch_size)
+
+        async def semaphore_helper(frame:LLMInformationExtractionFrame, text:str, context_size:int) -> dict:
+            async with semaphore:
                 messages = []
                 if self.system_prompt:
                     messages.append({'role': 'system', 'content': self.system_prompt})

-
-
-                                 "frame_2": str(frame_2.to_dict())}
-                                 )})
-
-                gen_text = self.inference_engine.chat(
-                    messages=messages,
-                    verbose=verbose
-                )
-                rel_json = self._extract_json(gen_text)
-                if self._post_process(rel_json):
-                    output.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})
+                context = self._get_context(frame, text, context_size)
+                messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})

+                gen_text = await self.inference_engine.chat_async(messages=messages)
+
                 if return_messages_log:
                     messages.append({"role": "assistant", "content": gen_text})
-                    messages_log.append(messages)

-
-
-
-
-
-
-
+                attribute_list = self._extract_json(gen_text=gen_text)
+                attributes = attribute_list[0] if isinstance(attribute_list, list) and len(attribute_list) > 0 else {}
+                return {"frame": frame, "attributes": attributes, "messages": messages}
+
+        # create tasks
+        tasks = [asyncio.create_task(semaphore_helper(frame, text, context_size)) for frame in frames]
+        results = await asyncio.gather(*tasks)
+
+        # process results
+        new_frames = []
+        messages_log = [] if return_messages_log else None
+
+        for result in results:
+            if return_messages_log:
+                messages_log.append(result["messages"])
+
+            if inplace:
+                result["frame"].attr.update(result["attributes"])
+            else:
+                new_frame = result["frame"].copy()
+                new_frame.attr.update(result["attributes"])
+                new_frames.append(new_frame)
+
+        # output
+        if inplace:
+            return messages_log if return_messages_log else None
+        else:
+            return (new_frames, messages_log) if return_messages_log else new_frames
+
+    def extract_attributes(self, frames:List[LLMInformationExtractionFrame], text:str, context_size:int=256,
+                           concurrent:bool=False, concurrent_batch_size:int=32, verbose:bool=False,
+                           return_messages_log:bool=False, inplace:bool=True) -> Union[None, List[LLMInformationExtractionFrame]]:
         """
-        This
+        This method extracts attributes from the document.

         Parameters:
         -----------
-
-            a
-
-            the
-
-            the
-
-            the
+        frames : List[LLMInformationExtractionFrame]
+            a list of frames to extract attributes from.
+        text : str
+            the entire document text.
+        context_size : int, Optional
+            the number of characters before and after the frame in the context text.
+        concurrent : bool, Optional
+            if True, the method will run in concurrent mode with batch size concurrent_batch_size.
         concurrent_batch_size : int, Optional
-            the
+            the batch size for concurrent processing.
+        verbose : bool, Optional
+            if True, LLM generated text will be printed in terminal in real-time.
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.
-
-
-            a list of dict with {"frame_1", "frame_2"}.
-        """
-        # Check if self.inference_engine.chat_async() is implemented
-        if not hasattr(self.inference_engine, 'chat_async'):
-            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
+        inplace : bool, Optional
+            if True, the method will modify the frames in-place.

-
-
-
-
-
-
-
-        for i in range(0, num_pairs, concurrent_batch_size):
-            rel_pair_list = []
-            tasks = []
-            batch = list(itertools.islice(pairs, concurrent_batch_size))
-            batch_messages = []
-            for frame_1, frame_2 in batch:
-                pos_rel = self.possible_relation_func(frame_1, frame_2)
-
-                if pos_rel:
-                    rel_pair_list.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id})
-                    roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
-                    messages = []
-                    if self.system_prompt:
-                        messages.append({'role': 'system', 'content': self.system_prompt})
-
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
-                                                                                                    "frame_1": str(frame_1.to_dict()),
-                                                                                                    "frame_2": str(frame_2.to_dict())}
-                                                                                      )})
-
-                    task = asyncio.create_task(
-                        self.inference_engine.chat_async(
-                            messages=messages
-                        )
-                    )
-                    tasks.append(task)
-                    batch_messages.append(messages)
-
-            responses = await asyncio.gather(*tasks)
+        Return : Union[None, List[LLMInformationExtractionFrame]]
+            if inplace is True, the method will modify the frames in-place.
+            if inplace is False, the method will return a list of frames with attributes extracted.
+        """
+        if concurrent:
+            if verbose:
+                warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)

-
-            if return_messages_log:
-                messages.append({"role": "assistant", "content": response})
-                messages_log.append(messages)
+            nest_asyncio.apply()  # For Jupyter notebook. Terminal does not need this.

-
-
-
+            return asyncio.run(self.extract_async(frames=frames, text=text, context_size=context_size,
+                                                  concurrent_batch_size=concurrent_batch_size,
+                                                  inplace=inplace, return_messages_log=return_messages_log))
+        else:
+            return self.extract(frames=frames, text=text, context_size=context_size,
+                                verbose=verbose, return_messages_log=return_messages_log, inplace=inplace)

-        if return_messages_log:
-            return output, messages_log
-        return output

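The AttributeExtractor above exposes one public entry point, extract_attributes(), which dispatches to extract() or extract_async(). A hedged usage sketch, assuming engine, template, frames (produced earlier by one of the FrameExtractors), and doc_text already exist:

# Sequential, in-place: each frame's attr dict is updated with the extracted attributes.
extractor = AttributeExtractor(inference_engine=engine, prompt_template=template)
extractor.extract_attributes(frames=frames, text=doc_text, context_size=256, inplace=True)

# Concurrent mode: up to 32 frames in flight via chat_async; with inplace=False,
# enriched copies are returned and the original frames are left untouched.
enriched = extractor.extract_attributes(frames=frames, text=doc_text,
                                        concurrent=True, concurrent_batch_size=32,
                                        inplace=False)

The inplace flag mirrors the method's return contract: in-place runs return None (or only the messages log), while inplace=False returns new frame copies.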
+class RelationExtractor(Extractor):
+    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, system_prompt:str=None):
+        """
+        This is the abstract class for relation extraction.
+        Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).

-
-
-
+        Parameters
+        ----------
+        inference_engine : InferenceEngine
+            the LLM inferencing engine object. Must implements the chat() method.
+        prompt_template : str
+            prompt template with "{{<placeholder name>}}" placeholder.
+        system_prompt : str, Optional
+            system prompt.
         """
-
+        super().__init__(inference_engine=inference_engine,
+                         prompt_template=prompt_template,
+                         system_prompt=system_prompt)
+
+    def _get_ROI(self, frame_1:LLMInformationExtractionFrame, frame_2:LLMInformationExtractionFrame,
+                 text:str, buffer_size:int=128) -> str:
+        """
+        This method returns the Region of Interest (ROI) that covers the two frames. Leaves a buffer_size of characters before and after.
+        The returned text has the two frames inline annotated with <entity_1>, <entity_2>.

         Parameters:
         -----------
-
-            a
+        frame_1 : LLMInformationExtractionFrame
+            a frame
+        frame_2 : LLMInformationExtractionFrame
+            the other frame
+        text : str
+            the entire document text
         buffer_size : int, Optional
             the number of characters before and after the two frames in the ROI text.
-        concurrent: bool, Optional
-            if True, the extraction will be done in concurrent.
-        concurrent_batch_size : int, Optional
-            the number of frame pairs to process in concurrent.
-        verbose : bool, Optional
-            if True, LLM generated text will be printed in terminal in real-time.
-        return_messages_log : bool, Optional
-            if True, a list of messages will be returned.

-        Return :
-
+        Return : str
+            the ROI text with the two frames inline annotated with <entity_1>, <entity_2>.
         """
+        left_frame, right_frame = sorted([frame_1, frame_2], key=lambda f: f.start)
+        left_frame_name = "entity_1" if left_frame.frame_id == frame_1.frame_id else "entity_2"
+        right_frame_name = "entity_1" if right_frame.frame_id == frame_1.frame_id else "entity_2"
+
+        start = max(left_frame.start - buffer_size, 0)
+        end = min(right_frame.end + buffer_size, len(text))
+        roi = text[start:end]
+
+        roi_annotated = roi[0:left_frame.start - start] + \
+                        f"<{left_frame_name}> " + \
+                        roi[left_frame.start - start:left_frame.end - start] + \
+                        f" </{left_frame_name}>" + \
+                        roi[left_frame.end - start:right_frame.start - start] + \
+                        f"<{right_frame_name}> " + \
+                        roi[right_frame.start - start:right_frame.end - start] + \
+                        f" </{right_frame_name}>" + \
+                        roi[right_frame.end - start:end - start]
+
+        if start > 0:
+            roi_annotated = "..." + roi_annotated
+        if end < len(text):
+            roi_annotated = roi_annotated + "..."
+        return roi_annotated
+
+    @abc.abstractmethod
+    def _get_task_if_possible(self, frame_1: LLMInformationExtractionFrame, frame_2: LLMInformationExtractionFrame,
+                              text: str, buffer_size: int) -> Optional[Dict[str, Any]]:
+        """Checks if a relation is possible and constructs the task payload."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+        """Processes the LLM output for a single pair and returns the final relation dictionary."""
+        raise NotImplementedError
+
+    def _extract(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, verbose: bool = False,
+                 return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
+        pairs = itertools.combinations(doc.frames, 2)
+        relations = []
+        messages_log = [] if return_messages_log else None
+
+        for frame_1, frame_2 in pairs:
+            task_payload = self._get_task_if_possible(frame_1, frame_2, doc.text, buffer_size)
+            if task_payload:
+                if verbose:
+                    print(f"\n\n{Fore.GREEN}Evaluating pair:{Style.RESET_ALL} ({frame_1.frame_id}, {frame_2.frame_id})")
+                    print(f"{Fore.YELLOW}ROI Text:{Style.RESET_ALL}\n{task_payload['roi_text']}\n")
+                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+
+                gen_text = self.inference_engine.chat(
+                    messages=task_payload['messages'],
+                    verbose=verbose
+                )
+                relation = self._post_process_result(gen_text, task_payload)
+                if relation:
+                    relations.append(relation)
+
+                if return_messages_log:
+                    task_payload['messages'].append({"role": "assistant", "content": gen_text})
+                    messages_log.append(task_payload['messages'])
+
+        return (relations, messages_log) if return_messages_log else relations
+
+    async def _extract_async(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent_batch_size: int = 32, return_messages_log: bool = False) -> Union[List[Dict], Tuple[List[Dict], List]]:
+        pairs = list(itertools.combinations(doc.frames, 2))
+        tasks_input = [self._get_task_if_possible(f1, f2, doc.text, buffer_size) for f1, f2 in pairs]
+        # Filter out impossible pairs
+        tasks_input = [task for task in tasks_input if task is not None]
+
+        relations = []
+        messages_log = [] if return_messages_log else None
+        semaphore = asyncio.Semaphore(concurrent_batch_size)
+
+        async def semaphore_helper(task_payload: Dict):
+            async with semaphore:
+                gen_text = await self.inference_engine.chat_async(messages=task_payload['messages'])
+                return gen_text, task_payload
+
+        tasks = [asyncio.create_task(semaphore_helper(payload)) for payload in tasks_input]
+        results = await asyncio.gather(*tasks)
+
+        for gen_text, task_payload in results:
+            relation = self._post_process_result(gen_text, task_payload)
+            if relation:
+                relations.append(relation)
+
+            if return_messages_log:
+                task_payload['messages'].append({"role": "assistant", "content": gen_text})
+                messages_log.append(task_payload['messages'])
+
+        return (relations, messages_log) if return_messages_log else relations
+
+    def extract_relations(self, doc: LLMInformationExtractionDocument, buffer_size: int = 128, concurrent: bool = False, concurrent_batch_size: int = 32, verbose: bool = False, return_messages_log: bool = False) -> List[Dict]:
         if not doc.has_frame():
             raise ValueError("Input document must have frames.")
-
         if doc.has_duplicate_frame_ids():
             raise ValueError("All frame_ids in the input document must be unique.")

         if concurrent:
             if verbose:
-                warnings.warn("
-
-
-            return asyncio.run(self.extract_async(doc=doc,
-                                                  buffer_size=buffer_size,
-                                                  concurrent_batch_size=concurrent_batch_size,
-                                                  return_messages_log=return_messages_log)
-                               )
+                warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
+            nest_asyncio.apply()
+            return asyncio.run(self._extract_async(doc, buffer_size, concurrent_batch_size, return_messages_log))
         else:
-            return self.
-
-
-
+            return self._extract(doc, buffer_size, verbose, return_messages_log)
+
+
+class BinaryRelationExtractor(RelationExtractor):
+    def __init__(self, inference_engine:InferenceEngine, prompt_template:str, possible_relation_func: Callable,
+                 system_prompt:str=None):
+        """
+        This class extracts binary (yes/no) relations between two entities.
+        Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
+
+        Parameters
+        ----------
+        inference_engine : InferenceEngine
+            the LLM inferencing engine object. Must implements the chat() method.
+        prompt_template : str
+            prompt template with "{{<placeholder name>}}" placeholder.
+        possible_relation_func : Callable, Optional
+            a function that inputs 2 frames and returns a bool indicating possible relations between them.
+        system_prompt : str, Optional
+            system prompt.
+        """
+        super().__init__(inference_engine, prompt_template, system_prompt)
+        if not callable(possible_relation_func):
+            raise TypeError(f"Expect possible_relation_func as a function, received {type(possible_relation_func)} instead.")
+
+        sig = inspect.signature(possible_relation_func)
+        if len(sig.parameters) != 2:
+            raise ValueError("The possible_relation_func must have exactly two parameters.")
+
+        if sig.return_annotation not in {bool, inspect.Signature.empty}:
+            warnings.warn(f"Expected possible_relation_func return annotation to be bool, but got {sig.return_annotation}.")
+
+        self.possible_relation_func = possible_relation_func
+
+    def _get_task_if_possible(self, frame_1: LLMInformationExtractionFrame, frame_2: LLMInformationExtractionFrame,
+                              text: str, buffer_size: int) -> Optional[Dict[str, Any]]:
+        if self.possible_relation_func(frame_1, frame_2):
+            roi_text = self._get_ROI(frame_1, frame_2, text, buffer_size)
+            messages = []
+            if self.system_prompt:
+                messages.append({'role': 'system', 'content': self.system_prompt})
+
+            messages.append({'role': 'user', 'content': self._get_user_prompt(
+                text_content={"roi_text": roi_text, "frame_1": str(frame_1.to_dict()), "frame_2": str(frame_2.to_dict())}
+            )})
+            return {"frame_1": frame_1, "frame_2": frame_2, "messages": messages, "roi_text": roi_text}
+        return None
+
+    def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+        rel_json = self._extract_json(gen_text)
+        if len(rel_json) > 0 and "Relation" in rel_json[0]:
+            rel = rel_json[0]["Relation"]
+            if (isinstance(rel, bool) and rel) or (isinstance(rel, str) and rel.lower() == 'true'):
+                return {'frame_1_id': pair_data['frame_1'].frame_id, 'frame_2_id': pair_data['frame_2'].frame_id}
+        return None


 class MultiClassRelationExtractor(RelationExtractor):
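Under this refactor, each concrete extractor only supplies _get_task_if_possible and _post_process_result; pair iteration, batching, and message logging live in RelationExtractor. A hedged sketch of a BinaryRelationExtractor setup, assuming engine, relation_template, and doc exist; the 500-character threshold is arbitrary, and the template needs {{roi_text}}, {{frame_1}}, and {{frame_2}} placeholders to match the keys passed to _get_user_prompt above:

# Sketch: a distance-based pair filter. Only frame attributes visible in this
# diff (start, frame_id) are used.
def possible_relation_func(frame_1, frame_2) -> bool:
    # Skip pairs that are far apart in the document; the threshold is arbitrary.
    return abs(frame_1.start - frame_2.start) < 500

extractor = BinaryRelationExtractor(
    inference_engine=engine,                       # assumed to exist
    prompt_template=relation_template,             # {{roi_text}}, {{frame_1}}, {{frame_2}}
    possible_relation_func=possible_relation_func,
)
relations = extractor.extract_relations(doc, buffer_size=128)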
@@ -1828,223 +1975,25 @@ class MultiClassRelationExtractor(RelationExtractor):
         self.possible_relation_types_func = possible_relation_types_func


-    def
-
-
-
-
-
-
-
-
-
-
-
-            the relation type (str) or None for no relation.
-        """
-        if len(rel_json) > 0:
-            if "RelationType" in rel_json[0]:
-                if rel_json[0]["RelationType"] in pos_rel_types:
-                    return rel_json[0]["RelationType"]
-            else:
-                warnings.warn('Extractor output JSON without "RelationType" key. Following default, relation = "No Relation".', RuntimeWarning)
-        else:
-            warnings.warn('Extractor did not output a JSON. Following default, relation = "No Relation".', RuntimeWarning)
+    def _get_task_if_possible(self, frame_1: LLMInformationExtractionFrame, frame_2: LLMInformationExtractionFrame,
+                              text: str, buffer_size: int) -> Optional[Dict[str, Any]]:
+        pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
+        if pos_rel_types:
+            roi_text = self._get_ROI(frame_1, frame_2, text, buffer_size)
+            messages = []
+            if self.system_prompt:
+                messages.append({'role': 'system', 'content': self.system_prompt})
+            messages.append({'role': 'user', 'content': self._get_user_prompt(
+                text_content={"roi_text": roi_text, "frame_1": str(frame_1.to_dict()), "frame_2": str(frame_2.to_dict()), "pos_rel_types": str(pos_rel_types)}
+            )})
+            return {"frame_1": frame_1, "frame_2": frame_2, "messages": messages, "pos_rel_types": pos_rel_types, "roi_text": roi_text}
         return None
-

-    def
-
-
-
-
-
-
-
-        buffer_size : int, Optional
-            the number of characters before and after the two frames in the ROI text.
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM should generate.
-        temperature : float, Optional
-            the temperature for token sampling.
-        stream : bool, Optional
-            if True, LLM generated text will be printed in terminal in real-time.
-        return_messages_log : bool, Optional
-            if True, a list of messages will be returned.
-
-        Return : List[Dict]
-            a list of dict with {"frame_1_id", "frame_2_id", "relation"} for all frame pairs.
-        """
-        pairs = itertools.combinations(doc.frames, 2)
-
-        if return_messages_log:
-            messages_log = []
-
-        output = []
-        for frame_1, frame_2 in pairs:
-            pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
-
-            if pos_rel_types:
-                roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
-                if verbose:
-                    print(f"\n\n{Fore.GREEN}ROI text:{Style.RESET_ALL} \n{roi_text}\n")
-                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
-                messages = []
-                if self.system_prompt:
-                    messages.append({'role': 'system', 'content': self.system_prompt})
-
-                messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
-                                                                                                "frame_1": str(frame_1.to_dict()),
-                                                                                                "frame_2": str(frame_2.to_dict()),
-                                                                                                "pos_rel_types":str(pos_rel_types)}
-                                                                                  )})
-
-                gen_text = self.inference_engine.chat(
-                    messages=messages,
-                    stream=False,
-                    verbose=verbose
-                )
-
-                if return_messages_log:
-                    messages.append({"role": "assistant", "content": gen_text})
-                    messages_log.append(messages)
-
-                rel_json = self._extract_json(gen_text)
-                rel = self._post_process(rel_json, pos_rel_types)
-                if rel:
-                    output.append({'frame_1_id':frame_1.frame_id, 'frame_2_id':frame_2.frame_id, 'relation':rel})
-
-        if return_messages_log:
-            return output, messages_log
-        return output
-
-
-    async def extract_async(self, doc:LLMInformationExtractionDocument, buffer_size:int=100,
-                            concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[Dict]:
-        """
-        This is the asynchronous version of the extract() method.
-
-        Parameters:
-        -----------
-        doc : LLMInformationExtractionDocument
-            a document with frames.
-        buffer_size : int, Optional
-            the number of characters before and after the two frames in the ROI text.
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM should generate.
-        temperature : float, Optional
-            the temperature for token sampling.
-        concurrent_batch_size : int, Optional
-            the number of frame pairs to process in concurrent.
-        return_messages_log : bool, Optional
-            if True, a list of messages will be returned.
-
-        Return : List[Dict]
-            a list of dict with {"frame_1_id", "frame_2_id", "relation"} for all frame pairs.
-        """
-        # Check if self.inference_engine.chat_async() is implemented
-        if not hasattr(self.inference_engine, 'chat_async'):
-            raise NotImplementedError(f"{self.inference_engine.__class__.__name__} does not have chat_async() method.")
-
-        pairs = itertools.combinations(doc.frames, 2)
-        if return_messages_log:
-            messages_log = []
-
-        n_frames = len(doc.frames)
-        num_pairs = (n_frames * (n_frames-1)) // 2
-        output = []
-        for i in range(0, num_pairs, concurrent_batch_size):
-            rel_pair_list = []
-            tasks = []
-            batch = list(itertools.islice(pairs, concurrent_batch_size))
-            batch_messages = []
-            for frame_1, frame_2 in batch:
-                pos_rel_types = self.possible_relation_types_func(frame_1, frame_2)
-
-                if pos_rel_types:
-                    rel_pair_list.append({'frame_1':frame_1.frame_id, 'frame_2':frame_2.frame_id, 'pos_rel_types':pos_rel_types})
-                    roi_text = self._get_ROI(frame_1, frame_2, doc.text, buffer_size=buffer_size)
-                    messages = []
-                    if self.system_prompt:
-                        messages.append({'role': 'system', 'content': self.system_prompt})
-
-                    messages.append({'role': 'user', 'content': self._get_user_prompt(text_content={"roi_text":roi_text,
-                                                                                                    "frame_1": str(frame_1.to_dict()),
-                                                                                                    "frame_2": str(frame_2.to_dict()),
-                                                                                                    "pos_rel_types":str(pos_rel_types)}
-                                                                                      )})
-                    task = asyncio.create_task(
-                        self.inference_engine.chat_async(
-                            messages=messages
-                        )
-                    )
-                    tasks.append(task)
-                    batch_messages.append(messages)
-
-            responses = await asyncio.gather(*tasks)
-
-            for d, response, messages in zip(rel_pair_list, responses, batch_messages):
-                if return_messages_log:
-                    messages.append({"role": "assistant", "content": response})
-                    messages_log.append(messages)
-
-                rel_json = self._extract_json(response)
-                rel = self._post_process(rel_json, d['pos_rel_types'])
-                if rel:
-                    output.append({'frame_1_id':d['frame_1'], 'frame_2_id':d['frame_2'], 'relation':rel})
-
-        if return_messages_log:
-            return output, messages_log
-        return output
-
-
-    def extract_relations(self, doc:LLMInformationExtractionDocument, buffer_size:int=100,
-                          concurrent:bool=False, concurrent_batch_size:int=32,
-                          verbose:bool=False, return_messages_log:bool=False, **kwrs) -> List[Dict]:
-        """
-        This method considers all combinations of two frames. Use the possible_relation_types_func to filter impossible pairs.
-
-        Parameters:
-        -----------
-        doc : LLMInformationExtractionDocument
-            a document with frames.
-        buffer_size : int, Optional
-            the number of characters before and after the two frames in the ROI text.
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM should generate.
-        temperature : float, Optional
-            the temperature for token sampling.
-        concurrent: bool, Optional
-            if True, the extraction will be done in concurrent.
-        concurrent_batch_size : int, Optional
-            the number of frame pairs to process in concurrent.
-        stream : bool, Optional
-            if True, LLM generated text will be printed in terminal in real-time.
-        return_messages_log : bool, Optional
-            if True, a list of messages will be returned.
-
-        Return : List[Dict]
-            a list of dict with {"frame_1", "frame_2", "relation"} for all relations.
-        """
-        if not doc.has_frame():
-            raise ValueError("Input document must have frames.")
-
-        if doc.has_duplicate_frame_ids():
-            raise ValueError("All frame_ids in the input document must be unique.")
-
-        if concurrent:
-            if verbose:
-                warnings.warn("stream=True is not supported in concurrent mode.", RuntimeWarning)
-
-            nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
-            return asyncio.run(self.extract_async(doc=doc,
-                                                  buffer_size=buffer_size,
-                                                  concurrent_batch_size=concurrent_batch_size,
-                                                  return_messages_log=return_messages_log)
-                               )
-        else:
-            return self.extract(doc=doc,
-                                buffer_size=buffer_size,
-                                verbose=verbose,
-                                return_messages_log=return_messages_log)
-
+    def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+        rel_json = self._extract_json(gen_text)
+        pos_rel_types = pair_data['pos_rel_types']
+        if len(rel_json) > 0 and "RelationType" in rel_json[0]:
+            rel_type = rel_json[0]["RelationType"]
+            if rel_type in pos_rel_types:
+                return {'frame_1_id': pair_data['frame_1'].frame_id, 'frame_2_id': pair_data['frame_2'].frame_id, 'relation': rel_type}
+        return None
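For MultiClassRelationExtractor, the filter function returns candidate relation types rather than a bool: an empty result skips the pair entirely in _get_task_if_possible, and _post_process_result only keeps a RelationType that appears in the candidate list. A hedged sketch; the "EntityType" attribute key is hypothetical and should be adapted to the frame schema in use:

# Sketch: a candidate-type function for MultiClassRelationExtractor.
def possible_relation_types_func(frame_1, frame_2) -> list:
    types = {frame_1.attr.get("EntityType"), frame_2.attr.get("EntityType")}  # hypothetical key
    if types == {"Diagnosis", "Date"}:
        return ["DiagnosisDate"]
    return []  # falsy -> the pair is never sent to the LLM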
{llm_ie-1.1.0.dist-info → llm_ie-1.2.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: llm-ie
-Version: 1.1.0
+Version: 1.2.0
 Summary: A comprehensive toolkit that provides building blocks for LLM-based named entity recognition, attribute extraction, and relation extraction pipelines.
 License: MIT
 Author: Enshuo (David) Hsu

{llm_ie-1.1.0.dist-info → llm_ie-1.2.0.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-llm_ie/__init__.py,sha256=
+llm_ie/__init__.py,sha256=rLP01qXkIisX0WLzZOv6y494Braw89g5JLmA6ZyrGGA,1590
 llm_ie/asset/PromptEditor_prompts/chat.txt,sha256=Fq62voV0JQ8xBRcxS1Nmdd7DkHs1fGYb-tmNwctZZK0,118
 llm_ie/asset/PromptEditor_prompts/comment.txt,sha256=C_lxx-dlOlFJ__jkHKosZ8HsNAeV1aowh2B36nIipBY,159
 llm_ie/asset/PromptEditor_prompts/rewrite.txt,sha256=JAwY9vm1jSmKf2qcLBYUvrSmME2EJH36bALmkwZDWYQ,178
@@ -9,6 +9,7 @@ llm_ie/asset/default_prompts/ReviewFrameExtractor_addition_review_prompt.txt,sha
 llm_ie/asset/default_prompts/ReviewFrameExtractor_revision_review_prompt.txt,sha256=lGGjdeFpzZEc56w-EtQDMyYFs7A3DQAM32sT42Nf_08,293
 llm_ie/asset/default_prompts/SentenceReviewFrameExtractor_addition_review_prompt.txt,sha256=Of11LFuXLB249oekFelzlIeoAB0cATReqWgFTvhNz_8,329
 llm_ie/asset/default_prompts/SentenceReviewFrameExtractor_revision_review_prompt.txt,sha256=kNJQK7NdoCx13TXGY8HYGrW_v4SEaErK8j9qIzd70CM,291
+llm_ie/asset/prompt_guide/AttributeExtractor_prompt_guide.txt,sha256=w2amKipinuJtCiyPsgWsjaJRwTpS1qOBDuPPtPCMeQA,2120
 llm_ie/asset/prompt_guide/BasicFrameExtractor_prompt_guide.txt,sha256=-Cli7rwu4wM4vSmkG0nInNkpStUhRqKESQ3oqD38pbE,10395
 llm_ie/asset/prompt_guide/BasicReviewFrameExtractor_prompt_guide.txt,sha256=-Cli7rwu4wM4vSmkG0nInNkpStUhRqKESQ3oqD38pbE,10395
 llm_ie/asset/prompt_guide/BinaryRelationExtractor_prompt_guide.txt,sha256=Z6Yc2_QRqroWcJ13owNJbo78I0wpS4XXDsOjXFR-aPk,2166
@@ -20,8 +21,8 @@ llm_ie/asset/prompt_guide/SentenceReviewFrameExtractor_prompt_guide.txt,sha256=9
 llm_ie/chunkers.py,sha256=24h9l-Ldyx3EgfYicFqGhV_b-XofUS3yovC1nBWdDoo,5143
 llm_ie/data_types.py,sha256=72-3bzzYpo7KZpD9bjoroWT2eiM0zmWyDkBr2nHoBV0,18559
 llm_ie/engines.py,sha256=uE5sag1YeKBYBFF4gY7rYZK9e1ttatf9T7bV_xSg9Pk,36075
-llm_ie/extractors.py,sha256=
+llm_ie/extractors.py,sha256=aCRqKhjSoKTAWZ3WhX_O6V-S_rIvYhPsk78nZLDpQw8,95149
 llm_ie/prompt_editor.py,sha256=zh7Es5Ta2qSTgHtfF9Y9ZKXs4DMue6XlyRt9O6_Uk6c,10962
-llm_ie-1.
-llm_ie-1.
-llm_ie-1.
+llm_ie-1.2.0.dist-info/METADATA,sha256=X9zsMDwBAq1QzIkX8SSbmwLsEFiiAVeNeA0GTiNkAkQ,728
+llm_ie-1.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+llm_ie-1.2.0.dist-info/RECORD,,

{llm_ie-1.1.0.dist-info → llm_ie-1.2.0.dist-info}/WHEEL
File without changes