hamtaa-texttools 1.3.0__py3-none-any.whl → 1.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- {hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.2.dist-info}/METADATA +8 -38
- {hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.2.dist-info}/RECORD +11 -14
- texttools/__init__.py +1 -3
- texttools/core/engine.py +21 -23
- texttools/core/internal_models.py +7 -3
- texttools/core/operators/async_operator.py +1 -3
- texttools/core/operators/sync_operator.py +1 -3
- texttools/batch/config.py +0 -40
- texttools/batch/manager.py +0 -228
- texttools/batch/runner.py +0 -228
- {hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.2.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.2.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.2.dist-info}/top_level.txt +0 -0
- /texttools/{batch → core/operators}/__init__.py +0 -0
{hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hamtaa-texttools
-Version: 1.3.0
+Version: 1.3.2
 Summary: A high-level NLP toolkit built on top of modern LLMs.
 Author-email: Tohidi <the.mohammad.tohidi@gmail.com>, Erfan Moosavi <erfanmoosavi84@gmail.com>, Montazer <montazerh82@gmail.com>, Givechi <mohamad.m.givechi@gmail.com>, Zareshahi <a.zareshahi1377@gmail.com>
 Maintainer-email: Erfan Moosavi <erfanmoosavi84@gmail.com>, Tohidi <the.mohammad.tohidi@gmail.com>
@@ -21,6 +21,9 @@ Dynamic: license-file

 # TextTools

+
+
+
 ## 📌 Overview

 **TextTools** is a high-level **NLP toolkit** built on top of **LLMs**.
@@ -44,11 +47,11 @@ Each tool is designed to work with structured outputs.
 - **`is_question()`** - Binary question detection
 - **`text_to_question()`** - Generates questions from text
 - **`merge_questions()`** - Merges multiple questions into one
-- **`rewrite()`** - Rewrites text in a
-- **`subject_to_question()`** - Generates questions about a
+- **`rewrite()`** - Rewrites text in a different way
+- **`subject_to_question()`** - Generates questions about a given subject
 - **`summarize()`** - Text summarization
 - **`translate()`** - Text translation
-- **`propositionize()`** - Convert text to atomic
+- **`propositionize()`** - Convert text to atomic independent meaningful sentences
 - **`check_fact()`** - Check whether a statement is relevant to the source text
 - **`run_custom()`** - Allows users to define a custom tool with an arbitrary BaseModel
@@ -66,7 +69,7 @@ pip install -U hamtaa-texttools

 ## 📊 Tool Quality Tiers

-| Status | Meaning | Tools |
+| Status | Meaning | Tools | Safe for Production? |
 |--------|---------|----------|-------------------|
 | **✅ Production** | Evaluated, tested, stable. | `categorize()` (list mode), `extract_keywords()`, `extract_entities()`, `is_question()`, `text_to_question()`, `merge_questions()`, `rewrite()`, `subject_to_question()`, `summarize()`, `run_custom()` | **Yes** - ready for reliable use. |
 | **🧪 Experimental** | Added to the package but **not fully evaluated**. Functional, but quality may vary. | `categorize()` (tree mode), `translate()`, `propositionize()`, `check_fact()` | **Use with caution** - outputs not yet validated. |
@@ -177,40 +180,7 @@ Use **TextTools** when you need to:

 ---

-## 📚 Batch Processing
-
-Process large datasets efficiently using OpenAI's batch API.
-
-## ⚡ Quick Start (Batch Runner)
-
-```python
-from pydantic import BaseModel
-from texttools import BatchRunner, BatchConfig
-
-config = BatchConfig(
-    system_prompt="Extract entities from the text",
-    job_name="entity_extraction",
-    input_data_path="data.json",
-    output_data_filename="results.json",
-    model="gpt-4o-mini"
-)
-
-class Output(BaseModel):
-    entities: list[str]
-
-runner = BatchRunner(config, output_model=Output)
-runner.run()
-```
-
----
-
 ## 🤝 Contributing

 Contributions are welcome!
 Feel free to **open issues, suggest new features, or submit pull requests**.
-
----
-
-## 🌿 License
-
-This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
{hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.2.dist-info}/RECORD
CHANGED
@@ -1,17 +1,14 @@
-hamtaa_texttools-1.3.
-texttools/__init__.py,sha256=
+hamtaa_texttools-1.3.2.dist-info/licenses/LICENSE,sha256=Hb2YOBKy2MJQLnyLrX37B4ZVuac8eaIcE71SvVIMOLg,1082
+texttools/__init__.py,sha256=RK1GAU6pq2lGwFtHdrCX5JkPRHmOLGcmGH67hd_7VAQ,175
 texttools/models.py,sha256=5eT2cSrFq8Xa38kANznV7gbi7lwB2PoDxciLKTpsd6c,2516
 texttools/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-texttools/batch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-texttools/batch/config.py,sha256=GDDXuhRZ_bOGVwSIlU4tWP247tx1_A7qzLJn7VqDyLU,1050
-texttools/batch/manager.py,sha256=XZtf8UkdClfQlnRKne4nWEcFvdSKE67EamEePKy7jwI,8730
-texttools/batch/runner.py,sha256=9qxXIMfYRXW5SXDqqKtRr61rnQdYZkbCGqKImhSrY6I,9923
 texttools/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-texttools/core/engine.py,sha256=
+texttools/core/engine.py,sha256=AjifrcJl6PeRu1W6nu9zcxySn-1439Ef2La4d7GpNKY,9481
 texttools/core/exceptions.py,sha256=6SDjUL1rmd3ngzD3ytF4LyTRj3bQMSFR9ECrLoqXXHw,395
-texttools/core/internal_models.py,sha256=
-texttools/core/operators/
-texttools/core/operators/
+texttools/core/internal_models.py,sha256=J1qGEO8V0OoX6_-1yxbSmZSR79tJF0ExAIG1QuvH0L0,1734
+texttools/core/operators/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+texttools/core/operators/async_operator.py,sha256=-72YQEGFkbk2uYW6PHkLT4wGxhj2p6Uqy3sJtVa9-rk,6386
+texttools/core/operators/sync_operator.py,sha256=mfXtEOlIAhHo4SHaHRKjGb0Z1T894clv-toUzUcbfpo,6291
 texttools/prompts/categorize.yaml,sha256=42Rp3SgVHaDLKrJ27_uK788LiQud0pOXJthz4r0a40Y,1214
 texttools/prompts/check_fact.yaml,sha256=zWFQDRhEE1ij9wSeeenS9YSTM-bY5zzUaG390zUgmcs,714
 texttools/prompts/extract_entities.yaml,sha256=_zYKHNJDIzVDI_-TnwFCKyMs-XLM5igvmWhvSTc3INQ,637
@@ -28,7 +25,7 @@ texttools/prompts/translate.yaml,sha256=Dd5bs3O8SI-FlVSwHMYGeEjMmdOWeRlcfBHkhixC
 texttools/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 texttools/tools/async_tools.py,sha256=2suwx8N0aRnowaSOpV6C57AqPlmQe5Z0Yx4E5QIMkmU,46939
 texttools/tools/sync_tools.py,sha256=mEuL-nlbxVW30dPE3hGkAUnYXbul-3gN2Le4CMVFCgU,42528
-hamtaa_texttools-1.3.
-hamtaa_texttools-1.3.
-hamtaa_texttools-1.3.
-hamtaa_texttools-1.3.
+hamtaa_texttools-1.3.2.dist-info/METADATA,sha256=LjhXLwovneW5Ii1DvAYhFT4JR64ar23UyptCvCO6Hpc,7448
+hamtaa_texttools-1.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+hamtaa_texttools-1.3.2.dist-info/top_level.txt,sha256=5Mh0jIxxZ5rOXHGJ6Mp-JPKviywwN0MYuH0xk5bEWqE,10
+hamtaa_texttools-1.3.2.dist-info/RECORD,,
texttools/__init__.py
CHANGED
@@ -1,7 +1,5 @@
-from .batch.config import BatchConfig
-from .batch.runner import BatchRunner
 from .models import CategoryTree
 from .tools.async_tools import AsyncTheTool
 from .tools.sync_tools import TheTool

-__all__ = ["TheTool", "AsyncTheTool", "CategoryTree"
+__all__ = ["TheTool", "AsyncTheTool", "CategoryTree"]
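The practical upshot of this change for downstream code, as a minimal sketch (the removed names were still importable up to 1.3.0):

```python
# Import surface after 1.3.2: the batch entry points are gone from the package root.
from texttools import TheTool, AsyncTheTool, CategoryTree  # still exported

# from texttools import BatchRunner, BatchConfig  # raises ImportError in 1.3.2
```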
texttools/core/engine.py
CHANGED
@@ -4,6 +4,7 @@ import random
 import re
 from functools import lru_cache
 from pathlib import Path
+from typing import Any

 import yaml

@@ -20,9 +21,6 @@ class PromptLoader:

     @lru_cache(maxsize=32)
     def _load_templates(self, prompt_file: str, mode: str | None) -> dict[str, str]:
-        """
-        Loads prompt templates from YAML file with optional mode selection.
-        """
         try:
             base_dir = Path(__file__).parent.parent / Path("prompts")
             prompt_path = base_dir / prompt_file
@@ -73,13 +71,12 @@
         self, prompt_file: str, text: str, mode: str, **extra_kwargs
     ) -> dict[str, str]:
         try:
-            template_configs = self._load_templates(prompt_file, mode)
             format_args = {"text": text}
             format_args.update(extra_kwargs)

-
-            for key in template_configs.
-            template_configs[key] =
+            template_configs = self._load_templates(prompt_file, mode)
+            for key, value in template_configs.items():
+                template_configs[key] = value.format(**format_args)

             return template_configs

@@ -97,30 +94,27 @@ class OperatorUtils:
         output_lang: str | None,
         user_prompt: str | None,
     ) -> str:
-
+        parts = []

         if analysis:
-
-
+            parts.append(f"Based on this analysis: {analysis}")
         if output_lang:
-
-
+            parts.append(f"Respond only in the {output_lang} language.")
         if user_prompt:
-
+            parts.append(f"Consider this instruction: {user_prompt}")

-
-
-        return main_prompt
+        parts.append(main_template)
+        return "\n".join(parts)

     @staticmethod
     def build_message(prompt: str) -> list[dict[str, str]]:
         return [{"role": "user", "content": prompt}]

     @staticmethod
-    def extract_logprobs(completion:
+    def extract_logprobs(completion: Any) -> list[dict]:
         """
-        Extracts and filters
-        Skips punctuation and structural tokens
+        Extracts and filters logprobs from completion.
+        Skips punctuation and structural tokens.
         """
         logprobs_data = []

@@ -153,16 +147,17 @@ class OperatorUtils:

     @staticmethod
     def get_retry_temp(base_temp: float) -> float:
-
-        new_temp = base_temp + delta_temp
-
+        new_temp = base_temp + random.choice([-1, 1]) * random.uniform(0.1, 0.9)
         return max(0.0, min(new_temp, 1.5))


 def text_to_chunks(text: str, size: int, overlap: int) -> list[str]:
+    """
+    Utility for chunking large texts. Used for translation tool
+    """
     separators = ["\n\n", "\n", " ", ""]
     is_separator_regex = False
-    keep_separator = True
+    keep_separator = True
     length_function = len
     strip_whitespace = True
     chunk_size = size
@@ -256,6 +251,9 @@ def text_to_chunks(text: str, size: int, overlap: int) -> list[str]:


 async def run_with_timeout(coro, timeout: float | None):
+    """
+    Utility for timeout logic defined in AsyncTheTool
+    """
     if timeout is None:
         return await coro
     try:
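The two functional fixes in this file are easiest to see in isolation. A standalone sketch (names mirror the diff; the YAML loading and OpenAI plumbing are stripped out):

```python
import random

# 1. The template-formatting loop: templates are now loaded *after* the format
#    arguments are assembled, and each template actually gets formatted.
def fill_templates(template_configs: dict[str, str], **format_args) -> dict[str, str]:
    for key, value in template_configs.items():
        template_configs[key] = value.format(**format_args)
    return template_configs

# 2. The retry temperature: jitter the base temperature by ±(0.1..0.9) and
#    clamp to [0.0, 1.5], replacing the previous delta_temp computation.
def get_retry_temp(base_temp: float) -> float:
    new_temp = base_temp + random.choice([-1, 1]) * random.uniform(0.1, 0.9)
    return max(0.0, min(new_temp, 1.5))

print(fill_templates({"user": "Text: {text}"}, text="hello"))  # {'user': 'Text: hello'}
print(get_retry_temp(0.7))  # e.g. 0.23 or 1.41, never outside the clamp
```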
texttools/core/internal_models.py
CHANGED
@@ -21,7 +21,9 @@ class Bool(BaseModel):

 class ListStr(BaseModel):
     result: list[str] = Field(
-        ...,
+        ...,
+        description="The output list of strings",
+        example=["text_1", "text_2", "text_3"],
     )


@@ -36,11 +38,13 @@ class ListDictStrStr(BaseModel):
 class ReasonListStr(BaseModel):
     reason: str = Field(..., description="Thinking process that led to the output")
     result: list[str] = Field(
-        ...,
+        ...,
+        description="The output list of strings",
+        example=["text_1", "text_2", "text_3"],
     )


-#
+# Create CategorizerOutput with dynamic categories
 def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
     literal_type = Literal[*allowed_values]

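The `create_dynamic_model` helper whose comment was restored here builds a categorizer output model constrained to the caller's labels. A hedged sketch of the idea — the `Literal[*allowed_values]` line is verbatim from the diff (Python 3.11+); the `create_model` call and the list-typed `result` field are an assumed completion, not shown in this diff:

```python
from typing import Literal, Type
from pydantic import BaseModel, create_model

def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
    literal_type = Literal[*allowed_values]  # taken verbatim from the diff
    # Assumed completion: a model whose `result` only admits the allowed labels.
    return create_model("CategorizerOutput", result=(list[literal_type], ...))

Categories = create_dynamic_model(["sports", "politics"])
print(Categories(result=["sports"]))   # result=['sports']
# Categories(result=["cooking"]) raises a ValidationError
```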
texttools/core/operators/async_operator.py
CHANGED
@@ -54,7 +54,7 @@ class AsyncOperator:
     ) -> tuple[T, Any]:
         """
         Parses a chat completion using OpenAI's structured output format.
-        Returns both the parsed
+        Returns both the parsed and the completion for logprobs.
         """
         try:
             request_kwargs = {
@@ -92,7 +92,6 @@

     async def run(
         self,
-        # User parameters
         text: str,
         with_analysis: bool,
         output_lang: str | None,
@@ -103,7 +102,6 @@
         validator: Callable[[Any], bool] | None,
         max_validation_retries: int | None,
         priority: int | None,
-        # Internal parameters
         tool_name: str,
         output_model: Type[T],
         mode: str | None,
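The docstring completed here (and identically in sync_operator.py below) describes one pattern: request a parsed Pydantic object but keep the raw completion so token logprobs survive. A sketch of that pattern with the openai SDK's structured-output helper — the method and field names are the SDK's, not this package's, so treat the exact call as an assumption:

```python
from openai import AsyncOpenAI
from pydantic import BaseModel

class Entities(BaseModel):
    result: list[str]

async def parse_with_logprobs(client: AsyncOpenAI, messages: list[dict]) -> tuple[Entities, object]:
    # beta.chat.completions.parse validates the JSON against the model;
    # logprobs=True keeps per-token log probabilities on the raw completion.
    completion = await client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=messages,
        response_format=Entities,
        logprobs=True,
    )
    parsed = completion.choices[0].message.parsed
    return parsed, completion  # "both the parsed and the completion for logprobs"
```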
texttools/core/operators/sync_operator.py
CHANGED
@@ -54,7 +54,7 @@ class Operator:
     ) -> tuple[T, Any]:
         """
         Parses a chat completion using OpenAI's structured output format.
-        Returns both the parsed
+        Returns both the parsed and the completion for logprobs.
         """
         try:
             request_kwargs = {
@@ -90,7 +90,6 @@

     def run(
         self,
-        # User parameters
         text: str,
         with_analysis: bool,
         output_lang: str | None,
@@ -101,7 +100,6 @@
         validator: Callable[[Any], bool] | None,
         max_validation_retries: int | None,
         priority: int | None,
-        # Internal parameters
         tool_name: str,
         output_model: Type[T],
         mode: str | None,
texttools/batch/config.py
DELETED
@@ -1,40 +0,0 @@
-from collections.abc import Callable
-from dataclasses import dataclass
-from typing import Any
-
-
-def export_data(data) -> list[dict[str, str]]:
-    """
-    Produces a structure of the following form from an initial data structure:
-    [{"id": str, "text": str},...]
-    """
-    return data
-
-
-def import_data(data) -> Any:
-    """
-    Takes the output and adds and aggregates it to the original structure.
-    """
-    return data
-
-
-@dataclass
-class BatchConfig:
-    """
-    Configuration for batch job runner.
-    """
-
-    system_prompt: str = ""
-    job_name: str = ""
-    input_data_path: str = ""
-    output_data_filename: str = ""
-    model: str = "gpt-4.1-mini"
-    MAX_BATCH_SIZE: int = 100
-    MAX_TOTAL_TOKENS: int = 2_000_000
-    CHARS_PER_TOKEN: float = 2.7
-    PROMPT_TOKEN_MULTIPLIER: int = 1_000
-    BASE_OUTPUT_DIR: str = "Data/batch_entity_result"
-    import_function: Callable = import_data
-    export_function: Callable = export_data
-    poll_interval_seconds: int = 30
-    max_retries: int = 3
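Before 1.3.2 this dataclass was the single knob for the batch pipeline, with the two hooks adapting caller data to and from the pipeline's shape. A minimal usage sketch against the removed API (field values are illustrative; works only on releases that still ship `texttools.batch`):

```python
# texttools <= 1.3.0 only: these symbols were removed in 1.3.2.
from texttools import BatchConfig

def export_rows(data) -> list[dict[str, str]]:
    # BatchConfig's documented contract: produce [{"id": str, "text": str}, ...]
    return [{"id": str(i), "text": row["body"]} for i, row in enumerate(data)]

config = BatchConfig(
    system_prompt="Extract entities from the text",
    job_name="entity_extraction",
    input_data_path="data.json",
    output_data_filename="results.json",
    export_function=export_rows,   # replaces the identity default
    poll_interval_seconds=60,      # poll the Batch API once a minute
)
```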
texttools/batch/manager.py
DELETED
@@ -1,228 +0,0 @@
-import json
-import logging
-import uuid
-from pathlib import Path
-from typing import Any, Type, TypeVar
-
-from openai import OpenAI
-from openai.lib._pydantic import to_strict_json_schema
-from pydantic import BaseModel
-
-# Base Model type for output models
-T = TypeVar("T", bound=BaseModel)
-
-logger = logging.getLogger("texttools.batch_manager")
-
-
-class BatchManager:
-    """
-    Manages batch processing jobs for OpenAI's chat completions with structured outputs.
-
-    Handles the full lifecycle of a batch job: creating tasks from input texts,
-    starting the job, monitoring status, and fetching results. Results are automatically
-    parsed into the specified Pydantic output model. Job state is persisted to disk.
-    """
-
-    def __init__(
-        self,
-        client: OpenAI,
-        model: str,
-        output_model: Type[T],
-        prompt_template: str,
-        state_dir: Path = Path(".batch_jobs"),
-        custom_json_schema_obj_str: dict | None = None,
-        **client_kwargs: Any,
-    ):
-        self._client = client
-        self._model = model
-        self._output_model = output_model
-        self._prompt_template = prompt_template
-        self._state_dir = state_dir
-        self._custom_json_schema_obj_str = custom_json_schema_obj_str
-        self._client_kwargs = client_kwargs
-        self._dict_input = False
-        self._state_dir.mkdir(parents=True, exist_ok=True)
-
-        if custom_json_schema_obj_str and not isinstance(
-            custom_json_schema_obj_str, dict
-        ):
-            raise ValueError("Schema should be a dict")
-
-    def _state_file(self, job_name: str) -> Path:
-        return self._state_dir / f"{job_name}.json"
-
-    def _load_state(self, job_name: str) -> list[dict[str, Any]]:
-        """
-        Loads the state (job information) from the state file for the given job name.
-        Returns an empty list if the state file does not exist.
-        """
-        path = self._state_file(job_name)
-        if path.exists():
-            with open(path, "r", encoding="utf-8") as f:
-                return json.load(f)
-        return []
-
-    def _save_state(self, job_name: str, jobs: list[dict[str, Any]]) -> None:
-        """
-        Saves the job state to the state file for the given job name.
-        """
-        with open(self._state_file(job_name), "w", encoding="utf-8") as f:
-            json.dump(jobs, f)
-
-    def _clear_state(self, job_name: str) -> None:
-        """
-        Deletes the state file for the given job name if it exists.
-        """
-        path = self._state_file(job_name)
-        if path.exists():
-            path.unlink()
-
-    def _build_task(self, text: str, idx: str) -> dict[str, Any]:
-        """
-        Builds a single task dictionary for the batch job, including the prompt, model, and response format configuration.
-        """
-        response_format_config: dict[str, Any]
-
-        if self._custom_json_schema_obj_str:
-            response_format_config = {
-                "type": "json_schema",
-                "json_schema": self._custom_json_schema_obj_str,
-            }
-        else:
-            raw_schema = to_strict_json_schema(self._output_model)
-            response_format_config = {
-                "type": "json_schema",
-                "json_schema": {
-                    "name": self._output_model.__name__,
-                    "schema": raw_schema,
-                },
-            }
-
-        return {
-            "custom_id": str(idx),
-            "method": "POST",
-            "url": "/v1/chat/completions",
-            "body": {
-                "model": self.model,
-                "messages": [
-                    {"role": "system", "content": self._prompt_template},
-                    {"role": "user", "content": text},
-                ],
-                "response_format": response_format_config,
-                **self._client_kwargs,
-            },
-        }
-
-    def _prepare_file(self, payload: list[str] | list[dict[str, str]]) -> Path:
-        """
-        Prepares a JSONL file containing all tasks for the batch job, based on the input payload.
-        Returns the path to the created file.
-        """
-        if not payload:
-            raise ValueError("Payload must not be empty")
-        if isinstance(payload[0], str):
-            tasks = [self._build_task(text, uuid.uuid4().hex) for text in payload]
-        elif isinstance(payload[0], dict):
-            tasks = [self._build_task(dic["text"], dic["id"]) for dic in payload]
-
-        else:
-            raise TypeError(
-                "The input must be either a list of texts or a dictionary in the form {'id': str, 'text': str}"
-            )
-
-        file_path = self._state_dir / f"batch_{uuid.uuid4().hex}.jsonl"
-        with open(file_path, "w", encoding="utf-8") as f:
-            for task in tasks:
-                f.write(json.dumps(task) + "\n")
-        return file_path
-
-    def start(self, payload: list[str | dict[str, str]], job_name: str):
-        """
-        Starts a new batch job by uploading the prepared file and creating a batch job on the server.
-        If a job with the same name already exists, it does nothing.
-        """
-        if self._load_state(job_name):
-            return
-
-        path = self._prepare_file(payload)
-        upload = self._client.files.create(file=open(path, "rb"), purpose="batch")
-        job = self._client.batches.create(
-            input_file_id=upload.id,
-            endpoint="/v1/chat/completions",
-            completion_window="24h",
-        ).to_dict()
-        self._save_state(job_name, [job])
-
-    def check_status(self, job_name: str) -> str:
-        """
-        Checks and returns the current status of the batch job with the given job name.
-        Updates the job state with the latest information from the server.
-        """
-        job = self._load_state(job_name)[0]
-        if not job:
-            return "completed"
-
-        info = self._client.batches.retrieve(job["id"])
-        job = info.to_dict()
-        self._save_state(job_name, [job])
-        logger.info("Batch job status: %s", job)
-        return job["status"]
-
-    def fetch_results(
-        self, job_name: str, remove_cache: bool = True
-    ) -> tuple[dict[str, str], list]:
-        """
-        Fetches the results of a completed batch job. Optionally saves the results to a file and/or removes the job cache.
-        Returns a tuple containing the parsed results and a log of errors (if any).
-        """
-        job = self._load_state(job_name)[0]
-        if not job:
-            return {}
-        batch_id = job["id"]
-
-        info = self._client.batches.retrieve(batch_id)
-        out_file_id = info.output_file_id
-        if not out_file_id:
-            error_file_id = info.error_file_id
-            if error_file_id:
-                err_content = (
-                    self._client.files.content(error_file_id).read().decode("utf-8")
-                )
-                logger.error("Error file content:", err_content)
-            return {}
-
-        content = self._client.files.content(out_file_id).read().decode("utf-8")
-        lines = content.splitlines()
-        results = {}
-        log = []
-        for line in lines:
-            result = json.loads(line)
-            custom_id = result["custom_id"]
-            if result["response"]["status_code"] == 200:
-                content = result["response"]["body"]["choices"][0]["message"]["content"]
-                try:
-                    parsed_content = json.loads(content)
-                    model_instance = self._output_model(**parsed_content)
-                    results[custom_id] = model_instance.model_dump(mode="json")
-                except json.JSONDecodeError:
-                    results[custom_id] = {"error": "Failed to parse content as JSON"}
-                    error_d = {custom_id: results[custom_id]}
-                    log.append(error_d)
-                except Exception as e:
-                    results[custom_id] = {"error": str(e)}
-                    error_d = {custom_id: results[custom_id]}
-                    log.append(error_d)
-            else:
-                error_message = (
-                    result["response"]["body"]
-                    .get("error", {})
-                    .get("message", "Unknown error")
-                )
-                results[custom_id] = {"error": error_message}
-                error_d = {custom_id: results[custom_id]}
-                log.append(error_d)
-
-        if remove_cache:
-            self._clear_state(job_name)
-
-        return results, log
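The heart of the deleted manager is `_build_task`, which serialized one request per input text into the JSONL format the /v1/batches endpoint consumes. A sketch of a single emitted line (values illustrative; `to_strict_json_schema` is the same private helper the deleted code imported):

```python
import json
from pydantic import BaseModel
from openai.lib._pydantic import to_strict_json_schema  # private helper used above

class Entities(BaseModel):
    entities: list[str]

# Shape mirrors the deleted _build_task; one such dict per input text.
task = {
    "custom_id": "doc-42",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "gpt-4.1-mini",
        "messages": [
            {"role": "system", "content": "Extract entities from the text"},
            {"role": "user", "content": "OpenAI was founded in 2015."},
        ],
        "response_format": {
            "type": "json_schema",
            "json_schema": {"name": Entities.__name__, "schema": to_strict_json_schema(Entities)},
        },
    },
}
print(json.dumps(task))  # one line of the uploaded .jsonl file
```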
texttools/batch/runner.py
DELETED
@@ -1,228 +0,0 @@
-import json
-import logging
-import os
-import time
-from pathlib import Path
-from typing import Any, Type, TypeVar
-
-from dotenv import load_dotenv
-from openai import OpenAI
-from pydantic import BaseModel
-
-from ..core.exceptions import TextToolsError
-from ..core.internal_models import Str
-from .config import BatchConfig
-from .manager import BatchManager
-
-# Base Model type for output models
-T = TypeVar("T", bound=BaseModel)
-
-logger = logging.getLogger("texttools.batch_runner")
-
-
-class BatchRunner:
-    """
-    Handles running batch jobs using a batch manager and configuration.
-    """
-
-    def __init__(
-        self, config: BatchConfig = BatchConfig(), output_model: Type[T] = Str
-    ):
-        try:
-            self._config = config
-            self._system_prompt = config.system_prompt
-            self._job_name = config.job_name
-            self._input_data_path = config.input_data_path
-            self._output_data_filename = config.output_data_filename
-            self._model = config.model
-            self._output_model = output_model
-            self._manager = self._init_manager()
-            self._data = self._load_data()
-            self._parts: list[list[dict[str, Any]]] = []
-            # Map part index to job name
-            self._part_idx_to_job_name: dict[int, str] = {}
-            # Track retry attempts per part
-            self._part_attempts: dict[int, int] = {}
-            self._partition_data()
-            Path(self._config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
-
-        except Exception as e:
-            raise TextToolsError(f"Batch runner initialization failed: {e}")
-
-    def _init_manager(self) -> BatchManager:
-        load_dotenv()
-        api_key = os.getenv("OPENAI_API_KEY")
-        client = OpenAI(api_key=api_key)
-        return BatchManager(
-            client=client,
-            model=self._model,
-            prompt_template=self._system_prompt,
-            output_model=self._output_model,
-        )
-
-    def _load_data(self):
-        with open(self._input_data_path, "r", encoding="utf-8") as f:
-            data = json.load(f)
-        data = self._config.export_function(data)
-
-        # Ensure data is a list of dicts with 'id' and 'content' as strings
-        if not isinstance(data, list):
-            raise ValueError(
-                "Exported data must be a list of dicts with 'id' and 'content' keys"
-            )
-        for item in data:
-            if not (isinstance(item, dict) and "id" in item and "content" in item):
-                raise ValueError(
-                    f"Item must be a dict with 'id' and 'content' keys. Got: {type(item)}"
-                )
-            if not (isinstance(item["id"], str) and isinstance(item["content"], str)):
-                raise ValueError("'id' and 'content' must be strings.")
-        return data
-
-    def _partition_data(self):
-        total_length = sum(len(item["content"]) for item in self._data)
-        prompt_length = len(self._system_prompt)
-        total = total_length + (prompt_length * len(self._data))
-        calculation = total / self._config.CHARS_PER_TOKEN
-        logger.info(
-            f"Total chars: {total_length}, Prompt chars: {prompt_length}, Total: {total}, Tokens: {calculation}"
-        )
-        if calculation < self._config.MAX_TOTAL_TOKENS:
-            self._parts = [self._data]
-        else:
-            # Partition into chunks of MAX_BATCH_SIZE
-            self._parts = [
-                self._data[i : i + self._config.MAX_BATCH_SIZE]
-                for i in range(0, len(self._data), self._config.MAX_BATCH_SIZE)
-            ]
-        logger.info(f"Data split into {len(self._parts)} part(s)")
-
-    def _submit_all_jobs(self) -> None:
-        for idx, part in enumerate(self._parts):
-            if self._result_exists(idx):
-                logger.info(f"Skipping part {idx + 1}: result already exists.")
-                continue
-            part_job_name = (
-                f"{self._job_name}_part_{idx + 1}"
-                if len(self._parts) > 1
-                else self._job_name
-            )
-            # If a job with this name already exists, register and skip submitting
-            existing_job = self._manager._load_state(part_job_name)
-            if existing_job:
-                logger.info(
-                    f"Skipping part {idx + 1}: job already exists ({part_job_name})."
-                )
-                self._part_idx_to_job_name[idx] = part_job_name
-                self._part_attempts.setdefault(idx, 0)
-                continue
-
-            payload = part
-            logger.info(
-                f"Submitting job for part {idx + 1}/{len(self._parts)}: {part_job_name}"
-            )
-            self._manager.start(payload, job_name=part_job_name)
-            self._part_idx_to_job_name[idx] = part_job_name
-            self._part_attempts.setdefault(idx, 0)
-            # This is added for letting file get uploaded, before starting the next part.
-            logger.info("Uploading...")
-            time.sleep(30)
-
-    def _save_results(
-        self,
-        output_data: list[dict[str, Any]] | dict[str, Any],
-        log: list[Any],
-        part_idx: int,
-    ):
-        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
-        result_path = (
-            Path(self._config.BASE_OUTPUT_DIR)
-            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
-        )
-        if not output_data:
-            logger.info("No output data to save. Skipping this part.")
-            return
-        else:
-            with open(result_path, "w", encoding="utf-8") as f:
-                json.dump(output_data, f, ensure_ascii=False, indent=4)
-        if log:
-            log_path = (
-                Path(self._config.BASE_OUTPUT_DIR)
-                / f"{Path(self._output_data_filename).stem}{part_suffix}_log.json"
-            )
-            with open(log_path, "w", encoding="utf-8") as f:
-                json.dump(log, f, ensure_ascii=False, indent=4)
-
-    def _result_exists(self, part_idx: int) -> bool:
-        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
-        result_path = (
-            Path(self._config.BASE_OUTPUT_DIR)
-            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
-        )
-        return result_path.exists()
-
-    def run(self):
-        """
-        Execute the batch job processing pipeline.
-
-        Submits jobs, monitors progress, handles retries, and saves results.
-        """
-        try:
-            # Submit all jobs up-front for concurrent execution
-            self._submit_all_jobs()
-            pending_parts: set[int] = set(self._part_idx_to_job_name.keys())
-            logger.info(f"Pending parts: {sorted(pending_parts)}")
-            # Polling loop
-            while pending_parts:
-                finished_this_round: list[int] = []
-                for part_idx in list(pending_parts):
-                    job_name = self._part_idx_to_job_name[part_idx]
-                    status = self._manager.check_status(job_name=job_name)
-                    logger.info(f"Status for {job_name}: {status}")
-                    if status == "completed":
-                        logger.info(
-                            f"Job completed. Fetching results for part {part_idx + 1}..."
-                        )
-                        output_data, log = self._manager.fetch_results(
-                            job_name=job_name, remove_cache=False
-                        )
-                        output_data = self._config.import_function(output_data)
-                        self._save_results(output_data, log, part_idx)
-                        logger.info(
-                            f"Fetched and saved results for part {part_idx + 1}."
-                        )
-                        finished_this_round.append(part_idx)
-                    elif status == "failed":
-                        attempt = self._part_attempts.get(part_idx, 0) + 1
-                        self._part_attempts[part_idx] = attempt
-                        if attempt <= self._config.max_retries:
-                            logger.info(
-                                f"Job {job_name} failed (attempt {attempt}). Retrying after short backoff..."
-                            )
-                            self._manager._clear_state(job_name)
-                            time.sleep(10)
-                            payload = self._to_manager_payload(self._parts[part_idx])
-                            new_job_name = (
-                                f"{self._job_name}_part_{part_idx + 1}_retry_{attempt}"
-                            )
-                            self._manager.start(payload, job_name=new_job_name)
-                            self._part_idx_to_job_name[part_idx] = new_job_name
-                        else:
-                            logger.info(
-                                f"Job {job_name} failed after {attempt - 1} retries. Marking as failed."
-                            )
-                            finished_this_round.append(part_idx)
-                    else:
-                        # Still running or queued
-                        continue
-                # Remove finished parts
-                for part_idx in finished_this_round:
-                    pending_parts.discard(part_idx)
-                if pending_parts:
-                    logger.info(
-                        f"Waiting {self._config.poll_interval_seconds}s before next status check for parts: {sorted(pending_parts)}"
-                    )
-                    time.sleep(self._config.poll_interval_seconds)
-
-        except Exception as e:
-            raise TextToolsError(f"Batch job execution failed: {e}")
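Two details of the deleted runner are easy to miss in a straight read. First, `_partition_data` sizes jobs with a character-based token estimate; a worked example of that arithmetic (corpus numbers invented for illustration, constants from the deleted `BatchConfig`):

```python
# Defaults from the deleted BatchConfig.
CHARS_PER_TOKEN = 2.7
MAX_TOTAL_TOKENS = 2_000_000
MAX_BATCH_SIZE = 100

n_items, avg_chars, prompt_chars = 5_000, 1_200, 300   # illustrative corpus
total = n_items * avg_chars + prompt_chars * n_items   # 7,500,000 chars
est_tokens = total / CHARS_PER_TOKEN                   # ~2,777,778 tokens
if est_tokens < MAX_TOTAL_TOKENS:
    n_parts = 1
else:
    n_parts = -(-n_items // MAX_BATCH_SIZE)            # ceil division -> 50 parts
print(round(est_tokens), n_parts)                      # 2777778 50
```

Second, the deleted modules disagreed on their own payload contract: `config.export_data` documents `{"id", "text"}` items and `manager._prepare_file` reads `dic["text"]`, while `runner._load_data` validates `"id"`/`"content"` keys; the retry path also calls a `self._to_manager_payload` helper that is not defined anywhere in the file.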
{hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.2.dist-info}/WHEEL
File without changes
{hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.2.dist-info}/licenses/LICENSE
File without changes
{hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.2.dist-info}/top_level.txt
File without changes
/texttools/{batch → core/operators}/__init__.py
File without changes