llm-ie 1.2.3__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_ie/__init__.py +6 -6
- llm_ie/asset/default_prompts/LLMUnitChunker_user_prompt.txt +129 -0
- llm_ie/asset/prompt_guide/AttributeExtractor_prompt_guide.txt +2 -2
- llm_ie/asset/prompt_guide/StructExtractor_prompt_guide.txt +53 -0
- llm_ie/chunkers.py +104 -4
- llm_ie/data_types.py +72 -44
- llm_ie/engines.py +44 -0
- llm_ie/extractors.py +421 -73
- llm_ie/prompt_editor.py +9 -32
- llm_ie/utils.py +95 -0
- {llm_ie-1.2.3.dist-info → llm_ie-1.3.0.dist-info}/METADATA +1 -1
- {llm_ie-1.2.3.dist-info → llm_ie-1.3.0.dist-info}/RECORD +13 -10
- {llm_ie-1.2.3.dist-info → llm_ie-1.3.0.dist-info}/WHEEL +0 -0
llm_ie/extractors.py
CHANGED
@@ -1,7 +1,5 @@
 import abc
 import re
-import json
-import json_repair
 import inspect
 import importlib.resources
 import warnings
@@ -10,6 +8,7 @@ import asyncio
 import nest_asyncio
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional, AsyncGenerator
+from llm_ie.utils import extract_json, apply_prompt_template
 from llm_ie.data_types import FrameExtractionUnit, LLMInformationExtractionFrame, LLMInformationExtractionDocument
 from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
 from llm_ie.chunkers import ContextChunker, NoContextChunker, WholeDocumentContextChunker, SlideWindowContextChunker
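Note: `json` and `json_repair` are no longer imported here; JSON parsing moves into the new `llm_ie/utils.py` (+95 lines in this release) and is pulled in as `extract_json`. The utils implementation is not part of this diff. A minimal sketch of such a helper, assuming it keeps the brace-scanning approach of the removed `_find_dict_strings`/`_extract_json` methods (visible in the next hunk) and still relies on `json_repair`:

# Hypothetical sketch only -- the actual llm_ie.utils.extract_json is not shown in this diff.
from typing import Any, Dict, List

import json_repair  # assumed dependency, mirroring the import removed above


def extract_json(gen_text: str) -> List[Dict[str, Any]]:
    """Find balanced {...} spans in LLM output and parse each one leniently."""
    candidates, start, open_brace = [], -1, 0
    for i, char in enumerate(gen_text):
        if char == '{':
            if open_brace == 0:
                start = i
            open_brace += 1
        elif char == '}' and open_brace > 0:
            open_brace -= 1
            if open_brace == 0 and start != -1:
                candidates.append(gen_text[start:i + 1])
                start = -1

    json_objects = []
    for candidate in candidates:
        obj = json_repair.loads(candidate)  # tolerant of minor formatting errors
        if isinstance(obj, dict):
            json_objects.append(obj)
    return json_objects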
@@ -96,79 +95,428 @@ class Extractor:
         Returns : str
             a user prompt.
         """
-
+        return apply_prompt_template(self.prompt_template, text_content)
+
+
+class StructExtractor(Extractor):
+    def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker, prompt_template:str,
+                 system_prompt:str=None, context_chunker:ContextChunker=None, aggregation_func:Callable=None):
+        """
+        This class is for unanchored structured information extraction.
+        Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
+
+        Parameters:
+        ----------
+        inference_engine : InferenceEngine
+            the LLM inferencing engine object. Must implements the chat() method.
+        unit_chunker : UnitChunker
+            the unit chunker object that determines how to chunk the document text into units.
+        prompt_template : str
+            prompt template with "{{<placeholder name>}}" placeholder.
+        system_prompt : str, Optional
+            system prompt.
+        context_chunker : ContextChunker
+            the context chunker object that determines how to get context for each unit.
+        aggregation_func : Callable
+            a function that inputs a list of structured information (dict)
+            and outputs an aggregated structured information (dict).
+            if not specified, the default is to merge all dicts by updating keys and overwriting values sequentially.
+        """
+        super().__init__(inference_engine=inference_engine,
+                         prompt_template=prompt_template,
+                         system_prompt=system_prompt)
+
+        self.unit_chunker = unit_chunker
+        self.context_chunker = context_chunker
+        self.aggregation_func = aggregation_func
+
+
+    def extract(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+                verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
+        """
+        This method inputs text content and outputs a string generated by LLM
+
+        Parameters:
+        ----------
+        text_content : Union[str, Dict[str,str]]
+            the input text content to put in prompt template.
+            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
+
+        Return : List[FrameExtractionUnit]
+            the output from LLM. Need post-processing.
+        """
+        # unit chunking
         if isinstance(text_content, str):
-
-            if len(matches) != 1:
-                raise ValueError("When text_content is str, the prompt template must has exactly 1 placeholder {{<placeholder name>}}.")
-            text = re.sub(r'\\', r'\\\\', text_content)
-            prompt = pattern.sub(text, self.prompt_template)
+            doc_text = text_content

         elif isinstance(text_content, dict):
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if document_key is None:
+                raise ValueError("document_key must be provided when text_content is dict.")
+            doc_text = text_content[document_key]
+
+        units = self.unit_chunker.chunk(doc_text)
+        # context chunker init
+        self.context_chunker.fit(doc_text, units)
+
+        # messages log
+        messages_logger = MessagesLogger() if return_messages_log else None
+
+        # generate unit by unit
+        for i, unit in enumerate(units):
+            try:
+                # construct chat messages
+                messages = []
+                if self.system_prompt:
+                    messages.append({'role': 'system', 'content': self.system_prompt})
+
+                context = self.context_chunker.chunk(unit)
+
+                if context == "":
+                    # no context, just place unit in user prompt
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                    else:
+                        unit_content = text_content.copy()
+                        unit_content[document_key] = unit.text
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+                else:
+                    # insert context to user prompt
+                    if isinstance(text_content, str):
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+                    else:
+                        context_content = text_content.copy()
+                        context_content[document_key] = context
+                        messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                    # simulate conversation where assistant confirms
+                    messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                    # place unit of interest
+                    messages.append({'role': 'user', 'content': unit.text})
+
+                if verbose:
+                    print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
+                    if context != "":
+                        print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+
+                    print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+
+
+                gen_text = self.inference_engine.chat(
+                    messages=messages,
+                    verbose=verbose,
+                    stream=False,
+                    messages_logger=messages_logger
+                )
+
+                # add generated text to unit
+                unit.set_generated_text(gen_text["response"])
+                unit.set_status("success")
+            except Exception as e:
+                unit.set_status("fail")
+                warnings.warn(f"LLM inference failed for unit {i} ({unit.start}, {unit.end}): {e}", RuntimeWarning)
+
+        if return_messages_log:
+            return units, messages_logger.get_messages_log()
+
+        return units

-    def
+    def stream(self, text_content: Union[str, Dict[str, str]],
+               document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnit]]:
         """
-
+        Streams LLM responses per unit with structured event types,
+        and returns collected data for post-processing.
+
+        Yields:
+        -------
+        Dict[str, Any]: (type, data)
+            - {"type": "info", "data": str_message}: General informational messages.
+            - {"type": "unit", "data": dict_unit_info}: Signals start of a new unit. dict_unit_info contains {'id', 'text', 'start', 'end'}
+            - {"type": "context", "data": str_context}: Context string for the current unit.
+            - {"type": "reasoning", "data": str_chunk}: A reasoning model thinking chunk from the LLM.
+            - {"type": "response", "data": str_chunk}: A response/answer chunk from the LLM.
+
+        Returns:
+        --------
+        List[FrameExtractionUnit]:
+            A list of FrameExtractionUnit objects, each containing the
+            original unit details and the fully accumulated 'gen_text' from the LLM.
+        """
+        if isinstance(text_content, str):
+            doc_text = text_content
+        elif isinstance(text_content, dict):
+            if document_key is None:
+                raise ValueError("document_key must be provided when text_content is dict.")
+            if document_key not in text_content:
+                raise ValueError(f"document_key '{document_key}' not found in text_content.")
+            doc_text = text_content[document_key]
+        else:
+            raise TypeError("text_content must be a string or a dictionary.")
+
+        units: List[FrameExtractionUnit] = self.unit_chunker.chunk(doc_text)
+        self.context_chunker.fit(doc_text, units)
+
+        yield {"type": "info", "data": f"Starting LLM processing for {len(units)} units."}
+
+        for i, unit in enumerate(units):
+            unit_info_payload = {"id": i, "text": unit.text, "start": unit.start, "end": unit.end}
+            yield {"type": "unit", "data": unit_info_payload}
+
+            messages = []
+            if self.system_prompt:
+                messages.append({'role': 'system', 'content': self.system_prompt})
+
+            context_str = self.context_chunker.chunk(unit)
+
+            # Construct prompt input based on whether text_content was str or dict
+            if context_str:
+                yield {"type": "context", "data": context_str}
+                prompt_input_for_context = context_str
+                if isinstance(text_content, dict):
+                    context_content_dict = text_content.copy()
+                    context_content_dict[document_key] = context_str
+                    prompt_input_for_context = context_content_dict
+                messages.append({'role': 'user', 'content': self._get_user_prompt(prompt_input_for_context)})
+                messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                messages.append({'role': 'user', 'content': unit.text})
+            else: # No context
+                prompt_input_for_unit = unit.text
+                if isinstance(text_content, dict):
+                    unit_content_dict = text_content.copy()
+                    unit_content_dict[document_key] = unit.text
+                    prompt_input_for_unit = unit_content_dict
+                messages.append({'role': 'user', 'content': self._get_user_prompt(prompt_input_for_unit)})
+
+            current_gen_text = ""
+
+            response_stream = self.inference_engine.chat(
+                messages=messages,
+                stream=True
+            )
+            for chunk in response_stream:
+                yield chunk
+                if chunk["type"] == "response":
+                    current_gen_text += chunk["data"]
+
+            # Store the result for this unit
+            unit.set_generated_text(current_gen_text)
+            unit.set_status("success")
+
+        yield {"type": "info", "data": "All units processed by LLM."}
+        return units
+
+    async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+                            concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
+        """
+        This is the asynchronous version of the extract() method.

         Parameters:
-
-
-            the input text
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    start = i
-                open_brace += 1
-            elif char == '}':
-                open_brace -= 1
-                if open_brace == 0 and start != -1:
-                    json_objects.append(text[start:i + 1])
-                    start = -1
-
-        return json_objects
-
-
-    def _extract_json(self, gen_text:str) -> List[Dict[str, str]]:
-        """
-        This method inputs a generated text and output a JSON of information tuples
+        ----------
+        text_content : Union[str, Dict[str,str]]
+            the input text content to put in prompt template.
+            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+        document_key : str, Optional
+            specify the key in text_content where document text is.
+            If text_content is str, this parameter will be ignored.
+        concurrent_batch_size : int, Optional
+            the batch size for concurrent processing.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
+
+        Return : List[FrameExtractionUnit]
+            the output from LLM for each unit. Contains the start, end, text, and generated text.
         """
-
-
-
-
-
-
-
-
-
-
-
+        if isinstance(text_content, str):
+            doc_text = text_content
+        elif isinstance(text_content, dict):
+            if document_key is None:
+                raise ValueError("document_key must be provided when text_content is dict.")
+            if document_key not in text_content:
+                raise ValueError(f"document_key '{document_key}' not found in text_content dictionary.")
+            doc_text = text_content[document_key]
+        else:
+            raise TypeError("text_content must be a string or a dictionary.")
+
+        units = self.unit_chunker.chunk(doc_text)
+
+        # context chunker init
+        self.context_chunker.fit(doc_text, units)
+
+        # messages logger init
+        messages_logger = MessagesLogger() if return_messages_log else None
+
+        # Prepare inputs for all units first
+        tasks_input = []
+        for i, unit in enumerate(units):
+            # construct chat messages
+            messages = []
+            if self.system_prompt:
+                messages.append({'role': 'system', 'content': self.system_prompt})
+
+            context = self.context_chunker.chunk(unit)
+
+            if context == "":
+                # no context, just place unit in user prompt
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+                else:
+                    unit_content = text_content.copy()
+                    unit_content[document_key] = unit.text
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+            else:
+                # insert context to user prompt
+                if isinstance(text_content, str):
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
                 else:
-
-
+                    context_content = text_content.copy()
+                    context_content[document_key] = context
+                    messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+                # simulate conversation where assistant confirms
+                messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+                # place unit of interest
+                messages.append({'role': 'user', 'content': unit.text})
+
+            # Store unit and messages together for the task
+            tasks_input.append({"unit": unit, "messages": messages, "original_index": i})
+
+        # Process units concurrently with asyncio.Semaphore
+        semaphore = asyncio.Semaphore(concurrent_batch_size)
+
+        async def semaphore_helper(task_data: Dict, **kwrs):
+            unit = task_data["unit"]
+            messages = task_data["messages"]
+
+            async with semaphore:
+                gen_text = await self.inference_engine.chat_async(
+                    messages=messages,
+                    messages_logger=messages_logger
+                )
+
+            unit.set_generated_text(gen_text["response"])
+            unit.set_status("success")
+
+        # Create and gather tasks
+        tasks = []
+        for task_inp in tasks_input:
+            task = asyncio.create_task(semaphore_helper(
+                task_inp
+            ))
+            tasks.append(task)
+
+        await asyncio.gather(*tasks)
+
+        # Return units
+        if return_messages_log:
+            return units, messages_logger.get_messages_log()
+        else:
+            return units
+
+    def _default_struct_aggregate(self, structs: List[Dict[str, Any]]) -> Dict[str, Any]:
+        """
+        Given a list of structured information (dict), aggregate them into a single dict by seqentially updating keys
+        and overwriting values.
+        """
+        aggregated_struct = {}
+        for struct in structs:
+            aggregated_struct.update(struct)
+        return aggregated_struct
+

+    def extract_struct(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+                       verbose:bool=False, concurrent:bool=False, concurrent_batch_size:int=32,
+                       return_messages_log:bool=False) -> List[Dict[str, Any]]:
+        """
+        This method inputs a document text and outputs a list of LLMInformationExtractionFrame
+        It use the extract() method and post-process outputs into frames.
+
+        Parameters:
+        ----------
+        text_content : Union[str, Dict[str,str]]
+            the input text content to put in prompt template.
+            If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+            If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+        document_key : str, Optional
+            specify the key in text_content where document text is.
+            If text_content is str, this parameter will be ignored.
+        verbose : bool, Optional
+            if True, LLM generated text will be printed in terminal in real-time.
+        concurrent : bool, Optional
+            if True, the sentences will be extracted in concurrent.
+        concurrent_batch_size : int, Optional
+            the number of sentences to process in concurrent. Only used when `concurrent` is True.
+        return_messages_log : bool, Optional
+            if True, a list of messages will be returned.
+
+        Return : List[Dict[str, Any]]
+            a list of unanchored structured information.
+        """
+        if concurrent:
+            if verbose:
+                warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
+
+            nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+            extraction_results = asyncio.run(self.extract_async(text_content=text_content,
+                                                                document_key=document_key,
+                                                                concurrent_batch_size=concurrent_batch_size,
+                                                                return_messages_log=return_messages_log)
+                                             )
+        else:
+            extraction_results = self.extract(text_content=text_content,
+                                              document_key=document_key,
+                                              verbose=verbose,
+                                              return_messages_log=return_messages_log)
+
+        units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
+
+        struct_json = []
+        for unit in units:
+            if unit.status != "success":
+                continue
+            try:
+                unit_struct_json = extract_json(unit.get_generated_text())
+                struct_json.extend(unit_struct_json)
+            except Exception as e:
+                unit.set_status("fail")
+                warnings.warn(f"Struct extraction failed for unit ({unit.start}, {unit.end}): {e}", RuntimeWarning)
+
+        if self.aggregation_func is None:
+            struct = self._default_struct_aggregate(struct_json)
+        else:
+            struct = self.aggregation_func(struct_json)
+
+        if return_messages_log:
+            return struct, messages_log
+        return struct
+
+
+class BasicStructExtractor(StructExtractor):
+    def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
+                 system_prompt:str=None, aggregation_func:Callable=None):
+        """
+        This class prompts the LLM with the whole document at once for structured information extraction.
+        Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
+
+        Parameters:
+        ----------
+        inference_engine : InferenceEngine
+            the LLM inferencing engine object. Must implements the chat() method.
+        prompt_template : str
+            prompt template with "{{<placeholder name>}}" placeholder.
+        system_prompt : str, Optional
+            system prompt.
+        aggregation_func : Callable
+            a function that inputs a list of structured information (dict)
+            and outputs an aggregated structured information (dict).
+            if not specified, the default is to merge all dicts by updating keys and overwriting values sequentially.
+        """
+        super().__init__(inference_engine=inference_engine,
+                         unit_chunker=WholeDocumentUnitChunker(),
+                         prompt_template=prompt_template,
+                         system_prompt=system_prompt,
+                         context_chunker=WholeDocumentContextChunker())
+

 class FrameExtractor(Extractor):
     from nltk.tokenize import RegexpTokenizer
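The StructExtractor / BasicStructExtractor classes added above run the prompt unit by unit, parse each generated text with extract_json(), and aggregate the resulting dicts (by default, later keys overwrite earlier values). A usage sketch under stated assumptions: the engine class and its constructor arguments, the prompt, and the input text are illustrative only, and the chunkers are assumed to take no required constructor arguments.

# Illustrative sketch only -- engine, prompt, and input text are placeholders.
from llm_ie.engines import OpenAIInferenceEngine  # assumed engine; any InferenceEngine with chat()/chat_async() works
from llm_ie.chunkers import SentenceUnitChunker, NoContextChunker
from llm_ie.extractors import StructExtractor


def collect_all_values(structs):
    # Custom aggregation_func: keep every extracted value per key instead of
    # the default behavior of overwriting values sequentially.
    merged = {}
    for struct in structs:
        for key, value in struct.items():
            merged.setdefault(key, []).append(value)
    return merged


extractor = StructExtractor(inference_engine=OpenAIInferenceEngine(model="gpt-4o-mini"),
                            unit_chunker=SentenceUnitChunker(),
                            context_chunker=NoContextChunker(),
                            prompt_template="Return a JSON object with any medications mentioned.\n\n{{document}}",
                            aggregation_func=collect_all_values)

structured = extractor.extract_struct(text_content="Started metformin. Continue lisinopril.",
                                      concurrent=True)
print(structured)  # e.g. {"medication": ["metformin", "lisinopril"]}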
@@ -372,7 +720,7 @@ class FrameExtractor(Extractor):
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.

-        Return :
+        Return : List[LLMInformationExtractionFrame]
             a list of frames.
         """
         return NotImplemented
@@ -731,7 +1079,7 @@ class DirectFrameExtractor(FrameExtractor):
         return_messages_log : bool, Optional
             if True, a list of messages will be returned.

-        Return :
+        Return : List[LLMInformationExtractionFrame]
             a list of frames.
         """
         ENTITY_KEY = "entity_text"
@@ -759,7 +1107,7 @@ class DirectFrameExtractor(FrameExtractor):
             if unit.status != "success":
                 warnings.warn(f"Skipping failed unit ({unit.start}, {unit.end}): {unit.text}", RuntimeWarning)
                 continue
-            for entity in
+            for entity in extract_json(gen_text=unit.gen_text):
                 if ENTITY_KEY in entity:
                     entity_json.append(entity)
                 else:
@@ -963,7 +1311,7 @@ class DirectFrameExtractor(FrameExtractor):
         frame_list = []
         for res in sorted(doc_results['units'], key=lambda r: r.start):
             entity_json = []
-            for entity in
+            for entity in extract_json(gen_text=res.gen_text):
                 if ENTITY_KEY in entity:
                     entity_json.append(entity)
                 else:
@@ -1712,7 +2060,7 @@ class AttributeExtractor(Extractor):
             messages_logger=messages_logger
         )

-        attribute_list =
+        attribute_list = extract_json(gen_text=gen_text["response"])
         if isinstance(attribute_list, list) and len(attribute_list) > 0:
             attributes = attribute_list[0]
         if return_messages_log:
@@ -1822,7 +2170,7 @@ class AttributeExtractor(Extractor):
         messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})

         gen_text = await self.inference_engine.chat_async(messages=messages, messages_logger=messages_logger)
-        attribute_list =
+        attribute_list = extract_json(gen_text=gen_text["response"])
         attributes = attribute_list[0] if isinstance(attribute_list, list) and len(attribute_list) > 0 else {}
         return {"frame": frame, "attributes": attributes, "messages": messages}

@@ -2075,7 +2423,7 @@ class BinaryRelationExtractor(RelationExtractor):
             return None

     def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
-        rel_json =
+        rel_json = extract_json(gen_text)
         if len(rel_json) > 0 and "Relation" in rel_json[0]:
             rel = rel_json[0]["Relation"]
             if (isinstance(rel, bool) and rel) or (isinstance(rel, str) and rel.lower() == 'true'):
@@ -2141,7 +2489,7 @@ class MultiClassRelationExtractor(RelationExtractor):
             return None

     def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
-        rel_json =
+        rel_json = extract_json(gen_text)
         pos_rel_types = pair_data['pos_rel_types']
         if len(rel_json) > 0 and "RelationType" in rel_json[0]:
             rel_type = rel_json[0]["RelationType"]
llm_ie/prompt_editor.py
CHANGED
@@ -2,6 +2,7 @@ import sys
 import warnings
 from typing import List, Dict, Generator
 import importlib.resources
+from llm_ie.utils import apply_prompt_template
 from llm_ie.engines import InferenceEngine
 from llm_ie.extractors import FrameExtractor
 import re
@@ -45,30 +46,6 @@ class PromptEditor:

         # internal memory (history messages) for the `chat` method
         self.messages = []
-
-    def _apply_prompt_template(self, text_content:Dict[str,str], prompt_template:str) -> str:
-        """
-        This method applies text_content to prompt_template and returns a prompt.
-
-        Parameters
-        ----------
-        text_content : Dict[str,str]
-            the input text content to put in prompt template.
-            all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-
-        Returns : str
-            a prompt.
-        """
-        pattern = re.compile(r'{{(.*?)}}')
-        placeholders = pattern.findall(prompt_template)
-        if len(placeholders) != len(text_content):
-            raise ValueError(f"Expect text_content ({len(text_content)}) and prompt template placeholder ({len(placeholders)}) to have equal size.")
-        if not all([k in placeholders for k, _ in text_content.items()]):
-            raise ValueError(f"All keys in text_content ({text_content.keys()}) must match placeholders in prompt template ({placeholders}).")
-
-        prompt = pattern.sub(lambda match: re.sub(r'\\', r'\\\\', text_content[match.group(1)]), prompt_template)
-
-        return prompt


     def rewrite(self, draft:str) -> str:
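The inline _apply_prompt_template removed here (together with the inline template logic dropped from Extractor._get_user_prompt in extractors.py) is consolidated into apply_prompt_template in the new llm_ie/utils.py. That implementation is not part of this diff; a minimal sketch modeled on the removed method, with the single-placeholder string case described in the extractor docstrings, might look like:

import re
from typing import Dict, Union

# Hypothetical stand-in for llm_ie.utils.apply_prompt_template -- modeled on the
# _apply_prompt_template method removed above; the real utils.py code is not shown here.
def apply_prompt_template(prompt_template: str, text_content: Union[str, Dict[str, str]]) -> str:
    pattern = re.compile(r'{{(.*?)}}')
    placeholders = pattern.findall(prompt_template)

    if isinstance(text_content, str):
        # str input is allowed only when the template has exactly one placeholder.
        if len(placeholders) != 1:
            raise ValueError("When text_content is str, the prompt template must have exactly 1 placeholder.")
        text_content = {placeholders[0]: text_content}

    if len(placeholders) != len(text_content):
        raise ValueError("text_content and prompt template placeholders must have equal size.")
    if not all(key in placeholders for key in text_content):
        raise ValueError("All keys in text_content must match placeholders in the prompt template.")

    # Backslash doubling mirrors the removed method's handling of substituted text.
    return pattern.sub(lambda match: re.sub(r'\\', r'\\\\', text_content[match.group(1)]), prompt_template)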
@@ -80,8 +57,8 @@
         with open(file_path, 'r') as f:
             rewrite_prompt_template = f.read()

-        prompt =
-
+        prompt = apply_prompt_template(prompt_template=rewrite_prompt_template,
+                                       text_content={"draft": draft, "prompt_guideline": self.prompt_guide})
         messages = [{"role": "system", "content": self.system_prompt},
                     {"role": "user", "content": prompt}]
         res = self.inference_engine.chat(messages, verbose=True)
@@ -96,8 +73,8 @@ class PromptEditor:
         with open(file_path, 'r') as f:
             comment_prompt_template = f.read()

-        prompt =
-
+        prompt = apply_prompt_template(prompt_template=comment_prompt_template,
+                                       text_content={"draft": draft, "prompt_guideline": self.prompt_guide})
         messages = [{"role": "system", "content": self.system_prompt},
                     {"role": "user", "content": prompt}]
         res = self.inference_engine.chat(messages, verbose=True)
@@ -254,8 +231,8 @@ class PromptEditor:
         with open(file_path, 'r') as f:
             chat_prompt_template = f.read()

-        guideline =
-
+        guideline = apply_prompt_template(prompt_template=chat_prompt_template,
+                                          text_content={"prompt_guideline": self.prompt_guide})

         self.messages = [{"role": "system", "content": self.system_prompt + guideline}]

@@ -288,8 +265,8 @@ class PromptEditor:
         with open(file_path, 'r') as f:
             chat_prompt_template = f.read()

-        guideline =
-
+        guideline = apply_prompt_template(prompt_template=chat_prompt_template,
+                                          text_content={"prompt_guideline": self.prompt_guide})

         messages = [{"role": "system", "content": self.system_prompt + guideline}] + messages
