llm-ie 1.2.3__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_ie/extractors.py CHANGED
@@ -1,7 +1,5 @@
  import abc
  import re
- import json
- import json_repair
  import inspect
  import importlib.resources
  import warnings
@@ -10,6 +8,7 @@ import asyncio
  import nest_asyncio
  from concurrent.futures import ThreadPoolExecutor
  from typing import Any, Set, List, Dict, Tuple, Union, Callable, Generator, Optional, AsyncGenerator
+ from llm_ie.utils import extract_json, apply_prompt_template
  from llm_ie.data_types import FrameExtractionUnit, LLMInformationExtractionFrame, LLMInformationExtractionDocument
  from llm_ie.chunkers import UnitChunker, WholeDocumentUnitChunker, SentenceUnitChunker
  from llm_ie.chunkers import ContextChunker, NoContextChunker, WholeDocumentContextChunker, SlideWindowContextChunker
@@ -96,79 +95,428 @@ class Extractor:
  Returns : str
  a user prompt.
  """
- pattern = re.compile(r'{{(.*?)}}')
+ return apply_prompt_template(self.prompt_template, text_content)
+
+
+ class StructExtractor(Extractor):
+ def __init__(self, inference_engine:InferenceEngine, unit_chunker:UnitChunker, prompt_template:str,
+ system_prompt:str=None, context_chunker:ContextChunker=None, aggregation_func:Callable=None):
+ """
+ This class is for unanchored structured information extraction.
+ Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
+
+ Parameters:
+ ----------
+ inference_engine : InferenceEngine
+ the LLM inferencing engine object. Must implements the chat() method.
+ unit_chunker : UnitChunker
+ the unit chunker object that determines how to chunk the document text into units.
+ prompt_template : str
+ prompt template with "{{<placeholder name>}}" placeholder.
+ system_prompt : str, Optional
+ system prompt.
+ context_chunker : ContextChunker
+ the context chunker object that determines how to get context for each unit.
+ aggregation_func : Callable
+ a function that inputs a list of structured information (dict)
+ and outputs an aggregated structured information (dict).
+ if not specified, the default is to merge all dicts by updating keys and overwriting values sequentially.
+ """
+ super().__init__(inference_engine=inference_engine,
+ prompt_template=prompt_template,
+ system_prompt=system_prompt)
+
+ self.unit_chunker = unit_chunker
+ self.context_chunker = context_chunker
+ self.aggregation_func = aggregation_func
+
+
+ def extract(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+ verbose:bool=False, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
+ """
+ This method inputs text content and outputs a string generated by LLM
+
+ Parameters:
+ ----------
+ text_content : Union[str, Dict[str,str]]
+ the input text content to put in prompt template.
+ If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+ If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+ return_messages_log : bool, Optional
+ if True, a list of messages will be returned.
+
+ Return : List[FrameExtractionUnit]
+ the output from LLM. Need post-processing.
+ """
+ # unit chunking
  if isinstance(text_content, str):
- matches = pattern.findall(self.prompt_template)
- if len(matches) != 1:
- raise ValueError("When text_content is str, the prompt template must has exactly 1 placeholder {{<placeholder name>}}.")
- text = re.sub(r'\\', r'\\\\', text_content)
- prompt = pattern.sub(text, self.prompt_template)
+ doc_text = text_content
 
  elif isinstance(text_content, dict):
- # Check if all values are str
- if not all([isinstance(v, str) for v in text_content.values()]):
- raise ValueError("All values in text_content must be str.")
- # Check if all keys are in the prompt template
- placeholders = pattern.findall(self.prompt_template)
- if len(placeholders) != len(text_content):
- raise ValueError(f"Expect text_content ({len(text_content)}) and prompt template placeholder ({len(placeholders)}) to have equal size.")
- if not all([k in placeholders for k, _ in text_content.items()]):
- raise ValueError(f"All keys in text_content ({text_content.keys()}) must match placeholders in prompt template ({placeholders}).")
-
- prompt = pattern.sub(lambda match: re.sub(r'\\', r'\\\\', text_content[match.group(1)]), self.prompt_template)
-
- return prompt
+ if document_key is None:
+ raise ValueError("document_key must be provided when text_content is dict.")
+ doc_text = text_content[document_key]
+
+ units = self.unit_chunker.chunk(doc_text)
+ # context chunker init
+ self.context_chunker.fit(doc_text, units)
+
+ # messages log
+ messages_logger = MessagesLogger() if return_messages_log else None
+
+ # generate unit by unit
+ for i, unit in enumerate(units):
+ try:
+ # construct chat messages
+ messages = []
+ if self.system_prompt:
+ messages.append({'role': 'system', 'content': self.system_prompt})
+
+ context = self.context_chunker.chunk(unit)
+
+ if context == "":
+ # no context, just place unit in user prompt
+ if isinstance(text_content, str):
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+ else:
+ unit_content = text_content.copy()
+ unit_content[document_key] = unit.text
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+ else:
+ # insert context to user prompt
+ if isinstance(text_content, str):
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
+ else:
+ context_content = text_content.copy()
+ context_content[document_key] = context
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+ # simulate conversation where assistant confirms
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+ # place unit of interest
+ messages.append({'role': 'user', 'content': unit.text})
+
+ if verbose:
+ print(f"\n\n{Fore.GREEN}Unit {i + 1}/{len(units)}:{Style.RESET_ALL}\n{unit.text}\n")
+ if context != "":
+ print(f"{Fore.YELLOW}Context:{Style.RESET_ALL}\n{context}\n")
+
+ print(f"{Fore.BLUE}Extraction:{Style.RESET_ALL}")
+
+
+ gen_text = self.inference_engine.chat(
+ messages=messages,
+ verbose=verbose,
+ stream=False,
+ messages_logger=messages_logger
+ )
+
+ # add generated text to unit
+ unit.set_generated_text(gen_text["response"])
+ unit.set_status("success")
+ except Exception as e:
+ unit.set_status("fail")
+ warnings.warn(f"LLM inference failed for unit {i} ({unit.start}, {unit.end}): {e}", RuntimeWarning)
+
+ if return_messages_log:
+ return units, messages_logger.get_messages_log()
+
+ return units
 
- def _find_dict_strings(self, text: str) -> List[str]:
+ def stream(self, text_content: Union[str, Dict[str, str]],
+ document_key: str = None) -> Generator[Dict[str, Any], None, List[FrameExtractionUnit]]:
  """
- Extracts balanced JSON-like dictionaries from a string, even if nested.
+ Streams LLM responses per unit with structured event types,
+ and returns collected data for post-processing.
+
+ Yields:
+ -------
+ Dict[str, Any]: (type, data)
+ - {"type": "info", "data": str_message}: General informational messages.
+ - {"type": "unit", "data": dict_unit_info}: Signals start of a new unit. dict_unit_info contains {'id', 'text', 'start', 'end'}
+ - {"type": "context", "data": str_context}: Context string for the current unit.
+ - {"type": "reasoning", "data": str_chunk}: A reasoning model thinking chunk from the LLM.
+ - {"type": "response", "data": str_chunk}: A response/answer chunk from the LLM.
+
+ Returns:
+ --------
+ List[FrameExtractionUnit]:
+ A list of FrameExtractionUnit objects, each containing the
+ original unit details and the fully accumulated 'gen_text' from the LLM.
+ """
+ if isinstance(text_content, str):
+ doc_text = text_content
+ elif isinstance(text_content, dict):
+ if document_key is None:
+ raise ValueError("document_key must be provided when text_content is dict.")
+ if document_key not in text_content:
+ raise ValueError(f"document_key '{document_key}' not found in text_content.")
+ doc_text = text_content[document_key]
+ else:
+ raise TypeError("text_content must be a string or a dictionary.")
+
+ units: List[FrameExtractionUnit] = self.unit_chunker.chunk(doc_text)
+ self.context_chunker.fit(doc_text, units)
+
+ yield {"type": "info", "data": f"Starting LLM processing for {len(units)} units."}
+
+ for i, unit in enumerate(units):
+ unit_info_payload = {"id": i, "text": unit.text, "start": unit.start, "end": unit.end}
+ yield {"type": "unit", "data": unit_info_payload}
+
+ messages = []
+ if self.system_prompt:
+ messages.append({'role': 'system', 'content': self.system_prompt})
+
+ context_str = self.context_chunker.chunk(unit)
+
+ # Construct prompt input based on whether text_content was str or dict
+ if context_str:
+ yield {"type": "context", "data": context_str}
+ prompt_input_for_context = context_str
+ if isinstance(text_content, dict):
+ context_content_dict = text_content.copy()
+ context_content_dict[document_key] = context_str
+ prompt_input_for_context = context_content_dict
+ messages.append({'role': 'user', 'content': self._get_user_prompt(prompt_input_for_context)})
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+ messages.append({'role': 'user', 'content': unit.text})
+ else: # No context
+ prompt_input_for_unit = unit.text
+ if isinstance(text_content, dict):
+ unit_content_dict = text_content.copy()
+ unit_content_dict[document_key] = unit.text
+ prompt_input_for_unit = unit_content_dict
+ messages.append({'role': 'user', 'content': self._get_user_prompt(prompt_input_for_unit)})
+
+ current_gen_text = ""
+
+ response_stream = self.inference_engine.chat(
+ messages=messages,
+ stream=True
+ )
+ for chunk in response_stream:
+ yield chunk
+ if chunk["type"] == "response":
+ current_gen_text += chunk["data"]
+
+ # Store the result for this unit
+ unit.set_generated_text(current_gen_text)
+ unit.set_status("success")
+
+ yield {"type": "info", "data": "All units processed by LLM."}
+ return units
+
+ async def extract_async(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+ concurrent_batch_size:int=32, return_messages_log:bool=False) -> List[FrameExtractionUnit]:
+ """
+ This is the asynchronous version of the extract() method.
 
  Parameters:
- -----------
- text : str
- the input text containing JSON-like structures.
-
- Returns : List[str]
- A list of valid JSON-like strings representing dictionaries.
- """
- open_brace = 0
- start = -1
- json_objects = []
-
- for i, char in enumerate(text):
- if char == '{':
- if open_brace == 0:
- # start of a new JSON object
- start = i
- open_brace += 1
- elif char == '}':
- open_brace -= 1
- if open_brace == 0 and start != -1:
- json_objects.append(text[start:i + 1])
- start = -1
-
- return json_objects
-
-
- def _extract_json(self, gen_text:str) -> List[Dict[str, str]]:
- """
- This method inputs a generated text and output a JSON of information tuples
+ ----------
+ text_content : Union[str, Dict[str,str]]
+ the input text content to put in prompt template.
+ If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+ If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+ document_key : str, Optional
+ specify the key in text_content where document text is.
+ If text_content is str, this parameter will be ignored.
+ concurrent_batch_size : int, Optional
+ the batch size for concurrent processing.
+ return_messages_log : bool, Optional
+ if True, a list of messages will be returned.
+
+ Return : List[FrameExtractionUnit]
+ the output from LLM for each unit. Contains the start, end, text, and generated text.
  """
- out = []
- dict_str_list = self._find_dict_strings(gen_text)
- for dict_str in dict_str_list:
- try:
- dict_obj = json.loads(dict_str)
- out.append(dict_obj)
- except json.JSONDecodeError:
- dict_obj = json_repair.repair_json(dict_str, skip_json_loads=True, return_objects=True)
- if dict_obj:
- warnings.warn(f'JSONDecodeError detected, fixed with repair_json:\n{dict_str}', RuntimeWarning)
- out.append(dict_obj)
+ if isinstance(text_content, str):
+ doc_text = text_content
+ elif isinstance(text_content, dict):
+ if document_key is None:
+ raise ValueError("document_key must be provided when text_content is dict.")
+ if document_key not in text_content:
+ raise ValueError(f"document_key '{document_key}' not found in text_content dictionary.")
+ doc_text = text_content[document_key]
+ else:
+ raise TypeError("text_content must be a string or a dictionary.")
+
+ units = self.unit_chunker.chunk(doc_text)
+
+ # context chunker init
+ self.context_chunker.fit(doc_text, units)
+
+ # messages logger init
+ messages_logger = MessagesLogger() if return_messages_log else None
+
+ # Prepare inputs for all units first
+ tasks_input = []
+ for i, unit in enumerate(units):
+ # construct chat messages
+ messages = []
+ if self.system_prompt:
+ messages.append({'role': 'system', 'content': self.system_prompt})
+
+ context = self.context_chunker.chunk(unit)
+
+ if context == "":
+ # no context, just place unit in user prompt
+ if isinstance(text_content, str):
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit.text)})
+ else:
+ unit_content = text_content.copy()
+ unit_content[document_key] = unit.text
+ messages.append({'role': 'user', 'content': self._get_user_prompt(unit_content)})
+ else:
+ # insert context to user prompt
+ if isinstance(text_content, str):
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context)})
  else:
- warnings.warn(f'JSONDecodeError could not be fixed:\n{dict_str}', RuntimeWarning)
- return out
+ context_content = text_content.copy()
+ context_content[document_key] = context
+ messages.append({'role': 'user', 'content': self._get_user_prompt(context_content)})
+ # simulate conversation where assistant confirms
+ messages.append({'role': 'assistant', 'content': 'Sure, please provide the unit text (e.g., sentence, line, chunk) of interest.'})
+ # place unit of interest
+ messages.append({'role': 'user', 'content': unit.text})
+
+ # Store unit and messages together for the task
+ tasks_input.append({"unit": unit, "messages": messages, "original_index": i})
+
+ # Process units concurrently with asyncio.Semaphore
+ semaphore = asyncio.Semaphore(concurrent_batch_size)
+
+ async def semaphore_helper(task_data: Dict, **kwrs):
+ unit = task_data["unit"]
+ messages = task_data["messages"]
+
+ async with semaphore:
+ gen_text = await self.inference_engine.chat_async(
+ messages=messages,
+ messages_logger=messages_logger
+ )
+
+ unit.set_generated_text(gen_text["response"])
+ unit.set_status("success")
+
+ # Create and gather tasks
+ tasks = []
+ for task_inp in tasks_input:
+ task = asyncio.create_task(semaphore_helper(
+ task_inp
+ ))
+ tasks.append(task)
+
+ await asyncio.gather(*tasks)
+
+ # Return units
+ if return_messages_log:
+ return units, messages_logger.get_messages_log()
+ else:
+ return units
+
+ def _default_struct_aggregate(self, structs: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Given a list of structured information (dict), aggregate them into a single dict by seqentially updating keys
+ and overwriting values.
+ """
+ aggregated_struct = {}
+ for struct in structs:
+ aggregated_struct.update(struct)
+ return aggregated_struct
+
 
+ def extract_struct(self, text_content:Union[str, Dict[str,str]], document_key:str=None,
+ verbose:bool=False, concurrent:bool=False, concurrent_batch_size:int=32,
+ return_messages_log:bool=False) -> List[Dict[str, Any]]:
+ """
+ This method inputs a document text and outputs a list of LLMInformationExtractionFrame
+ It use the extract() method and post-process outputs into frames.
+
+ Parameters:
+ ----------
+ text_content : Union[str, Dict[str,str]]
+ the input text content to put in prompt template.
+ If str, the prompt template must has only 1 placeholder {{<placeholder name>}}, regardless of placeholder name.
+ If dict, all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
+ document_key : str, Optional
+ specify the key in text_content where document text is.
+ If text_content is str, this parameter will be ignored.
+ verbose : bool, Optional
+ if True, LLM generated text will be printed in terminal in real-time.
+ concurrent : bool, Optional
+ if True, the sentences will be extracted in concurrent.
+ concurrent_batch_size : int, Optional
+ the number of sentences to process in concurrent. Only used when `concurrent` is True.
+ return_messages_log : bool, Optional
+ if True, a list of messages will be returned.
+
+ Return : List[Dict[str, Any]]
+ a list of unanchored structured information.
+ """
+ if concurrent:
+ if verbose:
+ warnings.warn("verbose=True is not supported in concurrent mode.", RuntimeWarning)
+
+ nest_asyncio.apply() # For Jupyter notebook. Terminal does not need this.
+ extraction_results = asyncio.run(self.extract_async(text_content=text_content,
+ document_key=document_key,
+ concurrent_batch_size=concurrent_batch_size,
+ return_messages_log=return_messages_log)
+ )
+ else:
+ extraction_results = self.extract(text_content=text_content,
+ document_key=document_key,
+ verbose=verbose,
+ return_messages_log=return_messages_log)
+
+ units, messages_log = extraction_results if return_messages_log else (extraction_results, None)
+
+ struct_json = []
+ for unit in units:
+ if unit.status != "success":
+ continue
+ try:
+ unit_struct_json = extract_json(unit.get_generated_text())
+ struct_json.extend(unit_struct_json)
+ except Exception as e:
+ unit.set_status("fail")
+ warnings.warn(f"Struct extraction failed for unit ({unit.start}, {unit.end}): {e}", RuntimeWarning)
+
+ if self.aggregation_func is None:
+ struct = self._default_struct_aggregate(struct_json)
+ else:
+ struct = self.aggregation_func(struct_json)
+
+ if return_messages_log:
+ return struct, messages_log
+ return struct
+
+
+ class BasicStructExtractor(StructExtractor):
+ def __init__(self, inference_engine:InferenceEngine, prompt_template:str,
+ system_prompt:str=None, aggregation_func:Callable=None):
+ """
+ This class prompts the LLM with the whole document at once for structured information extraction.
+ Input LLM inference engine, system prompt (optional), prompt template (with instruction, few-shot examples).
+
+ Parameters:
+ ----------
+ inference_engine : InferenceEngine
+ the LLM inferencing engine object. Must implements the chat() method.
+ prompt_template : str
+ prompt template with "{{<placeholder name>}}" placeholder.
+ system_prompt : str, Optional
+ system prompt.
+ aggregation_func : Callable
+ a function that inputs a list of structured information (dict)
+ and outputs an aggregated structured information (dict).
+ if not specified, the default is to merge all dicts by updating keys and overwriting values sequentially.
+ """
+ super().__init__(inference_engine=inference_engine,
+ unit_chunker=WholeDocumentUnitChunker(),
+ prompt_template=prompt_template,
+ system_prompt=system_prompt,
+ context_chunker=WholeDocumentContextChunker())
+
 
  class FrameExtractor(Extractor):
  from nltk.tokenize import RegexpTokenizer
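
For orientation, a minimal usage sketch of the BasicStructExtractor added in the hunk above. Only BasicStructExtractor and extract_struct() come from this diff; the engine class, model name, prompt text, and expected output are illustrative assumptions.

    # Hedged sketch, not part of the diff: engine class and model are assumptions.
    from llm_ie.engines import OpenAIInferenceEngine  # assumed engine; any InferenceEngine should work
    from llm_ie.extractors import BasicStructExtractor

    # Prompt template with a single {{placeholder}}, per the docstrings above.
    prompt_template = (
        "Extract the patient's age and diagnosis from the note below.\n"
        'Return one JSON object with keys "age" and "diagnosis".\n\n'
        "{{note}}"
    )

    extractor = BasicStructExtractor(
        inference_engine=OpenAIInferenceEngine(model="gpt-4o-mini"),  # illustrative
        prompt_template=prompt_template,
    )

    # extract_struct() parses the per-unit JSON output and merges it into one dict
    # (default aggregation: later keys overwrite earlier ones).
    struct = extractor.extract_struct(text_content="76 y/o male with hypertension.")
    print(struct)  # e.g. {"age": "76", "diagnosis": "hypertension"}
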
@@ -372,7 +720,7 @@ class FrameExtractor(Extractor):
  return_messages_log : bool, Optional
  if True, a list of messages will be returned.
 
- Return : str
+ Return : List[LLMInformationExtractionFrame]
  a list of frames.
  """
  return NotImplemented
@@ -731,7 +1079,7 @@ class DirectFrameExtractor(FrameExtractor):
  return_messages_log : bool, Optional
  if True, a list of messages will be returned.
 
- Return : str
+ Return : List[LLMInformationExtractionFrame]
  a list of frames.
  """
  ENTITY_KEY = "entity_text"
@@ -759,7 +1107,7 @@ class DirectFrameExtractor(FrameExtractor):
  if unit.status != "success":
  warnings.warn(f"Skipping failed unit ({unit.start}, {unit.end}): {unit.text}", RuntimeWarning)
  continue
- for entity in self._extract_json(gen_text=unit.gen_text):
+ for entity in extract_json(gen_text=unit.gen_text):
  if ENTITY_KEY in entity:
  entity_json.append(entity)
  else:
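
The call sites here now delegate JSON recovery to llm_ie.utils.extract_json. Judging from the removed _find_dict_strings()/_extract_json() methods above, the utility scans generated text for balanced {...} spans, parses each with json.loads, and falls back to json_repair for malformed output. A small sketch of that contract, assuming the new utility keeps the removed behavior:

    # Hedged sketch: assumes llm_ie.utils.extract_json keeps the contract of the
    # removed Extractor._extract_json (generated text in, list of dicts out).
    from llm_ie.utils import extract_json

    gen_text = (
        'Findings: {"entity_text": "hypertension", "certainty": "confirmed"} '
        'and {"entity_text": "metformin",}'  # trailing comma: repaired via json_repair
    )
    entities = extract_json(gen_text=gen_text)
    # Expected shape: [{"entity_text": "hypertension", "certainty": "confirmed"},
    #                  {"entity_text": "metformin"}]
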
@@ -963,7 +1311,7 @@ class DirectFrameExtractor(FrameExtractor):
  frame_list = []
  for res in sorted(doc_results['units'], key=lambda r: r.start):
  entity_json = []
- for entity in self._extract_json(gen_text=res.gen_text):
+ for entity in extract_json(gen_text=res.gen_text):
  if ENTITY_KEY in entity:
  entity_json.append(entity)
  else:
@@ -1712,7 +2060,7 @@ class AttributeExtractor(Extractor):
  messages_logger=messages_logger
  )
 
- attribute_list = self._extract_json(gen_text=gen_text["response"])
+ attribute_list = extract_json(gen_text=gen_text["response"])
  if isinstance(attribute_list, list) and len(attribute_list) > 0:
  attributes = attribute_list[0]
  if return_messages_log:
@@ -1822,7 +2170,7 @@ class AttributeExtractor(Extractor):
  messages.append({'role': 'user', 'content': self._get_user_prompt({"context": context, "frame": str(frame.to_dict())})})
 
  gen_text = await self.inference_engine.chat_async(messages=messages, messages_logger=messages_logger)
- attribute_list = self._extract_json(gen_text=gen_text["response"])
+ attribute_list = extract_json(gen_text=gen_text["response"])
  attributes = attribute_list[0] if isinstance(attribute_list, list) and len(attribute_list) > 0 else {}
  return {"frame": frame, "attributes": attributes, "messages": messages}
 
@@ -2075,7 +2423,7 @@ class BinaryRelationExtractor(RelationExtractor):
  return None
 
  def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- rel_json = self._extract_json(gen_text)
+ rel_json = extract_json(gen_text)
  if len(rel_json) > 0 and "Relation" in rel_json[0]:
  rel = rel_json[0]["Relation"]
  if (isinstance(rel, bool) and rel) or (isinstance(rel, str) and rel.lower() == 'true'):
@@ -2141,7 +2489,7 @@ class MultiClassRelationExtractor(RelationExtractor):
  return None
 
  def _post_process_result(self, gen_text: str, pair_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- rel_json = self._extract_json(gen_text)
+ rel_json = extract_json(gen_text)
  pos_rel_types = pair_data['pos_rel_types']
  if len(rel_json) > 0 and "RelationType" in rel_json[0]:
  rel_type = rel_json[0]["RelationType"]
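
Rounding out the extractors.py changes, a minimal consumer sketch for the new StructExtractor.stream() generator, which yields the typed events documented in its docstring. The extractor below is the BasicStructExtractor from the earlier sketch, and the note text is illustrative.

    # Hedged sketch: `extractor` is the BasicStructExtractor built in the earlier sketch.
    note_text = "76 y/o male with hypertension, started on metformin."

    for event in extractor.stream(text_content=note_text):
        if event["type"] == "unit":
            u = event["data"]
            print(f"\n-- unit {u['id']} [{u['start']}:{u['end']}]")
        elif event["type"] == "response":
            print(event["data"], end="")   # answer chunks as they arrive
        elif event["type"] == "reasoning":
            pass                           # thinking chunks; log or ignore
        elif event["type"] in ("info", "context"):
            print(f"[{event['type']}] {event['data']}")
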
llm_ie/prompt_editor.py CHANGED
@@ -2,6 +2,7 @@ import sys
  import warnings
  from typing import List, Dict, Generator
  import importlib.resources
+ from llm_ie.utils import apply_prompt_template
  from llm_ie.engines import InferenceEngine
  from llm_ie.extractors import FrameExtractor
  import re
@@ -45,30 +46,6 @@ class PromptEditor:
 
  # internal memory (history messages) for the `chat` method
  self.messages = []
-
- def _apply_prompt_template(self, text_content:Dict[str,str], prompt_template:str) -> str:
- """
- This method applies text_content to prompt_template and returns a prompt.
-
- Parameters
- ----------
- text_content : Dict[str,str]
- the input text content to put in prompt template.
- all the keys must be included in the prompt template placeholder {{<placeholder name>}}.
-
- Returns : str
- a prompt.
- """
- pattern = re.compile(r'{{(.*?)}}')
- placeholders = pattern.findall(prompt_template)
- if len(placeholders) != len(text_content):
- raise ValueError(f"Expect text_content ({len(text_content)}) and prompt template placeholder ({len(placeholders)}) to have equal size.")
- if not all([k in placeholders for k, _ in text_content.items()]):
- raise ValueError(f"All keys in text_content ({text_content.keys()}) must match placeholders in prompt template ({placeholders}).")
-
- prompt = pattern.sub(lambda match: re.sub(r'\\', r'\\\\', text_content[match.group(1)]), prompt_template)
-
- return prompt
 
 
  def rewrite(self, draft:str) -> str:
@@ -80,8 +57,8 @@ class PromptEditor:
  with open(file_path, 'r') as f:
  rewrite_prompt_template = f.read()
 
- prompt = self._apply_prompt_template(text_content={"draft": draft, "prompt_guideline": self.prompt_guide},
- prompt_template=rewrite_prompt_template)
+ prompt = apply_prompt_template(prompt_template=rewrite_prompt_template,
+ text_content={"draft": draft, "prompt_guideline": self.prompt_guide})
  messages = [{"role": "system", "content": self.system_prompt},
  {"role": "user", "content": prompt}]
  res = self.inference_engine.chat(messages, verbose=True)
@@ -96,8 +73,8 @@ class PromptEditor:
  with open(file_path, 'r') as f:
  comment_prompt_template = f.read()
 
- prompt = self._apply_prompt_template(text_content={"draft": draft, "prompt_guideline": self.prompt_guide},
- prompt_template=comment_prompt_template)
+ prompt = apply_prompt_template(prompt_template=comment_prompt_template,
+ text_content={"draft": draft, "prompt_guideline": self.prompt_guide})
  messages = [{"role": "system", "content": self.system_prompt},
  {"role": "user", "content": prompt}]
  res = self.inference_engine.chat(messages, verbose=True)
@@ -254,8 +231,8 @@ class PromptEditor:
  with open(file_path, 'r') as f:
  chat_prompt_template = f.read()
 
- guideline = self._apply_prompt_template(text_content={"prompt_guideline": self.prompt_guide},
- prompt_template=chat_prompt_template)
+ guideline = apply_prompt_template(prompt_template=chat_prompt_template,
+ text_content={"prompt_guideline": self.prompt_guide})
 
  self.messages = [{"role": "system", "content": self.system_prompt + guideline}]
 
@@ -288,8 +265,8 @@ class PromptEditor:
  with open(file_path, 'r') as f:
  chat_prompt_template = f.read()
 
- guideline = self._apply_prompt_template(text_content={"prompt_guideline": self.prompt_guide},
- prompt_template=chat_prompt_template)
+ guideline = apply_prompt_template(prompt_template=chat_prompt_template,
+ text_content={"prompt_guideline": self.prompt_guide})
 
  messages = [{"role": "system", "content": self.system_prompt + guideline}] + messages
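
Both files now share llm_ie.utils.apply_prompt_template for placeholder filling. Based on the two removed private implementations (the inline logic in Extractor._get_user_prompt and PromptEditor._apply_prompt_template), the helper substitutes {{placeholder}} slots: a plain string requires exactly one placeholder, while a dict must supply a value for every placeholder. A small sketch of that contract, with the template text as an illustrative assumption:

    # Hedged sketch: assumes apply_prompt_template preserves the behavior of the
    # removed private methods (see the call sites in both files above).
    from llm_ie.utils import apply_prompt_template

    template = "Note:\n{{note}}\n\nGuideline:\n{{guideline}}"

    # dict input: keys must match the {{...}} placeholders one-to-one
    prompt = apply_prompt_template(prompt_template=template,
                                   text_content={"note": "76 y/o male with hypertension.",
                                                 "guideline": "Answer in one sentence."})

    # str input: allowed only when the template has a single placeholder, matching
    # the apply_prompt_template(self.prompt_template, text_content) call in _get_user_prompt
    user_prompt = apply_prompt_template("Extract diagnoses from:\n{{note}}",
                                        "76 y/o male with hypertension.")
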