hamtaa-texttools 1.0.1__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hamtaa-texttools might be problematic. Click here for more details.
- hamtaa_texttools-1.1.7.dist-info/METADATA +228 -0
- hamtaa_texttools-1.1.7.dist-info/RECORD +30 -0
- {hamtaa_texttools-1.0.1.dist-info → hamtaa_texttools-1.1.7.dist-info}/licenses/LICENSE +20 -20
- {hamtaa_texttools-1.0.1.dist-info → hamtaa_texttools-1.1.7.dist-info}/top_level.txt +0 -0
- texttools/__init__.py +4 -9
- texttools/batch/__init__.py +3 -0
- texttools/{utils/batch_manager → batch}/batch_manager.py +226 -240
- texttools/batch/batch_runner.py +254 -0
- texttools/prompts/README.md +35 -0
- texttools/prompts/categorizer.yaml +28 -0
- texttools/prompts/extract_entities.yaml +20 -0
- texttools/prompts/extract_keywords.yaml +18 -0
- texttools/prompts/is_question.yaml +14 -0
- texttools/prompts/merge_questions.yaml +46 -0
- texttools/prompts/rewrite.yaml +111 -0
- texttools/prompts/run_custom.yaml +7 -0
- texttools/prompts/subject_to_question.yaml +22 -0
- texttools/prompts/summarize.yaml +14 -0
- texttools/prompts/text_to_question.yaml +20 -0
- texttools/prompts/translate.yaml +15 -0
- texttools/tools/__init__.py +4 -3
- texttools/tools/async_the_tool.py +435 -0
- texttools/tools/internals/async_operator.py +242 -0
- texttools/tools/internals/base_operator.py +100 -0
- texttools/tools/internals/formatters.py +24 -0
- texttools/tools/internals/operator.py +242 -0
- texttools/tools/internals/output_models.py +62 -0
- texttools/tools/internals/prompt_loader.py +60 -0
- texttools/tools/the_tool.py +433 -291
- hamtaa_texttools-1.0.1.dist-info/METADATA +0 -129
- hamtaa_texttools-1.0.1.dist-info/RECORD +0 -18
- texttools/formatters/base_formatter.py +0 -33
- texttools/formatters/user_merge_formatter/user_merge_formatter.py +0 -47
- texttools/prompts/__init__.py +0 -0
- texttools/tools/operator.py +0 -236
- texttools/tools/output_models.py +0 -54
- texttools/tools/prompt_loader.py +0 -84
- texttools/utils/__init__.py +0 -4
- texttools/utils/batch_manager/__init__.py +0 -4
- texttools/utils/batch_manager/batch_runner.py +0 -212
- {hamtaa_texttools-1.0.1.dist-info → hamtaa_texttools-1.1.7.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,435 @@
|
|
|
1
|
+
from typing import Literal, Any, Callable
|
|
2
|
+
|
|
3
|
+
from openai import AsyncOpenAI
|
|
4
|
+
|
|
5
|
+
from texttools.tools.internals.async_operator import AsyncOperator
|
|
6
|
+
import texttools.tools.internals.output_models as OutputModels
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AsyncTheTool:
    """
    Async counterpart to TheTool.

    Each method configures the async operator with a specific YAML prompt,
    output schema, and flags, then delegates execution to `operator.run()`.

    Parameters shared by most methods:
        text: Input text that is rendered into the prompt template.
        with_analysis: If True, the operator first runs the prompt's analysis
            template and feeds that analysis into the main request; the
            analysis text is returned in `ToolOutput.analysis`.
        output_lang: If set, the model is instructed to respond only in this
            language.
        user_prompt: Optional extra instruction appended to the request.
        temperature: Sampling temperature forwarded to the model.
        logprobs: If True, token log-probabilities are requested and returned.
        top_logprobs: Number of alternative tokens per position to request
            when `logprobs` is enabled.
        validator: Optional predicate applied to `result`. When it returns
            False the operator retries the request and, if validation keeps
            failing, records an error in `ToolOutput.errors`.

    Errors raised during execution are not propagated; the operator logs them
    and appends their message to `ToolOutput.errors`.

    Usage:
        async_client = AsyncOpenAI(...)
        tool = AsyncTheTool(async_client, model="model-name")
        result = await tool.categorize("text ...", with_analysis=True)
    """

    def __init__(
        self,
        client: AsyncOpenAI,
        model: str,
    ):
        """
        Args:
            client: Async OpenAI-compatible client used for every request.
            model: Model name sent with every request.
        """
        self.operator = AsyncOperator(client=client, model=model)

    async def categorize(
        self,
        text: str,
        with_analysis: bool = False,
        user_prompt: str | None = None,
        temperature: float | None = 0.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
        validator: Callable[[Any], bool] | None = None,
    ) -> OutputModels.ToolOutput:
        """
        Categorize a text into a single Islamic studies domain category.

        Returns:
            ToolOutput: Object containing:
            - result (str): The assigned Islamic studies category
            - logprobs (list | None): Probability data if logprobs enabled
            - analysis (str | None): Detailed reasoning if with_analysis enabled
        """
        return await self.operator.run(
            # User parameters
            text=text,
            with_analysis=with_analysis,
            user_prompt=user_prompt,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            validator=validator,
            # Internal parameters
            prompt_file="categorizer.yaml",
            output_model=OutputModels.CategorizerOutput,
            resp_format="parse",
            mode=None,
            output_lang=None,
        )

    async def extract_keywords(
        self,
        text: str,
        with_analysis: bool = False,
        output_lang: str | None = None,
        user_prompt: str | None = None,
        temperature: float | None = 0.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
        validator: Callable[[Any], bool] | None = None,
    ) -> OutputModels.ToolOutput:
        """
        Extract salient keywords from text.

        Returns:
            ToolOutput: Object containing:
            - result (list[str]): List of extracted keywords
            - logprobs (list | None): Probability data if logprobs enabled
            - analysis (str | None): Detailed reasoning if with_analysis enabled
        """
        return await self.operator.run(
            # User parameters
            text=text,
            with_analysis=with_analysis,
            output_lang=output_lang,
            user_prompt=user_prompt,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            validator=validator,
            # Internal parameters
            prompt_file="extract_keywords.yaml",
            output_model=OutputModels.ListStrOutput,
            resp_format="parse",
            mode=None,
        )

    async def extract_entities(
        self,
        text: str,
        with_analysis: bool = False,
        output_lang: str | None = None,
        user_prompt: str | None = None,
        temperature: float | None = 0.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
        validator: Callable[[Any], bool] | None = None,
    ) -> OutputModels.ToolOutput:
        """
        Perform Named Entity Recognition (NER) over the input text.

        Returns:
            ToolOutput: Object containing:
            - result (list[dict]): List of entities with 'text' and 'type' keys
            - logprobs (list | None): Probability data if logprobs enabled
            - analysis (str | None): Detailed reasoning if with_analysis enabled
        """
        return await self.operator.run(
            # User parameters
            text=text,
            with_analysis=with_analysis,
            output_lang=output_lang,
            user_prompt=user_prompt,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            validator=validator,
            # Internal parameters
            prompt_file="extract_entities.yaml",
            output_model=OutputModels.ListDictStrStrOutput,
            resp_format="parse",
            mode=None,
        )

    async def is_question(
        self,
        text: str,
        with_analysis: bool = False,
        user_prompt: str | None = None,
        temperature: float | None = 0.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
        validator: Callable[[Any], bool] | None = None,
    ) -> OutputModels.ToolOutput:
        """
        Detect if the input is phrased as a question.

        Returns:
            ToolOutput: Object containing:
            - result (bool): True if text is a question, False otherwise
            - logprobs (list | None): Probability data if logprobs enabled
            - analysis (str | None): Detailed reasoning if with_analysis enabled
        """
        return await self.operator.run(
            # User parameters
            text=text,
            with_analysis=with_analysis,
            user_prompt=user_prompt,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            validator=validator,
            # Internal parameters
            prompt_file="is_question.yaml",
            output_model=OutputModels.BoolOutput,
            resp_format="parse",
            mode=None,
            output_lang=None,
        )

    async def text_to_question(
        self,
        text: str,
        with_analysis: bool = False,
        output_lang: str | None = None,
        user_prompt: str | None = None,
        temperature: float | None = 0.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
        validator: Callable[[Any], bool] | None = None,
    ) -> OutputModels.ToolOutput:
        """
        Generate a single question from the given text.

        Returns:
            ToolOutput: Object containing:
            - result (str): The generated question
            - logprobs (list | None): Probability data if logprobs enabled
            - analysis (str | None): Detailed reasoning if with_analysis enabled
        """
        return await self.operator.run(
            # User parameters
            text=text,
            with_analysis=with_analysis,
            output_lang=output_lang,
            user_prompt=user_prompt,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            validator=validator,
            # Internal parameters
            prompt_file="text_to_question.yaml",
            output_model=OutputModels.StrOutput,
            resp_format="parse",
            mode=None,
        )

    async def merge_questions(
        self,
        text: list[str],
        with_analysis: bool = False,
        output_lang: str | None = None,
        user_prompt: str | None = None,
        temperature: float | None = 0.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
        mode: Literal["default", "reason"] = "default",
        validator: Callable[[Any], bool] | None = None,
    ) -> OutputModels.ToolOutput:
        """
        Merge multiple questions into a single unified question.

        Args:
            text: The questions to merge; they are joined with ", " before
                being rendered into the prompt.
            mode: Prompt variant selected from merge_questions.yaml.

        Returns:
            ToolOutput: Object containing:
            - result (str): The merged question
            - logprobs (list | None): Probability data if logprobs enabled
            - analysis (str | None): Detailed reasoning if with_analysis enabled
        """
        # Keep the list parameter intact; build the prompt text separately.
        joined_questions = ", ".join(text)
        return await self.operator.run(
            # User parameters
            text=joined_questions,
            with_analysis=with_analysis,
            output_lang=output_lang,
            user_prompt=user_prompt,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            validator=validator,
            # Internal parameters
            prompt_file="merge_questions.yaml",
            output_model=OutputModels.StrOutput,
            resp_format="parse",
            mode=mode,
        )

    async def rewrite(
        self,
        text: str,
        with_analysis: bool = False,
        output_lang: str | None = None,
        user_prompt: str | None = None,
        temperature: float | None = 0.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
        mode: Literal["positive", "negative", "hard_negative"] = "positive",
        validator: Callable[[Any], bool] | None = None,
    ) -> OutputModels.ToolOutput:
        """
        Rewrite a text with different modes.

        Args:
            mode: Prompt variant selected from rewrite.yaml.

        Returns:
            ToolOutput: Object containing:
            - result (str): The rewritten text
            - logprobs (list | None): Probability data if logprobs enabled
            - analysis (str | None): Detailed reasoning if with_analysis enabled
        """
        return await self.operator.run(
            # User parameters
            text=text,
            with_analysis=with_analysis,
            output_lang=output_lang,
            user_prompt=user_prompt,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            validator=validator,
            # Internal parameters
            prompt_file="rewrite.yaml",
            output_model=OutputModels.StrOutput,
            resp_format="parse",
            mode=mode,
        )

    async def subject_to_question(
        self,
        text: str,
        number_of_questions: int,
        with_analysis: bool = False,
        output_lang: str | None = None,
        user_prompt: str | None = None,
        temperature: float | None = 0.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
        validator: Callable[[Any], bool] | None = None,
    ) -> OutputModels.ToolOutput:
        """
        Generate a list of questions about a subject.

        Args:
            number_of_questions: Forwarded to the prompt template as
                `number_of_questions`.

        Returns:
            ToolOutput: Object containing:
            - result (list[str]): List of generated questions
            - logprobs (list | None): Probability data if logprobs enabled
            - analysis (str | None): Detailed reasoning if with_analysis enabled
        """
        return await self.operator.run(
            # User parameters
            text=text,
            number_of_questions=number_of_questions,
            with_analysis=with_analysis,
            output_lang=output_lang,
            user_prompt=user_prompt,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            validator=validator,
            # Internal parameters
            prompt_file="subject_to_question.yaml",
            output_model=OutputModels.ReasonListStrOutput,
            resp_format="parse",
            mode=None,
        )

    async def summarize(
        self,
        text: str,
        with_analysis: bool = False,
        output_lang: str | None = None,
        user_prompt: str | None = None,
        temperature: float | None = 0.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
        validator: Callable[[Any], bool] | None = None,
    ) -> OutputModels.ToolOutput:
        """
        Summarize the given subject text.

        Returns:
            ToolOutput: Object containing:
            - result (str): The summary text
            - logprobs (list | None): Probability data if logprobs enabled
            - analysis (str | None): Detailed reasoning if with_analysis enabled
        """
        return await self.operator.run(
            # User parameters
            text=text,
            with_analysis=with_analysis,
            output_lang=output_lang,
            user_prompt=user_prompt,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            validator=validator,
            # Internal parameters
            prompt_file="summarize.yaml",
            output_model=OutputModels.StrOutput,
            resp_format="parse",
            mode=None,
        )

    async def translate(
        self,
        text: str,
        target_language: str,
        with_analysis: bool = False,
        user_prompt: str | None = None,
        temperature: float | None = 0.0,
        logprobs: bool = False,
        top_logprobs: int | None = None,
        validator: Callable[[Any], bool] | None = None,
    ) -> OutputModels.ToolOutput:
        """
        Translate text between languages.

        Args:
            target_language: Forwarded to the prompt template as
                `target_language`.

        Returns:
            ToolOutput: Object containing:
            - result (str): The translated text
            - logprobs (list | None): Probability data if logprobs enabled
            - analysis (str | None): Detailed reasoning if with_analysis enabled
        """
        return await self.operator.run(
            # User parameters
            text=text,
            target_language=target_language,
            with_analysis=with_analysis,
            user_prompt=user_prompt,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            validator=validator,
            # Internal parameters
            prompt_file="translate.yaml",
            output_model=OutputModels.StrOutput,
            resp_format="parse",
            mode=None,
            output_lang=None,
        )

    async def run_custom(
        self,
        prompt: str,
        output_model: Any,
        output_lang: str | None = None,
        temperature: float | None = None,
        logprobs: bool | None = None,
        top_logprobs: int | None = None,
    ) -> OutputModels.ToolOutput:
        """
        Run an arbitrary prompt against a caller-supplied output schema.

        Args:
            prompt: Prompt text, rendered into run_custom.yaml as `text`.
            output_model: A pydantic model class; it must define a `result`
                field. Its JSON schema is passed to the template as
                `output_model_str`.

        Returns:
            ToolOutput: Object with fields:
            - result: The value of `output_model.result` (its type is
              whatever the supplied model declares)
            - logprobs (list | None): Probability data if logprobs enabled
        """
        return await self.operator.run(
            # User parameters
            text=prompt,
            output_model=output_model,
            output_model_str=output_model.model_json_schema(),
            output_lang=output_lang,
            temperature=temperature,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            # Internal parameters
            prompt_file="run_custom.yaml",
            resp_format="parse",
            user_prompt=None,
            with_analysis=False,
            mode=None,
            validator=None,
        )
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
from typing import Any, TypeVar, Type, Literal, Callable
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from openai import AsyncOpenAI
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from texttools.tools.internals.output_models import ToolOutput
|
|
8
|
+
from texttools.tools.internals.base_operator import BaseOperator
|
|
9
|
+
from texttools.tools.internals.formatters import Formatter
|
|
10
|
+
from texttools.tools.internals.prompt_loader import PromptLoader
|
|
11
|
+
|
|
12
|
+
# Module-level logger used by AsyncOperator for retry and failure messages.
logger = logging.getLogger("texttools.async_operator")

# Generic type variable for output models: any pydantic BaseModel subclass.
T = TypeVar("T", bound=BaseModel)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AsyncOperator(BaseOperator):
    """
    Core engine for running text-processing operations with an LLM (Async).

    It wires together:
    - `PromptLoader` → loads and renders YAML prompt templates.
    - `Formatter` → post-processes the message list (`user_merge_format`).
    - AsyncOpenAI client → executes completions/parsed completions.
    """

    # How many extra requests `run` makes when the caller's validator rejects
    # a result. Subclasses or instances may override it.
    max_validation_retries: int = 3

    def __init__(self, client: AsyncOpenAI, model: str):
        """
        Args:
            client: Async OpenAI-compatible client used for every request.
            model: Model name sent with every request.
        """
        self.client = client
        self.model = model

    async def _analyze(self, prompt_configs: dict[str, str], temperature: float) -> str:
        """
        Send the rendered "analyze_template" prompt and return the reply text.

        Raises:
            ValueError: If the model's reply carries no text content.
        """
        analyze_prompt = prompt_configs["analyze_template"]
        analyze_message = [self._build_user_message(analyze_prompt)]
        completion = await self.client.chat.completions.create(
            model=self.model,
            messages=analyze_message,
            temperature=temperature,
        )
        content = completion.choices[0].message.content
        # `content` may be None (e.g. refusals); fail with a clear message
        # instead of an AttributeError from `.strip()`.
        if content is None:
            raise ValueError("Analysis step returned no content")
        return content.strip()

    async def _parse_completion(
        self,
        message: list[dict[str, str]],
        output_model: Type[T],
        temperature: float,
        logprobs: bool = False,
        top_logprobs: int = 3,
    ) -> tuple[T | None, Any]:
        """
        Request a structured completion via the client's `beta...parse` API.

        Returns:
            (parsed, completion): the parsed `output_model` instance (None if
            the SDK produced no parsed value) and the raw completion.
        """
        request_kwargs = {
            "model": self.model,
            "messages": message,
            "response_format": output_model,
            "temperature": temperature,
        }

        if logprobs:
            request_kwargs["logprobs"] = True
            request_kwargs["top_logprobs"] = top_logprobs

        completion = await self.client.beta.chat.completions.parse(**request_kwargs)
        parsed = completion.choices[0].message.parsed
        return parsed, completion

    async def _vllm_completion(
        self,
        message: list[dict[str, str]],
        output_model: Type[T],
        temperature: float,
        logprobs: bool = False,
        top_logprobs: int = 3,
    ) -> tuple[T, Any]:
        """
        Request a completion constrained by `guided_json` (vLLM extension).

        Returns:
            (parsed, completion): the reply converted to `output_model` and the
            raw completion.
        """
        json_schema = output_model.model_json_schema()

        request_kwargs = {
            "model": self.model,
            "messages": message,
            # vLLM-specific field, sent through the SDK's `extra_body`.
            "extra_body": {"guided_json": json_schema},
            "temperature": temperature,
        }

        if logprobs:
            request_kwargs["logprobs"] = True
            request_kwargs["top_logprobs"] = top_logprobs

        completion = await self.client.chat.completions.create(**request_kwargs)
        response = completion.choices[0].message.content

        # Convert the string response to the output model.
        parsed = self._convert_to_output_model(response, output_model)
        return parsed, completion

    async def _dispatch_completion(
        self,
        resp_format: str,
        messages: list[dict[str, str]],
        output_model: Type[T],
        temperature: float,
        logprobs: bool,
        top_logprobs: int | None,
    ) -> tuple[Any, Any]:
        """
        Route the request to the backend selected by `resp_format`.

        Raises:
            ValueError: If `resp_format` is neither "vllm" nor "parse".
        """
        if resp_format == "vllm":
            return await self._vllm_completion(
                messages, output_model, temperature, logprobs, top_logprobs
            )
        if resp_format == "parse":
            return await self._parse_completion(
                messages, output_model, temperature, logprobs, top_logprobs
            )
        raise ValueError(f"Unsupported resp_format: {resp_format!r}")

    async def _retry_until_valid(
        self,
        validator: Callable[[Any], bool],
        result: Any,
        completion: Any,
        resp_format: str,
        messages: list[dict[str, str]],
        output_model: Type[T],
        temperature: float,
        logprobs: bool,
        top_logprobs: int | None,
    ) -> tuple[Any, Any, bool]:
        """
        Re-request the completion until `validator` accepts the result.

        Each attempt uses a temperature from `_get_retry_temp`. A failing
        attempt (request error, bad parse, or validator exception) is logged
        and skipped.

        Returns:
            (result, completion, is_valid): the latest successfully parsed
            result, the completion it came from, and whether it validated.
        """
        for attempt in range(1, self.max_validation_retries + 1):
            logger.warning("Validation failed, retrying for the %d time.", attempt)

            retry_temperature = self._get_retry_temp(temperature)
            try:
                new_parsed, new_completion = await self._dispatch_completion(
                    resp_format,
                    messages,
                    output_model,
                    retry_temperature,
                    logprobs,
                    top_logprobs,
                )
                new_result = new_parsed.result
                # Commit only after a full successful parse so `result` and
                # `completion` (used for logprobs) always stay in sync.
                result, completion = new_result, new_completion

                if validator(result):
                    logger.info("Validation passed on retry attempt %d", attempt)
                    return result, completion, True
                logger.warning(
                    "Validation still failing after retry attempt %d", attempt
                )
            except Exception as e:
                # Continue to the next retry attempt if this one fails.
                logger.error("Retry attempt %d failed: %s", attempt, e)

        return result, completion, False

    async def run(
        self,
        # User parameters
        text: str,
        with_analysis: bool,
        output_lang: str | None,
        user_prompt: str | None,
        temperature: float,
        logprobs: bool,
        top_logprobs: int | None,
        validator: Callable[[Any], bool] | None,
        # Internal parameters
        prompt_file: str,
        output_model: Type[T],
        resp_format: Literal["vllm", "parse"],
        mode: str | None,
        **extra_kwargs,
    ) -> ToolOutput:
        """
        Execute the async LLM pipeline with the given input text.

        Steps:
        1. Load and render the prompt templates.
        2. Optionally run the analysis step.
        3. Assemble the user messages and merge them with the formatter.
        4. Request a structured completion.
        5. Optionally validate the result and retry.

        Args:
            text: Input text (stripped before rendering).
            with_analysis: Run the analysis template first and include it.
            output_lang: If set, ask the model to answer in this language.
            user_prompt: Optional extra instruction for the model.
            temperature: Sampling temperature.
            logprobs / top_logprobs: Request token log-probabilities.
            validator: Optional predicate over `result`; triggers retries.
            prompt_file: YAML prompt file to load.
            output_model: Pydantic model class that must define `result`.
            resp_format: "parse" (OpenAI structured output) or "vllm".
            mode: Prompt variant passed to the prompt loader.
            **extra_kwargs: Additional template variables for the loader.

        Returns:
            ToolOutput with `result`, and `logprobs` / `analysis` when
            requested. Failures are logged and appended to `errors` instead
            of being raised.
        """
        prompt_loader = PromptLoader()
        formatter = Formatter()
        output = ToolOutput()

        try:
            # Prompt configs contain two string templates: "main_template" and
            # "analyze_template".
            prompt_configs = prompt_loader.load(
                prompt_file=prompt_file,
                text=text.strip(),
                mode=mode,
                **extra_kwargs,
            )

            messages: list[dict[str, str]] = []
            analysis: str | None = None

            if with_analysis:
                analysis = await self._analyze(prompt_configs, temperature)
                messages.append(
                    self._build_user_message(f"Based on this analysis: {analysis}")
                )

            if output_lang:
                messages.append(
                    self._build_user_message(
                        f"Respond only in the {output_lang} language."
                    )
                )

            if user_prompt:
                messages.append(
                    self._build_user_message(f"Consider this instruction {user_prompt}")
                )

            messages.append(self._build_user_message(prompt_configs["main_template"]))
            messages = formatter.user_merge_format(messages)

            parsed, completion = await self._dispatch_completion(
                resp_format, messages, output_model, temperature, logprobs, top_logprobs
            )

            # A refusal / unparseable reply yields no object at all; report it
            # distinctly rather than blaming the output model's schema.
            if parsed is None:
                error = "The model returned no parseable output"
                logger.error(error)
                output.errors.append(error)
                return output

            # Ensure output_model has a `result` field.
            if not hasattr(parsed, "result"):
                error = "The provided output_model must define a field named 'result'"
                logger.error(error)
                output.errors.append(error)
                return output

            output.result = parsed.result

            if validator is not None:
                # Evaluate the validator once per result instead of re-calling it.
                is_valid = validator(output.result)
                if not is_valid:
                    output.result, completion, is_valid = await self._retry_until_valid(
                        validator,
                        output.result,
                        completion,
                        resp_format,
                        messages,
                        output_model,
                        temperature,
                        logprobs,
                        top_logprobs,
                    )
                if not is_valid:
                    output.errors.append("Validation failed after all retry attempts")

            if logprobs:
                output.logprobs = self._extract_logprobs(completion)

            if with_analysis:
                output.analysis = analysis

            return output

        except Exception as e:
            # Keep the traceback in the log; callers receive the message via `errors`.
            logger.exception("AsyncTheTool failed: %s", e)
            output.errors.append(str(e))
            return output
|