hamtaa-texttools 1.1.20__py3-none-any.whl → 1.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hamtaa_texttools-1.1.20.dist-info → hamtaa_texttools-1.1.21.dist-info}/METADATA +8 -27
- hamtaa_texttools-1.1.21.dist-info/RECORD +32 -0
- texttools/batch/batch_config.py +14 -1
- texttools/batch/batch_runner.py +1 -1
- texttools/internals/async_operator.py +45 -79
- texttools/internals/models.py +74 -105
- texttools/internals/operator_utils.py +2 -26
- texttools/internals/prompt_loader.py +3 -20
- texttools/internals/sync_operator.py +44 -78
- texttools/prompts/README.md +2 -2
- texttools/prompts/categorize.yaml +35 -77
- texttools/prompts/check_fact.yaml +2 -2
- texttools/prompts/extract_entities.yaml +2 -2
- texttools/prompts/extract_keywords.yaml +6 -6
- texttools/prompts/is_question.yaml +2 -2
- texttools/prompts/merge_questions.yaml +4 -4
- texttools/prompts/propositionize.yaml +2 -2
- texttools/prompts/rewrite.yaml +6 -6
- texttools/prompts/run_custom.yaml +1 -1
- texttools/prompts/subject_to_question.yaml +2 -2
- texttools/prompts/summarize.yaml +2 -2
- texttools/prompts/text_to_question.yaml +2 -2
- texttools/prompts/translate.yaml +2 -2
- texttools/tools/async_tools.py +393 -485
- texttools/tools/sync_tools.py +394 -486
- hamtaa_texttools-1.1.20.dist-info/RECORD +0 -33
- texttools/batch/internals/utils.py +0 -13
- {hamtaa_texttools-1.1.20.dist-info → hamtaa_texttools-1.1.21.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.1.20.dist-info → hamtaa_texttools-1.1.21.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.1.20.dist-info → hamtaa_texttools-1.1.21.dist-info}/top_level.txt +0 -0
- /texttools/batch/{internals/batch_manager.py → batch_manager.py} +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hamtaa-texttools
|
|
3
|
-
Version: 1.1.20
|
|
3
|
+
Version: 1.1.21
|
|
4
4
|
Summary: A high-level NLP toolkit built on top of modern LLMs.
|
|
5
5
|
Author-email: Tohidi <the.mohammad.tohidi@gmail.com>, Montazer <montazerh82@gmail.com>, Givechi <mohamad.m.givechi@gmail.com>, MoosaviNejad <erfanmoosavi84@gmail.com>, Zareshahi <a.zareshahi1377@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -50,7 +50,7 @@ It provides ready-to-use utilities for **translation, question detection, keywor
|
|
|
50
50
|
TextTools provides a rich collection of high-level NLP utilities,
|
|
51
51
|
Each tool is designed to work with structured outputs (JSON / Pydantic).
|
|
52
52
|
|
|
53
|
-
- **`categorize()`** - Classifies text into given categories
|
|
53
|
+
- **`categorize()`** - Classifies text into given categories
|
|
54
54
|
- **`extract_keywords()`** - Extracts keywords from text
|
|
55
55
|
- **`extract_entities()`** - Named Entity Recognition (NER) system
|
|
56
56
|
- **`is_question()`** - Binary detection of whether input is a question
|
|
@@ -61,7 +61,7 @@ Each tool is designed to work with structured outputs (JSON / Pydantic).
|
|
|
61
61
|
- **`summarize()`** - Text summarization
|
|
62
62
|
- **`translate()`** - Text translation between languages
|
|
63
63
|
- **`propositionize()`** - Convert text into atomic, independently meaningful sentences
|
|
64
|
-
- **`check_fact()`** - Check a statement is relevant to source text
|
|
64
|
+
- **`check_fact()`** - Check whether a statement is relevant to the source text
|
|
65
65
|
- **`run_custom()`** - Allows users to define a custom tool with an arbitrary BaseModel
|
|
66
66
|
|
|
67
67
|
---
|
|
@@ -125,11 +125,12 @@ TextTools provides several optional flags to customize LLM behavior:
|
|
|
125
125
|
Every tool of `TextTools` returns a `ToolOutput` object which is a BaseModel with attributes:
|
|
126
126
|
- **`result: Any`** → The output of the LLM
|
|
127
127
|
- **`analysis: str`** → The reasoning step before generating the final output
|
|
128
|
-
- **`logprobs: list`** → Token-level probabilities for the generated output
|
|
129
|
-
- **`process: str`** → The tool name which processed the input
|
|
130
|
-
- **`processed_at: datetime`** → The process time
|
|
131
|
-
- **`execution_time: float`** → The execution time (seconds)
|
|
128
|
+
- **`logprobs: list`** → Token-level probabilities for the generated output
|
|
132
129
|
- **`errors: list[str]`** → Any errors that occurred while calling the LLM
|
|
130
|
+
- **`metadata: ToolOutputMetadata`** → Metadata about the tool call, with attributes:
|
|
131
|
+
- **`tool_name: str`** → The name of the tool that processed the input
|
|
132
|
+
- **`processed_at: datetime`** → The time at which the input was processed
|
|
133
|
+
- **`execution_time: float`** → The execution time (seconds)
|
|
133
134
|
|
|
134
135
|
**Note:** You can use `repr(ToolOutput)` to see details of your ToolOutput.
|
|
135
136
|
|
|
@@ -224,26 +225,6 @@ Use **TextTools** when you need to:
|
|
|
224
225
|
|
|
225
226
|
---
|
|
226
227
|
|
|
227
|
-
## 🔍 Logging
|
|
228
|
-
|
|
229
|
-
TextTools uses Python's standard `logging` module. The library's default logger level is `WARNING`; to change it, configure logging as follows:
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
```python
|
|
233
|
-
import logging
|
|
234
|
-
|
|
235
|
-
# Default: warnings and errors only
|
|
236
|
-
logging.basicConfig(level=logging.WARNING)
|
|
237
|
-
|
|
238
|
-
# Debug everything (verbose)
|
|
239
|
-
logging.basicConfig(level=logging.DEBUG)
|
|
240
|
-
|
|
241
|
-
# Complete silence
|
|
242
|
-
logging.basicConfig(level=logging.CRITICAL)
|
|
243
|
-
```
|
|
244
|
-
|
|
245
|
-
---
|
|
246
|
-
|
|
247
228
|
## 📚 Batch Processing
|
|
248
229
|
|
|
249
230
|
Process large datasets efficiently using OpenAI's batch API.
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
hamtaa_texttools-1.1.21.dist-info/licenses/LICENSE,sha256=Hb2YOBKy2MJQLnyLrX37B4ZVuac8eaIcE71SvVIMOLg,1082
|
|
2
|
+
texttools/__init__.py,sha256=CmCS9dEvO6061GiJ8A7gD3UAhCWHTkaID9q3Krlyq_o,311
|
|
3
|
+
texttools/batch/batch_config.py,sha256=scWYQBDuaTj8-b2x_a33Zu-zxm7eqEf5FFoquD-Sv94,1029
|
|
4
|
+
texttools/batch/batch_manager.py,sha256=6HfsexU0PHGGBH7HKReZ-CQxaQI9DXYKAPsFXxovb_I,8740
|
|
5
|
+
texttools/batch/batch_runner.py,sha256=fmoq7yxtEdvfLbEhcx95ma-lgrL-ZdI2EgxmEfVcKtE,10016
|
|
6
|
+
texttools/internals/async_operator.py,sha256=sKMYEy7jEcsXpwnBkA18PFubkM-TXZrBH3QwF7l-wSg,7054
|
|
7
|
+
texttools/internals/exceptions.py,sha256=h_yp_5i_5IfmqTBQ4S6ZOISrrliJBQ3HTEAjwJXrplk,495
|
|
8
|
+
texttools/internals/models.py,sha256=9uoCAe2TLrSzyS9lMJja5orPAYaCvVL1zoCb6FNdkfs,4541
|
|
9
|
+
texttools/internals/operator_utils.py,sha256=eLY2OjYQ3jT-50nx3I8gzuVzgGpMi52f5oB3cnFyxko,1864
|
|
10
|
+
texttools/internals/prompt_loader.py,sha256=yYXDD4YYG2zohGPAmvZwmv5f6xV_RSl5yOrObTh9w7I,3352
|
|
11
|
+
texttools/internals/sync_operator.py,sha256=IG3CXfGmv4PdFlAQ4AZcKuBAqPJdkIAK4mVw77zLbqI,6959
|
|
12
|
+
texttools/internals/text_to_chunks.py,sha256=vY3odhgCZK4E44k_SGlLoSiKkdN0ib6-lQAsPcplAHA,3843
|
|
13
|
+
texttools/prompts/README.md,sha256=ztajRJcmFLhyrUF0_qmOXaCwGsTGCFabfMjch2LAJG0,1375
|
|
14
|
+
texttools/prompts/categorize.yaml,sha256=016b1uGtbKXEwB8_2_bBgVuUelBlu_rgT85XK_c3Yv0,1219
|
|
15
|
+
texttools/prompts/check_fact.yaml,sha256=gQqacCXqUEx3u2FRwhFSZHvhyWGwsYuJd1nIJyhpu7Q,700
|
|
16
|
+
texttools/prompts/extract_entities.yaml,sha256=DN8lZjvzCjotODnHFkWIAxFvmVvoeSs-hDKdN1L6bec,608
|
|
17
|
+
texttools/prompts/extract_keywords.yaml,sha256=GoeApi9SUCLZgs18H2-2BxZiKQ3lHptMPesgq3cluqU,3171
|
|
18
|
+
texttools/prompts/is_question.yaml,sha256=w5qF-z05h62YVs-0x2b2ySlHDKIhukFC9pibnvNM0vc,469
|
|
19
|
+
texttools/prompts/merge_questions.yaml,sha256=f6bHEx54jJ8hnb8iDBUCxXeGdGwRFmuu7vOkVWdaIkM,1788
|
|
20
|
+
texttools/prompts/propositionize.yaml,sha256=agZKQY-NmeJD86DGjmd-paIuazf82bczIGadgzSP5Vs,1378
|
|
21
|
+
texttools/prompts/rewrite.yaml,sha256=h6x8aXcW8oRxEbp466eak0y-LCkUOKf-mJ-vNVp5j5M,5386
|
|
22
|
+
texttools/prompts/run_custom.yaml,sha256=IETY9H0wPGWIIzcnupfbwwKQblwZrbYAxB754W9MhgU,125
|
|
23
|
+
texttools/prompts/subject_to_question.yaml,sha256=TfVmZ6gDgaHRqJWCVkFlKpuJczpMvJTo4XLWPaq5zic,1145
|
|
24
|
+
texttools/prompts/summarize.yaml,sha256=CKx4vjhHbGus1TdjDz_oc0bNEQtq7zfHsZkV2WeYHDU,457
|
|
25
|
+
texttools/prompts/text_to_question.yaml,sha256=mnArBoYu7gpGHriaU2-Aw5SixB2ZIgoHMt99PnTPKD0,1003
|
|
26
|
+
texttools/prompts/translate.yaml,sha256=ew9RERAVSzg0cvxAinNwTSFIaOIjdwIsekbUsgAuNgo,632
|
|
27
|
+
texttools/tools/async_tools.py,sha256=VU3cqqCPILsyjRiG84w8kCw3iDSuFbI6S3VjExXZwFQ,44635
|
|
28
|
+
texttools/tools/sync_tools.py,sha256=2cqcosMYR6LHuYw32WFR-drvqQ-t7Q9_2rUBDOeYzho,44441
|
|
29
|
+
hamtaa_texttools-1.1.21.dist-info/METADATA,sha256=lExdE6uMFSs_wqUSElOyktjpHpZx4RY-cUH6azF-IYA,10183
|
|
30
|
+
hamtaa_texttools-1.1.21.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
31
|
+
hamtaa_texttools-1.1.21.dist-info/top_level.txt,sha256=5Mh0jIxxZ5rOXHGJ6Mp-JPKviywwN0MYuH0xk5bEWqE,10
|
|
32
|
+
hamtaa_texttools-1.1.21.dist-info/RECORD,,
|
texttools/batch/batch_config.py
CHANGED
|
@@ -1,7 +1,20 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
from collections.abc import Callable
|
|
3
3
|
|
|
4
|
-
|
|
4
|
+
|
|
5
|
+
def export_data(data) -> list[dict[str, str]]:
|
|
6
|
+
"""
|
|
7
|
+
Produces a structure of the following form from an initial data structure:
|
|
8
|
+
[{"id": str, "text": str},...]
|
|
9
|
+
"""
|
|
10
|
+
return data
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def import_data(data) -> object:
|
|
14
|
+
"""
|
|
15
|
+
Takes the output and adds and aggregates it to the original structure.
|
|
16
|
+
"""
|
|
17
|
+
return data
|
|
5
18
|
|
|
6
19
|
|
|
7
20
|
@dataclass
|
texttools/batch/batch_runner.py
CHANGED
|
@@ -9,7 +9,7 @@ from dotenv import load_dotenv
|
|
|
9
9
|
from openai import OpenAI
|
|
10
10
|
from pydantic import BaseModel
|
|
11
11
|
|
|
12
|
-
from texttools.batch.
|
|
12
|
+
from texttools.batch.batch_manager import BatchManager
|
|
13
13
|
from texttools.batch.batch_config import BatchConfig
|
|
14
14
|
from texttools.internals.models import Str
|
|
15
15
|
from texttools.internals.exceptions import TextToolsError, ConfigurationError
|
|
@@ -1,11 +1,10 @@
|
|
|
1
1
|
from typing import TypeVar, Type
|
|
2
2
|
from collections.abc import Callable
|
|
3
|
-
import logging
|
|
4
3
|
|
|
5
4
|
from openai import AsyncOpenAI
|
|
6
5
|
from pydantic import BaseModel
|
|
7
6
|
|
|
8
|
-
from texttools.internals.models import
|
|
7
|
+
from texttools.internals.models import OperatorOutput
|
|
9
8
|
from texttools.internals.operator_utils import OperatorUtils
|
|
10
9
|
from texttools.internals.prompt_loader import PromptLoader
|
|
11
10
|
from texttools.internals.exceptions import (
|
|
@@ -18,35 +17,23 @@ from texttools.internals.exceptions import (
|
|
|
18
17
|
# Base Model type for output models
|
|
19
18
|
T = TypeVar("T", bound=BaseModel)
|
|
20
19
|
|
|
21
|
-
logger = logging.getLogger("texttools.async_operator")
|
|
22
|
-
|
|
23
20
|
|
|
24
21
|
class AsyncOperator:
|
|
25
22
|
"""
|
|
26
|
-
Core engine for running text-processing operations with an LLM
|
|
27
|
-
|
|
28
|
-
It wires together:
|
|
29
|
-
- `PromptLoader` → loads YAML prompt templates.
|
|
30
|
-
- `UserMergeFormatter` → applies formatting to messages (e.g., merging).
|
|
31
|
-
- AsyncOpenAI client → executes completions/parsed completions.
|
|
23
|
+
Core engine for running text-processing operations with an LLM.
|
|
32
24
|
"""
|
|
33
25
|
|
|
34
26
|
def __init__(self, client: AsyncOpenAI, model: str):
|
|
35
27
|
self._client = client
|
|
36
28
|
self._model = model
|
|
37
29
|
|
|
38
|
-
async def
|
|
39
|
-
"""
|
|
40
|
-
Calls OpenAI API for analysis using the configured prompt template.
|
|
41
|
-
Returns the analyzed content as a string.
|
|
42
|
-
"""
|
|
30
|
+
async def _analyze_completion(self, analyze_prompt: str, temperature: float) -> str:
|
|
43
31
|
try:
|
|
44
|
-
analyze_prompt = prompt_configs["analyze_template"]
|
|
45
|
-
|
|
46
32
|
if not analyze_prompt:
|
|
47
33
|
raise PromptError("Analyze template is empty")
|
|
48
34
|
|
|
49
|
-
analyze_message =
|
|
35
|
+
analyze_message = OperatorUtils.build_user_message(analyze_prompt)
|
|
36
|
+
|
|
50
37
|
completion = await self._client.chat.completions.create(
|
|
51
38
|
model=self._model,
|
|
52
39
|
messages=analyze_message,
|
|
@@ -61,7 +48,7 @@ class AsyncOperator:
|
|
|
61
48
|
if not analysis:
|
|
62
49
|
raise LLMError("Empty analysis response")
|
|
63
50
|
|
|
64
|
-
return analysis
|
|
51
|
+
return analysis
|
|
65
52
|
|
|
66
53
|
except Exception as e:
|
|
67
54
|
if isinstance(e, (PromptError, LLMError)):
|
|
@@ -70,21 +57,23 @@ class AsyncOperator:
|
|
|
70
57
|
|
|
71
58
|
async def _parse_completion(
|
|
72
59
|
self,
|
|
73
|
-
|
|
60
|
+
main_prompt: str,
|
|
74
61
|
output_model: Type[T],
|
|
75
62
|
temperature: float,
|
|
76
|
-
logprobs: bool
|
|
77
|
-
top_logprobs: int
|
|
78
|
-
priority: int
|
|
63
|
+
logprobs: bool,
|
|
64
|
+
top_logprobs: int,
|
|
65
|
+
priority: int,
|
|
79
66
|
) -> tuple[T, object]:
|
|
80
67
|
"""
|
|
81
68
|
Parses a chat completion using OpenAI's structured output format.
|
|
82
69
|
Returns both the parsed object and the raw completion for logprobs.
|
|
83
70
|
"""
|
|
84
71
|
try:
|
|
72
|
+
main_message = OperatorUtils.build_user_message(main_prompt)
|
|
73
|
+
|
|
85
74
|
request_kwargs = {
|
|
86
75
|
"model": self._model,
|
|
87
|
-
"messages":
|
|
76
|
+
"messages": main_message,
|
|
88
77
|
"response_format": output_model,
|
|
89
78
|
"temperature": temperature,
|
|
90
79
|
}
|
|
@@ -92,8 +81,10 @@ class AsyncOperator:
|
|
|
92
81
|
if logprobs:
|
|
93
82
|
request_kwargs["logprobs"] = True
|
|
94
83
|
request_kwargs["top_logprobs"] = top_logprobs
|
|
84
|
+
|
|
95
85
|
if priority:
|
|
96
86
|
request_kwargs["extra_body"] = {"priority": priority}
|
|
87
|
+
|
|
97
88
|
completion = await self._client.beta.chat.completions.parse(
|
|
98
89
|
**request_kwargs
|
|
99
90
|
)
|
|
@@ -122,24 +113,22 @@ class AsyncOperator:
|
|
|
122
113
|
user_prompt: str | None,
|
|
123
114
|
temperature: float,
|
|
124
115
|
logprobs: bool,
|
|
125
|
-
top_logprobs: int
|
|
116
|
+
top_logprobs: int,
|
|
126
117
|
validator: Callable[[object], bool] | None,
|
|
127
118
|
max_validation_retries: int | None,
|
|
119
|
+
priority: int,
|
|
128
120
|
# Internal parameters
|
|
129
121
|
prompt_file: str,
|
|
130
122
|
output_model: Type[T],
|
|
131
123
|
mode: str | None,
|
|
132
|
-
priority: int | None = 0,
|
|
133
124
|
**extra_kwargs,
|
|
134
|
-
) ->
|
|
125
|
+
) -> OperatorOutput:
|
|
135
126
|
"""
|
|
136
|
-
Execute the LLM pipeline with the given input text. (
|
|
127
|
+
Execute the LLM pipeline with the given input text. (Sync)
|
|
137
128
|
"""
|
|
138
129
|
try:
|
|
139
130
|
prompt_loader = PromptLoader()
|
|
140
|
-
output = ToolOutput()
|
|
141
131
|
|
|
142
|
-
# Prompt configs contain two keys: main_template and analyze template, both are string
|
|
143
132
|
prompt_configs = prompt_loader.load(
|
|
144
133
|
prompt_file=prompt_file,
|
|
145
134
|
text=text.strip(),
|
|
@@ -147,47 +136,32 @@ class AsyncOperator:
|
|
|
147
136
|
**extra_kwargs,
|
|
148
137
|
)
|
|
149
138
|
|
|
150
|
-
|
|
139
|
+
main_prompt = ""
|
|
140
|
+
analysis = ""
|
|
151
141
|
|
|
152
142
|
if with_analysis:
|
|
153
|
-
analysis = await self.
|
|
154
|
-
|
|
155
|
-
OperatorUtils.build_user_message(
|
|
156
|
-
f"Based on this analysis: {analysis}"
|
|
157
|
-
)
|
|
143
|
+
analysis = await self._analyze_completion(
|
|
144
|
+
prompt_configs["analyze_template"], temperature
|
|
158
145
|
)
|
|
146
|
+
main_prompt += f"Based on this analysis:\n{analysis}\n"
|
|
159
147
|
|
|
160
148
|
if output_lang:
|
|
161
|
-
|
|
162
|
-
OperatorUtils.build_user_message(
|
|
163
|
-
f"Respond only in the {output_lang} language."
|
|
164
|
-
)
|
|
165
|
-
)
|
|
149
|
+
main_prompt += f"Respond only in the {output_lang} language.\n"
|
|
166
150
|
|
|
167
151
|
if user_prompt:
|
|
168
|
-
|
|
169
|
-
OperatorUtils.build_user_message(
|
|
170
|
-
f"Consider this instruction {user_prompt}"
|
|
171
|
-
)
|
|
172
|
-
)
|
|
173
|
-
|
|
174
|
-
messages.append(
|
|
175
|
-
OperatorUtils.build_user_message(prompt_configs["main_template"])
|
|
176
|
-
)
|
|
152
|
+
main_prompt += f"Consider this instruction {user_prompt}\n"
|
|
177
153
|
|
|
178
|
-
|
|
154
|
+
main_prompt += prompt_configs["main_template"]
|
|
179
155
|
|
|
180
156
|
if logprobs and (not isinstance(top_logprobs, int) or top_logprobs < 2):
|
|
181
157
|
raise ValueError("top_logprobs should be an integer greater than 1")
|
|
182
158
|
|
|
183
159
|
parsed, completion = await self._parse_completion(
|
|
184
|
-
|
|
160
|
+
main_prompt, output_model, temperature, logprobs, top_logprobs, priority
|
|
185
161
|
)
|
|
186
162
|
|
|
187
|
-
output.result = parsed.result
|
|
188
|
-
|
|
189
163
|
# Retry logic if validation fails
|
|
190
|
-
if validator and not validator(
|
|
164
|
+
if validator and not validator(parsed.result):
|
|
191
165
|
if (
|
|
192
166
|
not isinstance(max_validation_retries, int)
|
|
193
167
|
or max_validation_retries < 1
|
|
@@ -197,17 +171,13 @@ class AsyncOperator:
|
|
|
197
171
|
)
|
|
198
172
|
|
|
199
173
|
succeeded = False
|
|
200
|
-
for
|
|
201
|
-
|
|
202
|
-
f"Validation failed, retrying for the {attempt + 1} time."
|
|
203
|
-
)
|
|
204
|
-
|
|
205
|
-
# Generate new temperature for retry
|
|
174
|
+
for _ in range(max_validation_retries):
|
|
175
|
+
# Generate a new temperature to retry
|
|
206
176
|
retry_temperature = OperatorUtils.get_retry_temp(temperature)
|
|
207
177
|
|
|
208
178
|
try:
|
|
209
179
|
parsed, completion = await self._parse_completion(
|
|
210
|
-
|
|
180
|
+
main_prompt,
|
|
211
181
|
output_model,
|
|
212
182
|
retry_temperature,
|
|
213
183
|
logprobs,
|
|
@@ -215,30 +185,26 @@ class AsyncOperator:
|
|
|
215
185
|
priority=priority,
|
|
216
186
|
)
|
|
217
187
|
|
|
218
|
-
output.result = parsed.result
|
|
219
|
-
|
|
220
188
|
# Check if retry was successful
|
|
221
|
-
if validator(
|
|
189
|
+
if validator(parsed.result):
|
|
222
190
|
succeeded = True
|
|
223
191
|
break
|
|
224
192
|
|
|
225
|
-
except LLMError
|
|
226
|
-
|
|
193
|
+
except LLMError:
|
|
194
|
+
pass
|
|
227
195
|
|
|
228
196
|
if not succeeded:
|
|
229
|
-
raise ValidationError(
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
output.process = prompt_file[:-5]
|
|
197
|
+
raise ValidationError("Validation failed after all retries")
|
|
198
|
+
|
|
199
|
+
operator_output = OperatorOutput(
|
|
200
|
+
result=parsed.result,
|
|
201
|
+
analysis=analysis if with_analysis else None,
|
|
202
|
+
logprobs=OperatorUtils.extract_logprobs(completion)
|
|
203
|
+
if logprobs
|
|
204
|
+
else None,
|
|
205
|
+
)
|
|
240
206
|
|
|
241
|
-
return
|
|
207
|
+
return operator_output
|
|
242
208
|
|
|
243
209
|
except (PromptError, LLMError, ValidationError):
|
|
244
210
|
raise
|
texttools/internals/models.py
CHANGED
|
@@ -1,25 +1,39 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from datetime import datetime
|
|
2
|
-
from typing import Type, Literal
|
|
4
|
+
from typing import Type, Literal, Any
|
|
3
5
|
|
|
4
6
|
from pydantic import BaseModel, Field, create_model
|
|
5
7
|
|
|
6
8
|
|
|
7
|
-
class
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
analysis: str = ""
|
|
11
|
-
process: str | None = None
|
|
12
|
-
processed_at: datetime = datetime.now()
|
|
9
|
+
class ToolOutputMetadata(BaseModel):
|
|
10
|
+
tool_name: str
|
|
11
|
+
processed_at: datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
13
12
|
execution_time: float | None = None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ToolOutput(BaseModel):
|
|
16
|
+
result: Any = None
|
|
17
|
+
analysis: str | None = None
|
|
18
|
+
logprobs: list[dict[str, Any]] | None = None
|
|
14
19
|
errors: list[str] = []
|
|
20
|
+
metadata: ToolOutputMetadata | None = None
|
|
15
21
|
|
|
16
22
|
def __repr__(self) -> str:
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
+
base = f"""ToolOutput(result='{self.result}', result_type='{type(self.result)}', analysis='{self.analysis}', logprobs='{self.logprobs}', errors='{self.errors}'"""
|
|
24
|
+
|
|
25
|
+
if self.metadata:
|
|
26
|
+
base += f""", tool_name='{self.metadata.tool_name}',
|
|
27
|
+
processed_at='{self.metadata.processed_at}', execution_time='{self.metadata.execution_time}'
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
return base
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class OperatorOutput(BaseModel):
|
|
34
|
+
result: Any
|
|
35
|
+
analysis: str | None
|
|
36
|
+
logprobs: list[dict[str, Any]] | None
|
|
23
37
|
|
|
24
38
|
|
|
25
39
|
class Str(BaseModel):
|
|
@@ -53,114 +67,69 @@ class ReasonListStr(BaseModel):
|
|
|
53
67
|
)
|
|
54
68
|
|
|
55
69
|
|
|
56
|
-
class Node
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
70
|
+
class Node:
|
|
71
|
+
def __init__(self, name: str, description: str, level: int, parent: Node | None):
|
|
72
|
+
self.name = name
|
|
73
|
+
self.description = description
|
|
74
|
+
self.level = level
|
|
75
|
+
self.parent = parent
|
|
76
|
+
self.children = {}
|
|
62
77
|
|
|
63
78
|
|
|
64
79
|
class CategoryTree:
|
|
65
|
-
def __init__(self
|
|
66
|
-
self._root = Node(
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
self._new_id = 1
|
|
71
|
-
|
|
72
|
-
def get_all_nodes(self) -> list[Node]:
|
|
80
|
+
def __init__(self):
|
|
81
|
+
self._root = Node(name="root", description="root", level=0, parent=None)
|
|
82
|
+
self._all_nodes = {"root": self._root}
|
|
83
|
+
|
|
84
|
+
def get_all_nodes(self) -> dict[str, Node]:
|
|
73
85
|
return self._all_nodes
|
|
74
86
|
|
|
75
87
|
def get_level_count(self) -> int:
|
|
76
|
-
return max(
|
|
77
|
-
|
|
78
|
-
def get_node(self,
|
|
79
|
-
|
|
80
|
-
for node in self.get_all_nodes():
|
|
81
|
-
if node.name == identifier:
|
|
82
|
-
return node
|
|
83
|
-
return None
|
|
84
|
-
elif isinstance(identifier, int):
|
|
85
|
-
for node in self.get_all_nodes():
|
|
86
|
-
if node.node_id == identifier:
|
|
87
|
-
return node
|
|
88
|
-
return None
|
|
89
|
-
else:
|
|
90
|
-
return None
|
|
91
|
-
|
|
92
|
-
def get_children(self, parent_node: Node) -> list[Node] | None:
|
|
93
|
-
children = [
|
|
94
|
-
node
|
|
95
|
-
for node in self.get_all_nodes()
|
|
96
|
-
if parent_node.node_id == node.parent_id
|
|
97
|
-
]
|
|
98
|
-
return children if children else None
|
|
88
|
+
return max(node.level for node in self._all_nodes.values())
|
|
89
|
+
|
|
90
|
+
def get_node(self, name: str) -> Node | None:
|
|
91
|
+
return self._all_nodes.get(name)
|
|
99
92
|
|
|
100
93
|
def add_node(
|
|
101
94
|
self,
|
|
102
|
-
|
|
103
|
-
parent_name: str
|
|
95
|
+
name: str,
|
|
96
|
+
parent_name: str,
|
|
104
97
|
description: str | None = None,
|
|
105
98
|
) -> None:
|
|
106
|
-
if self.get_node(
|
|
107
|
-
raise ValueError(f"{
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
parent_id = parent_node.node_id
|
|
114
|
-
level = parent_node.level + 1
|
|
115
|
-
else:
|
|
116
|
-
level = 1
|
|
117
|
-
parent_id = 0
|
|
99
|
+
if self.get_node(name):
|
|
100
|
+
raise ValueError(f"Cannot add {name} category twice")
|
|
101
|
+
|
|
102
|
+
parent = self.get_node(parent_name)
|
|
103
|
+
|
|
104
|
+
if not parent:
|
|
105
|
+
raise ValueError(f"Parent category '{parent_name}' not found")
|
|
118
106
|
|
|
119
107
|
node_data = {
|
|
120
|
-
"
|
|
121
|
-
"name": node_name,
|
|
122
|
-
"level": level,
|
|
123
|
-
"parent_id": parent_id,
|
|
108
|
+
"name": name,
|
|
124
109
|
"description": description if description else "No description provided",
|
|
110
|
+
"level": parent.level + 1,
|
|
111
|
+
"parent": parent,
|
|
125
112
|
}
|
|
126
113
|
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
raise ValueError(f"Node with identifier: '{identifier}' not found.")
|
|
147
|
-
|
|
148
|
-
def dump_tree(self) -> dict:
|
|
149
|
-
def build_dict(node: Node) -> dict:
|
|
150
|
-
children = [
|
|
151
|
-
build_dict(child)
|
|
152
|
-
for child in self._all_nodes
|
|
153
|
-
if child.parent_id == node.node_id
|
|
154
|
-
]
|
|
155
|
-
return {
|
|
156
|
-
"node_id": node.node_id,
|
|
157
|
-
"name": node.name,
|
|
158
|
-
"level": node.level,
|
|
159
|
-
"parent_id": node.parent_id,
|
|
160
|
-
"children": children,
|
|
161
|
-
}
|
|
162
|
-
|
|
163
|
-
return {"category_tree": build_dict(self._root)["children"]}
|
|
114
|
+
new_node = Node(**node_data)
|
|
115
|
+
parent.children[name] = new_node
|
|
116
|
+
self._all_nodes[name] = new_node
|
|
117
|
+
|
|
118
|
+
def remove_node(self, name: str) -> None:
|
|
119
|
+
if name == "root":
|
|
120
|
+
raise ValueError("Cannot remove the root node")
|
|
121
|
+
|
|
122
|
+
node = self.get_node(name)
|
|
123
|
+
if not node:
|
|
124
|
+
raise ValueError(f"Category: '{name}' not found")
|
|
125
|
+
|
|
126
|
+
for child_name in list(node.children.keys()):
|
|
127
|
+
self.remove_node(child_name)
|
|
128
|
+
|
|
129
|
+
if node.parent:
|
|
130
|
+
del node.parent.children[name]
|
|
131
|
+
|
|
132
|
+
del self._all_nodes[name]
|
|
164
133
|
|
|
165
134
|
|
|
166
135
|
# This function is needed to create CategorizerOutput with dynamic categories
|
|
@@ -5,8 +5,8 @@ import random
|
|
|
5
5
|
|
|
6
6
|
class OperatorUtils:
|
|
7
7
|
@staticmethod
|
|
8
|
-
def build_user_message(prompt: str) -> dict[str, str]:
|
|
9
|
-
return {"role": "user", "content": prompt}
|
|
8
|
+
def build_user_message(prompt: str) -> list[dict[str, str]]:
|
|
9
|
+
return [{"role": "user", "content": prompt}]
|
|
10
10
|
|
|
11
11
|
@staticmethod
|
|
12
12
|
def extract_logprobs(completion: dict) -> list[dict]:
|
|
@@ -52,27 +52,3 @@ class OperatorUtils:
|
|
|
52
52
|
new_temp = base_temp + delta_temp
|
|
53
53
|
|
|
54
54
|
return max(0.0, min(new_temp, 1.5))
|
|
55
|
-
|
|
56
|
-
@staticmethod
|
|
57
|
-
def user_merge_format(messages: list[dict[str, str]]) -> list[dict[str, str]]:
|
|
58
|
-
"""
|
|
59
|
-
Merges consecutive user messages into a single message, separated by newlines.
|
|
60
|
-
|
|
61
|
-
This is useful for condensing a multi-turn user input into a single
|
|
62
|
-
message for the LLM. Assistant and system messages are left unchanged and
|
|
63
|
-
act as separators between user message groups.
|
|
64
|
-
"""
|
|
65
|
-
merged = []
|
|
66
|
-
|
|
67
|
-
for message in messages:
|
|
68
|
-
role, content = message["role"], message["content"].strip()
|
|
69
|
-
|
|
70
|
-
# Merge with previous user turn
|
|
71
|
-
if merged and role == "user" and merged[-1]["role"] == "user":
|
|
72
|
-
merged[-1]["content"] += "\n" + content
|
|
73
|
-
|
|
74
|
-
# Otherwise, start a new turn
|
|
75
|
-
else:
|
|
76
|
-
merged.append({"role": role, "content": content})
|
|
77
|
-
|
|
78
|
-
return merged
|