hamtaa-texttools 1.1.17__py3-none-any.whl → 1.1.19__py3-none-any.whl
This diff compares the content of publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- {hamtaa_texttools-1.1.17.dist-info → hamtaa_texttools-1.1.19.dist-info}/METADATA +31 -1
- hamtaa_texttools-1.1.19.dist-info/RECORD +33 -0
- texttools/__init__.py +1 -1
- texttools/batch/batch_runner.py +75 -64
- texttools/{tools/internals → internals}/async_operator.py +96 -48
- texttools/internals/exceptions.py +28 -0
- texttools/{tools/internals → internals}/models.py +2 -2
- texttools/internals/prompt_loader.py +108 -0
- texttools/{tools/internals → internals}/sync_operator.py +92 -47
- texttools/prompts/check_fact.yaml +19 -0
- texttools/prompts/propositionize.yaml +13 -6
- texttools/prompts/run_custom.yaml +1 -1
- texttools/tools/async_tools.py +576 -348
- texttools/tools/sync_tools.py +573 -346
- hamtaa_texttools-1.1.17.dist-info/RECORD +0 -32
- texttools/prompts/detect_entity.yaml +0 -22
- texttools/tools/internals/prompt_loader.py +0 -56
- {hamtaa_texttools-1.1.17.dist-info → hamtaa_texttools-1.1.19.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.1.17.dist-info → hamtaa_texttools-1.1.19.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.1.17.dist-info → hamtaa_texttools-1.1.19.dist-info}/top_level.txt +0 -0
- /texttools/{tools/internals → internals}/formatters.py +0 -0
- /texttools/{tools/internals → internals}/operator_utils.py +0 -0
{hamtaa_texttools-1.1.17.dist-info → hamtaa_texttools-1.1.19.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hamtaa-texttools
-Version: 1.1.17
+Version: 1.1.19
 Summary: A high-level NLP toolkit built on top of modern LLMs.
 Author-email: Tohidi <the.mohammad.tohidi@gmail.com>, Montazer <montazerh82@gmail.com>, Givechi <mohamad.m.givechi@gmail.com>, MoosaviNejad <erfanmoosavi84@gmail.com>, Zareshahi <a.zareshahi1377@gmail.com>
 License: MIT License
@@ -61,10 +61,40 @@ Each tool is designed to work with structured outputs (JSON / Pydantic).
 - **`summarize()`** - Text summarization
 - **`translate()`** - Text translation between languages
 - **`propositionize()`** - Convert text into atomic, self-contained, meaningful sentences
+- **`check_fact()`** - Check whether a statement is relevant to a source text
 - **`run_custom()`** - Allows users to define a custom tool with an arbitrary BaseModel
 
 ---
 
+## 📊 Tool Quality Tiers
+
+| Status | Meaning | Use in Production? |
+|--------|---------|-------------------|
+| **✅ Production** | Evaluated, tested, stable. | **Yes** - ready for reliable use. |
+| **🧪 Experimental** | Added to the package but **not fully evaluated**. Functional, but quality may vary. | **Use with caution** - outputs not yet validated. |
+
+### Current Status
+**Production Tools:**
+- `categorize()` (list mode)
+- `extract_keywords()`
+- `extract_entities()`
+- `is_question()`
+- `text_to_question()`
+- `merge_questions()`
+- `rewrite()`
+- `subject_to_question()`
+- `summarize()`
+- `run_custom()` (fine in most cases)
+
+**Experimental Tools:**
+- `categorize()` (tree mode)
+- `translate()`
+- `propositionize()`
+- `check_fact()`
+- `run_custom()` (not evaluated in all scenarios)
+
+---
+
 ## ⚙️ `with_analysis`, `logprobs`, `output_lang`, `user_prompt`, `temperature`, `validator` and `priority` parameters
 
 TextTools provides several optional flags to customize LLM behavior:
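For orientation, a minimal usage sketch of the tier split above, assuming `TheTool` can be constructed with defaults and that each tool method takes the input text directly; neither signature appears in this diff:

from texttools import TheTool

tool = TheTool()  # assumed default construction; the constructor is not shown in this diff

# Production tier: evaluated, tested, stable.
summary = tool.summarize("TextTools is a high-level NLP toolkit built on top of modern LLMs.")

# Experimental tier: called the same way, but validate the output yourself.
propositions = tool.propositionize("The toolkit is high-level. It is built on modern LLMs.")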
hamtaa_texttools-1.1.19.dist-info/RECORD
ADDED

@@ -0,0 +1,33 @@
+hamtaa_texttools-1.1.19.dist-info/licenses/LICENSE,sha256=Hb2YOBKy2MJQLnyLrX37B4ZVuac8eaIcE71SvVIMOLg,1082
+texttools/__init__.py,sha256=CmCS9dEvO6061GiJ8A7gD3UAhCWHTkaID9q3Krlyq_o,311
+texttools/batch/batch_config.py,sha256=m1UgILVKjNdWE6laNbfbG4vgi4o2fEegGZbeoam6pnY,749
+texttools/batch/batch_runner.py,sha256=Tz-jec27UZBSZAXc0sxitc5XycDfzvOYl47Yqzq6Myw,10031
+texttools/batch/internals/batch_manager.py,sha256=UoBe76vmFG72qrSaGKDZf4HzkykFBkkkbL9TLfV8TuQ,8730
+texttools/batch/internals/utils.py,sha256=F1_7YlVFKhjUROAFX4m0SaP8KiZVZyHRMIIB87VUGQc,373
+texttools/internals/async_operator.py,sha256=_RfYSm_66RJ6nppzorJ4r3BHdhr8xr404QjeVvsvX4Q,8485
+texttools/internals/exceptions.py,sha256=h_yp_5i_5IfmqTBQ4S6ZOISrrliJBQ3HTEAjwJXrplk,495
+texttools/internals/formatters.py,sha256=tACNLP6PeoqaRpNudVxBaHA25zyWqWYPZQuYysIu88g,941
+texttools/internals/models.py,sha256=zmgdFhMCNyfc-5dtSE4jwulhltVgxYzITZRMDJBUF0A,5977
+texttools/internals/operator_utils.py,sha256=w1k0RJ_W_CRbVc_J2w337VuL-opHpHiCxfhEOwtyuOo,1856
+texttools/internals/prompt_loader.py,sha256=i4OxcVJTjHFKPSoC-DWZUM3Vf8ye_vbD7b6t3N2qB08,3972
+texttools/internals/sync_operator.py,sha256=7SdsNoFQxgmMrSZbUUw7SJVqyO5Xhu8dui9lm64RKsk,8382
+texttools/prompts/README.md,sha256=-5YO93CN93QLifqZpUeUnCOCBbDiOTV-cFQeJ7Gg0I4,1377
+texttools/prompts/categorize.yaml,sha256=F7VezB25B_sT5yoC25ezODBddkuDD5lUHKetSpx9FKI,2743
+texttools/prompts/check_fact.yaml,sha256=5kpBjmfZxgp81Owc8-Pd0U8-cZowFGRdYlGTFQLYQ9o,702
+texttools/prompts/extract_entities.yaml,sha256=KiKjeDpHaeh3JVtZ6q1pa3k4DYucUIU9WnEcRTCA-SE,651
+texttools/prompts/extract_keywords.yaml,sha256=Vj4Tt3vT6LtpOo_iBZPo9oWI50oVdPGXe5i8yDR8ex4,3177
+texttools/prompts/is_question.yaml,sha256=d0-vKRbXWkxvO64ikvxRjEmpAXGpCYIPGhgexvPPjws,471
+texttools/prompts/merge_questions.yaml,sha256=0J85GvTirZB4ELwH3sk8ub_WcqqpYf6PrMKr3djlZeo,1792
+texttools/prompts/propositionize.yaml,sha256=kdj-UxPOYcLSTLF7cWARDxxTxSFB0qRBaRujdThPDxw,1380
+texttools/prompts/rewrite.yaml,sha256=LO7He_IA3MZKz8a-LxH9DHJpOjpYwaYN1pbjp1Y0tFo,5392
+texttools/prompts/run_custom.yaml,sha256=6oiMYOo_WctVbOmE01wZzI1ra7nFDMJzceTTtnGdmOA,126
+texttools/prompts/subject_to_question.yaml,sha256=C7x7rNNm6U_ZG9HOn6zuzYOtvJUZ2skuWbL1-aYdd3E,1147
+texttools/prompts/summarize.yaml,sha256=o6rxGPfWtZd61Duvm8NVvCJqfq73b-wAuMSKR6UYUqY,459
+texttools/prompts/text_to_question.yaml,sha256=UheKYpDn6iyKI8NxunHZtFpNyfCLZZe5cvkuXpurUJY,783
+texttools/prompts/translate.yaml,sha256=mGT2uBCei6uucWqVbs4silk-UV060v3G0jnt0P6sr50,634
+texttools/tools/async_tools.py,sha256=eNsKJqpTNL1AIM_enHvqUJYxov1Mb5ErnShZuX7oqRQ,49532
+texttools/tools/sync_tools.py,sha256=DbY7smzYnEgd1H0r-5sVW-NExJwhL23TtQ_n5ACqbBc,49344
+hamtaa_texttools-1.1.19.dist-info/METADATA,sha256=egHahU5ec3bSpRB0DK0CrT21kfmhbM6-xxJjoMp1eDU,10587
+hamtaa_texttools-1.1.19.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+hamtaa_texttools-1.1.19.dist-info/top_level.txt,sha256=5Mh0jIxxZ5rOXHGJ6Mp-JPKviywwN0MYuH0xk5bEWqE,10
+hamtaa_texttools-1.1.19.dist-info/RECORD,,
texttools/__init__.py
CHANGED
@@ -2,6 +2,6 @@ from .batch.batch_runner import BatchJobRunner
 from .batch.batch_config import BatchConfig
 from .tools.sync_tools import TheTool
 from .tools.async_tools import AsyncTheTool
-from .tools.internals.models import CategoryTree
+from .internals.models import CategoryTree
 
 __all__ = ["TheTool", "AsyncTheTool", "BatchJobRunner", "BatchConfig", "CategoryTree"]
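The public surface is unchanged by this move; per `__all__`, the five exported names still import from the package root:

from texttools import TheTool, AsyncTheTool, BatchJobRunner, BatchConfig, CategoryTree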
texttools/batch/batch_runner.py
CHANGED
@@ -11,7 +11,8 @@ from pydantic import BaseModel
 
 from texttools.batch.internals.batch_manager import BatchManager
 from texttools.batch.batch_config import BatchConfig
-from texttools.tools.internals.models import StrOutput
+from texttools.internals.models import StrOutput
+from texttools.internals.exceptions import TextToolsError, ConfigurationError
 
 # Base Model type for output models
 T = TypeVar("T", bound=BaseModel)
@@ -27,22 +28,26 @@ class BatchJobRunner:
     def __init__(
         self, config: BatchConfig = BatchConfig(), output_model: Type[T] = StrOutput
     ):
-        self._config = config
-        self._system_prompt = config.system_prompt
-        self._job_name = config.job_name
-        self._input_data_path = config.input_data_path
-        self._output_data_filename = config.output_data_filename
-        self._model = config.model
-        self._output_model = output_model
-        self._manager = self._init_manager()
-        self._data = self._load_data()
-        self._parts: list[list[dict[str, Any]]] = []
-        # Map part index to job name
-        self._part_idx_to_job_name: dict[int, str] = {}
-        # Track retry attempts per part
-        self._part_attempts: dict[int, int] = {}
-        self._partition_data()
-        Path(self._config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
+        try:
+            self._config = config
+            self._system_prompt = config.system_prompt
+            self._job_name = config.job_name
+            self._input_data_path = config.input_data_path
+            self._output_data_filename = config.output_data_filename
+            self._model = config.model
+            self._output_model = output_model
+            self._manager = self._init_manager()
+            self._data = self._load_data()
+            self._parts: list[list[dict[str, Any]]] = []
+            # Map part index to job name
+            self._part_idx_to_job_name: dict[int, str] = {}
+            # Track retry attempts per part
+            self._part_attempts: dict[int, int] = {}
+            self._partition_data()
+            Path(self._config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
+
+        except Exception as e:
+            raise ConfigurationError(f"Batch runner initialization failed: {e}")
 
     def _init_manager(self) -> BatchManager:
         load_dotenv()
@@ -162,56 +167,62 @@ class BatchJobRunner:
 
         Submits jobs, monitors progress, handles retries, and saves results.
         """
-        [old run() body (lines 165-217) elided by the diff viewer: the same submit/poll/retry loop without the try/except wrapper]
+        try:
+            # Submit all jobs up-front for concurrent execution
+            self._submit_all_jobs()
+            pending_parts: set[int] = set(self._part_idx_to_job_name.keys())
+            logger.info(f"Pending parts: {sorted(pending_parts)}")
+            # Polling loop
+            while pending_parts:
+                finished_this_round: list[int] = []
+                for part_idx in list(pending_parts):
+                    job_name = self._part_idx_to_job_name[part_idx]
+                    status = self._manager.check_status(job_name=job_name)
+                    logger.info(f"Status for {job_name}: {status}")
+                    if status == "completed":
+                        logger.info(
+                            f"Job completed. Fetching results for part {part_idx + 1}..."
+                        )
+                        output_data, log = self._manager.fetch_results(
+                            job_name=job_name, remove_cache=False
+                        )
+                        output_data = self._config.import_function(output_data)
+                        self._save_results(output_data, log, part_idx)
+                        logger.info(
+                            f"Fetched and saved results for part {part_idx + 1}."
+                        )
+                        finished_this_round.append(part_idx)
+                    elif status == "failed":
+                        attempt = self._part_attempts.get(part_idx, 0) + 1
+                        self._part_attempts[part_idx] = attempt
+                        if attempt <= self._config.max_retries:
+                            logger.info(
+                                f"Job {job_name} failed (attempt {attempt}). Retrying after short backoff..."
+                            )
+                            self._manager._clear_state(job_name)
+                            time.sleep(10)
+                            payload = self._to_manager_payload(self._parts[part_idx])
+                            new_job_name = (
+                                f"{self._job_name}_part_{part_idx + 1}_retry_{attempt}"
+                            )
+                            self._manager.start(payload, job_name=new_job_name)
+                            self._part_idx_to_job_name[part_idx] = new_job_name
+                        else:
+                            logger.info(
+                                f"Job {job_name} failed after {attempt - 1} retries. Marking as failed."
+                            )
+                            finished_this_round.append(part_idx)
+                    else:
+                        # Still running or queued
+                        continue
+                # Remove finished parts
+                for part_idx in finished_this_round:
+                    pending_parts.discard(part_idx)
+                if pending_parts:
+                    logger.info(
+                        f"Waiting {self._config.poll_interval_seconds}s before next status check for parts: {sorted(pending_parts)}"
+                    )
+                    time.sleep(self._config.poll_interval_seconds)
+
+        except Exception as e:
+            raise TextToolsError(f"Batch job execution failed: {e}")
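Taken together, the two batch_runner hunks split failures into two families: setup problems surface as ConfigurationError, and submit/poll/retry problems as TextToolsError. A sketch of a caller separating them, assuming the runner's entry point is named run() (the method name itself is not visible in this diff):

from texttools import BatchJobRunner, BatchConfig
from texttools.internals.exceptions import ConfigurationError, TextToolsError

try:
    runner = BatchJobRunner(config=BatchConfig())
except ConfigurationError as exc:
    print(f"bad configuration: {exc}")  # raised by the wrapped __init__
else:
    try:
        runner.run()  # hypothetical entry-point name; submits, polls, retries, saves
    except TextToolsError as exc:
        print(f"batch execution failed: {exc}")  # raised by the wrapped polling loop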
texttools/{tools/internals → internals}/async_operator.py
RENAMED

@@ -5,10 +5,16 @@ import logging
 from openai import AsyncOpenAI
 from pydantic import BaseModel
 
-from texttools.tools.internals.models import ToolOutput
-from texttools.tools.internals.operator_utils import OperatorUtils
-from texttools.tools.internals.formatters import Formatter
-from texttools.tools.internals.prompt_loader import PromptLoader
+from texttools.internals.models import ToolOutput
+from texttools.internals.operator_utils import OperatorUtils
+from texttools.internals.formatters import Formatter
+from texttools.internals.prompt_loader import PromptLoader
+from texttools.internals.exceptions import (
+    TextToolsError,
+    LLMError,
+    ValidationError,
+    PromptError,
+)
 
 # Base Model type for output models
 T = TypeVar("T", bound=BaseModel)
@@ -35,15 +41,33 @@ class AsyncOperator:
         Calls OpenAI API for analysis using the configured prompt template.
         Returns the analyzed content as a string.
         """
-        [old _analyze body (lines 38-46) elided by the diff viewer]
+        try:
+            analyze_prompt = prompt_configs["analyze_template"]
+
+            if not analyze_prompt:
+                raise PromptError("Analyze template is empty")
+
+            analyze_message = [OperatorUtils.build_user_message(analyze_prompt)]
+            completion = await self._client.chat.completions.create(
+                model=self._model,
+                messages=analyze_message,
+                temperature=temperature,
+            )
+
+            if not completion.choices:
+                raise LLMError("No choices returned from LLM")
+
+            analysis = completion.choices[0].message.content.strip()
+
+            if not analysis:
+                raise LLMError("Empty analysis response")
+
+            return analysis.strip()
+
+        except Exception as e:
+            if isinstance(e, (PromptError, LLMError)):
+                raise
+            raise LLMError(f"Analysis failed: {e}")
 
     async def _parse_completion(
         self,
@@ -58,21 +82,37 @@ class AsyncOperator:
         Parses a chat completion using OpenAI's structured output format.
         Returns both the parsed object and the raw completion for logprobs.
         """
-        [old _parse_completion body (lines 61-75) elided by the diff viewer]
+        try:
+            request_kwargs = {
+                "model": self._model,
+                "messages": message,
+                "response_format": output_model,
+                "temperature": temperature,
+            }
+
+            if logprobs:
+                request_kwargs["logprobs"] = True
+                request_kwargs["top_logprobs"] = top_logprobs
+            if priority:
+                request_kwargs["extra_body"] = {"priority": priority}
+            completion = await self._client.beta.chat.completions.parse(
+                **request_kwargs
+            )
+
+            if not completion.choices:
+                raise LLMError("No choices returned from LLM")
+
+            parsed = completion.choices[0].message.parsed
+
+            if not parsed:
+                raise LLMError("Failed to parse LLM response")
+
+            return parsed, completion
+
+        except Exception as e:
+            if isinstance(e, LLMError):
+                raise
+            raise LLMError(f"Completion failed: {e}")
 
     async def run(
         self,
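The parse path above is the OpenAI SDK's structured-output call: a Pydantic class goes in as response_format and comes back as choices[0].message.parsed. A self-contained sketch of that pattern outside the operator (model name and prompt are placeholders):

import asyncio

from openai import AsyncOpenAI
from pydantic import BaseModel

class Answer(BaseModel):
    result: str

async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    completion = await client.beta.chat.completions.parse(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "Summarize: ..."}],
        response_format=Answer,
    )
    answer: Answer = completion.choices[0].message.parsed
    print(answer.result)

asyncio.run(main())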
@@ -94,13 +134,13 @@ class AsyncOperator:
         **extra_kwargs,
     ) -> ToolOutput:
         """
-        Execute the [rest of the old docstring elided by the diff viewer]
+        Execute the LLM pipeline with the given input text. (Async)
         """
-        prompt_loader = PromptLoader()
-        formatter = Formatter()
-        output = ToolOutput()
-
         try:
+            prompt_loader = PromptLoader()
+            formatter = Formatter()
+            output = ToolOutput()
+
             # Prompt configs contain two keys: main_template and analyze template, both are string
             prompt_configs = prompt_loader.load(
                 prompt_file=prompt_file,
@@ -139,6 +179,9 @@ class AsyncOperator:
 
             messages = formatter.user_merge_format(messages)
 
+            if logprobs and (not isinstance(top_logprobs, int) or top_logprobs < 2):
+                raise ValueError("top_logprobs should be an integer greater than 1")
+
             parsed, completion = await self._parse_completion(
                 messages, output_model, temperature, logprobs, top_logprobs, priority
             )
@@ -147,6 +190,15 @@ class AsyncOperator:
 
             # Retry logic if validation fails
             if validator and not validator(output.result):
+                if (
+                    not isinstance(max_validation_retries, int)
+                    or max_validation_retries < 1
+                ):
+                    raise ValueError(
+                        "max_validation_retries should be a positive integer"
+                    )
+
+                succeeded = False
                 for attempt in range(max_validation_retries):
                     logger.warning(
                         f"Validation failed, retrying for the {attempt + 1} time."
@@ -154,6 +206,7 @@ class AsyncOperator:
 
                     # Generate new temperature for retry
                     retry_temperature = OperatorUtils.get_retry_temp(temperature)
+
                     try:
                         parsed, completion = await self._parse_completion(
                             messages,
@@ -161,28 +214,23 @@ class AsyncOperator:
                             retry_temperature,
                             logprobs,
                             top_logprobs,
+                            priority=priority,
                         )
 
                         output.result = parsed.result
 
                         # Check if retry was successful
                         if validator(output.result):
-                            logger.info(
-                                f"Validation passed on retry attempt {attempt + 1}"
-                            )
+                            succeeded = True
                             break
-                        else:
-                            logger.warning(
-                                f"Validation still failing after retry attempt {attempt + 1}"
-                            )
 
-                    except Exception as e:
+                    except LLMError as e:
                         logger.error(f"Retry attempt {attempt + 1} failed: {e}")
-                        # Continue to next retry attempt if this one fails
 
-
-
-
+                if not succeeded:
+                    raise ValidationError(
+                        f"Validation failed after {max_validation_retries} retries"
+                    )
 
             if logprobs:
                 output.logprobs = OperatorUtils.extract_logprobs(completion)
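The net effect of this hunk: when a caller-supplied validator keeps rejecting results, run() now raises ValidationError instead of quietly returning the last rejected output. The validator contract itself is just a predicate over the parsed result, for example:

def non_empty_string_list(result) -> bool:
    """Example validator: accept only a non-empty list of non-empty strings."""
    return (
        isinstance(result, list)
        and len(result) > 0
        and all(isinstance(item, str) and item.strip() for item in result)
    )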
@@ -194,7 +242,7 @@ class AsyncOperator:
 
             return output
 
+        except (PromptError, LLMError, ValidationError):
+            raise
         except Exception as e:
-            [one removed line elided by the diff viewer]
-            output.errors.append(str(e))
-            return output
+            raise TextToolsError(f"Unexpected error in operator: {e}")
texttools/internals/exceptions.py
ADDED

@@ -0,0 +1,28 @@
+class TextToolsError(Exception):
+    """Base exception for all TextTools errors."""
+
+    pass
+
+
+class PromptError(TextToolsError):
+    """Errors related to prompt loading and formatting."""
+
+    pass
+
+
+class LLMError(TextToolsError):
+    """Errors from LLM API calls."""
+
+    pass
+
+
+class ValidationError(TextToolsError):
+    """Errors from output validation."""
+
+    pass
+
+
+class ConfigurationError(TextToolsError):
+    """Errors from misconfiguration."""
+
+    pass
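Since every class above derives from TextToolsError, callers can catch as narrowly or broadly as they need; a minimal sketch:

from texttools.internals.exceptions import (
    TextToolsError,
    PromptError,
    LLMError,
    ValidationError,
)

def classify(exc: Exception) -> str:
    """Example dispatch over the hierarchy, most specific first."""
    if isinstance(exc, PromptError):
        return "bad or missing YAML template"
    if isinstance(exc, LLMError):
        return "LLM API call failed or returned nothing usable"
    if isinstance(exc, ValidationError):
        return "output never passed the caller-supplied validator"
    if isinstance(exc, TextToolsError):
        return "other TextTools error (e.g. ConfigurationError)"
    return "not a TextTools error"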
texttools/{tools/internals → internals}/models.py
RENAMED

@@ -8,9 +8,9 @@ class ToolOutput(BaseModel):
     result: Any = None
     logprobs: list[dict[str, Any]] = []
    analysis: str = ""
-    process: str = [old default elided]
+    process: str | None = None
     processed_at: datetime = datetime.now()
-    execution_time: float = [old default elided]
+    execution_time: float | None = None
     errors: list[str] = []
 
     def __repr__(self) -> str:
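The switch to Optional defaults lets consumers tell "never populated" apart from a legitimate empty or zero value:

from texttools.internals.models import ToolOutput

output = ToolOutput()
# Previously execution_time carried a plain float default, indistinguishable
# from a real measurement; None now unambiguously means "not measured".
if output.execution_time is not None:
    print(f"took {output.execution_time:.2f}s")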
texttools/internals/prompt_loader.py
ADDED

@@ -0,0 +1,108 @@
+from functools import lru_cache
+from pathlib import Path
+import yaml
+
+from texttools.internals.exceptions import PromptError
+
+
+class PromptLoader:
+    """
+    Utility for loading and formatting YAML prompt templates.
+
+    Responsibilities:
+    - Load and parse YAML prompt definitions.
+    - Select the right template (by mode, if applicable).
+    - Inject variables (`{input}`, plus any extra kwargs) into the templates.
+    """
+
+    MAIN_TEMPLATE = "main_template"
+    ANALYZE_TEMPLATE = "analyze_template"
+
+    @staticmethod
+    def _build_format_args(text: str, **extra_kwargs) -> dict[str, str]:
+        # Base formatting args
+        format_args = {"input": text}
+        # Merge extras
+        format_args.update(extra_kwargs)
+        return format_args
+
+    # Use lru_cache to load each file once
+    @lru_cache(maxsize=32)
+    def _load_templates(self, prompt_file: str, mode: str | None) -> dict[str, str]:
+        """
+        Loads prompt templates from YAML file with optional mode selection.
+        """
+        try:
+            base_dir = Path(__file__).parent.parent / Path("prompts")
+            prompt_path = base_dir / prompt_file
+
+            if not prompt_path.exists():
+                raise PromptError(f"Prompt file not found: {prompt_file}")
+
+            data = yaml.safe_load(prompt_path.read_text(encoding="utf-8"))
+
+            if self.MAIN_TEMPLATE not in data:
+                raise PromptError(f"Missing 'main_template' in {prompt_file}")
+
+            if self.ANALYZE_TEMPLATE not in data:
+                raise PromptError(f"Missing 'analyze_template' in {prompt_file}")
+
+            if mode and mode not in data.get(self.MAIN_TEMPLATE, {}):
+                raise PromptError(f"Mode '{mode}' not found in {prompt_file}")
+
+            # Extract templates based on mode
+            main_template = (
+                data[self.MAIN_TEMPLATE][mode]
+                if mode and isinstance(data[self.MAIN_TEMPLATE], dict)
+                else data[self.MAIN_TEMPLATE]
+            )
+
+            analyze_template = (
+                data[self.ANALYZE_TEMPLATE][mode]
+                if mode and isinstance(data[self.ANALYZE_TEMPLATE], dict)
+                else data[self.ANALYZE_TEMPLATE]
+            )
+
+            if not main_template or not main_template.strip():
+                raise PromptError(
+                    f"Empty main_template in {prompt_file}"
+                    + (f" for mode '{mode}'" if mode else "")
+                )
+
+            if (
+                not analyze_template
+                or not analyze_template.strip()
+                or analyze_template.strip() in ["{analyze_template}", "{}"]
+            ):
+                raise PromptError(
+                    "analyze_template cannot be empty"
+                    + (f" for mode '{mode}'" if mode else "")
+                )
+
+            return {
+                self.MAIN_TEMPLATE: main_template,
+                self.ANALYZE_TEMPLATE: analyze_template,
+            }
+
+        except yaml.YAMLError as e:
+            raise PromptError(f"Invalid YAML in {prompt_file}: {e}")
+        except Exception as e:
+            raise PromptError(f"Failed to load prompt {prompt_file}: {e}")
+
+    def load(
+        self, prompt_file: str, text: str, mode: str, **extra_kwargs
+    ) -> dict[str, str]:
+        try:
+            template_configs = self._load_templates(prompt_file, mode)
+            format_args = self._build_format_args(text, **extra_kwargs)
+
+            # Inject variables inside each template
+            for key in template_configs.keys():
+                template_configs[key] = template_configs[key].format(**format_args)
+
+            return template_configs
+
+        except KeyError as e:
+            raise PromptError(f"Missing template variable: {e}")
+        except Exception as e:
+            raise PromptError(f"Failed to format prompt: {e}")
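A usage sketch of the new loader, assuming prompts/summarize.yaml (present in the RECORD above) defines main_template and analyze_template as plain strings:

from texttools.internals.prompt_loader import PromptLoader
from texttools.internals.exceptions import PromptError

loader = PromptLoader()
try:
    configs = loader.load(prompt_file="summarize.yaml", text="Some input text", mode=None)
    main_prompt = configs["main_template"]        # template with {input} filled in
    analyze_prompt = configs["analyze_template"]  # companion analysis prompt
except PromptError as exc:
    print(f"prompt problem: {exc}")

One design wrinkle worth noting: _load_templates is an instance method under lru_cache, so the cache key includes the PromptLoader instance itself; code that builds a fresh loader per call (as the operator hunks above do) will re-read the YAML file each time.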