hamtaa-texttools 1.0.4__py3-none-any.whl → 1.0.6__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Potentially problematic release.
This version of hamtaa-texttools might be problematic.
- {hamtaa_texttools-1.0.4.dist-info → hamtaa_texttools-1.0.6.dist-info}/METADATA +192 -141
- hamtaa_texttools-1.0.6.dist-info/RECORD +30 -0
- {hamtaa_texttools-1.0.4.dist-info → hamtaa_texttools-1.0.6.dist-info}/licenses/LICENSE +20 -20
- {hamtaa_texttools-1.0.4.dist-info → hamtaa_texttools-1.0.6.dist-info}/top_level.txt +0 -0
- texttools/__init__.py +9 -9
- texttools/batch/__init__.py +4 -4
- texttools/batch/batch_manager.py +229 -240
- texttools/batch/batch_runner.py +263 -212
- texttools/formatters/base_formatter.py +33 -33
- texttools/formatters/{user_merge_formatter/user_merge_formatter.py → user_merge_formatter.py} +30 -30
- texttools/prompts/README.md +35 -31
- texttools/prompts/categorizer.yaml +28 -31
- texttools/prompts/{question_detector.yaml → is_question.yaml} +13 -14
- texttools/prompts/keyword_extractor.yaml +18 -14
- texttools/prompts/ner_extractor.yaml +20 -21
- texttools/prompts/question_merger.yaml +45 -48
- texttools/prompts/rewriter.yaml +111 -0
- texttools/prompts/run_custom.yaml +7 -0
- texttools/prompts/{subject_question_generator.yaml → subject_to_question.yaml} +22 -26
- texttools/prompts/summarizer.yaml +13 -11
- texttools/prompts/{question_generator.yaml → text_to_question.yaml} +19 -22
- texttools/prompts/translator.yaml +14 -14
- texttools/tools/__init__.py +4 -4
- texttools/tools/async_the_tool.py +277 -263
- texttools/tools/internals/async_operator.py +308 -288
- texttools/tools/internals/operator.py +295 -306
- texttools/tools/internals/output_models.py +52 -62
- texttools/tools/internals/prompt_loader.py +66 -82
- texttools/tools/the_tool.py +501 -400
- hamtaa_texttools-1.0.4.dist-info/RECORD +0 -29
- texttools/prompts/question_rewriter.yaml +0 -46
- {hamtaa_texttools-1.0.4.dist-info → hamtaa_texttools-1.0.6.dist-info}/WHEEL +0 -0
texttools/batch/batch_runner.py
CHANGED
@@ -1,212 +1,263 @@
-import json
-import os
-import time
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Any, Callable
-… (the remaining removed lines of the 1.0.4 implementation are omitted)
+import json
+import os
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Callable
+import logging
+
+from dotenv import load_dotenv
+from openai import OpenAI
+from pydantic import BaseModel
+
+from texttools.batch import SimpleBatchManager
+
+# Configure logger
+logger = logging.getLogger("batch_runner")
+logger.setLevel(logging.INFO)
+
+
+class OutputModel(BaseModel):
+    desired_output: str
+
+
+def export_data(data):
+    """
+    Produces a structure of the following form from an initial data structure:
+    [{"id": str, "text": str},...]
+    """
+    return data
+
+
+def import_data(data):
+    """
+    Takes the output and adds and aggregates it to the original structure.
+    """
+    return data
+
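The two hooks above are identity placeholders: export_data must reshape arbitrary caller data into the flat list the runner validates, and import_data folds results back into the caller's structure. A minimal sketch of custom hooks, assuming a hypothetical {"articles": [...]} source shape (note that _load_data below checks for a "content" key, even though the docstring says "text"):

def export_articles(data):
    # Hypothetical source shape: {"articles": [{"uid": 7, "body": "..."}]}
    return [
        {"id": str(article["uid"]), "content": article["body"]}
        for article in data["articles"]
    ]


def import_articles(results):
    # Assumption: each result dict still carries the submitted "id".
    return {item["id"]: item for item in results}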
39
|
+
@dataclass
|
|
40
|
+
class BatchConfig:
|
|
41
|
+
"""
|
|
42
|
+
Configuration for batch job runner.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
system_prompt: str = ""
|
|
46
|
+
job_name: str = ""
|
|
47
|
+
input_data_path: str = ""
|
|
48
|
+
output_data_filename: str = ""
|
|
49
|
+
model: str = "gpt-4.1-mini"
|
|
50
|
+
MAX_BATCH_SIZE: int = 100
|
|
51
|
+
MAX_TOTAL_TOKENS: int = 2000000
|
|
52
|
+
CHARS_PER_TOKEN: float = 2.7
|
|
53
|
+
PROMPT_TOKEN_MULTIPLIER: int = 1000
|
|
54
|
+
BASE_OUTPUT_DIR: str = "Data/batch_entity_result"
|
|
55
|
+
import_function: Callable = import_data
|
|
56
|
+
export_function: Callable = export_data
|
|
57
|
+
poll_interval_seconds: int = 30
|
|
58
|
+
max_retries: int = 3
|
|
59
|
+
|
|
60
|
+
|
|
+class BatchJobRunner:
+    """
+    Handles running batch jobs using a batch manager and configuration.
+    """
+
+    def __init__(
+        self, config: BatchConfig = BatchConfig(), output_model: type = OutputModel
+    ):
+        self.config = config
+        self.system_prompt = config.system_prompt
+        self.job_name = config.job_name
+        self.input_data_path = config.input_data_path
+        self.output_data_filename = config.output_data_filename
+        self.model = config.model
+        self.output_model = output_model
+        self.manager = self._init_manager()
+        self.data = self._load_data()
+        self.parts: list[list[dict[str, Any]]] = []
+        self._partition_data()
+        Path(self.config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
+        # Map part index to job name
+        self.part_idx_to_job_name: dict[int, str] = {}
+        # Track retry attempts per part
+        self.part_attempts: dict[int, int] = {}
+
+    def _init_manager(self) -> SimpleBatchManager:
+        load_dotenv()
+        api_key = os.getenv("OPENAI_API_KEY")
+        client = OpenAI(api_key=api_key)
+        return SimpleBatchManager(
+            client=client,
+            model=self.model,
+            prompt_template=self.system_prompt,
+            output_model=self.output_model,
+        )
+
+    def _load_data(self):
+        with open(self.input_data_path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+        data = self.config.export_function(data)
+
+        # Ensure data is a list of dicts with 'id' and 'content' as strings
+        if not isinstance(data, list):
+            raise ValueError(
+                'Exported data must be a list in this form: [ {"id": str, "content": str},...]'
+            )
+        for item in data:
+            if not (isinstance(item, dict) and "id" in item and "content" in item):
+                raise ValueError(
+                    "Each item must be a dict with 'id' and 'content' keys."
+                )
+            if not (isinstance(item["id"], str) and isinstance(item["content"], str)):
+                raise ValueError("'id' and 'content' must be strings.")
+        return data
+
+    def _partition_data(self):
+        total_length = sum(len(item["content"]) for item in self.data)
+        prompt_length = len(self.system_prompt)
+        total = total_length + (prompt_length * len(self.data))
+        calculation = total / self.config.CHARS_PER_TOKEN
+        logger.info(
+            f"Total chars: {total_length}, Prompt chars: {prompt_length}, Total: {total}, Tokens: {calculation}"
+        )
+        if calculation < self.config.MAX_TOTAL_TOKENS:
+            self.parts = [self.data]
+        else:
+            # Partition into chunks of MAX_BATCH_SIZE
+            self.parts = [
+                self.data[i : i + self.config.MAX_BATCH_SIZE]
+                for i in range(0, len(self.data), self.config.MAX_BATCH_SIZE)
+            ]
+        logger.info(f"Data split into {len(self.parts)} part(s)")
+
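_partition_data estimates token usage from raw character counts (CHARS_PER_TOKEN = 2.7) rather than running a tokenizer. A worked example of the arithmetic, with made-up numbers:

# Hypothetical corpus: 1,000 items of ~4,000 characters, 500-character prompt.
total_length = 1_000 * 4_000          # 4,000,000 content characters
total = total_length + 500 * 1_000    # prompt counted once per item: 4,500,000
tokens = total / 2.7                  # ~1,666,667 estimated tokens
# 1,666,667 < MAX_TOTAL_TOKENS (2,000,000): the data stays in a single part.
# Above the threshold it is chunked into parts of MAX_BATCH_SIZE (100) items.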
+    def _submit_all_jobs(self) -> None:
+        for idx, part in enumerate(self.parts):
+            if self._result_exists(idx):
+                logger.info(f"Skipping part {idx + 1}: result already exists.")
+                continue
+            part_job_name = (
+                f"{self.job_name}_part_{idx + 1}"
+                if len(self.parts) > 1
+                else self.job_name
+            )
+            # If a job with this name already exists, register and skip submitting
+            existing_job = self.manager._load_state(part_job_name)
+            if existing_job:
+                logger.info(
+                    f"Skipping part {idx + 1}: job already exists ({part_job_name})."
+                )
+                self.part_idx_to_job_name[idx] = part_job_name
+                self.part_attempts.setdefault(idx, 0)
+                continue
+
+            payload = part
+            logger.info(
+                f"Submitting job for part {idx + 1}/{len(self.parts)}: {part_job_name}"
+            )
+            self.manager.start(payload, job_name=part_job_name)
+            self.part_idx_to_job_name[idx] = part_job_name
+            self.part_attempts.setdefault(idx, 0)
+            # This is added for letting file get uploaded, before starting the next part.
+            logger.info("Uploading...")
+            time.sleep(30)
+
+    def run(self):
+        # Submit all jobs up-front for concurrent execution
+        self._submit_all_jobs()
+        pending_parts: set[int] = set(self.part_idx_to_job_name.keys())
+        logger.info(f"Pending parts: {sorted(pending_parts)}")
+        # Polling loop
+        while pending_parts:
+            finished_this_round: list[int] = []
+            for part_idx in list(pending_parts):
+                job_name = self.part_idx_to_job_name[part_idx]
+                status = self.manager.check_status(job_name=job_name)
+                logger.info(f"Status for {job_name}: {status}")
+                if status == "completed":
+                    logger.info(
+                        f"Job completed. Fetching results for part {part_idx + 1}..."
+                    )
+                    output_data, log = self.manager.fetch_results(
+                        job_name=job_name, remove_cache=False
+                    )
+                    output_data = self.config.import_function(output_data)
+                    self._save_results(output_data, log, part_idx)
+                    logger.info(f"Fetched and saved results for part {part_idx + 1}.")
+                    finished_this_round.append(part_idx)
+                elif status == "failed":
+                    attempt = self.part_attempts.get(part_idx, 0) + 1
+                    self.part_attempts[part_idx] = attempt
+                    if attempt <= self.config.max_retries:
+                        logger.info(
+                            f"Job {job_name} failed (attempt {attempt}). Retrying after short backoff..."
+                        )
+                        self.manager._clear_state(job_name)
+                        time.sleep(10)
+                        payload = self._to_manager_payload(self.parts[part_idx])
+                        new_job_name = (
+                            f"{self.job_name}_part_{part_idx + 1}_retry_{attempt}"
+                        )
+                        self.manager.start(payload, job_name=new_job_name)
+                        self.part_idx_to_job_name[part_idx] = new_job_name
+                    else:
+                        logger.info(
+                            f"Job {job_name} failed after {attempt - 1} retries. Marking as failed."
+                        )
+                        finished_this_round.append(part_idx)
+                else:
+                    # Still running or queued
+                    continue
+            # Remove finished parts
+            for part_idx in finished_this_round:
+                pending_parts.discard(part_idx)
+            if pending_parts:
+                logger.info(
+                    f"Waiting {self.config.poll_interval_seconds}s before next status check for parts: {sorted(pending_parts)}"
+                )
+                time.sleep(self.config.poll_interval_seconds)
+
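One catch in the retry branch above: it calls self._to_manager_payload(...), but no such method is defined anywhere in this version of the class, so a failed part would raise AttributeError on retry instead of being resubmitted. Judging from the submit path, where the part is passed to manager.start() unchanged (payload = part), a plausible fix is the identity helper sketched below; this is an assumption, not part of the release:

    def _to_manager_payload(self, part):
        # Assumption: SimpleBatchManager.start() accepts the raw part list as-is,
        # mirroring the payload = part pass-through in _submit_all_jobs().
        return part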
+    def _save_results(
+        self,
+        output_data: list[dict[str, Any]] | dict[str, Any],
+        log: list[Any],
+        part_idx: int,
+    ):
+        part_suffix = f"_part_{part_idx + 1}" if len(self.parts) > 1 else ""
+        result_path = (
+            Path(self.config.BASE_OUTPUT_DIR)
+            / f"{Path(self.output_data_filename).stem}{part_suffix}.json"
+        )
+        if not output_data:
+            logger.info("No output data to save. Skipping this part.")
+            return
+        else:
+            with open(result_path, "w", encoding="utf-8") as f:
+                json.dump(output_data, f, ensure_ascii=False, indent=4)
+        if log:
+            log_path = (
+                Path(self.config.BASE_OUTPUT_DIR)
+                / f"{Path(self.output_data_filename).stem}{part_suffix}_log.json"
+            )
+            with open(log_path, "w", encoding="utf-8") as f:
+                json.dump(log, f, ensure_ascii=False, indent=4)
+
+    def _result_exists(self, part_idx: int) -> bool:
+        part_suffix = f"_part_{part_idx + 1}" if len(self.parts) > 1 else ""
+        result_path = (
+            Path(self.config.BASE_OUTPUT_DIR)
+            / f"{Path(self.output_data_filename).stem}{part_suffix}.json"
+        )
+        return result_path.exists()
+
+
+if __name__ == "__main__":
+    logger.info("=== Batch Job Runner ===")
+    config = BatchConfig(
+        system_prompt="",
+        job_name="job_name",
+        input_data_path="Data.json",
+        output_data_filename="output",
+    )
+    runner = BatchJobRunner(config)
+    runner.run()
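The __main__ block runs with empty defaults; in practice the runner is configured with a real prompt and a structured output model. A minimal end-to-end sketch (hypothetical prompt, paths, and output model; assumes OPENAI_API_KEY is set in a .env file):

from pydantic import BaseModel

from texttools.batch.batch_runner import BatchConfig, BatchJobRunner


class EntityOutput(BaseModel):
    desired_output: str


config = BatchConfig(
    system_prompt="Extract the main entity mentioned in the text.",
    job_name="entity_extraction",
    input_data_path="Data/articles.json",  # hypothetical input file
    output_data_filename="entities",
)
runner = BatchJobRunner(config, output_model=EntityOutput)
runner.run()  # submits all parts, polls, writes Data/batch_entity_result/entities.json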
texttools/formatters/base_formatter.py
CHANGED

@@ -1,33 +1,33 @@
(all 33 lines are marked removed and re-added, but the visible text is identical; the file reads:)

from abc import ABC, abstractmethod
from typing import Any


class BaseFormatter(ABC):
    """
    Adapter to convert a conversation into a specific LLM API's input format.

    Concrete implementations transform standardized messages (e.g., list[dict]) into the
    exact payload required by a provider (e.g., OpenAI's message list, a single string, etc.).
    """

    @abstractmethod
    def format(
        self,
        messages: Any,
    ) -> Any:
        """
        Transform the input messages into a provider-specific payload.

        Args:
            messages: The input conversation. While often a list of dicts with
                'role' and 'content' keys, the exact type and structure may vary
                by implementation.

        Returns:
            A payload in the format expected by the target LLM API. This could be:
            - A list of role-content dictionaries (e.g., for OpenAI)
            - A single formatted string (e.g., for completion-style APIs)
            - A complex dictionary with additional parameters
            - Any other provider-specific data structure
        """
        pass
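For illustration, a minimal concrete formatter satisfying this interface might look like the following sketch (a hypothetical example, not shipped with the package), flattening the conversation into a single string for completion-style APIs:

from texttools.formatters.base_formatter import BaseFormatter


class PlainTextFormatter(BaseFormatter):
    """Flattens role/content messages into one prompt string."""

    def format(self, messages: list[dict[str, str]]) -> str:
        # Produces "system: ...\nuser: ...", one line per turn
        return "\n".join(f"{m['role']}: {m['content']}" for m in messages)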
texttools/formatters/{user_merge_formatter/user_merge_formatter.py → user_merge_formatter.py}
RENAMED
@@ -1,30 +1,30 @@
(the file moved as shown above; every line is marked removed and re-added with identical text, which reads:)

from texttools.formatters.base_formatter import BaseFormatter


class UserMergeFormatter(BaseFormatter):
    """
    Merges consecutive user messages into a single message, separated by newlines.

    This is useful for condensing a multi-turn user input into a single coherent
    message for the LLM. Assistant and system messages are left unchanged and
    act as separators between user message groups.

    Raises:
        ValueError: If the input messages have invalid structure or roles.
    """

    def format(self, messages: list[dict[str, str]]) -> list[dict[str, str]]:
        merged: list[dict[str, str]] = []

        for message in messages:
            role, content = message["role"], message["content"].strip()

            # Merge with previous user turn
            if merged and role == "user" and merged[-1]["role"] == "user":
                merged[-1]["content"] += "\n" + content

            # Otherwise, start a new turn
            else:
                merged.append({"role": role, "content": content})

        return merged
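To illustrate the merge behavior, a small sketch (expected output shown in comments):

formatter = UserMergeFormatter()
messages = [
    {"role": "user", "content": "Hello."},
    {"role": "user", "content": "What does this package do?"},
    {"role": "assistant", "content": "It wraps common LLM text tasks."},
    {"role": "user", "content": "Thanks!"},
]
print(formatter.format(messages))
# [{'role': 'user', 'content': 'Hello.\nWhat does this package do?'},
#  {'role': 'assistant', 'content': 'It wraps common LLM text tasks.'},
#  {'role': 'user', 'content': 'Thanks!'}]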