hamtaa-texttools 1.1.1-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hamtaa_texttools-1.2.0.dist-info/METADATA +212 -0
- hamtaa_texttools-1.2.0.dist-info/RECORD +34 -0
- texttools/__init__.py +6 -8
- texttools/batch/__init__.py +0 -4
- texttools/batch/config.py +40 -0
- texttools/batch/{batch_manager.py → manager.py} +41 -42
- texttools/batch/runner.py +228 -0
- texttools/core/__init__.py +0 -0
- texttools/core/engine.py +254 -0
- texttools/core/exceptions.py +22 -0
- texttools/core/internal_models.py +58 -0
- texttools/core/operators/async_operator.py +194 -0
- texttools/core/operators/sync_operator.py +192 -0
- texttools/models.py +88 -0
- texttools/prompts/categorize.yaml +36 -0
- texttools/prompts/check_fact.yaml +24 -0
- texttools/prompts/extract_entities.yaml +7 -3
- texttools/prompts/extract_keywords.yaml +80 -18
- texttools/prompts/is_question.yaml +6 -2
- texttools/prompts/merge_questions.yaml +12 -5
- texttools/prompts/propositionize.yaml +24 -0
- texttools/prompts/rewrite.yaml +9 -10
- texttools/prompts/run_custom.yaml +2 -2
- texttools/prompts/subject_to_question.yaml +7 -3
- texttools/prompts/summarize.yaml +6 -2
- texttools/prompts/text_to_question.yaml +12 -6
- texttools/prompts/translate.yaml +7 -2
- texttools/py.typed +0 -0
- texttools/tools/__init__.py +0 -4
- texttools/tools/async_tools.py +1093 -0
- texttools/tools/sync_tools.py +1092 -0
- hamtaa_texttools-1.1.1.dist-info/METADATA +0 -183
- hamtaa_texttools-1.1.1.dist-info/RECORD +0 -30
- texttools/batch/batch_runner.py +0 -263
- texttools/prompts/README.md +0 -35
- texttools/prompts/categorizer.yaml +0 -28
- texttools/tools/async_the_tool.py +0 -414
- texttools/tools/internals/async_operator.py +0 -179
- texttools/tools/internals/base_operator.py +0 -91
- texttools/tools/internals/formatters.py +0 -24
- texttools/tools/internals/operator.py +0 -179
- texttools/tools/internals/output_models.py +0 -59
- texttools/tools/internals/prompt_loader.py +0 -57
- texttools/tools/the_tool.py +0 -412
- {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.2.0.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.2.0.dist-info}/top_level.txt +0 -0
texttools/batch/runner.py ADDED

@@ -0,0 +1,228 @@

import json
import logging
import os
import time
from pathlib import Path
from typing import Any, Type, TypeVar

from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel

from ..core.exceptions import TextToolsError
from ..core.internal_models import Str
from .config import BatchConfig
from .manager import BatchManager

# Base Model type for output models
T = TypeVar("T", bound=BaseModel)

logger = logging.getLogger("texttools.batch_runner")


class BatchRunner:
    """
    Handles running batch jobs using a batch manager and configuration.
    """

    def __init__(
        self, config: BatchConfig = BatchConfig(), output_model: Type[T] = Str
    ):
        try:
            self._config = config
            self._system_prompt = config.system_prompt
            self._job_name = config.job_name
            self._input_data_path = config.input_data_path
            self._output_data_filename = config.output_data_filename
            self._model = config.model
            self._output_model = output_model
            self._manager = self._init_manager()
            self._data = self._load_data()
            self._parts: list[list[dict[str, Any]]] = []
            # Map part index to job name
            self._part_idx_to_job_name: dict[int, str] = {}
            # Track retry attempts per part
            self._part_attempts: dict[int, int] = {}
            self._partition_data()
            Path(self._config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

        except Exception as e:
            raise TextToolsError(f"Batch runner initialization failed: {e}")

    def _init_manager(self) -> BatchManager:
        load_dotenv()
        api_key = os.getenv("OPENAI_API_KEY")
        client = OpenAI(api_key=api_key)
        return BatchManager(
            client=client,
            model=self._model,
            prompt_template=self._system_prompt,
            output_model=self._output_model,
        )

    def _load_data(self):
        with open(self._input_data_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        data = self._config.export_function(data)

        # Ensure data is a list of dicts with 'id' and 'content' as strings
        if not isinstance(data, list):
            raise ValueError(
                "Exported data must be a list of dicts with 'id' and 'content' keys"
            )
        for item in data:
            if not (isinstance(item, dict) and "id" in item and "content" in item):
                raise ValueError(
                    f"Item must be a dict with 'id' and 'content' keys. Got: {type(item)}"
                )
            if not (isinstance(item["id"], str) and isinstance(item["content"], str)):
                raise ValueError("'id' and 'content' must be strings.")
        return data

    def _partition_data(self):
        total_length = sum(len(item["content"]) for item in self._data)
        prompt_length = len(self._system_prompt)
        total = total_length + (prompt_length * len(self._data))
        calculation = total / self._config.CHARS_PER_TOKEN
        logger.info(
            f"Total chars: {total_length}, Prompt chars: {prompt_length}, Total: {total}, Tokens: {calculation}"
        )
        if calculation < self._config.MAX_TOTAL_TOKENS:
            self._parts = [self._data]
        else:
            # Partition into chunks of MAX_BATCH_SIZE
            self._parts = [
                self._data[i : i + self._config.MAX_BATCH_SIZE]
                for i in range(0, len(self._data), self._config.MAX_BATCH_SIZE)
            ]
        logger.info(f"Data split into {len(self._parts)} part(s)")

    def _submit_all_jobs(self) -> None:
        for idx, part in enumerate(self._parts):
            if self._result_exists(idx):
                logger.info(f"Skipping part {idx + 1}: result already exists.")
                continue
            part_job_name = (
                f"{self._job_name}_part_{idx + 1}"
                if len(self._parts) > 1
                else self._job_name
            )
            # If a job with this name already exists, register and skip submitting
            existing_job = self._manager._load_state(part_job_name)
            if existing_job:
                logger.info(
                    f"Skipping part {idx + 1}: job already exists ({part_job_name})."
                )
                self._part_idx_to_job_name[idx] = part_job_name
                self._part_attempts.setdefault(idx, 0)
                continue

            payload = part
            logger.info(
                f"Submitting job for part {idx + 1}/{len(self._parts)}: {part_job_name}"
            )
            self._manager.start(payload, job_name=part_job_name)
            self._part_idx_to_job_name[idx] = part_job_name
            self._part_attempts.setdefault(idx, 0)
            # Give the uploaded file time to finish processing before starting the next part.
            logger.info("Uploading...")
            time.sleep(30)

    def _save_results(
        self,
        output_data: list[dict[str, Any]] | dict[str, Any],
        log: list[Any],
        part_idx: int,
    ):
        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
        result_path = (
            Path(self._config.BASE_OUTPUT_DIR)
            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
        )
        if not output_data:
            logger.info("No output data to save. Skipping this part.")
            return
        else:
            with open(result_path, "w", encoding="utf-8") as f:
                json.dump(output_data, f, ensure_ascii=False, indent=4)
        if log:
            log_path = (
                Path(self._config.BASE_OUTPUT_DIR)
                / f"{Path(self._output_data_filename).stem}{part_suffix}_log.json"
            )
            with open(log_path, "w", encoding="utf-8") as f:
                json.dump(log, f, ensure_ascii=False, indent=4)

    def _result_exists(self, part_idx: int) -> bool:
        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
        result_path = (
            Path(self._config.BASE_OUTPUT_DIR)
            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
        )
        return result_path.exists()

    def run(self):
        """
        Execute the batch job processing pipeline.

        Submits jobs, monitors progress, handles retries, and saves results.
        """
        try:
            # Submit all jobs up-front for concurrent execution
            self._submit_all_jobs()
            pending_parts: set[int] = set(self._part_idx_to_job_name.keys())
            logger.info(f"Pending parts: {sorted(pending_parts)}")
            # Polling loop
            while pending_parts:
                finished_this_round: list[int] = []
                for part_idx in list(pending_parts):
                    job_name = self._part_idx_to_job_name[part_idx]
                    status = self._manager.check_status(job_name=job_name)
                    logger.info(f"Status for {job_name}: {status}")
                    if status == "completed":
                        logger.info(
                            f"Job completed. Fetching results for part {part_idx + 1}..."
                        )
                        output_data, log = self._manager.fetch_results(
                            job_name=job_name, remove_cache=False
                        )
                        output_data = self._config.import_function(output_data)
                        self._save_results(output_data, log, part_idx)
                        logger.info(
                            f"Fetched and saved results for part {part_idx + 1}."
                        )
                        finished_this_round.append(part_idx)
                    elif status == "failed":
                        attempt = self._part_attempts.get(part_idx, 0) + 1
                        self._part_attempts[part_idx] = attempt
                        if attempt <= self._config.max_retries:
                            logger.info(
                                f"Job {job_name} failed (attempt {attempt}). Retrying after short backoff..."
                            )
                            self._manager._clear_state(job_name)
                            time.sleep(10)
                            payload = self._to_manager_payload(self._parts[part_idx])
                            new_job_name = (
                                f"{self._job_name}_part_{part_idx + 1}_retry_{attempt}"
                            )
                            self._manager.start(payload, job_name=new_job_name)
                            self._part_idx_to_job_name[part_idx] = new_job_name
                        else:
                            logger.info(
                                f"Job {job_name} failed after {attempt - 1} retries. Marking as failed."
                            )
                            finished_this_round.append(part_idx)
                    else:
                        # Still running or queued
                        continue
                # Remove finished parts
                for part_idx in finished_this_round:
                    pending_parts.discard(part_idx)
                if pending_parts:
                    logger.info(
                        f"Waiting {self._config.poll_interval_seconds}s before next status check for parts: {sorted(pending_parts)}"
                    )
                    time.sleep(self._config.poll_interval_seconds)

        except Exception as e:
            raise TextToolsError(f"Batch job execution failed: {e}")

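For orientation, a minimal usage sketch of the new runner (illustrative only, not part of the package diff). The BatchConfig keyword names below are assumptions inferred from the attributes BatchRunner reads; check texttools/batch/config.py for the actual field set.

# Hypothetical usage; BatchConfig field names are inferred, not confirmed against config.py.
from texttools.batch.config import BatchConfig
from texttools.batch.runner import BatchRunner

config = BatchConfig(
    system_prompt="Summarize the following text.",
    job_name="summaries",
    input_data_path="articles.json",        # must export to [{"id": str, "content": str}, ...]
    output_data_filename="summaries.json",
    model="gpt-4o-mini",
)

# OPENAI_API_KEY is read from the environment (via load_dotenv) inside the runner.
runner = BatchRunner(config=config)  # default output_model is Str
runner.run()  # submits parts, polls status, retries failures, saves JSON results
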
texttools/core/__init__.py: File without changes

texttools/core/engine.py ADDED

@@ -0,0 +1,254 @@

import math
import random
import re
from functools import lru_cache
from pathlib import Path

import yaml

from .exceptions import PromptError


class PromptLoader:
    """
    Utility for loading and formatting YAML prompt templates.
    """

    MAIN_TEMPLATE = "main_template"
    ANALYZE_TEMPLATE = "analyze_template"

    @lru_cache(maxsize=32)
    def _load_templates(self, prompt_file: str, mode: str | None) -> dict[str, str]:
        """
        Loads prompt templates from YAML file with optional mode selection.
        """
        try:
            base_dir = Path(__file__).parent.parent / Path("prompts")
            prompt_path = base_dir / prompt_file

            if not prompt_path.exists():
                raise PromptError(f"Prompt file not found: {prompt_file}")

            data = yaml.safe_load(prompt_path.read_text(encoding="utf-8"))

            if self.MAIN_TEMPLATE not in data:
                raise PromptError(f"Missing 'main_template' in {prompt_file}")

            if self.ANALYZE_TEMPLATE not in data:
                raise PromptError(f"Missing 'analyze_template' in {prompt_file}")

            if mode and mode not in data.get(self.MAIN_TEMPLATE, {}):
                raise PromptError(f"Mode '{mode}' not found in {prompt_file}")

            main_template = (
                data[self.MAIN_TEMPLATE][mode]
                if mode and isinstance(data[self.MAIN_TEMPLATE], dict)
                else data[self.MAIN_TEMPLATE]
            )

            analyze_template = (
                data[self.ANALYZE_TEMPLATE][mode]
                if mode and isinstance(data[self.ANALYZE_TEMPLATE], dict)
                else data[self.ANALYZE_TEMPLATE]
            )

            if not main_template or not main_template.strip():
                raise PromptError(
                    f"Empty main_template in {prompt_file}"
                    + (f" for mode '{mode}'" if mode else "")
                )

            return {
                self.MAIN_TEMPLATE: main_template,
                self.ANALYZE_TEMPLATE: analyze_template,
            }

        except yaml.YAMLError as e:
            raise PromptError(f"Invalid YAML in {prompt_file}: {e}")
        except Exception as e:
            raise PromptError(f"Failed to load prompt {prompt_file}: {e}")

    def load(
        self, prompt_file: str, text: str, mode: str, **extra_kwargs
    ) -> dict[str, str]:
        try:
            template_configs = self._load_templates(prompt_file, mode)
            format_args = {"text": text}
            format_args.update(extra_kwargs)

            # Inject variables inside each template
            for key in template_configs.keys():
                template_configs[key] = template_configs[key].format(**format_args)

            return template_configs

        except KeyError as e:
            raise PromptError(f"Missing template variable: {e}")
        except Exception as e:
            raise PromptError(f"Failed to format prompt: {e}")


class OperatorUtils:
    @staticmethod
    def build_main_prompt(
        main_template: str,
        analysis: str | None,
        output_lang: str | None,
        user_prompt: str | None,
    ) -> str:
        main_prompt = ""

        if analysis:
            main_prompt += f"Based on this analysis:\n{analysis}\n"

        if output_lang:
            main_prompt += f"Respond only in the {output_lang} language.\n"

        if user_prompt:
            main_prompt += f"Consider this instruction {user_prompt}\n"

        main_prompt += main_template

        return main_prompt

    @staticmethod
    def build_message(prompt: str) -> list[dict[str, str]]:
        return [{"role": "user", "content": prompt}]

    @staticmethod
    def extract_logprobs(completion: dict) -> list[dict]:
        """
        Extracts and filters token probabilities from completion logprobs.
        Skips punctuation and structural tokens, returns cleaned probability data.
        """
        logprobs_data = []

        ignore_pattern = re.compile(r'^(result|[\s\[\]\{\}",:]+)$')

        for choice in completion.choices:
            if not getattr(choice, "logprobs", None):
                raise ValueError("Your model does not support logprobs")

            for logprob_item in choice.logprobs.content:
                if ignore_pattern.match(logprob_item.token):
                    continue
                token_entry = {
                    "token": logprob_item.token,
                    "prob": round(math.exp(logprob_item.logprob), 8),
                    "top_alternatives": [],
                }
                for alt in logprob_item.top_logprobs:
                    if ignore_pattern.match(alt.token):
                        continue
                    token_entry["top_alternatives"].append(
                        {
                            "token": alt.token,
                            "prob": round(math.exp(alt.logprob), 8),
                        }
                    )
                logprobs_data.append(token_entry)

        return logprobs_data

    @staticmethod
    def get_retry_temp(base_temp: float) -> float:
        delta_temp = random.choice([-1, 1]) * random.uniform(0.1, 0.9)
        new_temp = base_temp + delta_temp

        return max(0.0, min(new_temp, 1.5))


def text_to_chunks(text: str, size: int, overlap: int) -> list[str]:
    separators = ["\n\n", "\n", " ", ""]
    is_separator_regex = False
    keep_separator = True  # Equivalent to 'start'
    length_function = len
    strip_whitespace = True
    chunk_size = size
    chunk_overlap = overlap

    def _split_text_with_regex(
        text: str, separator: str, keep_separator: bool
    ) -> list[str]:
        if not separator:
            return [text]
        if not keep_separator:
            return re.split(separator, text)
        _splits = re.split(f"({separator})", text)
        splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
        if len(_splits) % 2 == 0:
            splits += [_splits[-1]]
        return [_splits[0]] + splits if _splits[0] else splits

    def _join_docs(docs: list[str], separator: str) -> str | None:
        text = separator.join(docs)
        if strip_whitespace:
            text = text.strip()
        return text if text else None

    def _merge_splits(splits: list[str], separator: str) -> list[str]:
        separator_len = length_function(separator)
        docs = []
        current_doc = []
        total = 0
        for d in splits:
            len_ = length_function(d)
            if total + len_ + (separator_len if current_doc else 0) > chunk_size:
                if total > chunk_size:
                    pass
                if current_doc:
                    doc = _join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    while total > chunk_overlap or (
                        total + len_ + (separator_len if current_doc else 0)
                        > chunk_size
                        and total > 0
                    ):
                        total -= length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += len_ + (separator_len if len(current_doc) > 1 else 0)
        doc = _join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    def _split_text(text: str, separators: list[str]) -> list[str]:
        final_chunks = []
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            separator_ = _s if is_separator_regex else re.escape(_s)
            if not _s:
                separator = _s
                break
            if re.search(separator_, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break
        separator_ = separator if is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex(text, separator_, keep_separator)
        _separator = "" if keep_separator else separator
        good_splits = []
        for s in splits:
            if length_function(s) < chunk_size:
                good_splits.append(s)
            else:
                if good_splits:
                    merged_text = _merge_splits(good_splits, _separator)
                    final_chunks.extend(merged_text)
                    good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = _split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if good_splits:
            merged_text = _merge_splits(good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    return _split_text(text, separators)

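text_to_chunks above is a self-contained recursive character splitter: it tries the separators "\n\n", "\n", " ", "" in order and merges the pieces back into chunks of roughly the requested size with the requested overlap. A short illustration of calling it (illustrative only, not part of the diff):

# Example call of texttools.core.engine.text_to_chunks.
from texttools.core.engine import text_to_chunks

text = ("First paragraph about topic A.\n\n" "Second paragraph about topic B.\n\n") * 20
chunks = text_to_chunks(text, size=200, overlap=40)

# Chunks split preferentially on blank lines, then newlines, then spaces;
# each stays near 200 characters and neighbouring chunks can share up to ~40 characters.
print(len(chunks), max(len(c) for c in chunks))
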
texttools/core/exceptions.py ADDED

@@ -0,0 +1,22 @@

class TextToolsError(Exception):
    """Base exception for all TextTools errors."""

    pass


class PromptError(TextToolsError):
    """Errors related to prompt loading and formatting."""

    pass


class LLMError(TextToolsError):
    """Errors from LLM API calls."""

    pass


class ValidationError(TextToolsError):
    """Errors from output validation."""

    pass

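Because PromptError, LLMError, and ValidationError all subclass TextToolsError, callers can handle every library failure with a single except clause; a brief sketch (illustrative only):

from texttools.core.exceptions import PromptError, TextToolsError

try:
    raise PromptError("Prompt file not found: categorize.yaml")
except TextToolsError as exc:
    # Also catches LLMError and ValidationError.
    print(f"texttools failure: {exc}")
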
texttools/core/internal_models.py ADDED

@@ -0,0 +1,58 @@

from typing import Any, Literal, Type

from pydantic import BaseModel, Field, create_model


class OperatorOutput(BaseModel):
    result: Any
    analysis: str | None
    logprobs: list[dict[str, Any]] | None


class Str(BaseModel):
    result: str = Field(..., description="The output string", example="text")


class Bool(BaseModel):
    result: bool = Field(
        ..., description="Boolean indicating the output state", example=True
    )


class ListStr(BaseModel):
    result: list[str] = Field(
        ..., description="The output list of strings", example=["text_1", "text_2"]
    )


class ListDictStrStr(BaseModel):
    result: list[dict[str, str]] = Field(
        ...,
        description="List of dictionaries containing string key-value pairs",
        example=[{"text": "Mohammad", "type": "PER"}, {"text": "Iran", "type": "LOC"}],
    )


class ReasonListStr(BaseModel):
    reason: str = Field(..., description="Thinking process that led to the output")
    result: list[str] = Field(
        ..., description="The output list of strings", example=["text_1", "text_2"]
    )


# This function is needed to create CategorizerOutput with dynamic categories
def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
    literal_type = Literal[*allowed_values]

    CategorizerOutput = create_model(
        "CategorizerOutput",
        reason=(
            str,
            Field(
                ..., description="Explanation of why the input belongs to the category"
            ),
        ),
        result=(literal_type, Field(..., description="Predicted category label")),
    )

    return CategorizerOutput

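A short sketch of how the dynamic categorizer model can be used (illustrative only; the category labels are invented). Note that the Literal[*allowed_values] unpacking syntax requires Python 3.11 or newer.

from pydantic import ValidationError as PydanticValidationError

from texttools.core.internal_models import create_dynamic_model

CategorizerOutput = create_dynamic_model(["sports", "politics", "economy"])

ok = CategorizerOutput(reason="Mentions a football match.", result="sports")
print(ok.result)  # "sports"

try:
    CategorizerOutput(reason="Off-topic.", result="weather")  # not an allowed category
except PydanticValidationError as exc:
    print("rejected:", exc.errors()[0]["type"])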