hamtaa-texttools 1.3.0-py3-none-any.whl → 1.3.1-py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the changes between those published versions.
- {hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/METADATA +1 -28
- {hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/RECORD +6 -10
- texttools/__init__.py +1 -3
- texttools/batch/__init__.py +0 -0
- texttools/batch/config.py +0 -40
- texttools/batch/manager.py +0 -228
- texttools/batch/runner.py +0 -228
- {hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/top_level.txt +0 -0
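Taken together, these changes amount to one thing: 1.3.1 removes the entire `texttools.batch` subpackage, so `BatchConfig` and `BatchRunner` disappear from the public API. For code that has to run against both versions, a minimal compatibility sketch (the guard below is a suggestion, not something the package ships):

```python
# Guarded import: texttools <= 1.3.0 exports the batch helpers, 1.3.1 does not.
try:
    from texttools import BatchConfig, BatchRunner  # removed in 1.3.1
    HAS_BATCH = True
except ImportError:
    BatchConfig = BatchRunner = None
    HAS_BATCH = False
```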
{hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hamtaa-texttools
-Version: 1.3.0
+Version: 1.3.1
 Summary: A high-level NLP toolkit built on top of modern LLMs.
 Author-email: Tohidi <the.mohammad.tohidi@gmail.com>, Erfan Moosavi <erfanmoosavi84@gmail.com>, Montazer <montazerh82@gmail.com>, Givechi <mohamad.m.givechi@gmail.com>, Zareshahi <a.zareshahi1377@gmail.com>
 Maintainer-email: Erfan Moosavi <erfanmoosavi84@gmail.com>, Tohidi <the.mohammad.tohidi@gmail.com>
@@ -177,33 +177,6 @@ Use **TextTools** when you need to:
 
 ---
 
-## 📚 Batch Processing
-
-Process large datasets efficiently using OpenAI's batch API.
-
-## ⚡ Quick Start (Batch Runner)
-
-```python
-from pydantic import BaseModel
-from texttools import BatchRunner, BatchConfig
-
-config = BatchConfig(
-    system_prompt="Extract entities from the text",
-    job_name="entity_extraction",
-    input_data_path="data.json",
-    output_data_filename="results.json",
-    model="gpt-4o-mini"
-)
-
-class Output(BaseModel):
-    entities: list[str]
-
-runner = BatchRunner(config, output_model=Output)
-runner.run()
-```
-
----
-
 ## 🤝 Contributing
 
 Contributions are welcome!
{hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/RECORD CHANGED

@@ -1,11 +1,7 @@
-hamtaa_texttools-1.3.
-texttools/__init__.py,sha256=
+hamtaa_texttools-1.3.1.dist-info/licenses/LICENSE,sha256=Hb2YOBKy2MJQLnyLrX37B4ZVuac8eaIcE71SvVIMOLg,1082
+texttools/__init__.py,sha256=RK1GAU6pq2lGwFtHdrCX5JkPRHmOLGcmGH67hd_7VAQ,175
 texttools/models.py,sha256=5eT2cSrFq8Xa38kANznV7gbi7lwB2PoDxciLKTpsd6c,2516
 texttools/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-texttools/batch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-texttools/batch/config.py,sha256=GDDXuhRZ_bOGVwSIlU4tWP247tx1_A7qzLJn7VqDyLU,1050
-texttools/batch/manager.py,sha256=XZtf8UkdClfQlnRKne4nWEcFvdSKE67EamEePKy7jwI,8730
-texttools/batch/runner.py,sha256=9qxXIMfYRXW5SXDqqKtRr61rnQdYZkbCGqKImhSrY6I,9923
 texttools/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 texttools/core/engine.py,sha256=iRHdlIOPuUwIN6_72HNyTQQE7h_7xUZhC-WO-fDA5k8,9597
 texttools/core/exceptions.py,sha256=6SDjUL1rmd3ngzD3ytF4LyTRj3bQMSFR9ECrLoqXXHw,395

@@ -28,7 +24,7 @@ texttools/prompts/translate.yaml,sha256=Dd5bs3O8SI-FlVSwHMYGeEjMmdOWeRlcfBHkhixC
 texttools/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 texttools/tools/async_tools.py,sha256=2suwx8N0aRnowaSOpV6C57AqPlmQe5Z0Yx4E5QIMkmU,46939
 texttools/tools/sync_tools.py,sha256=mEuL-nlbxVW30dPE3hGkAUnYXbul-3gN2Le4CMVFCgU,42528
-hamtaa_texttools-1.3.
-hamtaa_texttools-1.3.
-hamtaa_texttools-1.3.
-hamtaa_texttools-1.3.
+hamtaa_texttools-1.3.1.dist-info/METADATA,sha256=6wLYAaPVOFpzUz8tN7lfzbAGhEr10JBXgRHcZZvrt5s,7453
+hamtaa_texttools-1.3.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+hamtaa_texttools-1.3.1.dist-info/top_level.txt,sha256=5Mh0jIxxZ5rOXHGJ6Mp-JPKviywwN0MYuH0xk5bEWqE,10
+hamtaa_texttools-1.3.1.dist-info/RECORD,,
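Each RECORD entry is `path,sha256=<digest>,size`. The identical digests on the surviving modules confirm that, apart from the renamed dist-info metadata, only `texttools/__init__.py` actually changed and the four `texttools/batch/` files were dropped. To inspect a wheel's RECORD yourself (filenames below are illustrative):

```python
import zipfile

# A wheel is a zip archive; RECORD lives in the .dist-info directory.
with zipfile.ZipFile("hamtaa_texttools-1.3.1-py3-none-any.whl") as whl:
    print(whl.read("hamtaa_texttools-1.3.1.dist-info/RECORD").decode())
```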
texttools/__init__.py CHANGED

@@ -1,7 +1,5 @@
-from .batch.config import BatchConfig
-from .batch.runner import BatchRunner
 from .models import CategoryTree
 from .tools.async_tools import AsyncTheTool
 from .tools.sync_tools import TheTool
 
-__all__ = ["TheTool", "AsyncTheTool", "CategoryTree"
+__all__ = ["TheTool", "AsyncTheTool", "CategoryTree"]
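After this change the package's import surface is exactly the three names in the new `__all__`:

```python
# The public API of texttools 1.3.1 after this change:
from texttools import TheTool, AsyncTheTool, CategoryTree

# The 1.3.0 batch entry points now fail:
# from texttools import BatchRunner, BatchConfig  # ImportError in 1.3.1
```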
texttools/batch/__init__.py DELETED

File without changes (the file was empty; RECORD lists it at 0 bytes)
texttools/batch/config.py DELETED

@@ -1,40 +0,0 @@
-from collections.abc import Callable
-from dataclasses import dataclass
-from typing import Any
-
-
-def export_data(data) -> list[dict[str, str]]:
-    """
-    Produces a structure of the following form from an initial data structure:
-    [{"id": str, "text": str}, ...]
-    """
-    return data
-
-
-def import_data(data) -> Any:
-    """
-    Takes the output and aggregates it back into the original structure.
-    """
-    return data
-
-
-@dataclass
-class BatchConfig:
-    """
-    Configuration for the batch job runner.
-    """
-
-    system_prompt: str = ""
-    job_name: str = ""
-    input_data_path: str = ""
-    output_data_filename: str = ""
-    model: str = "gpt-4.1-mini"
-    MAX_BATCH_SIZE: int = 100
-    MAX_TOTAL_TOKENS: int = 2_000_000
-    CHARS_PER_TOKEN: float = 2.7
-    PROMPT_TOKEN_MULTIPLIER: int = 1_000
-    BASE_OUTPUT_DIR: str = "Data/batch_entity_result"
-    import_function: Callable = import_data
-    export_function: Callable = export_data
-    poll_interval_seconds: int = 30
-    max_retries: int = 3
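The two module-level hooks were the config's extension points: `export_function` turns the raw input file into the `[{"id": str, "text": str}, ...]` shape its docstring describes (note that the deleted `runner.py` below validates `id`/`content` keys instead, an inconsistency in the removed code), and `import_function` maps the results back onto the original structure. A hypothetical pair of hooks, assuming an input file shaped like `{"docs": [{"doc_id": ..., "body": ...}]}`:

```python
# Hypothetical hooks for the removed BatchConfig; the input shape is assumed.
def my_export(data) -> list[dict[str, str]]:
    # Flatten the raw file into the [{"id", "text"}] shape the docstring asks for.
    return [{"id": d["doc_id"], "text": d["body"]} for d in data["docs"]]

def my_import(results):
    # Results arrive keyed by the ids supplied above; pass them through unchanged.
    return results

config = BatchConfig(
    system_prompt="Extract entities from the text",
    job_name="demo",
    input_data_path="docs.json",
    output_data_filename="results.json",
    export_function=my_export,
    import_function=my_import,
)
```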
texttools/batch/manager.py DELETED

@@ -1,228 +0,0 @@
-import json
-import logging
-import uuid
-from pathlib import Path
-from typing import Any, Type, TypeVar
-
-from openai import OpenAI
-from openai.lib._pydantic import to_strict_json_schema
-from pydantic import BaseModel
-
-# Base Model type for output models
-T = TypeVar("T", bound=BaseModel)
-
-logger = logging.getLogger("texttools.batch_manager")
-
-
-class BatchManager:
-    """
-    Manages batch processing jobs for OpenAI's chat completions with structured outputs.
-
-    Handles the full lifecycle of a batch job: creating tasks from input texts,
-    starting the job, monitoring status, and fetching results. Results are automatically
-    parsed into the specified Pydantic output model. Job state is persisted to disk.
-    """
-
-    def __init__(
-        self,
-        client: OpenAI,
-        model: str,
-        output_model: Type[T],
-        prompt_template: str,
-        state_dir: Path = Path(".batch_jobs"),
-        custom_json_schema_obj_str: dict | None = None,
-        **client_kwargs: Any,
-    ):
-        self._client = client
-        self._model = model
-        self._output_model = output_model
-        self._prompt_template = prompt_template
-        self._state_dir = state_dir
-        self._custom_json_schema_obj_str = custom_json_schema_obj_str
-        self._client_kwargs = client_kwargs
-        self._dict_input = False
-        self._state_dir.mkdir(parents=True, exist_ok=True)
-
-        if custom_json_schema_obj_str and not isinstance(
-            custom_json_schema_obj_str, dict
-        ):
-            raise ValueError("Schema should be a dict")
-
-    def _state_file(self, job_name: str) -> Path:
-        return self._state_dir / f"{job_name}.json"
-
-    def _load_state(self, job_name: str) -> list[dict[str, Any]]:
-        """
-        Loads the state (job information) from the state file for the given job name.
-        Returns an empty list if the state file does not exist.
-        """
-        path = self._state_file(job_name)
-        if path.exists():
-            with open(path, "r", encoding="utf-8") as f:
-                return json.load(f)
-        return []
-
-    def _save_state(self, job_name: str, jobs: list[dict[str, Any]]) -> None:
-        """
-        Saves the job state to the state file for the given job name.
-        """
-        with open(self._state_file(job_name), "w", encoding="utf-8") as f:
-            json.dump(jobs, f)
-
-    def _clear_state(self, job_name: str) -> None:
-        """
-        Deletes the state file for the given job name if it exists.
-        """
-        path = self._state_file(job_name)
-        if path.exists():
-            path.unlink()
-
-    def _build_task(self, text: str, idx: str) -> dict[str, Any]:
-        """
-        Builds a single task dictionary for the batch job, including the prompt, model, and response format configuration.
-        """
-        response_format_config: dict[str, Any]
-
-        if self._custom_json_schema_obj_str:
-            response_format_config = {
-                "type": "json_schema",
-                "json_schema": self._custom_json_schema_obj_str,
-            }
-        else:
-            raw_schema = to_strict_json_schema(self._output_model)
-            response_format_config = {
-                "type": "json_schema",
-                "json_schema": {
-                    "name": self._output_model.__name__,
-                    "schema": raw_schema,
-                },
-            }
-
-        return {
-            "custom_id": str(idx),
-            "method": "POST",
-            "url": "/v1/chat/completions",
-            "body": {
-                "model": self._model,
-                "messages": [
-                    {"role": "system", "content": self._prompt_template},
-                    {"role": "user", "content": text},
-                ],
-                "response_format": response_format_config,
-                **self._client_kwargs,
-            },
-        }
-
-    def _prepare_file(self, payload: list[str] | list[dict[str, str]]) -> Path:
-        """
-        Prepares a JSONL file containing all tasks for the batch job, based on the input payload.
-        Returns the path to the created file.
-        """
-        if not payload:
-            raise ValueError("Payload must not be empty")
-        if isinstance(payload[0], str):
-            tasks = [self._build_task(text, uuid.uuid4().hex) for text in payload]
-        elif isinstance(payload[0], dict):
-            tasks = [self._build_task(dic["text"], dic["id"]) for dic in payload]
-
-        else:
-            raise TypeError(
-                "The input must be either a list of texts or a list of dicts in the form {'id': str, 'text': str}"
-            )
-
-        file_path = self._state_dir / f"batch_{uuid.uuid4().hex}.jsonl"
-        with open(file_path, "w", encoding="utf-8") as f:
-            for task in tasks:
-                f.write(json.dumps(task) + "\n")
-        return file_path
-
-    def start(self, payload: list[str | dict[str, str]], job_name: str):
-        """
-        Starts a new batch job by uploading the prepared file and creating a batch job on the server.
-        If a job with the same name already exists, it does nothing.
-        """
-        if self._load_state(job_name):
-            return
-
-        path = self._prepare_file(payload)
-        upload = self._client.files.create(file=open(path, "rb"), purpose="batch")
-        job = self._client.batches.create(
-            input_file_id=upload.id,
-            endpoint="/v1/chat/completions",
-            completion_window="24h",
-        ).to_dict()
-        self._save_state(job_name, [job])
-
-    def check_status(self, job_name: str) -> str:
-        """
-        Checks and returns the current status of the batch job with the given job name.
-        Updates the job state with the latest information from the server.
-        """
-        job = next(iter(self._load_state(job_name)), None)
-        if not job:
-            return "completed"
-
-        info = self._client.batches.retrieve(job["id"])
-        job = info.to_dict()
-        self._save_state(job_name, [job])
-        logger.info("Batch job status: %s", job)
-        return job["status"]
-
-    def fetch_results(
-        self, job_name: str, remove_cache: bool = True
-    ) -> tuple[dict[str, str], list]:
-        """
-        Fetches the results of a completed batch job. Optionally removes the job cache.
-        Returns a tuple containing the parsed results and a log of errors (if any).
-        """
-        job = next(iter(self._load_state(job_name)), None)
-        if not job:
-            return {}, []
-        batch_id = job["id"]
-
-        info = self._client.batches.retrieve(batch_id)
-        out_file_id = info.output_file_id
-        if not out_file_id:
-            error_file_id = info.error_file_id
-            if error_file_id:
-                err_content = (
-                    self._client.files.content(error_file_id).read().decode("utf-8")
-                )
-                logger.error("Error file content: %s", err_content)
-            return {}, []
-
-        content = self._client.files.content(out_file_id).read().decode("utf-8")
-        lines = content.splitlines()
-        results = {}
-        log = []
-        for line in lines:
-            result = json.loads(line)
-            custom_id = result["custom_id"]
-            if result["response"]["status_code"] == 200:
-                content = result["response"]["body"]["choices"][0]["message"]["content"]
-                try:
-                    parsed_content = json.loads(content)
-                    model_instance = self._output_model(**parsed_content)
-                    results[custom_id] = model_instance.model_dump(mode="json")
-                except json.JSONDecodeError:
-                    results[custom_id] = {"error": "Failed to parse content as JSON"}
-                    error_d = {custom_id: results[custom_id]}
-                    log.append(error_d)
-                except Exception as e:
-                    results[custom_id] = {"error": str(e)}
-                    error_d = {custom_id: results[custom_id]}
-                    log.append(error_d)
-            else:
-                error_message = (
-                    result["response"]["body"]
-                    .get("error", {})
-                    .get("message", "Unknown error")
-                )
-                results[custom_id] = {"error": error_message}
-                error_d = {custom_id: results[custom_id]}
-                log.append(error_d)
-
-        if remove_cache:
-            self._clear_state(job_name)
-
-        return results, log
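For reference, the manager's lifecycle as the deleted code defines it: `start` writes the tasks to a JSONL file, uploads it, and creates the batch (skipping if state for the job name already exists); `check_status` polls the server and persists the latest state; `fetch_results` downloads the output file and parses each line into the Pydantic model. A usage sketch against the removed class (`Entities` is a made-up output model):

```python
from openai import OpenAI
from pydantic import BaseModel

class Entities(BaseModel):
    entities: list[str]

manager = BatchManager(
    client=OpenAI(),
    model="gpt-4o-mini",
    output_model=Entities,
    prompt_template="Extract entities from the text",
)
manager.start(["first text", "second text"], job_name="demo")
if manager.check_status("demo") == "completed":
    # results: {custom_id: model dump}; log: per-id error records
    results, log = manager.fetch_results("demo")
```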
texttools/batch/runner.py DELETED

@@ -1,228 +0,0 @@
-import json
-import logging
-import os
-import time
-from pathlib import Path
-from typing import Any, Type, TypeVar
-
-from dotenv import load_dotenv
-from openai import OpenAI
-from pydantic import BaseModel
-
-from ..core.exceptions import TextToolsError
-from ..core.internal_models import Str
-from .config import BatchConfig
-from .manager import BatchManager
-
-# Base Model type for output models
-T = TypeVar("T", bound=BaseModel)
-
-logger = logging.getLogger("texttools.batch_runner")
-
-
-class BatchRunner:
-    """
-    Handles running batch jobs using a batch manager and configuration.
-    """
-
-    def __init__(
-        self, config: BatchConfig = BatchConfig(), output_model: Type[T] = Str
-    ):
-        try:
-            self._config = config
-            self._system_prompt = config.system_prompt
-            self._job_name = config.job_name
-            self._input_data_path = config.input_data_path
-            self._output_data_filename = config.output_data_filename
-            self._model = config.model
-            self._output_model = output_model
-            self._manager = self._init_manager()
-            self._data = self._load_data()
-            self._parts: list[list[dict[str, Any]]] = []
-            # Map part index to job name
-            self._part_idx_to_job_name: dict[int, str] = {}
-            # Track retry attempts per part
-            self._part_attempts: dict[int, int] = {}
-            self._partition_data()
-            Path(self._config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
-
-        except Exception as e:
-            raise TextToolsError(f"Batch runner initialization failed: {e}")
-
-    def _init_manager(self) -> BatchManager:
-        load_dotenv()
-        api_key = os.getenv("OPENAI_API_KEY")
-        client = OpenAI(api_key=api_key)
-        return BatchManager(
-            client=client,
-            model=self._model,
-            prompt_template=self._system_prompt,
-            output_model=self._output_model,
-        )
-
-    def _load_data(self):
-        with open(self._input_data_path, "r", encoding="utf-8") as f:
-            data = json.load(f)
-        data = self._config.export_function(data)
-
-        # Ensure data is a list of dicts with 'id' and 'content' as strings
-        if not isinstance(data, list):
-            raise ValueError(
-                "Exported data must be a list of dicts with 'id' and 'content' keys"
-            )
-        for item in data:
-            if not (isinstance(item, dict) and "id" in item and "content" in item):
-                raise ValueError(
-                    f"Item must be a dict with 'id' and 'content' keys. Got: {type(item)}"
-                )
-            if not (isinstance(item["id"], str) and isinstance(item["content"], str)):
-                raise ValueError("'id' and 'content' must be strings.")
-        return data
-
-    def _partition_data(self):
-        total_length = sum(len(item["content"]) for item in self._data)
-        prompt_length = len(self._system_prompt)
-        total = total_length + (prompt_length * len(self._data))
-        calculation = total / self._config.CHARS_PER_TOKEN
-        logger.info(
-            f"Total chars: {total_length}, Prompt chars: {prompt_length}, Total: {total}, Tokens: {calculation}"
-        )
-        if calculation < self._config.MAX_TOTAL_TOKENS:
-            self._parts = [self._data]
-        else:
-            # Partition into chunks of MAX_BATCH_SIZE
-            self._parts = [
-                self._data[i : i + self._config.MAX_BATCH_SIZE]
-                for i in range(0, len(self._data), self._config.MAX_BATCH_SIZE)
-            ]
-        logger.info(f"Data split into {len(self._parts)} part(s)")
-
-    def _submit_all_jobs(self) -> None:
-        for idx, part in enumerate(self._parts):
-            if self._result_exists(idx):
-                logger.info(f"Skipping part {idx + 1}: result already exists.")
-                continue
-            part_job_name = (
-                f"{self._job_name}_part_{idx + 1}"
-                if len(self._parts) > 1
-                else self._job_name
-            )
-            # If a job with this name already exists, register and skip submitting
-            existing_job = self._manager._load_state(part_job_name)
-            if existing_job:
-                logger.info(
-                    f"Skipping part {idx + 1}: job already exists ({part_job_name})."
-                )
-                self._part_idx_to_job_name[idx] = part_job_name
-                self._part_attempts.setdefault(idx, 0)
-                continue
-
-            payload = part
-            logger.info(
-                f"Submitting job for part {idx + 1}/{len(self._parts)}: {part_job_name}"
-            )
-            self._manager.start(payload, job_name=part_job_name)
-            self._part_idx_to_job_name[idx] = part_job_name
-            self._part_attempts.setdefault(idx, 0)
-            # Give the upload time to finish before starting the next part.
-            logger.info("Uploading...")
-            time.sleep(30)
-
-    def _save_results(
-        self,
-        output_data: list[dict[str, Any]] | dict[str, Any],
-        log: list[Any],
-        part_idx: int,
-    ):
-        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
-        result_path = (
-            Path(self._config.BASE_OUTPUT_DIR)
-            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
-        )
-        if not output_data:
-            logger.info("No output data to save. Skipping this part.")
-            return
-        else:
-            with open(result_path, "w", encoding="utf-8") as f:
-                json.dump(output_data, f, ensure_ascii=False, indent=4)
-        if log:
-            log_path = (
-                Path(self._config.BASE_OUTPUT_DIR)
-                / f"{Path(self._output_data_filename).stem}{part_suffix}_log.json"
-            )
-            with open(log_path, "w", encoding="utf-8") as f:
-                json.dump(log, f, ensure_ascii=False, indent=4)
-
-    def _result_exists(self, part_idx: int) -> bool:
-        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
-        result_path = (
-            Path(self._config.BASE_OUTPUT_DIR)
-            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
-        )
-        return result_path.exists()
-
-    def run(self):
-        """
-        Execute the batch job processing pipeline.
-
-        Submits jobs, monitors progress, handles retries, and saves results.
-        """
-        try:
-            # Submit all jobs up-front for concurrent execution
-            self._submit_all_jobs()
-            pending_parts: set[int] = set(self._part_idx_to_job_name.keys())
-            logger.info(f"Pending parts: {sorted(pending_parts)}")
-            # Polling loop
-            while pending_parts:
-                finished_this_round: list[int] = []
-                for part_idx in list(pending_parts):
-                    job_name = self._part_idx_to_job_name[part_idx]
-                    status = self._manager.check_status(job_name=job_name)
-                    logger.info(f"Status for {job_name}: {status}")
-                    if status == "completed":
-                        logger.info(
-                            f"Job completed. Fetching results for part {part_idx + 1}..."
-                        )
-                        output_data, log = self._manager.fetch_results(
-                            job_name=job_name, remove_cache=False
-                        )
-                        output_data = self._config.import_function(output_data)
-                        self._save_results(output_data, log, part_idx)
-                        logger.info(
-                            f"Fetched and saved results for part {part_idx + 1}."
-                        )
-                        finished_this_round.append(part_idx)
-                    elif status == "failed":
-                        attempt = self._part_attempts.get(part_idx, 0) + 1
-                        self._part_attempts[part_idx] = attempt
-                        if attempt <= self._config.max_retries:
-                            logger.info(
-                                f"Job {job_name} failed (attempt {attempt}). Retrying after short backoff..."
-                            )
-                            self._manager._clear_state(job_name)
-                            time.sleep(10)
-                            payload = self._parts[part_idx]
-                            new_job_name = (
-                                f"{self._job_name}_part_{part_idx + 1}_retry_{attempt}"
-                            )
-                            self._manager.start(payload, job_name=new_job_name)
-                            self._part_idx_to_job_name[part_idx] = new_job_name
-                        else:
-                            logger.info(
-                                f"Job {job_name} failed after {attempt - 1} retries. Marking as failed."
-                            )
-                            finished_this_round.append(part_idx)
-                    else:
-                        # Still running or queued
-                        continue
-                # Remove finished parts
-                for part_idx in finished_this_round:
-                    pending_parts.discard(part_idx)
-                if pending_parts:
-                    logger.info(
-                        f"Waiting {self._config.poll_interval_seconds}s before next status check for parts: {sorted(pending_parts)}"
-                    )
-                    time.sleep(self._config.poll_interval_seconds)
-
-        except Exception as e:
-            raise TextToolsError(f"Batch job execution failed: {e}")
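The runner's partitioning heuristic deserves a worked example: it estimates token usage as `(total content chars + prompt chars × n_items) / CHARS_PER_TOKEN` and only splits the data once that estimate exceeds `MAX_TOTAL_TOKENS`. With the default constants and made-up volumes:

```python
# Worked example of the deleted _partition_data heuristic (volumes are made up).
CHARS_PER_TOKEN, MAX_TOTAL_TOKENS, MAX_BATCH_SIZE = 2.7, 2_000_000, 100

n_items, content_chars, prompt_chars = 50_000, 200, 120
total_chars = n_items * content_chars + prompt_chars * n_items  # 16,000,000
estimated_tokens = total_chars / CHARS_PER_TOKEN                # ~5.9M > 2M
n_parts = -(-n_items // MAX_BATCH_SIZE)                         # ceil -> 500 parts
```

Note the cliff this creates: under the threshold everything goes as a single job, but over it the data is always chopped into `MAX_BATCH_SIZE`-item parts, each submitted 30 seconds apart and polled until it completes or exhausts `max_retries`.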
{hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/WHEEL RENAMED

File without changes

{hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/licenses/LICENSE RENAMED

File without changes

{hamtaa_texttools-1.3.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/top_level.txt RENAMED

File without changes