hamtaa-texttools 1.0.5-py3-none-any.whl → 1.1.16-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hamtaa_texttools-1.1.16.dist-info/METADATA +255 -0
- hamtaa_texttools-1.1.16.dist-info/RECORD +31 -0
- texttools/__init__.py +6 -8
- texttools/batch/batch_config.py +26 -0
- texttools/batch/batch_runner.py +144 -139
- texttools/batch/{batch_manager.py → internals/batch_manager.py} +42 -54
- texttools/batch/internals/utils.py +16 -0
- texttools/prompts/README.md +8 -4
- texttools/prompts/categorize.yaml +77 -0
- texttools/prompts/detect_entity.yaml +22 -0
- texttools/prompts/extract_keywords.yaml +68 -0
- texttools/prompts/{question_merger.yaml → merge_questions.yaml} +5 -5
- texttools/tools/async_tools.py +804 -0
- texttools/tools/internals/async_operator.py +139 -236
- texttools/tools/internals/formatters.py +24 -0
- texttools/tools/internals/models.py +183 -0
- texttools/tools/internals/operator_utils.py +54 -0
- texttools/tools/internals/prompt_loader.py +23 -43
- texttools/tools/internals/sync_operator.py +201 -0
- texttools/tools/sync_tools.py +804 -0
- hamtaa_texttools-1.0.5.dist-info/METADATA +0 -192
- hamtaa_texttools-1.0.5.dist-info/RECORD +0 -30
- texttools/batch/__init__.py +0 -4
- texttools/formatters/base_formatter.py +0 -33
- texttools/formatters/user_merge_formatter.py +0 -30
- texttools/prompts/categorizer.yaml +0 -28
- texttools/prompts/keyword_extractor.yaml +0 -18
- texttools/tools/__init__.py +0 -4
- texttools/tools/async_the_tool.py +0 -277
- texttools/tools/internals/operator.py +0 -295
- texttools/tools/internals/output_models.py +0 -52
- texttools/tools/the_tool.py +0 -501
- {hamtaa_texttools-1.0.5.dist-info → hamtaa_texttools-1.1.16.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.0.5.dist-info → hamtaa_texttools-1.1.16.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.0.5.dist-info → hamtaa_texttools-1.1.16.dist-info}/top_level.txt +0 -0
- /texttools/prompts/{ner_extractor.yaml → extract_entities.yaml} +0 -0
- /texttools/prompts/{question_detector.yaml → is_question.yaml} +0 -0
- /texttools/prompts/{rewriter.yaml → rewrite.yaml} +0 -0
- /texttools/prompts/{custom_tool.yaml → run_custom.yaml} +0 -0
- /texttools/prompts/{subject_question_generator.yaml → subject_to_question.yaml} +0 -0
- /texttools/prompts/{summarizer.yaml → summarize.yaml} +0 -0
- /texttools/prompts/{question_generator.yaml → text_to_question.yaml} +0 -0
- /texttools/prompts/{translator.yaml → translate.yaml} +0 -0
texttools/tools/internals/async_operator.py
@@ -1,297 +1,200 @@
-from
-
-import
-import math
-import re
-from typing import Any, Literal, Optional, TypeVar
+from typing import Any, TypeVar, Type
+from collections.abc import Callable
+import logging

 from openai import AsyncOpenAI
 from pydantic import BaseModel

-from texttools.
-
-
+from texttools.tools.internals.models import ToolOutput
+from texttools.tools.internals.operator_utils import OperatorUtils
+from texttools.tools.internals.formatters import Formatter
 from texttools.tools.internals.prompt_loader import PromptLoader

 # Base Model type for output models
 T = TypeVar("T", bound=BaseModel)

+logger = logging.getLogger("texttools.async_operator")
+

 class AsyncOperator:
     """
-
+    Core engine for running text-processing operations with an LLM (Async).

-
+    It wires together:
+    - `PromptLoader` → loads YAML prompt templates.
+    - `UserMergeFormatter` → applies formatting to messages (e.g., merging).
+    - AsyncOpenAI client → executes completions/parsed completions.
     """

-    def __init__(
-        self
-
-        *,
-        model: str,
-        temperature: float = 0.0,
-        **client_kwargs: Any,
-    ):
-        self.client: AsyncOpenAI = client
-        self.model = model
-        self.temperature = temperature
-        self.client_kwargs = client_kwargs
-
-    def _build_user_message(self, prompt: str) -> dict[str, str]:
-        return {"role": "user", "content": prompt}
-
-    async def _analysis_completion(self, analyze_message: list[dict[str, str]]) -> str:
-        try:
-            completion = await self.client.chat.completions.create(
-                model=self.model,
-                messages=analyze_message,
-                temperature=self.temperature,
-                **self.client_kwargs,
-            )
-            analysis = completion.choices[0].message.content.strip()
-            return analysis
+    def __init__(self, client: AsyncOpenAI, model: str):
+        self._client = client
+        self._model = model

-
-
-
-
-
+    async def _analyze(self, prompt_configs: dict[str, str], temperature: float) -> str:
+        """
+        Calls OpenAI API for analysis using the configured prompt template.
+        Returns the analyzed content as a string.
+        """
         analyze_prompt = prompt_configs["analyze_template"]
-        analyze_message = [
-
-
+        analyze_message = [OperatorUtils.build_user_message(analyze_prompt)]
+        completion = await self._client.chat.completions.create(
+            model=self._model,
+            messages=analyze_message,
+            temperature=temperature,
+        )
+        analysis = completion.choices[0].message.content.strip()
         return analysis

     async def _parse_completion(
         self,
         message: list[dict[str, str]],
-        output_model: T,
+        output_model: Type[T],
+        temperature: float,
         logprobs: bool = False,
         top_logprobs: int = 3,
-
+        priority: int | None = 0,
     ) -> tuple[T, Any]:
-        try:
-            request_kwargs = {
-                "model": self.model,
-                "messages": message,
-                "response_format": output_model,
-                "temperature": self.temperature,
-                **self.client_kwargs,
-            }
-
-            if max_tokens is not None:
-                request_kwargs["max_tokens"] = max_tokens
-
-            if logprobs:
-                request_kwargs["logprobs"] = True
-                request_kwargs["top_logprobs"] = top_logprobs
-
-            completion = await self.client.beta.chat.completions.parse(**request_kwargs)
-            parsed = completion.choices[0].message.parsed
-            return parsed, completion
-
-        except Exception as e:
-            print(f"[ERROR] Failed to parse completion: {e}")
-            raise
-
-    def _clean_json_response(self, response: str) -> str:
         """
-
-
-        - ```json{"result": "value"}```
+        Parses a chat completion using OpenAI's structured output format.
+        Returns both the parsed object and the raw completion for logprobs.
         """
-
-
-
-
-
-
-
-        if
-
-
-
-
-
-
-
-
-        Args:
-            response_string: The JSON string (may contain code block markers)
-            output_model: Your Pydantic output model class (e.g., StrOutput, ListStrOutput)
-
-        Returns:
-            Instance of your output model
-        """
-        try:
-            # Clean the response string
-            cleaned_json = self._clean_json_response(response_string)
-
-            # Fix Python-style booleans
-            cleaned_json = cleaned_json.replace("False", "false").replace(
-                "True", "true"
-            )
-
-            # Convert string to Python dictionary
-            response_dict = json.loads(cleaned_json)
-
-            # Convert dictionary to output model
-            return output_model(**response_dict)
-
-        except json.JSONDecodeError as e:
-            raise ValueError(
-                f"Failed to parse JSON response: {e}\nResponse: {response_string}"
-            )
-        except Exception as e:
-            raise ValueError(f"Failed to convert to output model: {e}")
-
-    async def _vllm_completion(
-        self,
-        message: list[dict[str, str]],
-        output_model: T,
-        logprobs: bool = False,
-        top_logprobs: int = 3,
-        max_tokens: int | None = None,
-    ) -> tuple[T, Any]:
-        try:
-            json_schema = output_model.model_json_schema()
-
-            # Build kwargs dynamically
-            request_kwargs = {
-                "model": self.model,
-                "messages": message,
-                "extra_body": {"guided_json": json_schema},
-                "temperature": self.temperature,
-                **self.client_kwargs,
-            }
-
-            if max_tokens is not None:
-                request_kwargs["max_tokens"] = max_tokens
-
-            if logprobs:
-                request_kwargs["logprobs"] = True
-                request_kwargs["top_logprobs"] = top_logprobs
-
-            completion = await self.client.chat.completions.create(**request_kwargs)
-            response = completion.choices[0].message.content
-
-            # Convert the string response to output model
-            parsed = self._convert_to_output_model(response, output_model)
-
-            return parsed, completion
-
-        except Exception as e:
-            print(f"[ERROR] Failed to get vLLM structured output: {e}")
-            raise
-
-    def _extract_logprobs(self, completion: dict):
-        logprobs_data = []
-        ignore_pattern = re.compile(r'^(result|[\s\[\]\{\}",:]+)$')
-
-        for choice in completion.choices:
-            if not getattr(choice, "logprobs", None):
-                continue
-
-            for logprob_item in choice.logprobs.content:
-                if ignore_pattern.match(logprob_item.token):
-                    continue
-                token_entry = {
-                    "token": logprob_item.token,
-                    "prob": round(math.exp(logprob_item.logprob), 8),
-                    "top_alternatives": [],
-                }
-                for alt in logprob_item.top_logprobs:
-                    if ignore_pattern.match(alt.token):
-                        continue
-                    token_entry["top_alternatives"].append(
-                        {
-                            "token": alt.token,
-                            "prob": round(math.exp(alt.logprob), 8),
-                        }
-                    )
-                logprobs_data.append(token_entry)
-
-        return logprobs_data
+        request_kwargs = {
+            "model": self._model,
+            "messages": message,
+            "response_format": output_model,
+            "temperature": temperature,
+        }
+
+        if logprobs:
+            request_kwargs["logprobs"] = True
+            request_kwargs["top_logprobs"] = top_logprobs
+        if priority:
+            request_kwargs["extra_body"] = {"priority": priority}
+        completion = await self._client.beta.chat.completions.parse(**request_kwargs)
+        parsed = completion.choices[0].message.parsed
+        return parsed, completion

     async def run(
         self,
-
+        # User parameters
+        text: str,
+        with_analysis: bool,
+        output_lang: str | None,
+        user_prompt: str | None,
+        temperature: float,
+        logprobs: bool,
+        top_logprobs: int | None,
+        validator: Callable[[Any], bool] | None,
+        max_validation_retries: int | None,
+        # Internal parameters
         prompt_file: str,
-        output_model: T,
-
-
-        mode: str = "",
-        resp_format: Literal["vllm", "parse"] = "parse",
-        output_lang: str | None = None,
-        logprobs: bool = False,
-        top_logprobs: int = 3,
-        max_tokens: int | None = None,
+        output_model: Type[T],
+        mode: str | None,
+        priority: int | None = 0,
         **extra_kwargs,
-    ) ->
+    ) -> ToolOutput:
         """
-        Execute the async LLM pipeline with the given input text.
+        Execute the async LLM pipeline with the given input text. (Async)
         """
         prompt_loader = PromptLoader()
-        formatter =
+        formatter = Formatter()
+        output = ToolOutput()

         try:
-
-
-            # FIXED: Correct parameter order for load
+            # Prompt configs contain two keys: main_template and analyze template, both are string
             prompt_configs = prompt_loader.load(
-                prompt_file=prompt_file,
-                text=
-                mode=mode
+                prompt_file=prompt_file,
+                text=text.strip(),
+                mode=mode,
                 **extra_kwargs,
             )

-            messages
+            messages = []

             if with_analysis:
-                analysis = await self._analyze(prompt_configs)
+                analysis = await self._analyze(prompt_configs, temperature)
                 messages.append(
-
+                    OperatorUtils.build_user_message(
+                        f"Based on this analysis: {analysis}"
+                    )
                 )

             if output_lang:
                 messages.append(
-
+                    OperatorUtils.build_user_message(
                         f"Respond only in the {output_lang} language."
                     )
                 )

-
-
-
-
-
-                messages,
-                output_model,
-                logprobs,
-                top_logprobs,
-                max_tokens,  # Pass max_tokens
-            )
-            elif resp_format == "parse":
-                parsed, completion = await self._parse_completion(
-                    messages,
-                    output_model,
-                    logprobs,
-                    top_logprobs,
-                    max_tokens,  # Pass max_tokens
+            if user_prompt:
+                messages.append(
+                    OperatorUtils.build_user_message(
+                        f"Consider this instruction {user_prompt}"
+                    )
                 )
-            else:
-                raise ValueError(f"Unknown resp_format: {resp_format}")

-
+            messages.append(
+                OperatorUtils.build_user_message(prompt_configs["main_template"])
+            )
+
+            messages = formatter.user_merge_format(messages)
+
+            parsed, completion = await self._parse_completion(
+                messages, output_model, temperature, logprobs, top_logprobs, priority
+            )
+
+            output.result = parsed.result
+
+            # Retry logic if validation fails
+            if validator and not validator(output.result):
+                for attempt in range(max_validation_retries):
+                    logger.warning(
+                        f"Validation failed, retrying for the {attempt + 1} time."
+                    )
+
+                    # Generate new temperature for retry
+                    retry_temperature = OperatorUtils.get_retry_temp(temperature)
+                    try:
+                        parsed, completion = await self._parse_completion(
+                            messages,
+                            output_model,
+                            retry_temperature,
+                            logprobs,
+                            top_logprobs,
+                        )
+
+                        output.result = parsed.result
+
+                        # Check if retry was successful
+                        if validator(output.result):
+                            logger.info(
+                                f"Validation passed on retry attempt {attempt + 1}"
+                            )
+                            break
+                        else:
+                            logger.warning(
+                                f"Validation still failing after retry attempt {attempt + 1}"
+                            )
+
+                    except Exception as e:
+                        logger.error(f"Retry attempt {attempt + 1} failed: {e}")
+                        # Continue to next retry attempt if this one fails
+
+            # Final check after all retries
+            if validator and not validator(output.result):
+                output.errors.append("Validation failed after all retry attempts")

             if logprobs:
-
+                output.logprobs = OperatorUtils.extract_logprobs(completion)

             if with_analysis:
-
+                output.analysis = analysis
+
+            output.process = prompt_file[:-5]

-            return
+            return output

         except Exception as e:
-
-
+            logger.error(f"AsyncTheTool failed: {e}")
+            output.errors.append(str(e))
+            return output
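The rewritten run drops the old resp_format/max_tokens branches in favor of a single structured-output call with a per-call temperature, an optional validator with retries, and a priority hint, always returning a ToolOutput. Below is a minimal sketch of a direct call, inferred from the signatures above; the endpoint, API key, model name, and input text are placeholder assumptions, and in the released package the operator is presumably driven through the public wrappers in texttools/tools/async_tools.py rather than called directly.

import asyncio

from openai import AsyncOpenAI

from texttools.tools.internals.async_operator import AsyncOperator
from texttools.tools.internals.models import StrOutput


async def main() -> None:
    # Placeholder endpoint/credentials: any OpenAI-compatible server should do.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    operator = AsyncOperator(client=client, model="my-model")

    output = await operator.run(
        text="Large language models are ...",
        with_analysis=False,                    # skip the extra analysis round-trip
        output_lang="English",
        user_prompt=None,
        temperature=0.0,
        logprobs=False,
        top_logprobs=None,
        validator=lambda result: bool(result),  # retry on empty results
        max_validation_retries=2,
        prompt_file="summarize.yaml",           # one of the renamed prompt files above
        output_model=StrOutput,
        mode=None,
    )
    # output.process is the prompt file name minus its ".yaml" suffix.
    print(output.process, output.result, output.errors)


asyncio.run(main())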
texttools/tools/internals/formatters.py
@@ -0,0 +1,24 @@
+class Formatter:
+    @staticmethod
+    def user_merge_format(messages: list[dict[str, str]]) -> list[dict[str, str]]:
+        """
+        Merges consecutive user messages into a single message, separated by newlines.
+
+        This is useful for condensing a multi-turn user input into a single
+        message for the LLM. Assistant and system messages are left unchanged and
+        act as separators between user message groups.
+        """
+        merged: list[dict[str, str]] = []
+
+        for message in messages:
+            role, content = message["role"], message["content"].strip()
+
+            # Merge with previous user turn
+            if merged and role == "user" and merged[-1]["role"] == "user":
+                merged[-1]["content"] += "\n" + content
+
+            # Otherwise, start a new turn
+            else:
+                merged.append({"role": role, "content": content})
+
+        return merged
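user_merge_format is what lets run build its prompt as several separate user messages (analysis, language, instruction, main template) and still send the model a single user turn. Its behavior is easy to verify in isolation; this check uses only the code added above, with illustrative message contents:

from texttools.tools.internals.formatters import Formatter

messages = [
    {"role": "user", "content": "Based on this analysis: ..."},
    {"role": "user", "content": "Respond only in the English language."},
    {"role": "user", "content": "Summarize the following text: ..."},
]

# Three consecutive user turns collapse into one newline-joined message.
merged = Formatter.user_merge_format(messages)
assert len(merged) == 1
assert merged[0]["content"].count("\n") == 2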
texttools/tools/internals/models.py
@@ -0,0 +1,183 @@
+from datetime import datetime
+from typing import Type, Any, Literal
+
+from pydantic import BaseModel, Field, create_model
+
+
+class ToolOutput(BaseModel):
+    result: Any = None
+    analysis: str = ""
+    logprobs: list[dict[str, Any]] = []
+    process: str = ""
+    processed_at: datetime = datetime.now()
+    execution_time: float = -1.0
+    errors: list[str] = []
+
+    def __repr__(self) -> str:
+        return f"ToolOutput(process='{self.process}', result_type='{type(self.result)}', result='{self.result}', analysis='{self.analysis}', logprobs='{self.logprobs}', errors='{self.errors}', processed_at='{self.processed_at}', execution_time='{self.execution_time}'"
+
+
+class StrOutput(BaseModel):
+    result: str = Field(..., description="The output string")
+
+
+class BoolOutput(BaseModel):
+    result: bool = Field(
+        ..., description="Boolean indicating the output state", example=True
+    )
+
+
+class ListStrOutput(BaseModel):
+    result: list[str] = Field(
+        ..., description="The output list of strings", example=["text_1", "text_2"]
+    )
+
+
+class ListDictStrStrOutput(BaseModel):
+    result: list[dict[str, str]] = Field(
+        ...,
+        description="List of dictionaries containing string key-value pairs",
+        example=[{"text": "Mohammad", "type": "PER"}],
+    )
+
+
+class ReasonListStrOutput(BaseModel):
+    reason: str = Field(..., description="Thinking process that led to the output")
+    result: list[str] = Field(..., description="The output list of strings")
+
+
+class Node(BaseModel):
+    node_id: int
+    name: str
+    level: int
+    parent_id: int | None
+    description: str = "No description provided"
+
+
+class CategoryTree:
+    def __init__(self, tree_name):
+        self.root = Node(node_id=0, name=tree_name, level=0, parent_id=None)
+        self.all_nodes: list[Node] = [self.root]
+        self.new_id = 1
+
+    def add_node(
+        self,
+        node_name: str,
+        parent_name: str | None = None,
+        description: str | None = None,
+    ) -> None:
+        if self.find_node(node_name):
+            raise ValueError(f"{node_name} has been chosen for another category before")
+
+        if parent_name:
+            parent_node = self.find_node(parent_name)
+            if parent_node is None:
+                raise ValueError(f"Parent category '{parent_name}' not found")
+            parent_id = parent_node.node_id
+            level = parent_node.level + 1
+        else:
+            level = 1
+            parent_id = 0
+
+        node_data = {
+            "node_id": self.new_id,
+            "name": node_name,
+            "level": level,
+            "parent_id": parent_id,
+        }
+
+        if description is not None:
+            node_data["description"] = description
+
+        self.all_nodes.append(Node(**node_data))
+        self.new_id += 1
+
+    def get_nodes(self) -> list[Node]:
+        return self.all_nodes
+
+    def get_level_count(self) -> int:
+        return max([item.level for item in self.all_nodes])
+
+    def find_node(self, identifier: int | str) -> Node | None:
+        if isinstance(identifier, str):
+            for node in self.get_nodes():
+                if node.name == identifier:
+                    return node
+            return None
+        elif isinstance(identifier, int):
+            for node in self.get_nodes():
+                if node.node_id == identifier:
+                    return node
+            return None
+        else:
+            return None
+
+    def find_children(self, parent_node: Node) -> list[Node] | None:
+        children = [
+            node for node in self.get_nodes() if parent_node.node_id == node.parent_id
+        ]
+        return children if children else None
+
+    def remove_node(self, identifier: int | str) -> None:
+        node = self.find_node(identifier)
+
+        if node is not None:
+            # Remove node's children recursively
+            children = self.find_children(node)
+
+            # Ending condition
+            if children is None:
+                self.all_nodes.remove(node)
+                return
+
+            for child in children:
+                self.remove_node(child.name)
+
+            # Remove the node from tree
+            self.all_nodes.remove(node)
+        else:
+            raise ValueError(f"Node with identifier: '{identifier}' not found.")
+
+    def dump_tree(self) -> dict:
+        def build_dict(node: Node) -> dict:
+            children = [
+                build_dict(child)
+                for child in self.all_nodes
+                if child.parent_id == node.node_id
+            ]
+            return {
+                "node_id": node.node_id,
+                "name": node.name,
+                "level": node.level,
+                "parent_id": node.parent_id,
+                "children": children,
+            }
+
+        return {"category_tree": build_dict(self.root)["children"]}
+
+
+# This function is needed to create CategorizerOutput with dynamic categories
+def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
+    literal_type = Literal[*allowed_values]
+
+    CategorizerOutput = create_model(
+        "CategorizerOutput",
+        reason=(
+            str,
+            Field(
+                ..., description="Explanation of why the input belongs to the category"
+            ),
+        ),
+        result=(literal_type, Field(..., description="Predicted category label")),
+    )
+
+    return CategorizerOutput
+
+
+class Entity(BaseModel):
+    text: str = Field(description="The exact text of the entity")
+    entity_type: str = Field(description="The type of the entity")
+
+
+class EntityDetectorOutput(BaseModel):
+    result: list[Entity] = Field(description="List of all extracted entities")