hamtaa-texttools 1.1.1__py3-none-any.whl → 1.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/METADATA +98 -26
- hamtaa_texttools-1.1.16.dist-info/RECORD +31 -0
- texttools/__init__.py +6 -8
- texttools/batch/batch_config.py +26 -0
- texttools/batch/batch_runner.py +105 -151
- texttools/batch/{batch_manager.py → internals/batch_manager.py} +39 -40
- texttools/batch/internals/utils.py +16 -0
- texttools/prompts/README.md +4 -4
- texttools/prompts/categorize.yaml +77 -0
- texttools/prompts/detect_entity.yaml +22 -0
- texttools/prompts/extract_keywords.yaml +68 -18
- texttools/tools/async_tools.py +804 -0
- texttools/tools/internals/async_operator.py +90 -69
- texttools/tools/internals/models.py +183 -0
- texttools/tools/internals/operator_utils.py +54 -0
- texttools/tools/internals/prompt_loader.py +13 -14
- texttools/tools/internals/sync_operator.py +201 -0
- texttools/tools/sync_tools.py +804 -0
- hamtaa_texttools-1.1.1.dist-info/RECORD +0 -30
- texttools/batch/__init__.py +0 -4
- texttools/prompts/categorizer.yaml +0 -28
- texttools/tools/__init__.py +0 -4
- texttools/tools/async_the_tool.py +0 -414
- texttools/tools/internals/base_operator.py +0 -91
- texttools/tools/internals/operator.py +0 -179
- texttools/tools/internals/output_models.py +0 -59
- texttools/tools/the_tool.py +0 -412
- {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/top_level.txt +0 -0
{hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hamtaa-texttools
-Version: 1.1.1
+Version: 1.1.16
 Summary: A high-level NLP toolkit built on top of modern LLMs.
 Author-email: Tohidi <the.mohammad.tohidi@gmail.com>, Montazer <montazerh82@gmail.com>, Givechi <mohamad.m.givechi@gmail.com>, MoosaviNejad <erfanmoosavi84@gmail.com>
 License: MIT License
@@ -29,6 +29,7 @@ Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: openai==1.97.1
+Requires-Dist: pydantic>=2.0.0
 Requires-Dist: pyyaml>=6.0
 Dynamic: license-file

@@ -40,50 +41,69 @@ Dynamic: license-file

 It provides both **sync (`TheTool`)** and **async (`AsyncTheTool`)** APIs for maximum flexibility.

-It provides ready-to-use utilities for **translation, question detection, keyword extraction, categorization, NER
+It provides ready-to-use utilities for **translation, question detection, keyword extraction, categorization, NER extraction, and more** - designed to help you integrate AI-powered text processing into your applications with minimal effort.

 ---

 ## ✨ Features

-TextTools provides a rich collection of high-level NLP utilities
-Each tool is designed to work
+TextTools provides a rich collection of high-level NLP utilities.
+Each tool is designed to work with structured outputs (JSON / Pydantic).

-- **`categorize()`** - Classifies text into
-- **`is_question()`** - Binary detection of whether input is a question
+- **`categorize()`** - Classifies text into given categories (you have to create a category tree)
 - **`extract_keywords()`** - Extracts keywords from text
 - **`extract_entities()`** - Named Entity Recognition (NER) system
-- **`
+- **`is_question()`** - Binary detection of whether input is a question
 - **`text_to_question()`** - Generates questions from text
 - **`merge_questions()`** - Merges multiple questions with different modes
 - **`rewrite()`** - Rewrites text with different wording/meaning
 - **`subject_to_question()`** - Generates questions about a specific subject
+- **`summarize()`** - Text summarization
 - **`translate()`** - Text translation between languages
-- **`run_custom()`** - Allows users to define a custom tool with arbitrary BaseModel
+- **`run_custom()`** - Allows users to define a custom tool with an arbitrary BaseModel

 ---

-## ⚙️ `with_analysis`, `logprobs`, `output_lang`, `user_prompt` and
+## ⚙️ `with_analysis`, `logprobs`, `output_lang`, `user_prompt`, `temperature`, `validator` and `priority` parameters

 TextTools provides several optional flags to customize LLM behavior:

-- **`with_analysis
-Note
+- **`with_analysis (bool)`** → Adds a reasoning step before generating the final output.
+**Note:** This doubles token usage per call because it triggers an additional LLM request.

-- **`logprobs
+- **`logprobs (bool)`** → Returns token-level probabilities for the generated output. You can also specify `top_logprobs=<N>` to get the top N alternative tokens and their probabilities.
+**Note:** This feature works only if the model supports it.

-- **`output_lang
+- **`output_lang (str)`** → Forces the model to respond in a specific language. The model will ignore other instructions about language and respond strictly in the requested language.

-- **`user_prompt
+- **`user_prompt (str)`** → Allows you to inject a custom instruction or prompt into the model alongside the main template. This gives you fine-grained control over how the model interprets or modifies the input text.

-- **`temperature
+- **`temperature (float)`** → Determines how creatively the model responds. Takes a float from `0.0` to `2.0`.

-
+- **`validator (Callable)`** → Forces TheTool to validate the output against your custom validator. The validator should return a bool (`True` if the output is acceptable, `False` if validation fails). If validation fails, TheTool retries with a modified `temperature`. You can specify `max_validation_retries=<N>` to change the number of retries.
+
+- **`priority (int)`** → Task execution priority level. Higher values = higher priority. Affects processing order in queues.
+**Note:** This feature works only if the model and vLLM support it.

 **Note:** Some tools may not support all of the parameters above.

 ---

+## 🧩 ToolOutput
+
+Every tool in `TextTools` returns a `ToolOutput` object, a BaseModel with the following attributes:
+- **`result: Any`** → The LLM output
+- **`analysis: str`** → The reasoning step before generating the final output
+- **`logprobs: list`** → Token-level probabilities for the generated output
+- **`process: str`** → The name of the tool that processed the input
+- **`processed_at: datetime`** → The time of processing
+- **`execution_time: float`** → The execution time (seconds)
+- **`errors: list[str]`** → Any errors that occurred while calling the LLM
+
+**Note:** You can use `repr(ToolOutput)` to see the details of your ToolOutput.
+
+---
+
 ## 🚀 Installation

 Install the latest release via PyPI:
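To make the optional parameters above concrete, here is a minimal sketch of how several of them might be combined on a single call. It is not taken from the package: `the_tool` is assumed to be a configured `TheTool` instance (as in the quick-start example further down), and `summarize()`'s exact signature may differ.

```python
# Hypothetical sketch: combining the optional parameters documented above.
# Assumes `the_tool` is a configured TheTool instance.
def non_empty(result) -> bool:
    # Custom validator: return True to accept the output, False to retry.
    return bool(result)

summary = the_tool.summarize(
    "TextTools is a high-level NLP toolkit built on top of modern LLMs.",
    with_analysis=True,        # adds a reasoning step (doubles token usage)
    temperature=0.2,           # 0.0-2.0; lower is more deterministic
    validator=non_empty,       # retried with a modified temperature on failure
    max_validation_retries=2,  # number of validation retries
)
print(summary.result)
```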
@@ -94,7 +114,7 @@ pip install -U hamtaa-texttools

 ---

-## Sync vs Async
+## 🧨 Sync vs Async
 | Tool | Style | Use case |
 |--------------|---------|---------------------------------------------|
 | `TheTool` | Sync | Simple scripts, sequential workflows |
@@ -121,13 +141,13 @@ the_tool = TheTool(client=client, model=model)
 detection = the_tool.is_question("Is this project open source?", logprobs=True, top_logprobs=2)
 print(detection.result)
 print(detection.logprobs)
-# Output: True
+# Output: True + logprobs

 # Example: Translation
 translation = the_tool.translate("سلام، حالت چطوره؟", target_language="English", with_analysis=True)
 print(translation.result)
 print(translation.analysis)
-# Output: "Hi! How are you?"
+# Output: "Hi! How are you?" + analysis
 ```

 ---
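Since every call above returns a `ToolOutput` (see the attribute list earlier in this diff), the metadata fields can be inspected alongside the result. A small sketch reusing the `detection` object from the example:

```python
# Sketch: ToolOutput fields documented in the new README section above.
print(detection.result)          # the LLM output (Any)
print(detection.process)         # name of the tool that processed the input
print(detection.processed_at)    # datetime of processing
print(detection.execution_time)  # seconds
print(detection.errors)          # errors raised while calling the LLM, if any
print(repr(detection))           # full details in one place
```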
@@ -147,19 +167,22 @@ async def main():
     model = "gpt-4o-mini"

     # Create an instance of AsyncTheTool
-
+    async_the_tool = AsyncTheTool(client=async_client, model=model)
+
+    # Example: Async Translation and Keyword Extraction
+    translation_task = async_the_tool.translate("سلام، حالت چطوره؟", target_language="English")
+    keywords_task = async_the_tool.extract_keywords("Tomorrow, we will be dead by the car crash")

-
-    translation = await the_tool.translate("سلام، حالت چطوره؟", target_language="English")
+    (translation, keywords) = await asyncio.gather(translation_task, keywords_task)
     print(translation.result)
-
+    print(keywords.result)

 asyncio.run(main())
 ```

 ---

-##
+## 👍 Use Cases

 Use **TextTools** when you need to:

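A small variation on the async example above: `asyncio.gather(..., return_exceptions=True)` (standard library behavior, not TextTools-specific) keeps one failed task from discarding the other results. Assuming freshly created tasks:

```python
# Sketch: tolerate per-task failures when gathering tool calls.
results = await asyncio.gather(
    async_the_tool.translate("سلام، حالت چطوره؟", target_language="English"),
    async_the_tool.extract_keywords("Tomorrow, we will be dead by the car crash"),
    return_exceptions=True,
)
for r in results:
    if isinstance(r, Exception):
        print(f"task failed: {r}")
    else:
        print(r.result)
```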
@@ -167,7 +190,56 @@ Use **TextTools** when you need to:
 - 🌍 **Translate** and process multilingual corpora with ease
 - 🧩 **Integrate** LLMs into production pipelines (structured outputs)
 - 📊 **Analyze** large text collections using embeddings and categorization
-
+
+---
+
+## 🔍 Logging
+
+TextTools uses Python's standard `logging` module. The library's default logger level is `WARNING`; to change it, configure logging as follows:
+
+```python
+import logging
+
+# Default: warnings and errors only
+logging.basicConfig(level=logging.WARNING)
+
+# Debug everything (verbose)
+logging.basicConfig(level=logging.DEBUG)
+
+# Complete silence
+logging.basicConfig(level=logging.CRITICAL)
+```
+
+---
+
+## 📚 Batch Processing
+
+Process large datasets efficiently using OpenAI's batch API.
+
+## ⚡ Quick Start (Batch)
+
+```python
+from pydantic import BaseModel
+from texttools import BatchJobRunner, BatchConfig
+
+# Configure your batch job
+config = BatchConfig(
+    system_prompt="Extract entities from the text",
+    job_name="entity_extraction",
+    input_data_path="data.json",
+    output_data_filename="results.json",
+    model="gpt-4o-mini"
+)
+
+# Define your output schema
+class Output(BaseModel):
+    entities: list[str]
+
+# Run the batch job
+runner = BatchJobRunner(config, output_model=Output)
+runner.run()
+```

 ---

@@ -178,6 +250,6 @@ Feel free to **open issues, suggest new features, or submit pull requests**.

 ---

-## License
+## 🌿 License

 This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
hamtaa_texttools-1.1.16.dist-info/RECORD
ADDED
@@ -0,0 +1,31 @@
+hamtaa_texttools-1.1.16.dist-info/licenses/LICENSE,sha256=Hb2YOBKy2MJQLnyLrX37B4ZVuac8eaIcE71SvVIMOLg,1082
+texttools/__init__.py,sha256=dc81lXGWP29k7oVvq2BMoMotz6lgiwX4PO2jHHBe2S8,317
+texttools/batch/batch_config.py,sha256=m1UgILVKjNdWE6laNbfbG4vgi4o2fEegGZbeoam6pnY,749
+texttools/batch/batch_runner.py,sha256=9e4SPLlvLHHs3U7bHkuuMVw8TFNwsGUzRjkAMKN4_ik,9378
+texttools/batch/internals/batch_manager.py,sha256=UoBe76vmFG72qrSaGKDZf4HzkykFBkkkbL9TLfV8TuQ,8730
+texttools/batch/internals/utils.py,sha256=F1_7YlVFKhjUROAFX4m0SaP8KiZVZyHRMIIB87VUGQc,373
+texttools/prompts/README.md,sha256=-5YO93CN93QLifqZpUeUnCOCBbDiOTV-cFQeJ7Gg0I4,1377
+texttools/prompts/categorize.yaml,sha256=F7VezB25B_sT5yoC25ezODBddkuDD5lUHKetSpx9FKI,2743
+texttools/prompts/detect_entity.yaml,sha256=1rhMkJOjxSQcT4j_c5SRcIm77AUdeG-rUmeidb6VOFc,981
+texttools/prompts/extract_entities.yaml,sha256=KiKjeDpHaeh3JVtZ6q1pa3k4DYucUIU9WnEcRTCA-SE,651
+texttools/prompts/extract_keywords.yaml,sha256=Vj4Tt3vT6LtpOo_iBZPo9oWI50oVdPGXe5i8yDR8ex4,3177
+texttools/prompts/is_question.yaml,sha256=d0-vKRbXWkxvO64ikvxRjEmpAXGpCYIPGhgexvPPjws,471
+texttools/prompts/merge_questions.yaml,sha256=0J85GvTirZB4ELwH3sk8ub_WcqqpYf6PrMKr3djlZeo,1792
+texttools/prompts/rewrite.yaml,sha256=LO7He_IA3MZKz8a-LxH9DHJpOjpYwaYN1pbjp1Y0tFo,5392
+texttools/prompts/run_custom.yaml,sha256=38OkCoVITbuuS9c08UZSP1jZW4WjSmRIi8fR0RAiPu4,108
+texttools/prompts/subject_to_question.yaml,sha256=C7x7rNNm6U_ZG9HOn6zuzYOtvJUZ2skuWbL1-aYdd3E,1147
+texttools/prompts/summarize.yaml,sha256=o6rxGPfWtZd61Duvm8NVvCJqfq73b-wAuMSKR6UYUqY,459
+texttools/prompts/text_to_question.yaml,sha256=UheKYpDn6iyKI8NxunHZtFpNyfCLZZe5cvkuXpurUJY,783
+texttools/prompts/translate.yaml,sha256=mGT2uBCei6uucWqVbs4silk-UV060v3G0jnt0P6sr50,634
+texttools/tools/async_tools.py,sha256=vNAg0gxwUZPsMS4q8JCv7RlYymS8l_5FsFI5adEYT7w,34376
+texttools/tools/sync_tools.py,sha256=hFifFa9YatvSeGif2E_bIG006eMdIBr6SV9HsZ_dAlg,34187
+texttools/tools/internals/async_operator.py,sha256=1TMr8e1qbE9GSz8jl0q3MKdM8lIYE-1ZuSxHjYPqKHI,7198
+texttools/tools/internals/formatters.py,sha256=tACNLP6PeoqaRpNudVxBaHA25zyWqWYPZQuYysIu88g,941
+texttools/tools/internals/models.py,sha256=2QnvMiijuSqOqpCl026848rJy_pHNbRoDESlQvcdHlk,5839
+texttools/tools/internals/operator_utils.py,sha256=w1k0RJ_W_CRbVc_J2w337VuL-opHpHiCxfhEOwtyuOo,1856
+texttools/tools/internals/prompt_loader.py,sha256=4g6-U8kqrGN7VpNaRcrBcnF-h03PXjUDBP0lL0_4EZY,1953
+texttools/tools/internals/sync_operator.py,sha256=4-V__o55Q8w29lWxkhG4St-exZLZTfBbiW76knOXbc0,7106
+hamtaa_texttools-1.1.16.dist-info/METADATA,sha256=DL-cjlGMv7bft8QVd-pn5E_tNDuPgQHkTKGl4YTosGw,9555
+hamtaa_texttools-1.1.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+hamtaa_texttools-1.1.16.dist-info/top_level.txt,sha256=5Mh0jIxxZ5rOXHGJ6Mp-JPKviywwN0MYuH0xk5bEWqE,10
+hamtaa_texttools-1.1.16.dist-info/RECORD,,
texttools/__init__.py
CHANGED
@@ -1,9 +1,7 @@
-from .batch import BatchJobRunner
-from .
+from .batch.batch_runner import BatchJobRunner
+from .batch.batch_config import BatchConfig
+from .tools.sync_tools import TheTool
+from .tools.async_tools import AsyncTheTool
+from .tools.internals.models import CategoryTree

-__all__ = [
-    "TheTool",
-    "AsyncTheTool",
-    "SimpleBatchManager",
-    "BatchJobRunner",
-]
+__all__ = ["TheTool", "AsyncTheTool", "BatchJobRunner", "BatchConfig", "CategoryTree"]
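With the batch internals reorganized, everything re-exported in the new `__all__` is importable from the package root:

```python
from texttools import TheTool, AsyncTheTool, BatchJobRunner, BatchConfig, CategoryTree
```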
texttools/batch/batch_config.py
ADDED
@@ -0,0 +1,26 @@
+from dataclasses import dataclass
+from collections.abc import Callable
+
+from texttools.batch.internals.utils import import_data, export_data
+
+
+@dataclass
+class BatchConfig:
+    """
+    Configuration for batch job runner.
+    """
+
+    system_prompt: str = ""
+    job_name: str = ""
+    input_data_path: str = ""
+    output_data_filename: str = ""
+    model: str = "gpt-4.1-mini"
+    MAX_BATCH_SIZE: int = 100
+    MAX_TOTAL_TOKENS: int = 2_000_000
+    CHARS_PER_TOKEN: float = 2.7
+    PROMPT_TOKEN_MULTIPLIER: int = 1_000
+    BASE_OUTPUT_DIR: str = "Data/batch_entity_result"
+    import_function: Callable = import_data
+    export_function: Callable = export_data
+    poll_interval_seconds: int = 30
+    max_retries: int = 3
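The `import_function`/`export_function` hooks are the seam for adapting your own data shape: as `_load_data()` in `batch_runner.py` below enforces, `export_function` must return a list of dicts with string `id` and `content` keys. A sketch with hypothetical field names (`pk`, `body`):

```python
# Sketch: custom export/import hooks for BatchConfig.
# export_function must return [{"id": str, "content": str}, ...];
# _load_data() in batch_runner.py raises ValueError otherwise.
def export_rows(data):
    return [{"id": str(row["pk"]), "content": row["body"]} for row in data]

def import_rows(output_data):
    # Merge the batch results back into whatever shape the caller needs.
    return output_data

config = BatchConfig(
    job_name="entity_extraction",
    input_data_path="data.json",
    output_data_filename="results.json",
    export_function=export_rows,
    import_function=import_rows,
)
```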
texttools/batch/batch_runner.py
CHANGED
@@ -1,61 +1,22 @@
 import json
 import os
 import time
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Any,
+from typing import Any, Type, TypeVar
 import logging

 from dotenv import load_dotenv
 from openai import OpenAI
 from pydantic import BaseModel

-from texttools.batch import
+from texttools.batch.internals.batch_manager import BatchManager
+from texttools.batch.batch_config import BatchConfig
+from texttools.tools.internals.models import StrOutput

-#
-
-logger.setLevel(logging.INFO)
+# Base Model type for output models
+T = TypeVar("T", bound=BaseModel)

-
-class OutputModel(BaseModel):
-    desired_output: str
-
-
-def export_data(data):
-    """
-    Produces a structure of the following form from an initial data structure:
-    [{"id": str, "text": str},...]
-    """
-    return data
-
-
-def import_data(data):
-    """
-    Takes the output and adds and aggregates it to the original structure.
-    """
-    return data
-
-
-@dataclass
-class BatchConfig:
-    """
-    Configuration for batch job runner.
-    """
-
-    system_prompt: str = ""
-    job_name: str = ""
-    input_data_path: str = ""
-    output_data_filename: str = ""
-    model: str = "gpt-4.1-mini"
-    MAX_BATCH_SIZE: int = 100
-    MAX_TOTAL_TOKENS: int = 2000000
-    CHARS_PER_TOKEN: float = 2.7
-    PROMPT_TOKEN_MULTIPLIER: int = 1000
-    BASE_OUTPUT_DIR: str = "Data/batch_entity_result"
-    import_function: Callable = import_data
-    export_function: Callable = export_data
-    poll_interval_seconds: int = 30
-    max_retries: int = 3
+logger = logging.getLogger("texttools.batch_runner")


 class BatchJobRunner:
@@ -64,142 +25,180 @@ class BatchJobRunner:
     """

     def __init__(
-        self, config: BatchConfig = BatchConfig(), output_model:
+        self, config: BatchConfig = BatchConfig(), output_model: Type[T] = StrOutput
     ):
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self._partition_data()
-        Path(self.config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
+        self._config = config
+        self._system_prompt = config.system_prompt
+        self._job_name = config.job_name
+        self._input_data_path = config.input_data_path
+        self._output_data_filename = config.output_data_filename
+        self._model = config.model
+        self._output_model = output_model
+        self._manager = self._init_manager()
+        self._data = self._load_data()
+        self._parts: list[list[dict[str, Any]]] = []
         # Map part index to job name
-        self.
+        self._part_idx_to_job_name: dict[int, str] = {}
         # Track retry attempts per part
-        self.
+        self._part_attempts: dict[int, int] = {}
+        self._partition_data()
+        Path(self._config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

-    def _init_manager(self) ->
+    def _init_manager(self) -> BatchManager:
         load_dotenv()
         api_key = os.getenv("OPENAI_API_KEY")
         client = OpenAI(api_key=api_key)
-        return
+        return BatchManager(
             client=client,
-            model=self.
-            prompt_template=self.
-            output_model=self.
+            model=self._model,
+            prompt_template=self._system_prompt,
+            output_model=self._output_model,
         )

     def _load_data(self):
-        with open(self.
+        with open(self._input_data_path, "r", encoding="utf-8") as f:
             data = json.load(f)
-        data = self.
+        data = self._config.export_function(data)

         # Ensure data is a list of dicts with 'id' and 'content' as strings
         if not isinstance(data, list):
             raise ValueError(
-
+                "Exported data must be a list of dicts with 'id' and 'content' keys"
             )
         for item in data:
             if not (isinstance(item, dict) and "id" in item and "content" in item):
                 raise ValueError(
-                    "
+                    f"Item must be a dict with 'id' and 'content' keys. Got: {type(item)}"
                 )
             if not (isinstance(item["id"], str) and isinstance(item["content"], str)):
                 raise ValueError("'id' and 'content' must be strings.")
         return data

     def _partition_data(self):
-        total_length = sum(len(item["content"]) for item in self.
-        prompt_length = len(self.
-        total = total_length + (prompt_length * len(self.
-        calculation = total / self.
+        total_length = sum(len(item["content"]) for item in self._data)
+        prompt_length = len(self._system_prompt)
+        total = total_length + (prompt_length * len(self._data))
+        calculation = total / self._config.CHARS_PER_TOKEN
         logger.info(
             f"Total chars: {total_length}, Prompt chars: {prompt_length}, Total: {total}, Tokens: {calculation}"
         )
-        if calculation < self.
-            self.
+        if calculation < self._config.MAX_TOTAL_TOKENS:
+            self._parts = [self._data]
         else:
             # Partition into chunks of MAX_BATCH_SIZE
-            self.
-            self.
-            for i in range(0, len(self.
+            self._parts = [
+                self._data[i : i + self._config.MAX_BATCH_SIZE]
+                for i in range(0, len(self._data), self._config.MAX_BATCH_SIZE)
             ]
-        logger.info(f"Data split into {len(self.
+        logger.info(f"Data split into {len(self._parts)} part(s)")

     def _submit_all_jobs(self) -> None:
-        for idx, part in enumerate(self.
+        for idx, part in enumerate(self._parts):
             if self._result_exists(idx):
                 logger.info(f"Skipping part {idx + 1}: result already exists.")
                 continue
             part_job_name = (
-                f"{self.
-                if len(self.
-                else self.
+                f"{self._job_name}_part_{idx + 1}"
+                if len(self._parts) > 1
+                else self._job_name
             )
             # If a job with this name already exists, register and skip submitting
-            existing_job = self.
+            existing_job = self._manager._load_state(part_job_name)
             if existing_job:
                 logger.info(
                     f"Skipping part {idx + 1}: job already exists ({part_job_name})."
                 )
-                self.
-                self.
+                self._part_idx_to_job_name[idx] = part_job_name
+                self._part_attempts.setdefault(idx, 0)
                 continue

             payload = part
             logger.info(
-                f"Submitting job for part {idx + 1}/{len(self.
+                f"Submitting job for part {idx + 1}/{len(self._parts)}: {part_job_name}"
             )
-            self.
-            self.
-            self.
+            self._manager.start(payload, job_name=part_job_name)
+            self._part_idx_to_job_name[idx] = part_job_name
+            self._part_attempts.setdefault(idx, 0)
             # This is added for letting file get uploaded, before starting the next part.
             logger.info("Uploading...")
             time.sleep(30)

+    def _save_results(
+        self,
+        output_data: list[dict[str, Any]] | dict[str, Any],
+        log: list[Any],
+        part_idx: int,
+    ):
+        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
+        result_path = (
+            Path(self._config.BASE_OUTPUT_DIR)
+            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
+        )
+        if not output_data:
+            logger.info("No output data to save. Skipping this part.")
+            return
+        else:
+            with open(result_path, "w", encoding="utf-8") as f:
+                json.dump(output_data, f, ensure_ascii=False, indent=4)
+        if log:
+            log_path = (
+                Path(self._config.BASE_OUTPUT_DIR)
+                / f"{Path(self._output_data_filename).stem}{part_suffix}_log.json"
+            )
+            with open(log_path, "w", encoding="utf-8") as f:
+                json.dump(log, f, ensure_ascii=False, indent=4)
+
+    def _result_exists(self, part_idx: int) -> bool:
+        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
+        result_path = (
+            Path(self._config.BASE_OUTPUT_DIR)
+            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
+        )
+        return result_path.exists()
+
     def run(self):
+        """
+        Execute the batch job processing pipeline.
+
+        Submits jobs, monitors progress, handles retries, and saves results.
+        """
         # Submit all jobs up-front for concurrent execution
         self._submit_all_jobs()
-        pending_parts: set[int] = set(self.
+        pending_parts: set[int] = set(self._part_idx_to_job_name.keys())
         logger.info(f"Pending parts: {sorted(pending_parts)}")
         # Polling loop
         while pending_parts:
             finished_this_round: list[int] = []
             for part_idx in list(pending_parts):
-                job_name = self.
-                status = self.
+                job_name = self._part_idx_to_job_name[part_idx]
+                status = self._manager.check_status(job_name=job_name)
                 logger.info(f"Status for {job_name}: {status}")
                 if status == "completed":
                     logger.info(
                         f"Job completed. Fetching results for part {part_idx + 1}..."
                     )
-                    output_data, log = self.
+                    output_data, log = self._manager.fetch_results(
                         job_name=job_name, remove_cache=False
                     )
-                    output_data = self.
+                    output_data = self._config.import_function(output_data)
                     self._save_results(output_data, log, part_idx)
                     logger.info(f"Fetched and saved results for part {part_idx + 1}.")
                     finished_this_round.append(part_idx)
                 elif status == "failed":
-                    attempt = self.
-                    self.
-                    if attempt <= self.
+                    attempt = self._part_attempts.get(part_idx, 0) + 1
+                    self._part_attempts[part_idx] = attempt
+                    if attempt <= self._config.max_retries:
                         logger.info(
                             f"Job {job_name} failed (attempt {attempt}). Retrying after short backoff..."
                         )
-                        self.
+                        self._manager._clear_state(job_name)
                         time.sleep(10)
-                        payload = self._to_manager_payload(self.
+                        payload = self._to_manager_payload(self._parts[part_idx])
                         new_job_name = (
-                            f"{self.
+                            f"{self._job_name}_part_{part_idx + 1}_retry_{attempt}"
                         )
-                        self.
-                        self.
+                        self._manager.start(payload, job_name=new_job_name)
+                        self._part_idx_to_job_name[part_idx] = new_job_name
                     else:
                         logger.info(
                             f"Job {job_name} failed after {attempt - 1} retries. Marking as failed."
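To make `_partition_data()`'s thresholds concrete, a worked estimate with illustrative numbers:

```python
# 1,000 items x 6,000 chars each, plus a 200-char prompt per item:
total = 1_000 * 6_000 + 200 * 1_000  # 6,200,000 chars
tokens = total / 2.7                 # ~2,296,296 > MAX_TOTAL_TOKENS (2,000,000)
# so the data is split into chunks of MAX_BATCH_SIZE=100 -> 10 parts
```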
@@ -213,51 +212,6 @@ class BatchJobRunner:
             pending_parts.discard(part_idx)
         if pending_parts:
             logger.info(
-                f"Waiting {self.
+                f"Waiting {self._config.poll_interval_seconds}s before next status check for parts: {sorted(pending_parts)}"
             )
-            time.sleep(self.
-
-    def _save_results(
-        self,
-        output_data: list[dict[str, Any]] | dict[str, Any],
-        log: list[Any],
-        part_idx: int,
-    ):
-        part_suffix = f"_part_{part_idx + 1}" if len(self.parts) > 1 else ""
-        result_path = (
-            Path(self.config.BASE_OUTPUT_DIR)
-            / f"{Path(self.output_data_filename).stem}{part_suffix}.json"
-        )
-        if not output_data:
-            logger.info("No output data to save. Skipping this part.")
-            return
-        else:
-            with open(result_path, "w", encoding="utf-8") as f:
-                json.dump(output_data, f, ensure_ascii=False, indent=4)
-        if log:
-            log_path = (
-                Path(self.config.BASE_OUTPUT_DIR)
-                / f"{Path(self.output_data_filename).stem}{part_suffix}_log.json"
-            )
-            with open(log_path, "w", encoding="utf-8") as f:
-                json.dump(log, f, ensure_ascii=False, indent=4)
-
-    def _result_exists(self, part_idx: int) -> bool:
-        part_suffix = f"_part_{part_idx + 1}" if len(self.parts) > 1 else ""
-        result_path = (
-            Path(self.config.BASE_OUTPUT_DIR)
-            / f"{Path(self.output_data_filename).stem}{part_suffix}.json"
-        )
-        return result_path.exists()
-
-
-if __name__ == "__main__":
-    logger.info("=== Batch Job Runner ===")
-    config = BatchConfig(
-        system_prompt="",
-        job_name="job_name",
-        input_data_path="Data.json",
-        output_data_filename="output",
-    )
-    runner = BatchJobRunner(config)
-    runner.run()
+            time.sleep(self._config.poll_interval_seconds)