hamtaa-texttools 1.0.5__py3-none-any.whl → 1.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hamtaa_texttools-1.1.16.dist-info/METADATA +255 -0
- hamtaa_texttools-1.1.16.dist-info/RECORD +31 -0
- texttools/__init__.py +6 -8
- texttools/batch/batch_config.py +26 -0
- texttools/batch/batch_runner.py +144 -139
- texttools/batch/{batch_manager.py → internals/batch_manager.py} +42 -54
- texttools/batch/internals/utils.py +16 -0
- texttools/prompts/README.md +8 -4
- texttools/prompts/categorize.yaml +77 -0
- texttools/prompts/detect_entity.yaml +22 -0
- texttools/prompts/extract_keywords.yaml +68 -0
- texttools/prompts/{question_merger.yaml → merge_questions.yaml} +5 -5
- texttools/tools/async_tools.py +804 -0
- texttools/tools/internals/async_operator.py +139 -236
- texttools/tools/internals/formatters.py +24 -0
- texttools/tools/internals/models.py +183 -0
- texttools/tools/internals/operator_utils.py +54 -0
- texttools/tools/internals/prompt_loader.py +23 -43
- texttools/tools/internals/sync_operator.py +201 -0
- texttools/tools/sync_tools.py +804 -0
- hamtaa_texttools-1.0.5.dist-info/METADATA +0 -192
- hamtaa_texttools-1.0.5.dist-info/RECORD +0 -30
- texttools/batch/__init__.py +0 -4
- texttools/formatters/base_formatter.py +0 -33
- texttools/formatters/user_merge_formatter.py +0 -30
- texttools/prompts/categorizer.yaml +0 -28
- texttools/prompts/keyword_extractor.yaml +0 -18
- texttools/tools/__init__.py +0 -4
- texttools/tools/async_the_tool.py +0 -277
- texttools/tools/internals/operator.py +0 -295
- texttools/tools/internals/output_models.py +0 -52
- texttools/tools/the_tool.py +0 -501
- {hamtaa_texttools-1.0.5.dist-info → hamtaa_texttools-1.1.16.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.0.5.dist-info → hamtaa_texttools-1.1.16.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.0.5.dist-info → hamtaa_texttools-1.1.16.dist-info}/top_level.txt +0 -0
- /texttools/prompts/{ner_extractor.yaml → extract_entities.yaml} +0 -0
- /texttools/prompts/{question_detector.yaml → is_question.yaml} +0 -0
- /texttools/prompts/{rewriter.yaml → rewrite.yaml} +0 -0
- /texttools/prompts/{custom_tool.yaml → run_custom.yaml} +0 -0
- /texttools/prompts/{subject_question_generator.yaml → subject_to_question.yaml} +0 -0
- /texttools/prompts/{summarizer.yaml → summarize.yaml} +0 -0
- /texttools/prompts/{question_generator.yaml → text_to_question.yaml} +0 -0
- /texttools/prompts/{translator.yaml → translate.yaml} +0 -0
texttools/batch/batch_runner.py
CHANGED
```diff
@@ -1,212 +1,217 @@
 import json
 import os
 import time
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any, Type, TypeVar
+import logging
 
+from dotenv import load_dotenv
 from openai import OpenAI
 from pydantic import BaseModel
 
-from texttools.batch.batch_manager import SimpleBatchManager
+from texttools.batch.internals.batch_manager import BatchManager
+from texttools.batch.batch_config import BatchConfig
+from texttools.tools.internals.models import StrOutput
 
+# Base Model type for output models
+T = TypeVar("T", bound=BaseModel)
 
-
-    output: str
-
-
-def export_data(data):
-    """
-    Produces a structure of the following form from an initial data structure:
-    [
-        {"id": str, "content": str},...
-    ]
-    """
-    return data
-
-
-def import_data(data):
-    """
-    Takes the output and adds and aggregates it to the original structure.
-    """
-    return data
-
-
-@dataclass
-class BatchConfig:
-    """
-    Configuration for batch job runner.
-    """
-
-    system_prompt: str = ""
-    job_name: str = ""
-    input_data_path: str = ""
-    output_data_filename: str = ""
-    model: str = "gpt-4.1-mini"
-    MAX_BATCH_SIZE: int = 100
-    MAX_TOTAL_TOKENS: int = 2000000
-    CHARS_PER_TOKEN: float = 2.7
-    PROMPT_TOKEN_MULTIPLIER: int = 1000
-    BASE_OUTPUT_DIR: str = "Data/batch_entity_result"
-    import_function: Callable = import_data
-    export_function: Callable = export_data
+logger = logging.getLogger("texttools.batch_runner")
 
 
 class BatchJobRunner:
     """
-
-
-    Handles data loading, partitioning, job execution via SimpleBatchManager,
-    and result saving. Manages the complete workflow from input data to processed outputs,
-    including retries and progress tracking across multiple batch parts.
+    Handles running batch jobs using a batch manager and configuration.
     """
 
     def __init__(
-        self, config: BatchConfig = BatchConfig(), output_model:
+        self, config: BatchConfig = BatchConfig(), output_model: Type[T] = StrOutput
     ):
-        self.config = config
-        self.system_prompt = config.system_prompt
-        self.job_name = config.job_name
-        self.input_data_path = config.input_data_path
-        self.output_data_filename = config.output_data_filename
-        self.model = config.model
-        self.output_model = output_model
-        self.manager = self._init_manager()
-        self.data = self._load_data()
-        self.parts: list[list[dict[str, Any]]] = []
+        self._config = config
+        self._system_prompt = config.system_prompt
+        self._job_name = config.job_name
+        self._input_data_path = config.input_data_path
+        self._output_data_filename = config.output_data_filename
+        self._model = config.model
+        self._output_model = output_model
+        self._manager = self._init_manager()
+        self._data = self._load_data()
+        self._parts: list[list[dict[str, Any]]] = []
+        # Map part index to job name
+        self._part_idx_to_job_name: dict[int, str] = {}
+        # Track retry attempts per part
+        self._part_attempts: dict[int, int] = {}
         self._partition_data()
-        Path(self.config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
+        Path(self._config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
 
-    def _init_manager(self) -> SimpleBatchManager:
+    def _init_manager(self) -> BatchManager:
+        load_dotenv()
         api_key = os.getenv("OPENAI_API_KEY")
         client = OpenAI(api_key=api_key)
-        return SimpleBatchManager(
+        return BatchManager(
             client=client,
-            model=self.model,
-            prompt_template=self.system_prompt,
-            output_model=self.output_model,
+            model=self._model,
+            prompt_template=self._system_prompt,
+            output_model=self._output_model,
         )
 
     def _load_data(self):
-        with open(self.input_data_path, "r", encoding="utf-8") as f:
+        with open(self._input_data_path, "r", encoding="utf-8") as f:
             data = json.load(f)
-        data = self.config.export_function(data)
+        data = self._config.export_function(data)
 
         # Ensure data is a list of dicts with 'id' and 'content' as strings
         if not isinstance(data, list):
             raise ValueError(
-
+                "Exported data must be a list of dicts with 'id' and 'content' keys"
             )
         for item in data:
             if not (isinstance(item, dict) and "id" in item and "content" in item):
                 raise ValueError(
-                    "
+                    f"Item must be a dict with 'id' and 'content' keys. Got: {type(item)}"
                 )
             if not (isinstance(item["id"], str) and isinstance(item["content"], str)):
                 raise ValueError("'id' and 'content' must be strings.")
         return data
 
     def _partition_data(self):
-        total_length = sum(len(item["content"]) for item in self.data)
-        prompt_length = len(self.system_prompt)
-        total = total_length + (prompt_length * len(self.data))
-        calculation = total / self.config.CHARS_PER_TOKEN
-        print(
+        total_length = sum(len(item["content"]) for item in self._data)
+        prompt_length = len(self._system_prompt)
+        total = total_length + (prompt_length * len(self._data))
+        calculation = total / self._config.CHARS_PER_TOKEN
+        logger.info(
             f"Total chars: {total_length}, Prompt chars: {prompt_length}, Total: {total}, Tokens: {calculation}"
         )
-        if calculation < self.config.MAX_TOTAL_TOKENS:
-            self.parts = [self.data]
+        if calculation < self._config.MAX_TOTAL_TOKENS:
+            self._parts = [self._data]
         else:
             # Partition into chunks of MAX_BATCH_SIZE
-            self.parts = [
-                self.data[i : i + self.config.MAX_BATCH_SIZE]
-                for i in range(0, len(self.data), self.config.MAX_BATCH_SIZE)
+            self._parts = [
+                self._data[i : i + self._config.MAX_BATCH_SIZE]
+                for i in range(0, len(self._data), self._config.MAX_BATCH_SIZE)
             ]
-
+        logger.info(f"Data split into {len(self._parts)} part(s)")
 
-    def
-        for idx, part in enumerate(self.parts):
+    def _submit_all_jobs(self) -> None:
+        for idx, part in enumerate(self._parts):
             if self._result_exists(idx):
-
+                logger.info(f"Skipping part {idx + 1}: result already exists.")
                 continue
             part_job_name = (
-                f"{self.job_name}_part_{idx + 1}"
-                if len(self.parts) > 1
-                else self.job_name
+                f"{self._job_name}_part_{idx + 1}"
+                if len(self._parts) > 1
+                else self._job_name
             )
-
-
-
-
+            # If a job with this name already exists, register and skip submitting
+            existing_job = self._manager._load_state(part_job_name)
+            if existing_job:
+                logger.info(
+                    f"Skipping part {idx + 1}: job already exists ({part_job_name})."
+                )
+                self._part_idx_to_job_name[idx] = part_job_name
+                self._part_attempts.setdefault(idx, 0)
+                continue
 
-
-
-
-
-
-        self.
-
-
-
-
-                if status == "completed":
-                    print("Job completed. Fetching results...")
-                    output_data, log = self.manager.fetch_results(
-                        job_name=part_job_name, remove_cache=False
-                    )
-                    output_data = self.config.import_function(output_data)
-                    self._save_results(output_data, log, part_idx)
-                    print("Fetched and saved results for this part.")
-                    return
-                elif status == "failed":
-                    print("Job failed. Clearing state, waiting, and retrying...")
-                    self.manager._clear_state(part_job_name)
-                    # Wait before retrying
-                    time.sleep(10)
-                    # Break inner loop to restart the job
-                    break
-                else:
-                    # Wait before checking again
-                    time.sleep(5)
+            payload = part
+            logger.info(
+                f"Submitting job for part {idx + 1}/{len(self._parts)}: {part_job_name}"
+            )
+            self._manager.start(payload, job_name=part_job_name)
+            self._part_idx_to_job_name[idx] = part_job_name
+            self._part_attempts.setdefault(idx, 0)
+            # This is added for letting file get uploaded, before starting the next part.
+            logger.info("Uploading...")
+            time.sleep(30)
 
     def _save_results(
-        self,
+        self,
+        output_data: list[dict[str, Any]] | dict[str, Any],
+        log: list[Any],
+        part_idx: int,
     ):
-        part_suffix = f"_part_{part_idx + 1}" if len(self.parts) > 1 else ""
+        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
         result_path = (
-            Path(self.config.BASE_OUTPUT_DIR)
-            / f"{Path(self.output_data_filename).stem}{part_suffix}.json"
+            Path(self._config.BASE_OUTPUT_DIR)
+            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
        )
         if not output_data:
-
+            logger.info("No output data to save. Skipping this part.")
             return
         else:
             with open(result_path, "w", encoding="utf-8") as f:
                 json.dump(output_data, f, ensure_ascii=False, indent=4)
         if log:
             log_path = (
-                Path(self.config.BASE_OUTPUT_DIR)
-                / f"{Path(self.output_data_filename).stem}{part_suffix}_log.json"
+                Path(self._config.BASE_OUTPUT_DIR)
+                / f"{Path(self._output_data_filename).stem}{part_suffix}_log.json"
             )
             with open(log_path, "w", encoding="utf-8") as f:
                 json.dump(log, f, ensure_ascii=False, indent=4)
 
     def _result_exists(self, part_idx: int) -> bool:
-        part_suffix = f"_part_{part_idx + 1}" if len(self.parts) > 1 else ""
+        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
         result_path = (
-            Path(self.config.BASE_OUTPUT_DIR)
-            / f"{Path(self.output_data_filename).stem}{part_suffix}.json"
+            Path(self._config.BASE_OUTPUT_DIR)
+            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
         )
         return result_path.exists()
 
-
-
-
-
-
-
-
-
-
-
-
+    def run(self):
+        """
+        Execute the batch job processing pipeline.
+
+        Submits jobs, monitors progress, handles retries, and saves results.
+        """
+        # Submit all jobs up-front for concurrent execution
+        self._submit_all_jobs()
+        pending_parts: set[int] = set(self._part_idx_to_job_name.keys())
+        logger.info(f"Pending parts: {sorted(pending_parts)}")
+        # Polling loop
+        while pending_parts:
+            finished_this_round: list[int] = []
+            for part_idx in list(pending_parts):
+                job_name = self._part_idx_to_job_name[part_idx]
+                status = self._manager.check_status(job_name=job_name)
+                logger.info(f"Status for {job_name}: {status}")
+                if status == "completed":
+                    logger.info(
+                        f"Job completed. Fetching results for part {part_idx + 1}..."
+                    )
+                    output_data, log = self._manager.fetch_results(
+                        job_name=job_name, remove_cache=False
+                    )
+                    output_data = self._config.import_function(output_data)
+                    self._save_results(output_data, log, part_idx)
+                    logger.info(f"Fetched and saved results for part {part_idx + 1}.")
+                    finished_this_round.append(part_idx)
+                elif status == "failed":
+                    attempt = self._part_attempts.get(part_idx, 0) + 1
+                    self._part_attempts[part_idx] = attempt
+                    if attempt <= self._config.max_retries:
+                        logger.info(
+                            f"Job {job_name} failed (attempt {attempt}). Retrying after short backoff..."
+                        )
+                        self._manager._clear_state(job_name)
+                        time.sleep(10)
+                        payload = self._to_manager_payload(self._parts[part_idx])
+                        new_job_name = (
+                            f"{self._job_name}_part_{part_idx + 1}_retry_{attempt}"
+                        )
+                        self._manager.start(payload, job_name=new_job_name)
+                        self._part_idx_to_job_name[part_idx] = new_job_name
+                    else:
+                        logger.info(
+                            f"Job {job_name} failed after {attempt - 1} retries. Marking as failed."
+                        )
+                        finished_this_round.append(part_idx)
+                else:
+                    # Still running or queued
+                    continue
+            # Remove finished parts
+            for part_idx in finished_this_round:
+                pending_parts.discard(part_idx)
+            if pending_parts:
+                logger.info(
+                    f"Waiting {self._config.poll_interval_seconds}s before next status check for parts: {sorted(pending_parts)}"
+                )
+                time.sleep(self._config.poll_interval_seconds)
```
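The new runner splits the old monolith apart: configuration now lives in `texttools/batch/batch_config.py`, the manager moved under `internals/`, and `run()` submits every part up front and then polls them concurrently instead of processing one part at a time. A minimal usage sketch, assuming `BatchConfig` keeps the field names shown in the removed dataclass plus the `max_retries` and `poll_interval_seconds` attributes that `run()` reads (their defaults are not visible in this diff):

```python
# Hedged sketch, not taken from the package docs: field names are assumed from
# the old BatchConfig dataclass; max_retries and poll_interval_seconds are
# assumed to have gained defaults in the new batch_config.py.
from texttools.batch.batch_config import BatchConfig
from texttools.batch.batch_runner import BatchJobRunner
from texttools.tools.internals.models import StrOutput

config = BatchConfig(
    system_prompt="Summarize the following text in one sentence.",
    job_name="summaries",
    input_data_path="Data/input.json",      # JSON that export_function maps to [{"id", "content"}, ...]
    output_data_filename="summaries.json",  # written under BASE_OUTPUT_DIR, one file per part
)

runner = BatchJobRunner(config, output_model=StrOutput)
runner.run()  # submits all parts, polls, retries failed parts, saves JSON results
```

Since `_init_manager()` now calls `load_dotenv()`, the `OPENAI_API_KEY` it reads can come from a `.env` file rather than the process environment.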
texttools/batch/{batch_manager.py → internals/batch_manager.py}
RENAMED
```diff
@@ -1,14 +1,20 @@
 import json
 import uuid
 from pathlib import Path
-from typing import Any, Type
+from typing import Any, Type, TypeVar
+import logging
 
 from pydantic import BaseModel
 from openai import OpenAI
 from openai.lib._pydantic import to_strict_json_schema
 
+# Base Model type for output models
+T = TypeVar("T", bound=BaseModel)
 
-class SimpleBatchManager:
+logger = logging.getLogger("texttools.batch_manager")
+
+
+class BatchManager:
     """
     Manages batch processing jobs for OpenAI's chat completions with structured outputs.
 
@@ -21,30 +27,29 @@ class SimpleBatchManager:
         self,
         client: OpenAI,
         model: str,
-        output_model: Type[BaseModel],
+        output_model: Type[T],
         prompt_template: str,
-        handlers: list[Any] | None = None,
         state_dir: Path = Path(".batch_jobs"),
         custom_json_schema_obj_str: dict | None = None,
         **client_kwargs: Any,
     ):
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-
-
-
-
-
+        self._client = client
+        self._model = model
+        self._output_model = output_model
+        self._prompt_template = prompt_template
+        self._state_dir = state_dir
+        self._custom_json_schema_obj_str = custom_json_schema_obj_str
+        self._client_kwargs = client_kwargs
+        self._dict_input = False
+        self._state_dir.mkdir(parents=True, exist_ok=True)
+
+        if custom_json_schema_obj_str and not isinstance(
+            custom_json_schema_obj_str, dict
+        ):
+            raise ValueError("Schema should be a dict")
 
     def _state_file(self, job_name: str) -> Path:
-        return self.state_dir / f"{job_name}.json"
+        return self._state_dir / f"{job_name}.json"
 
     def _load_state(self, job_name: str) -> list[dict[str, Any]]:
         """
@@ -78,17 +83,17 @@ class SimpleBatchManager:
         """
         response_format_config: dict[str, Any]
 
-        if self.custom_json_schema_obj_str:
+        if self._custom_json_schema_obj_str:
             response_format_config = {
                 "type": "json_schema",
-                "json_schema": self.custom_json_schema_obj_str,
+                "json_schema": self._custom_json_schema_obj_str,
             }
         else:
-            raw_schema = to_strict_json_schema(self.output_model)
+            raw_schema = to_strict_json_schema(self._output_model)
             response_format_config = {
                 "type": "json_schema",
                 "json_schema": {
-                    "name": self.output_model.__name__,
+                    "name": self._output_model.__name__,
                     "schema": raw_schema,
                 },
             }
@@ -100,11 +105,11 @@ class SimpleBatchManager:
             "body": {
                 "model": self.model,
                 "messages": [
-                    {"role": "system", "content": self.prompt_template},
+                    {"role": "system", "content": self._prompt_template},
                     {"role": "user", "content": text},
                 ],
                 "response_format": response_format_config,
-                **self.client_kwargs,
+                **self._client_kwargs,
             },
         }
 
@@ -122,10 +127,10 @@ class SimpleBatchManager:
 
         else:
             raise TypeError(
-                "The input must be either a list of texts or a dictionary in the form {'id': str, 'text': str}
+                "The input must be either a list of texts or a dictionary in the form {'id': str, 'text': str}"
             )
 
-        file_path = self.state_dir / f"batch_{uuid.uuid4().hex}.jsonl"
+        file_path = self._state_dir / f"batch_{uuid.uuid4().hex}.jsonl"
         with open(file_path, "w", encoding="utf-8") as f:
             for task in tasks:
                 f.write(json.dumps(task) + "\n")
@@ -138,9 +143,10 @@ class SimpleBatchManager:
         """
         if self._load_state(job_name):
             return
+
         path = self._prepare_file(payload)
-        upload = self.client.files.create(file=open(path, "rb"), purpose="batch")
-        job = self.client.batches.create(
+        upload = self._client.files.create(file=open(path, "rb"), purpose="batch")
+        job = self._client.batches.create(
             input_file_id=upload.id,
             endpoint="/v1/chat/completions",
             completion_window="24h",
@@ -156,28 +162,12 @@ class SimpleBatchManager:
         if not job:
             return "completed"
 
-        info = self.client.batches.retrieve(job["id"])
+        info = self._client.batches.retrieve(job["id"])
         job = info.to_dict()
         self._save_state(job_name, [job])
-
+        logger.info("Batch job status: %s", job)
         return job["status"]
 
-    def _parsed(self, result: dict) -> list:
-        """
-        Parses the result dictionary, extracting the desired output or error for each item.
-        Returns a list of dictionaries with 'id' and 'output' keys.
-        """
-        modified_result = []
-
-        for key, d in result.items():
-            if "desired_output" in d:
-                new_dict = {"id": key, "output": d["desired_output"]}
-                modified_result.append(new_dict)
-            else:
-                new_dict = {"id": key, "output": d["error"]}
-                modified_result.append(new_dict)
-        return modified_result
-
     def fetch_results(
         self, job_name: str, remove_cache: bool = True
     ) -> tuple[dict[str, str], list]:
@@ -190,18 +180,18 @@ class SimpleBatchManager:
             return {}
         batch_id = job["id"]
 
-        info = self.client.batches.retrieve(batch_id)
+        info = self._client.batches.retrieve(batch_id)
         out_file_id = info.output_file_id
         if not out_file_id:
             error_file_id = info.error_file_id
             if error_file_id:
                 err_content = (
-                    self.client.files.content(error_file_id).read().decode("utf-8")
+                    self._client.files.content(error_file_id).read().decode("utf-8")
                 )
-
+                logger.error("Error file content:", err_content)
             return {}
 
-        content = self.client.files.content(out_file_id).read().decode("utf-8")
+        content = self._client.files.content(out_file_id).read().decode("utf-8")
         lines = content.splitlines()
         results = {}
         log = []
@@ -212,7 +202,7 @@ class SimpleBatchManager:
             content = result["response"]["body"]["choices"][0]["message"]["content"]
             try:
                 parsed_content = json.loads(content)
-                model_instance = self.output_model(**parsed_content)
+                model_instance = self._output_model(**parsed_content)
                 results[custom_id] = model_instance.model_dump(mode="json")
             except json.JSONDecodeError:
                 results[custom_id] = {"error": "Failed to parse content as JSON"}
@@ -232,8 +222,6 @@ class SimpleBatchManager:
             error_d = {custom_id: results[custom_id]}
             log.append(error_d)
 
-        for handler in self.handlers:
-            handler.handle(results)
         if remove_cache:
             self._clear_state(job_name)
 
```
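`BatchManager` is now an internal API, but its surface is visible in the diff: `start()` writes a JSONL file of chat-completion tasks and uploads it, `check_status()` persists job state under `state_dir`, and `fetch_results()` validates each returned line against the configured Pydantic model. A hedged sketch of driving it directly, assuming only the signatures shown above:

```python
# Hedged sketch based only on methods visible in this diff; BatchManager lives
# under texttools.batch.internals and may change without notice.
import time

from openai import OpenAI

from texttools.batch.internals.batch_manager import BatchManager
from texttools.tools.internals.models import StrOutput  # default output model

manager = BatchManager(
    client=OpenAI(),  # reads OPENAI_API_KEY from the environment
    model="gpt-4.1-mini",
    output_model=StrOutput,
    prompt_template="Summarize the user's text.",
)

# Per the TypeError above, input is a list of texts or {"id": str, "text": str} dicts.
manager.start(["first document", "second document"], job_name="demo_summaries")
while manager.check_status(job_name="demo_summaries") not in ("completed", "failed"):
    time.sleep(30)

results, log = manager.fetch_results(job_name="demo_summaries")
# results maps each custom_id to the parsed model_dump(); log collects per-item errors.
```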
texttools/batch/internals/utils.py
ADDED
```diff
@@ -0,0 +1,16 @@
+from typing import Any
+
+
+def export_data(data) -> list[dict[str, str]]:
+    """
+    Produces a structure of the following form from an initial data structure:
+    [{"id": str, "text": str},...]
+    """
+    return data
+
+
+def import_data(data) -> Any:
+    """
+    Takes the output and adds and aggregates it to the original structure.
+    """
+    return data
```
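These default hooks are identity functions; real use supplies replacements through `BatchConfig.export_function` and `BatchConfig.import_function`. A hedged sketch of custom hooks (note that the runner's `_load_data()` validates `{"id", "content"}` keys even though the docstring above says `"text"`):

```python
# Hedged sketch: custom hooks wired through BatchConfig. The record fields
# ("qid", "question") are hypothetical; only the {"id": str, "content": str}
# input shape and the {custom_id: parsed_output} results shape come from the diff.
from typing import Any


def export_data(data: list[dict[str, Any]]) -> list[dict[str, str]]:
    # Flatten the original records into the shape _load_data() validates.
    return [{"id": str(rec["qid"]), "content": rec["question"]} for rec in data]


def import_data(results: dict[str, Any]) -> list[dict[str, Any]]:
    # fetch_results() returns {custom_id: parsed_output}; reattach outputs to ids.
    return [{"id": qid, "output": out} for qid, out in results.items()]
```

They would then be passed as `BatchConfig(..., export_function=export_data, import_function=import_data)`.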
texttools/prompts/README.md
CHANGED
````diff
@@ -3,6 +3,8 @@
 ## Overview
 This folder contains YAML files for all prompts used in the project. Each file represents a separate prompt template, which can be loaded by tools or scripts that require structured prompts for AI models.
 
+---
+
 ## Structure
 - **prompt_file.yaml**: Each YAML file represents a single prompt template.
 - **main_template**: The main instruction template for the model.
@@ -12,18 +14,20 @@ This folder contains YAML files for all prompts used in the project. Each file r
 ### Example YAML Structure
 ```yaml
 main_template:
-
+  mode_1: |
     Your main instructions here with placeholders like {input}.
-
+  mode_2: |
     Optional reasoning instructions here.
 
 analyze_template:
-
+  mode_1: |
     Analyze and summarize the input.
-
+  mode_2: |
     Optional detailed analysis template.
 ```
 
+---
+
 ## Guidelines
 1. **Naming**: Use descriptive names for each YAML file corresponding to the tool or task it serves.
 2. **Placeholders**: Use `{input}` or other relevant placeholders to dynamically inject data.
````