hamtaa-texttools 1.1.16__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hamtaa_texttools-1.2.0.dist-info/METADATA +212 -0
- hamtaa_texttools-1.2.0.dist-info/RECORD +34 -0
- texttools/__init__.py +5 -5
- texttools/batch/__init__.py +0 -0
- texttools/batch/{batch_config.py → config.py} +16 -2
- texttools/batch/{internals/batch_manager.py → manager.py} +2 -2
- texttools/batch/{batch_runner.py → runner.py} +80 -69
- texttools/core/__init__.py +0 -0
- texttools/core/engine.py +254 -0
- texttools/core/exceptions.py +22 -0
- texttools/core/internal_models.py +58 -0
- texttools/core/operators/async_operator.py +194 -0
- texttools/core/operators/sync_operator.py +192 -0
- texttools/models.py +88 -0
- texttools/prompts/categorize.yaml +36 -77
- texttools/prompts/check_fact.yaml +24 -0
- texttools/prompts/extract_entities.yaml +7 -3
- texttools/prompts/extract_keywords.yaml +21 -9
- texttools/prompts/is_question.yaml +6 -2
- texttools/prompts/merge_questions.yaml +12 -5
- texttools/prompts/propositionize.yaml +24 -0
- texttools/prompts/rewrite.yaml +9 -10
- texttools/prompts/run_custom.yaml +2 -2
- texttools/prompts/subject_to_question.yaml +7 -3
- texttools/prompts/summarize.yaml +6 -2
- texttools/prompts/text_to_question.yaml +12 -6
- texttools/prompts/translate.yaml +7 -2
- texttools/py.typed +0 -0
- texttools/tools/__init__.py +0 -0
- texttools/tools/async_tools.py +778 -489
- texttools/tools/sync_tools.py +775 -487
- hamtaa_texttools-1.1.16.dist-info/METADATA +0 -255
- hamtaa_texttools-1.1.16.dist-info/RECORD +0 -31
- texttools/batch/internals/utils.py +0 -16
- texttools/prompts/README.md +0 -35
- texttools/prompts/detect_entity.yaml +0 -22
- texttools/tools/internals/async_operator.py +0 -200
- texttools/tools/internals/formatters.py +0 -24
- texttools/tools/internals/models.py +0 -183
- texttools/tools/internals/operator_utils.py +0 -54
- texttools/tools/internals/prompt_loader.py +0 -56
- texttools/tools/internals/sync_operator.py +0 -201
- {hamtaa_texttools-1.1.16.dist-info → hamtaa_texttools-1.2.0.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.1.16.dist-info → hamtaa_texttools-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.1.16.dist-info → hamtaa_texttools-1.2.0.dist-info}/top_level.txt +0 -0
texttools/core/engine.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import random
|
|
3
|
+
import re
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import yaml
|
|
8
|
+
|
|
9
|
+
from .exceptions import PromptError
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PromptLoader:
    """
    Utility for loading and formatting YAML prompt templates.

    Prompt files are read from the ``prompts`` directory that sits next to the
    package directory containing this module. Every file must define
    ``main_template`` and ``analyze_template``; each of them may be either a
    plain string or a mapping from mode name to string.
    """

    MAIN_TEMPLATE = "main_template"
    ANALYZE_TEMPLATE = "analyze_template"

    @classmethod
    @lru_cache(maxsize=32)
    def _read_templates(cls, prompt_file: str, mode: str | None) -> tuple[str, str]:
        """
        Read, validate and cache the raw (unformatted) templates of one file.

        The cache is attached to the class and keyed by ``(prompt_file, mode)``,
        so it is shared by every ``PromptLoader`` instance. A cache on an
        instance method would also be keyed by ``self``: it would never hit
        for the fresh loaders created per call, and it would keep those
        instances alive. The result is an immutable tuple so that no caller
        can corrupt a cache entry.

        Returns:
            ``(main_template, analyze_template)`` for the selected mode.

        Raises:
            PromptError: If the file is missing, is not valid YAML, lacks a
                required key, does not define the requested mode, or has an
                empty main template.
        """
        try:
            prompt_path = Path(__file__).parent.parent / "prompts" / prompt_file

            if not prompt_path.exists():
                raise PromptError(f"Prompt file not found: {prompt_file}")

            data = yaml.safe_load(prompt_path.read_text(encoding="utf-8"))

            if not isinstance(data, dict):
                raise PromptError(
                    f"Prompt file {prompt_file} must contain a YAML mapping"
                )

            for key in (cls.MAIN_TEMPLATE, cls.ANALYZE_TEMPLATE):
                if key not in data:
                    raise PromptError(f"Missing '{key}' in {prompt_file}")

            main_template = cls._select_mode(
                data[cls.MAIN_TEMPLATE], mode, prompt_file
            )
            analyze_template = cls._select_mode(
                data[cls.ANALYZE_TEMPLATE], mode, prompt_file
            )

            if not isinstance(main_template, str) or not main_template.strip():
                raise PromptError(
                    f"Empty main_template in {prompt_file}"
                    + (f" for mode '{mode}'" if mode else "")
                )

            return main_template, analyze_template

        except PromptError:
            # Already descriptive; do not wrap it a second time.
            raise
        except yaml.YAMLError as e:
            raise PromptError(f"Invalid YAML in {prompt_file}: {e}") from e
        except Exception as e:
            raise PromptError(f"Failed to load prompt {prompt_file}: {e}") from e

    @staticmethod
    def _select_mode(template, mode: str | None, prompt_file: str):
        """
        Pick the variant of ``template`` that belongs to ``mode``.

        A mapping template is indexed by ``mode``. Any other value (normally a
        plain string) applies to every mode and is returned unchanged.

        Raises:
            PromptError: If a mapping template is used without a mode, or
                does not contain the requested mode.
        """
        if not isinstance(template, dict):
            return template
        if not mode:
            raise PromptError(f"A mode is required for {prompt_file}")
        if mode not in template:
            raise PromptError(f"Mode '{mode}' not found in {prompt_file}")
        return template[mode]

    def _load_templates(self, prompt_file: str, mode: str | None) -> dict[str, str]:
        """
        Loads prompt templates from YAML file with optional mode selection.

        A new dict is returned on every call, so callers may modify it
        without affecting later calls.
        """
        main_template, analyze_template = self._read_templates(prompt_file, mode)
        return {
            self.MAIN_TEMPLATE: main_template,
            self.ANALYZE_TEMPLATE: analyze_template,
        }

    def load(
        self, prompt_file: str, text: str, mode: str | None, **extra_kwargs
    ) -> dict[str, str]:
        """
        Load the templates of ``prompt_file`` and fill in their placeholders.

        Args:
            prompt_file: File name inside the prompts directory.
            text: Value for the ``{text}`` placeholder.
            mode: Variant to select in mode-keyed templates, or ``None``.
            **extra_kwargs: Values for any additional placeholders.

        Returns:
            A dict mapping ``main_template`` and ``analyze_template`` to the
            formatted prompt strings.

        Raises:
            PromptError: If loading fails or a placeholder has no value.
        """
        try:
            templates = self._load_templates(prompt_file, mode)
            format_args = {"text": text, **extra_kwargs}

            # Build a new dict instead of formatting in place, so the
            # templates served from the cache always keep their placeholders.
            return {
                key: template.format(**format_args)
                for key, template in templates.items()
            }

        except PromptError:
            raise
        except KeyError as e:
            raise PromptError(f"Missing template variable: {e}") from e
        except Exception as e:
            raise PromptError(f"Failed to format prompt: {e}") from e
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class OperatorUtils:
|
|
92
|
+
@staticmethod
|
|
93
|
+
def build_main_prompt(
|
|
94
|
+
main_template: str,
|
|
95
|
+
analysis: str | None,
|
|
96
|
+
output_lang: str | None,
|
|
97
|
+
user_prompt: str | None,
|
|
98
|
+
) -> str:
|
|
99
|
+
main_prompt = ""
|
|
100
|
+
|
|
101
|
+
if analysis:
|
|
102
|
+
main_prompt += f"Based on this analysis:\n{analysis}\n"
|
|
103
|
+
|
|
104
|
+
if output_lang:
|
|
105
|
+
main_prompt += f"Respond only in the {output_lang} language.\n"
|
|
106
|
+
|
|
107
|
+
if user_prompt:
|
|
108
|
+
main_prompt += f"Consider this instruction {user_prompt}\n"
|
|
109
|
+
|
|
110
|
+
main_prompt += main_template
|
|
111
|
+
|
|
112
|
+
return main_prompt
|
|
113
|
+
|
|
114
|
+
@staticmethod
|
|
115
|
+
def build_message(prompt: str) -> list[dict[str, str]]:
|
|
116
|
+
return [{"role": "user", "content": prompt}]
|
|
117
|
+
|
|
118
|
+
@staticmethod
|
|
119
|
+
def extract_logprobs(completion: dict) -> list[dict]:
|
|
120
|
+
"""
|
|
121
|
+
Extracts and filters token probabilities from completion logprobs.
|
|
122
|
+
Skips punctuation and structural tokens, returns cleaned probability data.
|
|
123
|
+
"""
|
|
124
|
+
logprobs_data = []
|
|
125
|
+
|
|
126
|
+
ignore_pattern = re.compile(r'^(result|[\s\[\]\{\}",:]+)$')
|
|
127
|
+
|
|
128
|
+
for choice in completion.choices:
|
|
129
|
+
if not getattr(choice, "logprobs", None):
|
|
130
|
+
raise ValueError("Your model does not support logprobs")
|
|
131
|
+
|
|
132
|
+
for logprob_item in choice.logprobs.content:
|
|
133
|
+
if ignore_pattern.match(logprob_item.token):
|
|
134
|
+
continue
|
|
135
|
+
token_entry = {
|
|
136
|
+
"token": logprob_item.token,
|
|
137
|
+
"prob": round(math.exp(logprob_item.logprob), 8),
|
|
138
|
+
"top_alternatives": [],
|
|
139
|
+
}
|
|
140
|
+
for alt in logprob_item.top_logprobs:
|
|
141
|
+
if ignore_pattern.match(alt.token):
|
|
142
|
+
continue
|
|
143
|
+
token_entry["top_alternatives"].append(
|
|
144
|
+
{
|
|
145
|
+
"token": alt.token,
|
|
146
|
+
"prob": round(math.exp(alt.logprob), 8),
|
|
147
|
+
}
|
|
148
|
+
)
|
|
149
|
+
logprobs_data.append(token_entry)
|
|
150
|
+
|
|
151
|
+
return logprobs_data
|
|
152
|
+
|
|
153
|
+
@staticmethod
|
|
154
|
+
def get_retry_temp(base_temp: float) -> float:
|
|
155
|
+
delta_temp = random.choice([-1, 1]) * random.uniform(0.1, 0.9)
|
|
156
|
+
new_temp = base_temp + delta_temp
|
|
157
|
+
|
|
158
|
+
return max(0.0, min(new_temp, 1.5))
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def text_to_chunks(text: str, size: int, overlap: int) -> list[str]:
|
|
162
|
+
separators = ["\n\n", "\n", " ", ""]
|
|
163
|
+
is_separator_regex = False
|
|
164
|
+
keep_separator = True # Equivalent to 'start'
|
|
165
|
+
length_function = len
|
|
166
|
+
strip_whitespace = True
|
|
167
|
+
chunk_size = size
|
|
168
|
+
chunk_overlap = overlap
|
|
169
|
+
|
|
170
|
+
def _split_text_with_regex(
|
|
171
|
+
text: str, separator: str, keep_separator: bool
|
|
172
|
+
) -> list[str]:
|
|
173
|
+
if not separator:
|
|
174
|
+
return [text]
|
|
175
|
+
if not keep_separator:
|
|
176
|
+
return re.split(separator, text)
|
|
177
|
+
_splits = re.split(f"({separator})", text)
|
|
178
|
+
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
|
|
179
|
+
if len(_splits) % 2 == 0:
|
|
180
|
+
splits += [_splits[-1]]
|
|
181
|
+
return [_splits[0]] + splits if _splits[0] else splits
|
|
182
|
+
|
|
183
|
+
def _join_docs(docs: list[str], separator: str) -> str | None:
|
|
184
|
+
text = separator.join(docs)
|
|
185
|
+
if strip_whitespace:
|
|
186
|
+
text = text.strip()
|
|
187
|
+
return text if text else None
|
|
188
|
+
|
|
189
|
+
def _merge_splits(splits: list[str], separator: str) -> list[str]:
|
|
190
|
+
separator_len = length_function(separator)
|
|
191
|
+
docs = []
|
|
192
|
+
current_doc = []
|
|
193
|
+
total = 0
|
|
194
|
+
for d in splits:
|
|
195
|
+
len_ = length_function(d)
|
|
196
|
+
if total + len_ + (separator_len if current_doc else 0) > chunk_size:
|
|
197
|
+
if total > chunk_size:
|
|
198
|
+
pass
|
|
199
|
+
if current_doc:
|
|
200
|
+
doc = _join_docs(current_doc, separator)
|
|
201
|
+
if doc is not None:
|
|
202
|
+
docs.append(doc)
|
|
203
|
+
while total > chunk_overlap or (
|
|
204
|
+
total + len_ + (separator_len if current_doc else 0)
|
|
205
|
+
> chunk_size
|
|
206
|
+
and total > 0
|
|
207
|
+
):
|
|
208
|
+
total -= length_function(current_doc[0]) + (
|
|
209
|
+
separator_len if len(current_doc) > 1 else 0
|
|
210
|
+
)
|
|
211
|
+
current_doc = current_doc[1:]
|
|
212
|
+
current_doc.append(d)
|
|
213
|
+
total += len_ + (separator_len if len(current_doc) > 1 else 0)
|
|
214
|
+
doc = _join_docs(current_doc, separator)
|
|
215
|
+
if doc is not None:
|
|
216
|
+
docs.append(doc)
|
|
217
|
+
return docs
|
|
218
|
+
|
|
219
|
+
def _split_text(text: str, separators: list[str]) -> list[str]:
|
|
220
|
+
final_chunks = []
|
|
221
|
+
separator = separators[-1]
|
|
222
|
+
new_separators = []
|
|
223
|
+
for i, _s in enumerate(separators):
|
|
224
|
+
separator_ = _s if is_separator_regex else re.escape(_s)
|
|
225
|
+
if not _s:
|
|
226
|
+
separator = _s
|
|
227
|
+
break
|
|
228
|
+
if re.search(separator_, text):
|
|
229
|
+
separator = _s
|
|
230
|
+
new_separators = separators[i + 1 :]
|
|
231
|
+
break
|
|
232
|
+
separator_ = separator if is_separator_regex else re.escape(separator)
|
|
233
|
+
splits = _split_text_with_regex(text, separator_, keep_separator)
|
|
234
|
+
_separator = "" if keep_separator else separator
|
|
235
|
+
good_splits = []
|
|
236
|
+
for s in splits:
|
|
237
|
+
if length_function(s) < chunk_size:
|
|
238
|
+
good_splits.append(s)
|
|
239
|
+
else:
|
|
240
|
+
if good_splits:
|
|
241
|
+
merged_text = _merge_splits(good_splits, _separator)
|
|
242
|
+
final_chunks.extend(merged_text)
|
|
243
|
+
good_splits = []
|
|
244
|
+
if not new_separators:
|
|
245
|
+
final_chunks.append(s)
|
|
246
|
+
else:
|
|
247
|
+
other_info = _split_text(s, new_separators)
|
|
248
|
+
final_chunks.extend(other_info)
|
|
249
|
+
if good_splits:
|
|
250
|
+
merged_text = _merge_splits(good_splits, _separator)
|
|
251
|
+
final_chunks.extend(merged_text)
|
|
252
|
+
return final_chunks
|
|
253
|
+
|
|
254
|
+
return _split_text(text, separators)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
class TextToolsError(Exception):
    """Common base class for the TextTools exception types; catch it to handle any of them."""
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PromptError(TextToolsError):
    """Raised when a prompt template cannot be loaded, selected or formatted."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class LLMError(TextToolsError):
    """Raised when an LLM API call fails or returns an unusable response."""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ValidationError(TextToolsError):
    """Raised when model output is still rejected by validation, e.g. after all retries."""
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from typing import Any, Literal, Type
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field, create_model
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Result of one operator run, as assembled by the operators.
class OperatorOutput(BaseModel):
    # The `result` field taken from the parsed structured response.
    result: Any
    # Free-text analysis from the optional pre-analysis step, else None.
    analysis: str | None
    # Filtered per-token probabilities when logprobs were requested, else None.
    logprobs: list[dict[str, Any]] | None
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Structured-output schema: a single string in `result`.
# Documented with a comment on purpose: Pydantic uses a model's docstring as
# its JSON-schema description, which would change the schema sent to the LLM.
class Str(BaseModel):
    result: str = Field(
        default=...,
        description="The output string",
        example="text",
    )
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Structured-output schema: a single boolean in `result`.
class Bool(BaseModel):
    result: bool = Field(
        default=...,
        description="Boolean indicating the output state",
        example=True,
    )
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Structured-output schema: a list of strings in `result`.
class ListStr(BaseModel):
    result: list[str] = Field(
        default=...,
        description="The output list of strings",
        example=["text_1", "text_2"],
    )
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Structured-output schema: a list of string-to-string dicts in `result`
# (the example shows entity-style records with "text" and "type" keys).
class ListDictStrStr(BaseModel):
    result: list[dict[str, str]] = Field(
        default=...,
        description="List of dictionaries containing string key-value pairs",
        example=[
            {"text": "Mohammad", "type": "PER"},
            {"text": "Iran", "type": "LOC"},
        ],
    )
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Structured-output schema: a `reason` string followed by a list of strings in
# `result`. The field order is kept as-is so the schema order stays the same.
class ReasonListStr(BaseModel):
    reason: str = Field(
        default=...,
        description="Thinking process that led to the output",
    )
    result: list[str] = Field(
        default=...,
        description="The output list of strings",
        example=["text_1", "text_2"],
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# This function is needed to create CategorizerOutput with dynamic categories
def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
    """
    Build a ``CategorizerOutput`` model whose ``result`` is restricted to the
    given category labels.

    Args:
        allowed_values: Category labels the model may return; any iterable of
            strings is accepted. Must contain at least one label.

    Returns:
        A new Pydantic model class with a ``reason: str`` field and a
        ``result`` field typed as ``Literal`` over ``allowed_values``.

    Raises:
        ValueError: If ``allowed_values`` is empty.
    """
    values = tuple(allowed_values)
    if not values:
        raise ValueError("allowed_values must contain at least one category")

    # `Literal[values_tuple]` is equivalent to `Literal[*values]`. Unlike the
    # star-in-subscript syntax (PEP 646, Python 3.11+), it also parses on
    # Python 3.10.
    literal_type = Literal[values]

    CategorizerOutput = create_model(
        "CategorizerOutput",
        reason=(
            str,
            Field(
                ..., description="Explanation of why the input belongs to the category"
            ),
        ),
        result=(literal_type, Field(..., description="Predicted category label")),
    )

    return CategorizerOutput
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import Any, Type, TypeVar
|
|
3
|
+
|
|
4
|
+
from openai import AsyncOpenAI
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from ..engine import OperatorUtils, PromptLoader
|
|
8
|
+
from ..exceptions import LLMError, PromptError, TextToolsError, ValidationError
|
|
9
|
+
from ..internal_models import OperatorOutput
|
|
10
|
+
|
|
11
|
+
# Type variable for the Pydantic model class passed as `output_model`;
# the parsed completion is an instance of that same class.
T = TypeVar("T", bound=BaseModel)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AsyncOperator:
    """
    Core engine for running text-processing operations with an LLM.

    ``run`` loads the tool's YAML prompt and optionally asks the model for a
    free-text analysis. It then requests a structured response parsed into
    ``output_model``. When a validator rejects the result, the request is
    repeated at randomly perturbed temperatures.
    """

    def __init__(self, client: AsyncOpenAI, model: str):
        # Async client and model name used for every request of this operator.
        self._client = client
        self._model = model

    async def _analyze_completion(self, analyze_message: list[dict[str, str]]) -> str:
        """
        Ask the model for a free-text analysis of ``analyze_message``.

        Returns:
            The analysis text with surrounding whitespace stripped.

        Raises:
            LLMError: If the request fails, returns no choices, or returns
                empty content.
        """
        try:
            completion = await self._client.chat.completions.create(
                model=self._model,
                messages=analyze_message,
            )

            if not completion.choices:
                raise LLMError("No choices returned from LLM")

            # `content` can be None; treat it as empty so the specific
            # "Empty analysis response" error below is raised.
            analysis = (completion.choices[0].message.content or "").strip()

            if not analysis:
                raise LLMError("Empty analysis response")

            return analysis

        except (PromptError, LLMError):
            raise
        except Exception as e:
            raise LLMError(f"Analysis failed: {e}") from e

    async def _parse_completion(
        self,
        main_message: list[dict[str, str]],
        output_model: Type[T],
        temperature: float,
        logprobs: bool,
        top_logprobs: int,
        priority: int | None,
    ) -> tuple[T, Any]:
        """
        Parses a chat completion using OpenAI's structured output format.

        Returns:
            ``(parsed, completion)``: the ``output_model`` instance and the raw
            completion, kept for logprob extraction.

        Raises:
            LLMError: If the request fails, returns no choices, or yields no
                parsed output (the model's refusal text is included when
                present).
        """
        try:
            request_kwargs: dict[str, Any] = {
                "model": self._model,
                "messages": main_message,
                "response_format": output_model,
                "temperature": temperature,
            }

            if logprobs:
                request_kwargs["logprobs"] = True
                request_kwargs["top_logprobs"] = top_logprobs

            if priority is not None:
                # Sent through `extra_body`, i.e. added to the request JSON as-is.
                request_kwargs["extra_body"] = {"priority": priority}

            completion = await self._client.beta.chat.completions.parse(
                **request_kwargs
            )

            if not completion.choices:
                raise LLMError("No choices returned from LLM")

            message = completion.choices[0].message
            parsed = message.parsed

            if parsed is None:
                refusal = getattr(message, "refusal", None)
                if refusal:
                    raise LLMError(f"LLM refused to respond: {refusal}")
                raise LLMError("Failed to parse LLM response")

            return parsed, completion

        except LLMError:
            raise
        except Exception as e:
            raise LLMError(f"Completion failed: {e}") from e

    async def run(
        self,
        # User parameters
        text: str,
        with_analysis: bool,
        output_lang: str | None,
        user_prompt: str | None,
        temperature: float,
        logprobs: bool,
        top_logprobs: int,
        validator: Callable[[Any], bool] | None,
        max_validation_retries: int | None,
        priority: int | None,
        # Internal parameters
        tool_name: str,
        output_model: Type[T],
        mode: str | None,
        **extra_kwargs,
    ) -> OperatorOutput:
        """
        Execute the LLM pipeline with the given input text.

        Raises:
            PromptError: If the tool's prompt cannot be loaded or formatted.
            LLMError: If an LLM request fails (outside of validation retries).
            ValidationError: If the validator still rejects the result after
                ``max_validation_retries`` retries.
            TextToolsError: For any other unexpected error, including an
                invalid ``max_validation_retries``.
        """
        try:
            prompt_loader = PromptLoader()
            prompt_configs = prompt_loader.load(
                prompt_file=tool_name + ".yaml",
                text=text.strip(),
                mode=mode,
                **extra_kwargs,
            )

            analysis: str | None = None

            if with_analysis:
                analyze_message = OperatorUtils.build_message(
                    prompt_configs[PromptLoader.ANALYZE_TEMPLATE]
                )
                analysis = await self._analyze_completion(analyze_message)

            main_message = OperatorUtils.build_message(
                OperatorUtils.build_main_prompt(
                    prompt_configs[PromptLoader.MAIN_TEMPLATE],
                    analysis,
                    output_lang,
                    user_prompt,
                )
            )

            parsed, completion = await self._parse_completion(
                main_message,
                output_model,
                temperature,
                logprobs,
                top_logprobs,
                priority,
            )

            # Retry logic if validation fails
            if validator and not validator(parsed.result):
                if (
                    not isinstance(max_validation_retries, int)
                    or max_validation_retries < 1
                ):
                    raise ValueError("max_validation_retries should be a positive int")

                succeeded = False
                for _ in range(max_validation_retries):
                    # Re-sample at a perturbed temperature for a different answer.
                    retry_temperature = OperatorUtils.get_retry_temp(temperature)

                    try:
                        parsed, completion = await self._parse_completion(
                            main_message,
                            output_model,
                            retry_temperature,
                            logprobs,
                            top_logprobs,
                            priority,
                        )
                    except LLMError:
                        # A failed request still uses up one retry attempt.
                        continue

                    if validator(parsed.result):
                        succeeded = True
                        break

                if not succeeded:
                    raise ValidationError("Validation failed after all retries")

            return OperatorOutput(
                result=parsed.result,
                analysis=analysis,  # Still None when with_analysis is False.
                logprobs=OperatorUtils.extract_logprobs(completion)
                if logprobs
                else None,
            )

        except (PromptError, LLMError, ValidationError):
            raise
        except Exception as e:
            raise TextToolsError(f"Unexpected error in operator: {e}") from e
|