hamtaa_texttools-0.1.43-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hamtaa-texttools might be problematic.
- hamtaa_texttools-0.1.43.dist-info/METADATA +60 -0
- hamtaa_texttools-0.1.43.dist-info/RECORD +60 -0
- hamtaa_texttools-0.1.43.dist-info/WHEEL +5 -0
- hamtaa_texttools-0.1.43.dist-info/top_level.txt +1 -0
- texttools/__init__.py +26 -0
- texttools/base/__init__.py +3 -0
- texttools/base/base_categorizer.py +40 -0
- texttools/base/base_keyword_extractor.py +35 -0
- texttools/base/base_ner_extractor.py +61 -0
- texttools/base/base_question_detector.py +35 -0
- texttools/base/base_question_generator.py +99 -0
- texttools/base/base_question_merger.py +59 -0
- texttools/base/base_question_rewriter.py +61 -0
- texttools/base/base_router.py +33 -0
- texttools/base/base_summarizer.py +55 -0
- texttools/base/base_task_performer.py +53 -0
- texttools/base/base_translator.py +38 -0
- texttools/batch_manager/__init__.py +2 -0
- texttools/batch_manager/batch_manager.py +241 -0
- texttools/batch_manager/batch_runner.py +207 -0
- texttools/formatter/__init__.py +1 -0
- texttools/formatter/base.py +26 -0
- texttools/formatter/gemma3_formatter.py +51 -0
- texttools/handlers/__init__.py +6 -0
- texttools/handlers/categorizer/__init__.py +6 -0
- texttools/handlers/categorizer/categorizer.py +61 -0
- texttools/handlers/handlers.py +88 -0
- texttools/tools/__init__.py +33 -0
- texttools/tools/categorizer/__init__.py +2 -0
- texttools/tools/categorizer/encoder_model/__init__.py +1 -0
- texttools/tools/categorizer/encoder_model/encoder_vectorizer.py +51 -0
- texttools/tools/categorizer/llm/__init__.py +2 -0
- texttools/tools/categorizer/llm/gemma_categorizer.py +169 -0
- texttools/tools/categorizer/llm/openai_categorizer.py +80 -0
- texttools/tools/keyword_extractor/__init__.py +1 -0
- texttools/tools/keyword_extractor/gemma_extractor.py +138 -0
- texttools/tools/merger/__init__.py +2 -0
- texttools/tools/merger/gemma_question_merger.py +214 -0
- texttools/tools/ner/__init__.py +1 -0
- texttools/tools/ner/gemma_ner_extractor.py +157 -0
- texttools/tools/question_detector/__init__.py +2 -0
- texttools/tools/question_detector/gemma_detector.py +130 -0
- texttools/tools/question_detector/llm_detector.py +112 -0
- texttools/tools/question_generator/__init__.py +1 -0
- texttools/tools/question_generator/gemma_question_generator.py +198 -0
- texttools/tools/reranker/__init__.py +3 -0
- texttools/tools/reranker/reranker.py +137 -0
- texttools/tools/reranker/scorer.py +216 -0
- texttools/tools/reranker/sorter.py +278 -0
- texttools/tools/rewriter/__init__.py +2 -0
- texttools/tools/rewriter/gemma_question_rewriter.py +213 -0
- texttools/tools/router/__init__.py +0 -0
- texttools/tools/router/gemma_router.py +169 -0
- texttools/tools/subject_to_question/__init__.py +1 -0
- texttools/tools/subject_to_question/gemma_question_generator.py +224 -0
- texttools/tools/summarizer/__init__.py +2 -0
- texttools/tools/summarizer/gemma_summarizer.py +140 -0
- texttools/tools/summarizer/llm_summerizer.py +108 -0
- texttools/tools/translator/__init__.py +1 -0
- texttools/tools/translator/gemma_translator.py +202 -0
texttools/base/base_task_performer.py
@@ -0,0 +1,53 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Optional


class BaseTaskPerformer(ABC):
    """
    Base class for common functionalities of LLM-based task performers.
    This includes features like text preprocessing and dispatching results
    to registered handlers.
    """

    def __init__(self, handlers: Optional[list[Any]] = None):
        """
        Initializes the BaseTaskPerformer with optional result handlers.

        :param handlers: An optional list of handlers to process the component's results.
        """
        self.handlers = handlers or []

    def _preprocess(self, text: str) -> str:
        """
        Preprocesses input text by stripping leading/trailing whitespace.
        This can be extended for more complex preprocessing if needed.

        :param text: The raw input text.
        :return: The preprocessed text.
        """
        return text.strip()

    @abstractmethod
    def perform(self, *args, **kwargs) -> Any:
        """
        Abstract method to be implemented by concrete task performers.
        This method will execute the primary task of the class (e.g., scoring, sorting).
        The signature of args and kwargs will vary based on the specific task.
        """
        pass

    def _dispatch(self, result_data: dict[str, Any]) -> None:
        """
        Dispatches the component's results to any registered result handlers.
        Each handler receives a dictionary of result data.

        :param result_data: A dictionary containing the results specific to the component.
        """
        for handler in self.handlers:
            try:
                handler.handle(result_data)
            except Exception as e:
                logging.error(
                    f"Handler {handler.__class__.__name__} failed: {e}", exc_info=True
                )
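
Usage sketch (not part of the release): a minimal concrete performer built on the base class above. The WordCounter and PrintHandler names are hypothetical; only BaseTaskPerformer and its _preprocess/_dispatch hooks come from the package.

from typing import Any

from texttools.base.base_task_performer import BaseTaskPerformer


class PrintHandler:
    """Hypothetical handler: prints whatever result it receives."""

    def handle(self, result_data: dict[str, Any]) -> None:
        print("result:", result_data)


class WordCounter(BaseTaskPerformer):
    """Hypothetical performer: counts words in the preprocessed text."""

    def perform(self, text: str) -> dict[str, Any]:
        cleaned = self._preprocess(text)
        result = {"text": cleaned, "word_count": len(cleaned.split())}
        self._dispatch(result)  # forward to registered handlers
        return result


counter = WordCounter(handlers=[PrintHandler()])
counter.perform("  hello batch world  ")  # {'text': 'hello batch world', 'word_count': 3}
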
texttools/base/base_translator.py
@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from typing import Any, Optional


class BaseTranslator(ABC):
    """
    Base class for all translators that output a translated string.
    """

    def __init__(
        self,
        handlers: Optional[list[Any]] = None,
    ):
        self.handlers = handlers or []

    @abstractmethod
    def translate(
        self, text: str, target_language: str, source_language: Optional[str] = None
    ) -> str:
        """
        Translate the input text from the source language to the target language.
        Should return the translated string.
        The source_language can be optional if the LLM can detect it automatically.
        """
        pass

    def preprocess(self, text: str) -> str:
        """
        Optional text preprocessing step.
        """
        return text.strip()

    def _dispatch(self, result: dict) -> None:
        """
        Dispatch the result to handlers.
        """
        for handler in self.handlers:
            handler.handle(result)
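
Usage sketch (not part of the release): a toy subclass showing the BaseTranslator contract. UppercaseTranslator is hypothetical; it only demonstrates how translate(), preprocess(), and _dispatch() fit together.

from typing import Optional

from texttools.base.base_translator import BaseTranslator


class UppercaseTranslator(BaseTranslator):
    """Hypothetical translator: 'translates' by uppercasing, to show the contract."""

    def translate(
        self, text: str, target_language: str, source_language: Optional[str] = None
    ) -> str:
        translated = self.preprocess(text).upper()
        self._dispatch(
            {
                "source_language": source_language,
                "target_language": target_language,
                "translation": translated,
            }
        )
        return translated


print(UppercaseTranslator().translate("hola", target_language="en"))  # HOLA
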
texttools/batch_manager/batch_manager.py
@@ -0,0 +1,241 @@
import json
import uuid
from pathlib import Path
from typing import Any, Optional, Type

from openai import OpenAI
from openai.lib._pydantic import to_strict_json_schema
from pydantic import BaseModel


class SimpleBatchManager:
    def __init__(
        self,
        client: OpenAI,
        model: str,
        output_model: Type[BaseModel],
        prompt_template: str,
        handlers: Optional[list[Any]] = None,
        state_dir: Path = Path(".batch_jobs"),
        custom_json_schema_obj_str: Optional[dict] = None,
        **client_kwargs: Any,
    ):
        self.client = client
        self.model = model
        self.output_model = output_model
        self.prompt_template = prompt_template
        self.handlers = handlers or []
        self.state_dir = state_dir
        self.state_dir.mkdir(parents=True, exist_ok=True)
        self.custom_json_schema_obj_str = custom_json_schema_obj_str
        self.client_kwargs = client_kwargs
        self.dict_input = False

        # The custom schema, if provided, must already be a dict.
        if self.custom_json_schema_obj_str is not None and not isinstance(
            self.custom_json_schema_obj_str, dict
        ):
            raise ValueError("schema should be a dict")

    def _state_file(self, job_name: str) -> Path:
        return self.state_dir / f"{job_name}.json"

    def _load_state(self, job_name: str) -> list[dict[str, Any]]:
        """
        Loads the state (job information) from the state file for the given job name.
        Returns an empty list if the state file does not exist.
        """
        path = self._state_file(job_name)
        if path.exists():
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        return []

    def _save_state(self, job_name: str, jobs: list[dict[str, Any]]) -> None:
        """
        Saves the job state to the state file for the given job name.
        """
        with open(self._state_file(job_name), "w", encoding="utf-8") as f:
            json.dump(jobs, f)

    def _clear_state(self, job_name: str) -> None:
        """
        Deletes the state file for the given job name if it exists.
        """
        path = self._state_file(job_name)
        if path.exists():
            path.unlink()

    def _build_task(self, text: str, idx: str) -> dict[str, Any]:
        """
        Builds a single task dictionary for the batch job, including the prompt,
        model, and response format configuration.
        """
        response_format_config: dict[str, Any]
        if self.custom_json_schema_obj_str:
            response_format_config = {
                "type": "json_schema",
                "json_schema": self.custom_json_schema_obj_str,
            }
        else:
            raw_schema = to_strict_json_schema(self.output_model)
            response_format_config = {
                "type": "json_schema",
                "json_schema": {
                    "name": self.output_model.__name__,
                    "schema": raw_schema,
                },
            }

        return {
            "custom_id": str(idx),
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": self.model,
                "messages": [
                    {"role": "system", "content": self.prompt_template},
                    {"role": "user", "content": text},
                ],
                "response_format": response_format_config,
                **self.client_kwargs,
            },
        }

    def _prepare_file(self, payload: list[str] | list[dict[str, str]]) -> Path:
        """
        Prepares a JSONL file containing all tasks for the batch job, based on
        the input payload. Returns the path to the created file.
        """
        if not payload:
            raise ValueError("Payload must not be empty")
        if isinstance(payload[0], str):
            tasks = [self._build_task(text, uuid.uuid4().hex) for text in payload]
        elif isinstance(payload[0], dict):
            tasks = [self._build_task(dic["text"], dic["id"]) for dic in payload]
        else:
            raise TypeError(
                "The payload must be either a list of texts or a list of "
                "dictionaries in the form {'id': str, 'text': str}."
            )

        file_path = self.state_dir / f"batch_{uuid.uuid4().hex}.jsonl"
        with open(file_path, "w", encoding="utf-8") as f:
            for task in tasks:
                f.write(json.dumps(task) + "\n")
        return file_path

    def start(self, payload: list[str | dict[str, str]], job_name: str):
        """
        Starts a new batch job by uploading the prepared file and creating a
        batch job on the server. If a job with the same name already exists,
        it does nothing.
        """
        if self._load_state(job_name):
            return
        path = self._prepare_file(payload)
        with open(path, "rb") as f:
            upload = self.client.files.create(file=f, purpose="batch")
        job = self.client.batches.create(
            input_file_id=upload.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        ).to_dict()
        self._save_state(job_name, [job])

    def check_status(self, job_name: str) -> str:
        """
        Checks and returns the current status of the batch job with the given job name.
        Updates the job state with the latest information from the server.
        """
        jobs = self._load_state(job_name)
        if not jobs:
            return "completed"

        info = self.client.batches.retrieve(jobs[0]["id"])
        job = info.to_dict()
        self._save_state(job_name, [job])
        return job["status"]

    def _parsed(self, result: dict) -> list:
        """
        Parses the result dictionary, extracting the desired output or error for each item.
        Returns a list of dictionaries with 'id' and 'output' keys.
        """
        modified_result = []
        for key, d in result.items():
            if "desired_output" in d:
                modified_result.append({"id": key, "output": d["desired_output"]})
            else:
                modified_result.append({"id": key, "output": d["error"]})
        return modified_result

    def fetch_results(
        self, job_name: str, remove_cache: bool = True
    ) -> tuple[dict[str, Any], list]:
        """
        Fetches the results of a completed batch job and dispatches them to any
        registered handlers. Optionally removes the job cache.
        Returns a tuple containing the parsed results and a log of errors (if any).
        """
        jobs = self._load_state(job_name)
        if not jobs:
            return {}, []
        batch_id = jobs[0]["id"]

        info = self.client.batches.retrieve(batch_id)
        out_file_id = info.output_file_id
        if not out_file_id:
            error_file_id = info.error_file_id
            if error_file_id:
                err_content = (
                    self.client.files.content(error_file_id).read().decode("utf-8")
                )
                print("Error file content:", err_content)
            return {}, []

        content = self.client.files.content(out_file_id).read().decode("utf-8")
        lines = content.splitlines()
        results = {}
        log = []
        for line in lines:
            result = json.loads(line)
            custom_id = result["custom_id"]
            if result["response"]["status_code"] == 200:
                message_content = result["response"]["body"]["choices"][0]["message"][
                    "content"
                ]
                try:
                    parsed_content = json.loads(message_content)
                    model_instance = self.output_model(**parsed_content)
                    results[custom_id] = model_instance.model_dump(mode="json")
                except json.JSONDecodeError:
                    results[custom_id] = {"error": "Failed to parse content as JSON"}
                    log.append({custom_id: results[custom_id]})
                except Exception as e:
                    results[custom_id] = {"error": str(e)}
                    log.append({custom_id: results[custom_id]})
            else:
                error_message = (
                    result["response"]["body"]
                    .get("error", {})
                    .get("message", "Unknown error")
                )
                results[custom_id] = {"error": error_message}
                log.append({custom_id: results[custom_id]})

        for handler in self.handlers:
            handler.handle(results)
        if remove_cache:
            self._clear_state(job_name)
        return results, log
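
Usage sketch (not part of the release): how SimpleBatchManager appears intended to be driven, based on the code above. The Sentiment model, prompt, and job name are hypothetical; the OpenAI client must be able to authenticate (e.g. via OPENAI_API_KEY).

from openai import OpenAI
from pydantic import BaseModel

from texttools.batch_manager import SimpleBatchManager


class Sentiment(BaseModel):
    desired_output: str


manager = SimpleBatchManager(
    client=OpenAI(),  # reads OPENAI_API_KEY from the environment
    model="gpt-4.1-mini",
    output_model=Sentiment,
    prompt_template="Classify the sentiment of the user's text as positive or negative.",
)

manager.start(["great library!", "this is broken"], job_name="sentiment_demo")
print(manager.check_status("sentiment_demo"))  # e.g. "validating" or "in_progress"

# ...later, once check_status() returns "completed":
results, log = manager.fetch_results("sentiment_demo")
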
texttools/batch_manager/batch_runner.py
@@ -0,0 +1,207 @@
import json
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable

# from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel

from texttools.batch_manager import SimpleBatchManager


class OutputModel(BaseModel):
    desired_output: str


def exporting_data(data):
    """
    Produces a structure of the following form from an initial data structure:
    [
        {"id": str, "content": str}, ...
    ]
    """
    return data


def importing_data(data):
    """
    Takes the batch output and merges it back into the original structure.
    """
    return data


@dataclass
class BatchConfig:
    """
    Configuration for the batch job runner.
    """

    system_prompt: str = ""
    job_name: str = ""
    input_data_path: str = ""
    output_data_filename: str = ""
    model: str = "gpt-4.1-mini"
    MAX_BATCH_SIZE: int = 100
    MAX_TOTAL_TOKENS: int = 2000000
    CHARS_PER_TOKEN: float = 2.7
    PROMPT_TOKEN_MULTIPLIER: int = 1000
    BASE_OUTPUT_DIR: str = "Data/batch_entity_result"
    import_function: Callable = importing_data
    export_function: Callable = exporting_data


class BatchJobRunner:
    """
    Handles running batch jobs using a batch manager and configuration.
    """

    def __init__(
        self, config: BatchConfig = BatchConfig(), output_model: type = OutputModel
    ):
        self.config = config
        self.system_prompt = config.system_prompt
        self.job_name = config.job_name
        self.input_data_path = config.input_data_path
        self.output_data_filename = config.output_data_filename
        self.model = config.model
        self.output_model = output_model
        self.manager = self._init_manager()
        self.data = self._load_data()
        self.parts: list[list[dict[str, Any]]] = []
        self._partition_data()
        Path(self.config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

    def _init_manager(self) -> SimpleBatchManager:
        # load_dotenv()
        api_key = os.getenv("OPENAI_API_KEY")
        client = OpenAI(api_key=api_key)
        return SimpleBatchManager(
            client=client,
            model=self.model,
            prompt_template=self.system_prompt,
            output_model=self.output_model,
        )

    def _load_data(self):
        with open(self.input_data_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        data = self.config.export_function(data)

        # Validation: ensure data is a list of dicts with 'id' and 'content' as strings
        if not isinstance(data, list):
            raise ValueError(
                'Exported data must be a list in this form: [{"id": str, "content": str}, ...]'
            )
        for item in data:
            if not (isinstance(item, dict) and "id" in item and "content" in item):
                raise ValueError(
                    "Each item must be a dict with 'id' and 'content' keys."
                )
            if not (isinstance(item["id"], str) and isinstance(item["content"], str)):
                raise ValueError("'id' and 'content' must be strings.")
        return data

    def _partition_data(self):
        total_length = sum(len(item["content"]) for item in self.data)
        prompt_length = len(self.system_prompt)
        total = total_length + (prompt_length * len(self.data))
        calculation = total / self.config.CHARS_PER_TOKEN
        print(
            f"Total chars: {total_length}, Prompt chars: {prompt_length}, "
            f"Total: {total}, Tokens: {calculation}"
        )
        if calculation < self.config.MAX_TOTAL_TOKENS:
            self.parts = [self.data]
        else:
            # Partition into chunks of MAX_BATCH_SIZE
            self.parts = [
                self.data[i : i + self.config.MAX_BATCH_SIZE]
                for i in range(0, len(self.data), self.config.MAX_BATCH_SIZE)
            ]
        print(f"Data split into {len(self.parts)} part(s)")

    def run(self):
        for idx, part in enumerate(self.parts):
            if self._result_exists(idx):
                print(f"Skipping part {idx + 1}: result already exists.")
                continue
            part_job_name = (
                f"{self.job_name}_part_{idx + 1}"
                if len(self.parts) > 1
                else self.job_name
            )
            print(
                f"\n--- Processing part {idx + 1}/{len(self.parts)}: {part_job_name} ---"
            )
            self._process_part(part, part_job_name, idx)

    def _process_part(
        self, part: list[dict[str, Any]], part_job_name: str, part_idx: int
    ):
        # SimpleBatchManager expects payload items shaped {"id": str, "text": str},
        # so rename the validated "content" field before starting the job.
        payload = [{"id": item["id"], "text": item["content"]} for item in part]
        while True:
            print(f"Starting job for part: {part_job_name}")
            self.manager.start(payload, job_name=part_job_name)
            print("Started batch job. Checking status...")
            while True:
                status = self.manager.check_status(job_name=part_job_name)
                print(f"Status: {status}")
                if status == "completed":
                    print("Job completed. Fetching results...")
                    output_data, log = self.manager.fetch_results(
                        job_name=part_job_name, remove_cache=False
                    )
                    output_data = self.config.import_function(output_data)
                    self._save_results(output_data, log, part_idx)
                    print("Fetched and saved results for this part.")
                    return
                elif status == "failed":
                    print("Job failed. Clearing state, waiting, and retrying...")
                    self.manager._clear_state(part_job_name)
                    time.sleep(10)  # Wait before retrying
                    break  # Break inner loop to restart the job
                else:
                    time.sleep(5)  # Wait before checking again

    def _save_results(
        self, output_data: list[dict[str, Any]], log: list[Any], part_idx: int
    ):
        part_suffix = f"_part_{part_idx + 1}" if len(self.parts) > 1 else ""
        result_path = (
            Path(self.config.BASE_OUTPUT_DIR)
            / f"{Path(self.output_data_filename).stem}{part_suffix}.json"
        )
        if not output_data:
            print("No output data to save. Skipping this part.")
            return
        with open(result_path, "w", encoding="utf-8") as f:
            json.dump(output_data, f, ensure_ascii=False, indent=4)
        if log:
            log_path = (
                Path(self.config.BASE_OUTPUT_DIR)
                / f"{Path(self.output_data_filename).stem}{part_suffix}_log.json"
            )
            with open(log_path, "w", encoding="utf-8") as f:
                json.dump(log, f, ensure_ascii=False, indent=4)

    def _result_exists(self, part_idx: int) -> bool:
        part_suffix = f"_part_{part_idx + 1}" if len(self.parts) > 1 else ""
        result_path = (
            Path(self.config.BASE_OUTPUT_DIR)
            / f"{Path(self.output_data_filename).stem}{part_suffix}.json"
        )
        return result_path.exists()


if __name__ == "__main__":
    print("=== Batch Job Runner ===")
    config = BatchConfig(
        system_prompt="",
        job_name="job_name",
        input_data_path="Data.json",
        output_data_filename="output",
    )
    runner = BatchJobRunner(config)
    runner.run()
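
Configuration sketch (not part of the release): custom export/import hooks wired into BatchConfig. The input shape ({"articles": [{"slug", "body"}]}) and all names here are hypothetical; the fixed contract is only that export_function must return [{"id": str, "content": str}, ...] and that import_function receives the {custom_id: output} dict from fetch_results.

from texttools.batch_manager.batch_runner import BatchConfig, BatchJobRunner


def export_articles(data):
    # Flatten the raw file into the [{"id", "content"}] shape the runner validates.
    # The "articles"/"slug"/"body" keys are hypothetical input-file fields.
    return [{"id": a["slug"], "content": a["body"]} for a in data["articles"]]


def import_articles(results):
    # fetch_results() yields {custom_id: parsed_output_dict}; reshape it into a
    # list of {"id", "output"} records for saving.
    return [{"id": key, "output": value} for key, value in results.items()]


config = BatchConfig(
    system_prompt="Extract the named entities from the text.",
    job_name="entities",
    input_data_path="articles.json",
    output_data_filename="entities_output",
    export_function=export_articles,
    import_function=import_articles,
)
BatchJobRunner(config).run()
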
texttools/formatter/__init__.py
@@ -0,0 +1 @@
from .gemma3_formatter import Gemma3Formatter
texttools/formatter/base.py
@@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from typing import Any, Optional


class ChatFormatter(ABC):
    """
    Given (raw_text, reason, maybe other hints), produce whatever payload
    the provider expects:
    A) a single string prompt (for providers that don't support multiple messages), or
    B) a list of {role, content} dicts, or
    C) whatever other shape the provider needs.
    """

    @abstractmethod
    def format(
        self,
        text: str,
        reason: Optional[str],
        schema_instr: str,
        prompt_template: Optional[str],
    ) -> Any:
        """
        - For an OpenAI-style API, this might return list[{"role": "user"/"assistant", "content": "…"}].
        - For a one-shot "text only" API, this might return a single string combining everything.
        - For some niche service, it might return JSON: {"inputs": […], "parameters": {…}}.
        """
        pass
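
Implementation sketch (not part of the release): one way a concrete ChatFormatter could satisfy variant (A) above, returning a single prompt string. SingleStringFormatter is hypothetical.

from typing import Optional

from texttools.formatter.base import ChatFormatter


class SingleStringFormatter(ChatFormatter):
    """Hypothetical formatter for a text-only provider: one combined prompt string."""

    def format(
        self,
        text: str,
        reason: Optional[str],
        schema_instr: str,
        prompt_template: Optional[str],
    ) -> str:
        parts = [prompt_template or "", schema_instr, text]
        if reason:
            parts.append(f"Reason: {reason}")
        # Join the non-empty pieces into a single prompt.
        return "\n\n".join(p for p in parts if p)
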
texttools/formatter/gemma3_formatter.py
@@ -0,0 +1,51 @@
from typing import Literal

from texttools.formatter.base import ChatFormatter


class Gemma3Formatter(ChatFormatter):
    """
    Formatter that merges consecutive user messages (strings) with '\n'
    and leaves assistant messages alone. No image handling, no extra tokens.
    """

    ROLE = "role"
    USER_ROLE = "user"
    ASSISTANT_ROLE = "assistant"
    CONTENT = "content"
    VALID_ROLES = {USER_ROLE, ASSISTANT_ROLE}

    def format(
        self, messages: list[dict[Literal["role", "content"], str]]
    ) -> list[dict[str, str]]:
        """
        :param messages: list of {"role": ..., "content": ...}, where role is "user", "assistant", or "system"
        :return: a new list where consecutive "user" messages are merged into single entries
        """

        merged: list[dict[str, str]] = []

        for msg in messages:
            role, content = msg[self.ROLE], msg[self.CONTENT].strip()

            # Replace "system" role with "user" role
            if role == "system":
                role = self.USER_ROLE

            # Raise a ValueError if the message role isn't valid
            if role not in self.VALID_ROLES:
                raise ValueError(f"Unexpected role: {role}")

            # Merge with previous user turn
            if (
                merged
                and role == self.USER_ROLE
                and merged[-1][self.ROLE] == self.USER_ROLE
            ):
                merged[-1][self.CONTENT] += "\n" + content

            # Otherwise, start a new turn
            else:
                merged.append({self.ROLE: role, self.CONTENT: content})

        return merged
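
Usage sketch (not part of the release): Gemma3Formatter folds the system message into the following user turn and merges consecutive user messages, as the code above shows.

from texttools.formatter import Gemma3Formatter

messages = [
    {"role": "system", "content": "You are a summarizer."},
    {"role": "user", "content": "Summarize this paragraph."},
    {"role": "assistant", "content": "Sure, here is a summary."},
]
print(Gemma3Formatter().format(messages))
# [{'role': 'user', 'content': 'You are a summarizer.\nSummarize this paragraph.'},
#  {'role': 'assistant', 'content': 'Sure, here is a summary.'}]
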