hamtaa-texttools 1.2.0-py3-none-any.whl → 1.3.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
- {hamtaa_texttools-1.2.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/METADATA +6 -29
- {hamtaa_texttools-1.2.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/RECORD +9 -13
- texttools/__init__.py +1 -3
- texttools/core/engine.py +10 -0
- texttools/tools/async_tools.py +338 -269
- texttools/tools/sync_tools.py +16 -17
- texttools/batch/__init__.py +0 -0
- texttools/batch/config.py +0 -40
- texttools/batch/manager.py +0 -228
- texttools/batch/runner.py +0 -228
- {hamtaa_texttools-1.2.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.2.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.2.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/top_level.txt +0 -0
texttools/tools/sync_tools.py
CHANGED
@@ -1,4 +1,3 @@
-import sys
 from collections.abc import Callable
 from time import perf_counter
 from typing import Any, Literal
@@ -66,7 +65,7 @@ class TheTool:
             ToolOutput

         """
-        tool_name =
+        tool_name = "categorize"
         start = perf_counter()

         try:
@@ -198,7 +197,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name =
+        tool_name = "extract_keywords"
         start = perf_counter()

         try:
@@ -211,7 +210,6 @@ class TheTool:
                 temperature=temperature,
                 logprobs=logprobs,
                 top_logprobs=top_logprobs,
-                mode=mode,
                 number_of_keywords=number_of_keywords,
                 validator=validator,
                 max_validation_retries=max_validation_retries,
@@ -219,6 +217,7 @@ class TheTool:
                 # Internal parameters
                 tool_name=tool_name,
                 output_model=ListStr,
+                mode=mode,
             )

             metadata = ToolOutputMetadata(
@@ -272,7 +271,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name =
+        tool_name = "extract_entities"
         start = perf_counter()

         try:
@@ -343,7 +342,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name =
+        tool_name = "is_question"
         start = perf_counter()

         try:
@@ -416,7 +415,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name =
+        tool_name = "text_to_question"
         start = perf_counter()

         try:
@@ -489,7 +488,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name =
+        tool_name = "merge_questions"
         start = perf_counter()

         try:
@@ -562,7 +561,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name =
+        tool_name = "rewrite"
         start = perf_counter()

         try:
@@ -635,7 +634,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name =
+        tool_name = "subject_to_question"
         start = perf_counter()

         try:
@@ -707,7 +706,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name =
+        tool_name = "summarize"
         start = perf_counter()

         try:
@@ -782,7 +781,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name =
+        tool_name = "translate"
         start = perf_counter()

         try:
@@ -900,7 +899,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name =
+        tool_name = "propositionize"
         start = perf_counter()

         try:
@@ -961,7 +960,7 @@ class TheTool:

         Arguments:
             text: The input text
-            source_text:
+            source_text: The source text that we want to check relation of text to it
             with_analysis: Whether to include detailed reasoning analysis
             output_lang: Language for the output
             user_prompt: Additional instructions
@@ -975,13 +974,14 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name =
+        tool_name = "check_fact"
         start = perf_counter()

         try:
             operator_output = self._operator.run(
                 # User parameters
                 text=text,
+                source_text=source_text,
                 with_analysis=with_analysis,
                 output_lang=output_lang,
                 user_prompt=user_prompt,
@@ -995,7 +995,6 @@ class TheTool:
                 tool_name=tool_name,
                 output_model=Bool,
                 mode=None,
-                source_text=source_text,
             )

             metadata = ToolOutputMetadata(
@@ -1049,7 +1048,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name =
+        tool_name = "run_custom"
         start = perf_counter()

         try:
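Taken together, the hunks above fill in the previously empty `tool_name` identifiers, document the `source_text` argument, and reroute two parameters: `mode` now travels with the internal parameters of `extract_keywords`, and `source_text` is forwarded with the user parameters of `check_fact` rather than the internal ones. A minimal usage sketch of the two affected methods follows; `TheTool`'s constructor is not part of this diff, so the setup line is an assumption.

# Sketch only: TheTool's construction and the ToolOutput fields are not
# shown in this diff; treat the setup below as hypothetical.
from texttools.tools.sync_tools import TheTool

tool = TheTool()  # hypothetical default construction

# check_fact: source_text is now forwarded with the other user parameters
verdict = tool.check_fact(
    text="The claim to verify.",
    source_text="The reference text the claim is checked against.",
)

# extract_keywords: mode is still accepted, but is now passed internally
# alongside tool_name and output_model (see the @@ -219,6 +217,7 @@ hunk)
keywords = tool.extract_keywords(
    text="Some input text.",
    number_of_keywords=5,
)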
texttools/batch/__init__.py
DELETED
File without changes
texttools/batch/config.py
DELETED
@@ -1,40 +0,0 @@
-from collections.abc import Callable
-from dataclasses import dataclass
-from typing import Any
-
-
-def export_data(data) -> list[dict[str, str]]:
-    """
-    Produces a structure of the following form from an initial data structure:
-    [{"id": str, "text": str},...]
-    """
-    return data
-
-
-def import_data(data) -> Any:
-    """
-    Takes the output and adds and aggregates it to the original structure.
-    """
-    return data
-
-
-@dataclass
-class BatchConfig:
-    """
-    Configuration for batch job runner.
-    """
-
-    system_prompt: str = ""
-    job_name: str = ""
-    input_data_path: str = ""
-    output_data_filename: str = ""
-    model: str = "gpt-4.1-mini"
-    MAX_BATCH_SIZE: int = 100
-    MAX_TOTAL_TOKENS: int = 2_000_000
-    CHARS_PER_TOKEN: float = 2.7
-    PROMPT_TOKEN_MULTIPLIER: int = 1_000
-    BASE_OUTPUT_DIR: str = "Data/batch_entity_result"
-    import_function: Callable = import_data
-    export_function: Callable = export_data
-    poll_interval_seconds: int = 30
-    max_retries: int = 3
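For readers migrating off the removed module: in 1.2.0, `BatchConfig` was a plain dataclass whose `export_function`/`import_function` hooks shaped data into and out of the batch pipeline. A minimal sketch of how it was customized against the 1.2.0 API above; the field values and the `doc_id`/`body` keys are illustrative placeholders.

from texttools.batch.config import BatchConfig  # removed in 1.3.1


def my_export(data):
    # Shape arbitrary input JSON into the [{"id": str, "text": str}, ...]
    # records described in export_data's docstring above.
    return [{"id": item["doc_id"], "text": item["body"]} for item in data]


def my_import(results):
    # Merge the per-id results back into whatever structure the caller needs.
    return results


config = BatchConfig(
    system_prompt="Extract the named entities from the text.",
    job_name="entities",
    input_data_path="Data/input.json",      # placeholder path
    output_data_filename="entities.json",   # placeholder name
    export_function=my_export,
    import_function=my_import,
)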
texttools/batch/manager.py
DELETED
@@ -1,228 +0,0 @@
-import json
-import logging
-import uuid
-from pathlib import Path
-from typing import Any, Type, TypeVar
-
-from openai import OpenAI
-from openai.lib._pydantic import to_strict_json_schema
-from pydantic import BaseModel
-
-# Base Model type for output models
-T = TypeVar("T", bound=BaseModel)
-
-logger = logging.getLogger("texttools.batch_manager")
-
-
-class BatchManager:
-    """
-    Manages batch processing jobs for OpenAI's chat completions with structured outputs.
-
-    Handles the full lifecycle of a batch job: creating tasks from input texts,
-    starting the job, monitoring status, and fetching results. Results are automatically
-    parsed into the specified Pydantic output model. Job state is persisted to disk.
-    """
-
-    def __init__(
-        self,
-        client: OpenAI,
-        model: str,
-        output_model: Type[T],
-        prompt_template: str,
-        state_dir: Path = Path(".batch_jobs"),
-        custom_json_schema_obj_str: dict | None = None,
-        **client_kwargs: Any,
-    ):
-        self._client = client
-        self._model = model
-        self._output_model = output_model
-        self._prompt_template = prompt_template
-        self._state_dir = state_dir
-        self._custom_json_schema_obj_str = custom_json_schema_obj_str
-        self._client_kwargs = client_kwargs
-        self._dict_input = False
-        self._state_dir.mkdir(parents=True, exist_ok=True)
-
-        if custom_json_schema_obj_str and not isinstance(
-            custom_json_schema_obj_str, dict
-        ):
-            raise ValueError("Schema should be a dict")
-
-    def _state_file(self, job_name: str) -> Path:
-        return self._state_dir / f"{job_name}.json"
-
-    def _load_state(self, job_name: str) -> list[dict[str, Any]]:
-        """
-        Loads the state (job information) from the state file for the given job name.
-        Returns an empty list if the state file does not exist.
-        """
-        path = self._state_file(job_name)
-        if path.exists():
-            with open(path, "r", encoding="utf-8") as f:
-                return json.load(f)
-        return []
-
-    def _save_state(self, job_name: str, jobs: list[dict[str, Any]]) -> None:
-        """
-        Saves the job state to the state file for the given job name.
-        """
-        with open(self._state_file(job_name), "w", encoding="utf-8") as f:
-            json.dump(jobs, f)
-
-    def _clear_state(self, job_name: str) -> None:
-        """
-        Deletes the state file for the given job name if it exists.
-        """
-        path = self._state_file(job_name)
-        if path.exists():
-            path.unlink()
-
-    def _build_task(self, text: str, idx: str) -> dict[str, Any]:
-        """
-        Builds a single task dictionary for the batch job, including the prompt, model, and response format configuration.
-        """
-        response_format_config: dict[str, Any]
-
-        if self._custom_json_schema_obj_str:
-            response_format_config = {
-                "type": "json_schema",
-                "json_schema": self._custom_json_schema_obj_str,
-            }
-        else:
-            raw_schema = to_strict_json_schema(self._output_model)
-            response_format_config = {
-                "type": "json_schema",
-                "json_schema": {
-                    "name": self._output_model.__name__,
-                    "schema": raw_schema,
-                },
-            }
-
-        return {
-            "custom_id": str(idx),
-            "method": "POST",
-            "url": "/v1/chat/completions",
-            "body": {
-                "model": self.model,
-                "messages": [
-                    {"role": "system", "content": self._prompt_template},
-                    {"role": "user", "content": text},
-                ],
-                "response_format": response_format_config,
-                **self._client_kwargs,
-            },
-        }
-
-    def _prepare_file(self, payload: list[str] | list[dict[str, str]]) -> Path:
-        """
-        Prepares a JSONL file containing all tasks for the batch job, based on the input payload.
-        Returns the path to the created file.
-        """
-        if not payload:
-            raise ValueError("Payload must not be empty")
-        if isinstance(payload[0], str):
-            tasks = [self._build_task(text, uuid.uuid4().hex) for text in payload]
-        elif isinstance(payload[0], dict):
-            tasks = [self._build_task(dic["text"], dic["id"]) for dic in payload]
-
-        else:
-            raise TypeError(
-                "The input must be either a list of texts or a dictionary in the form {'id': str, 'text': str}"
-            )
-
-        file_path = self._state_dir / f"batch_{uuid.uuid4().hex}.jsonl"
-        with open(file_path, "w", encoding="utf-8") as f:
-            for task in tasks:
-                f.write(json.dumps(task) + "\n")
-        return file_path
-
-    def start(self, payload: list[str | dict[str, str]], job_name: str):
-        """
-        Starts a new batch job by uploading the prepared file and creating a batch job on the server.
-        If a job with the same name already exists, it does nothing.
-        """
-        if self._load_state(job_name):
-            return
-
-        path = self._prepare_file(payload)
-        upload = self._client.files.create(file=open(path, "rb"), purpose="batch")
-        job = self._client.batches.create(
-            input_file_id=upload.id,
-            endpoint="/v1/chat/completions",
-            completion_window="24h",
-        ).to_dict()
-        self._save_state(job_name, [job])
-
-    def check_status(self, job_name: str) -> str:
-        """
-        Checks and returns the current status of the batch job with the given job name.
-        Updates the job state with the latest information from the server.
-        """
-        job = self._load_state(job_name)[0]
-        if not job:
-            return "completed"
-
-        info = self._client.batches.retrieve(job["id"])
-        job = info.to_dict()
-        self._save_state(job_name, [job])
-        logger.info("Batch job status: %s", job)
-        return job["status"]
-
-    def fetch_results(
-        self, job_name: str, remove_cache: bool = True
-    ) -> tuple[dict[str, str], list]:
-        """
-        Fetches the results of a completed batch job. Optionally saves the results to a file and/or removes the job cache.
-        Returns a tuple containing the parsed results and a log of errors (if any).
-        """
-        job = self._load_state(job_name)[0]
-        if not job:
-            return {}
-        batch_id = job["id"]
-
-        info = self._client.batches.retrieve(batch_id)
-        out_file_id = info.output_file_id
-        if not out_file_id:
-            error_file_id = info.error_file_id
-            if error_file_id:
-                err_content = (
-                    self._client.files.content(error_file_id).read().decode("utf-8")
-                )
-                logger.error("Error file content:", err_content)
-            return {}
-
-        content = self._client.files.content(out_file_id).read().decode("utf-8")
-        lines = content.splitlines()
-        results = {}
-        log = []
-        for line in lines:
-            result = json.loads(line)
-            custom_id = result["custom_id"]
-            if result["response"]["status_code"] == 200:
-                content = result["response"]["body"]["choices"][0]["message"]["content"]
-                try:
-                    parsed_content = json.loads(content)
-                    model_instance = self._output_model(**parsed_content)
-                    results[custom_id] = model_instance.model_dump(mode="json")
-                except json.JSONDecodeError:
-                    results[custom_id] = {"error": "Failed to parse content as JSON"}
-                    error_d = {custom_id: results[custom_id]}
-                    log.append(error_d)
-                except Exception as e:
-                    results[custom_id] = {"error": str(e)}
-                    error_d = {custom_id: results[custom_id]}
-                    log.append(error_d)
-            else:
-                error_message = (
-                    result["response"]["body"]
-                    .get("error", {})
-                    .get("message", "Unknown error")
-                )
-                results[custom_id] = {"error": error_message}
-                error_d = {custom_id: results[custom_id]}
-                log.append(error_d)
-
-        if remove_cache:
-            self._clear_state(job_name)
-
-        return results, log
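The removed manager exposed a three-step lifecycle: `start` (idempotent per `job_name`), `check_status`, and `fetch_results`. A minimal sketch against the 1.2.0 API above; the `Entities` model and the input texts are illustrative, and in practice `check_status` was polled in a loop (as the runner below does).

from openai import OpenAI
from pydantic import BaseModel

from texttools.batch.manager import BatchManager  # removed in 1.3.1


class Entities(BaseModel):  # illustrative output model
    names: list[str]


manager = BatchManager(
    client=OpenAI(),
    model="gpt-4.1-mini",
    output_model=Entities,
    prompt_template="List the person names mentioned in the text.",
)

# Uploads the JSONL task file and creates the batch; does nothing if state
# for this job_name already exists on disk under .batch_jobs/.
manager.start(["Ada met Grace.", "Linus met Dennis."], job_name="demo")

# Once the server reports completion, parse results into the output model.
if manager.check_status("demo") == "completed":
    results, errors = manager.fetch_results("demo")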
texttools/batch/runner.py
DELETED
@@ -1,228 +0,0 @@
-import json
-import logging
-import os
-import time
-from pathlib import Path
-from typing import Any, Type, TypeVar
-
-from dotenv import load_dotenv
-from openai import OpenAI
-from pydantic import BaseModel
-
-from ..core.exceptions import TextToolsError
-from ..core.internal_models import Str
-from .config import BatchConfig
-from .manager import BatchManager
-
-# Base Model type for output models
-T = TypeVar("T", bound=BaseModel)
-
-logger = logging.getLogger("texttools.batch_runner")
-
-
-class BatchRunner:
-    """
-    Handles running batch jobs using a batch manager and configuration.
-    """
-
-    def __init__(
-        self, config: BatchConfig = BatchConfig(), output_model: Type[T] = Str
-    ):
-        try:
-            self._config = config
-            self._system_prompt = config.system_prompt
-            self._job_name = config.job_name
-            self._input_data_path = config.input_data_path
-            self._output_data_filename = config.output_data_filename
-            self._model = config.model
-            self._output_model = output_model
-            self._manager = self._init_manager()
-            self._data = self._load_data()
-            self._parts: list[list[dict[str, Any]]] = []
-            # Map part index to job name
-            self._part_idx_to_job_name: dict[int, str] = {}
-            # Track retry attempts per part
-            self._part_attempts: dict[int, int] = {}
-            self._partition_data()
-            Path(self._config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
-
-        except Exception as e:
-            raise TextToolsError(f"Batch runner initialization failed: {e}")
-
-    def _init_manager(self) -> BatchManager:
-        load_dotenv()
-        api_key = os.getenv("OPENAI_API_KEY")
-        client = OpenAI(api_key=api_key)
-        return BatchManager(
-            client=client,
-            model=self._model,
-            prompt_template=self._system_prompt,
-            output_model=self._output_model,
-        )
-
-    def _load_data(self):
-        with open(self._input_data_path, "r", encoding="utf-8") as f:
-            data = json.load(f)
-        data = self._config.export_function(data)
-
-        # Ensure data is a list of dicts with 'id' and 'content' as strings
-        if not isinstance(data, list):
-            raise ValueError(
-                "Exported data must be a list of dicts with 'id' and 'content' keys"
-            )
-        for item in data:
-            if not (isinstance(item, dict) and "id" in item and "content" in item):
-                raise ValueError(
-                    f"Item must be a dict with 'id' and 'content' keys. Got: {type(item)}"
-                )
-            if not (isinstance(item["id"], str) and isinstance(item["content"], str)):
-                raise ValueError("'id' and 'content' must be strings.")
-        return data
-
-    def _partition_data(self):
-        total_length = sum(len(item["content"]) for item in self._data)
-        prompt_length = len(self._system_prompt)
-        total = total_length + (prompt_length * len(self._data))
-        calculation = total / self._config.CHARS_PER_TOKEN
-        logger.info(
-            f"Total chars: {total_length}, Prompt chars: {prompt_length}, Total: {total}, Tokens: {calculation}"
-        )
-        if calculation < self._config.MAX_TOTAL_TOKENS:
-            self._parts = [self._data]
-        else:
-            # Partition into chunks of MAX_BATCH_SIZE
-            self._parts = [
-                self._data[i : i + self._config.MAX_BATCH_SIZE]
-                for i in range(0, len(self._data), self._config.MAX_BATCH_SIZE)
-            ]
-        logger.info(f"Data split into {len(self._parts)} part(s)")
-
-    def _submit_all_jobs(self) -> None:
-        for idx, part in enumerate(self._parts):
-            if self._result_exists(idx):
-                logger.info(f"Skipping part {idx + 1}: result already exists.")
-                continue
-            part_job_name = (
-                f"{self._job_name}_part_{idx + 1}"
-                if len(self._parts) > 1
-                else self._job_name
-            )
-            # If a job with this name already exists, register and skip submitting
-            existing_job = self._manager._load_state(part_job_name)
-            if existing_job:
-                logger.info(
-                    f"Skipping part {idx + 1}: job already exists ({part_job_name})."
-                )
-                self._part_idx_to_job_name[idx] = part_job_name
-                self._part_attempts.setdefault(idx, 0)
-                continue
-
-            payload = part
-            logger.info(
-                f"Submitting job for part {idx + 1}/{len(self._parts)}: {part_job_name}"
-            )
-            self._manager.start(payload, job_name=part_job_name)
-            self._part_idx_to_job_name[idx] = part_job_name
-            self._part_attempts.setdefault(idx, 0)
-            # This is added for letting file get uploaded, before starting the next part.
-            logger.info("Uploading...")
-            time.sleep(30)
-
-    def _save_results(
-        self,
-        output_data: list[dict[str, Any]] | dict[str, Any],
-        log: list[Any],
-        part_idx: int,
-    ):
-        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
-        result_path = (
-            Path(self._config.BASE_OUTPUT_DIR)
-            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
-        )
-        if not output_data:
-            logger.info("No output data to save. Skipping this part.")
-            return
-        else:
-            with open(result_path, "w", encoding="utf-8") as f:
-                json.dump(output_data, f, ensure_ascii=False, indent=4)
-        if log:
-            log_path = (
-                Path(self._config.BASE_OUTPUT_DIR)
-                / f"{Path(self._output_data_filename).stem}{part_suffix}_log.json"
-            )
-            with open(log_path, "w", encoding="utf-8") as f:
-                json.dump(log, f, ensure_ascii=False, indent=4)
-
-    def _result_exists(self, part_idx: int) -> bool:
-        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
-        result_path = (
-            Path(self._config.BASE_OUTPUT_DIR)
-            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
-        )
-        return result_path.exists()
-
-    def run(self):
-        """
-        Execute the batch job processing pipeline.
-
-        Submits jobs, monitors progress, handles retries, and saves results.
-        """
-        try:
-            # Submit all jobs up-front for concurrent execution
-            self._submit_all_jobs()
-            pending_parts: set[int] = set(self._part_idx_to_job_name.keys())
-            logger.info(f"Pending parts: {sorted(pending_parts)}")
-            # Polling loop
-            while pending_parts:
-                finished_this_round: list[int] = []
-                for part_idx in list(pending_parts):
-                    job_name = self._part_idx_to_job_name[part_idx]
-                    status = self._manager.check_status(job_name=job_name)
-                    logger.info(f"Status for {job_name}: {status}")
-                    if status == "completed":
-                        logger.info(
-                            f"Job completed. Fetching results for part {part_idx + 1}..."
-                        )
-                        output_data, log = self._manager.fetch_results(
-                            job_name=job_name, remove_cache=False
-                        )
-                        output_data = self._config.import_function(output_data)
-                        self._save_results(output_data, log, part_idx)
-                        logger.info(
-                            f"Fetched and saved results for part {part_idx + 1}."
-                        )
-                        finished_this_round.append(part_idx)
-                    elif status == "failed":
-                        attempt = self._part_attempts.get(part_idx, 0) + 1
-                        self._part_attempts[part_idx] = attempt
-                        if attempt <= self._config.max_retries:
-                            logger.info(
-                                f"Job {job_name} failed (attempt {attempt}). Retrying after short backoff..."
-                            )
-                            self._manager._clear_state(job_name)
-                            time.sleep(10)
-                            payload = self._to_manager_payload(self._parts[part_idx])
-                            new_job_name = (
-                                f"{self._job_name}_part_{part_idx + 1}_retry_{attempt}"
-                            )
-                            self._manager.start(payload, job_name=new_job_name)
-                            self._part_idx_to_job_name[part_idx] = new_job_name
-                        else:
-                            logger.info(
-                                f"Job {job_name} failed after {attempt - 1} retries. Marking as failed."
-                            )
-                            finished_this_round.append(part_idx)
-                    else:
-                        # Still running or queued
-                        continue
-                # Remove finished parts
-                for part_idx in finished_this_round:
-                    pending_parts.discard(part_idx)
-                if pending_parts:
-                    logger.info(
-                        f"Waiting {self._config.poll_interval_seconds}s before next status check for parts: {sorted(pending_parts)}"
-                    )
-                    time.sleep(self._config.poll_interval_seconds)
-
-        except Exception as e:
-            raise TextToolsError(f"Batch job execution failed: {e}")
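The runner tied the config and manager together: it loaded the input file, partitioned it when the estimated token count exceeded `MAX_TOTAL_TOKENS`, submitted all parts, and polled until every part completed or exhausted `max_retries`. A minimal sketch of how it was driven in 1.2.0; it expects `OPENAI_API_KEY` in the environment (loaded via dotenv), and the config values are placeholders.

from texttools.batch.config import BatchConfig
from texttools.batch.runner import BatchRunner  # removed in 1.3.1

config = BatchConfig(
    system_prompt="Summarize the text in one sentence.",
    job_name="summaries",
    input_data_path="Data/input.json",       # placeholder
    output_data_filename="summaries.json",   # placeholder
)

# Submits all parts, polls every poll_interval_seconds, retries failed
# parts, and writes results (plus an error log, if any) under
# config.BASE_OUTPUT_DIR.
BatchRunner(config).run()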
{hamtaa_texttools-1.2.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/WHEEL
File without changes

{hamtaa_texttools-1.2.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/licenses/LICENSE
File without changes

{hamtaa_texttools-1.2.0.dist-info → hamtaa_texttools-1.3.1.dist-info}/top_level.txt
File without changes