hamtaa-texttools 1.2.0__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,3 @@
-import sys
 from collections.abc import Callable
 from time import perf_counter
 from typing import Any, Literal
@@ -66,7 +65,7 @@ class TheTool:
             ToolOutput
 
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "categorize"
         start = perf_counter()
 
         try:
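The removal of `import sys` above goes with this change: every tool method now hardcodes its `tool_name` instead of reading it from the current stack frame. A minimal sketch of the two approaches, using hypothetical standalone functions rather than the package's actual methods:

```python
import sys


def categorize_via_frame() -> str:
    # sys._getframe() returns the currently executing frame; co_name is the
    # name of the enclosing function, so this yields "categorize_via_frame".
    # It relies on a private CPython API and changes silently if the function
    # is renamed or wrapped.
    return sys._getframe().f_code.co_name


def categorize_hardcoded() -> str:
    # A string literal keeps the reported tool name stable and drops the
    # dependency on sys.
    return "categorize_hardcoded"


print(categorize_via_frame())   # categorize_via_frame
print(categorize_hardcoded())   # categorize_hardcoded
```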
@@ -198,7 +197,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "extract_keywords"
         start = perf_counter()
 
         try:
@@ -211,7 +210,6 @@ class TheTool:
                 temperature=temperature,
                 logprobs=logprobs,
                 top_logprobs=top_logprobs,
-                mode=mode,
                 number_of_keywords=number_of_keywords,
                 validator=validator,
                 max_validation_retries=max_validation_retries,
@@ -219,6 +217,7 @@ class TheTool:
                 # Internal parameters
                 tool_name=tool_name,
                 output_model=ListStr,
+                mode=mode,
             )
 
             metadata = ToolOutputMetadata(
@@ -272,7 +271,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "extract_entities"
         start = perf_counter()
 
         try:
@@ -343,7 +342,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "is_question"
         start = perf_counter()
 
         try:
@@ -416,7 +415,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "text_to_question"
         start = perf_counter()
 
         try:
@@ -489,7 +488,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "merge_questions"
         start = perf_counter()
 
         try:
@@ -562,7 +561,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "rewrite"
         start = perf_counter()
 
         try:
@@ -635,7 +634,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "subject_to_question"
         start = perf_counter()
 
         try:
@@ -707,7 +706,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "summarize"
         start = perf_counter()
 
         try:
@@ -782,7 +781,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "translate"
         start = perf_counter()
 
         try:
@@ -900,7 +899,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "propositionize"
         start = perf_counter()
 
         try:
@@ -961,7 +960,7 @@ class TheTool:
 
         Arguments:
             text: The input text
-            source_text: the source text that we want to check relation of text to it
+            source_text: The source text that we want to check relation of text to it
             with_analysis: Whether to include detailed reasoning analysis
             output_lang: Language for the output
             user_prompt: Additional instructions
@@ -975,13 +974,14 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "check_fact"
         start = perf_counter()
 
         try:
             operator_output = self._operator.run(
                 # User parameters
                 text=text,
+                source_text=source_text,
                 with_analysis=with_analysis,
                 output_lang=output_lang,
                 user_prompt=user_prompt,
@@ -995,7 +995,6 @@ class TheTool:
                 tool_name=tool_name,
                 output_model=Bool,
                 mode=None,
-                source_text=source_text,
             )
 
             metadata = ToolOutputMetadata(
@@ -1049,7 +1048,7 @@ class TheTool:
         Returns:
             ToolOutput
         """
-        tool_name = sys._getframe().f_code.co_name
+        tool_name = "run_custom"
         start = perf_counter()
 
         try:
File without changes
texttools/batch/config.py DELETED
@@ -1,40 +0,0 @@
1
- from collections.abc import Callable
2
- from dataclasses import dataclass
3
- from typing import Any
4
-
5
-
6
- def export_data(data) -> list[dict[str, str]]:
7
- """
8
- Produces a structure of the following form from an initial data structure:
9
- [{"id": str, "text": str},...]
10
- """
11
- return data
12
-
13
-
14
- def import_data(data) -> Any:
15
- """
16
- Takes the output and adds and aggregates it to the original structure.
17
- """
18
- return data
19
-
20
-
21
- @dataclass
22
- class BatchConfig:
23
- """
24
- Configuration for batch job runner.
25
- """
26
-
27
- system_prompt: str = ""
28
- job_name: str = ""
29
- input_data_path: str = ""
30
- output_data_filename: str = ""
31
- model: str = "gpt-4.1-mini"
32
- MAX_BATCH_SIZE: int = 100
33
- MAX_TOTAL_TOKENS: int = 2_000_000
34
- CHARS_PER_TOKEN: float = 2.7
35
- PROMPT_TOKEN_MULTIPLIER: int = 1_000
36
- BASE_OUTPUT_DIR: str = "Data/batch_entity_result"
37
- import_function: Callable = import_data
38
- export_function: Callable = export_data
39
- poll_interval_seconds: int = 30
40
- max_retries: int = 3
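For context, the removed `BatchConfig` was a plain dataclass, so batch runs were configured by overriding its field defaults and, optionally, its `export_function`/`import_function` hooks. A minimal sketch against the 1.2.0 API shown above; the prompt, job name, and paths are hypothetical:

```python
from texttools.batch.config import BatchConfig  # module removed in 1.3.1

config = BatchConfig(
    system_prompt="Extract the named entities from the text.",  # hypothetical
    job_name="entity_extraction",                               # hypothetical
    input_data_path="Data/articles.json",                       # hypothetical
    output_data_filename="entities.json",                       # hypothetical
    model="gpt-4.1-mini",
)

# The default hooks pass data through unchanged; callers could override
# export_function to map their own structure to [{"id": str, "text": str}, ...]
# and import_function to fold results back into it.
print(config.MAX_BATCH_SIZE, config.poll_interval_seconds)
```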
texttools/batch/manager.py DELETED
@@ -1,228 +0,0 @@
-import json
-import logging
-import uuid
-from pathlib import Path
-from typing import Any, Type, TypeVar
-
-from openai import OpenAI
-from openai.lib._pydantic import to_strict_json_schema
-from pydantic import BaseModel
-
-# Base Model type for output models
-T = TypeVar("T", bound=BaseModel)
-
-logger = logging.getLogger("texttools.batch_manager")
-
-
-class BatchManager:
-    """
-    Manages batch processing jobs for OpenAI's chat completions with structured outputs.
-
-    Handles the full lifecycle of a batch job: creating tasks from input texts,
-    starting the job, monitoring status, and fetching results. Results are automatically
-    parsed into the specified Pydantic output model. Job state is persisted to disk.
-    """
-
-    def __init__(
-        self,
-        client: OpenAI,
-        model: str,
-        output_model: Type[T],
-        prompt_template: str,
-        state_dir: Path = Path(".batch_jobs"),
-        custom_json_schema_obj_str: dict | None = None,
-        **client_kwargs: Any,
-    ):
-        self._client = client
-        self._model = model
-        self._output_model = output_model
-        self._prompt_template = prompt_template
-        self._state_dir = state_dir
-        self._custom_json_schema_obj_str = custom_json_schema_obj_str
-        self._client_kwargs = client_kwargs
-        self._dict_input = False
-        self._state_dir.mkdir(parents=True, exist_ok=True)
-
-        if custom_json_schema_obj_str and not isinstance(
-            custom_json_schema_obj_str, dict
-        ):
-            raise ValueError("Schema should be a dict")
-
-    def _state_file(self, job_name: str) -> Path:
-        return self._state_dir / f"{job_name}.json"
-
-    def _load_state(self, job_name: str) -> list[dict[str, Any]]:
-        """
-        Loads the state (job information) from the state file for the given job name.
-        Returns an empty list if the state file does not exist.
-        """
-        path = self._state_file(job_name)
-        if path.exists():
-            with open(path, "r", encoding="utf-8") as f:
-                return json.load(f)
-        return []
-
-    def _save_state(self, job_name: str, jobs: list[dict[str, Any]]) -> None:
-        """
-        Saves the job state to the state file for the given job name.
-        """
-        with open(self._state_file(job_name), "w", encoding="utf-8") as f:
-            json.dump(jobs, f)
-
-    def _clear_state(self, job_name: str) -> None:
-        """
-        Deletes the state file for the given job name if it exists.
-        """
-        path = self._state_file(job_name)
-        if path.exists():
-            path.unlink()
-
-    def _build_task(self, text: str, idx: str) -> dict[str, Any]:
-        """
-        Builds a single task dictionary for the batch job, including the prompt, model, and response format configuration.
-        """
-        response_format_config: dict[str, Any]
-
-        if self._custom_json_schema_obj_str:
-            response_format_config = {
-                "type": "json_schema",
-                "json_schema": self._custom_json_schema_obj_str,
-            }
-        else:
-            raw_schema = to_strict_json_schema(self._output_model)
-            response_format_config = {
-                "type": "json_schema",
-                "json_schema": {
-                    "name": self._output_model.__name__,
-                    "schema": raw_schema,
-                },
-            }
-
-        return {
-            "custom_id": str(idx),
-            "method": "POST",
-            "url": "/v1/chat/completions",
-            "body": {
-                "model": self.model,
-                "messages": [
-                    {"role": "system", "content": self._prompt_template},
-                    {"role": "user", "content": text},
-                ],
-                "response_format": response_format_config,
-                **self._client_kwargs,
-            },
-        }
-
-    def _prepare_file(self, payload: list[str] | list[dict[str, str]]) -> Path:
-        """
-        Prepares a JSONL file containing all tasks for the batch job, based on the input payload.
-        Returns the path to the created file.
-        """
-        if not payload:
-            raise ValueError("Payload must not be empty")
-        if isinstance(payload[0], str):
-            tasks = [self._build_task(text, uuid.uuid4().hex) for text in payload]
-        elif isinstance(payload[0], dict):
-            tasks = [self._build_task(dic["text"], dic["id"]) for dic in payload]
-
-        else:
-            raise TypeError(
-                "The input must be either a list of texts or a dictionary in the form {'id': str, 'text': str}"
-            )
-
-        file_path = self._state_dir / f"batch_{uuid.uuid4().hex}.jsonl"
-        with open(file_path, "w", encoding="utf-8") as f:
-            for task in tasks:
-                f.write(json.dumps(task) + "\n")
-        return file_path
-
-    def start(self, payload: list[str | dict[str, str]], job_name: str):
-        """
-        Starts a new batch job by uploading the prepared file and creating a batch job on the server.
-        If a job with the same name already exists, it does nothing.
-        """
-        if self._load_state(job_name):
-            return
-
-        path = self._prepare_file(payload)
-        upload = self._client.files.create(file=open(path, "rb"), purpose="batch")
-        job = self._client.batches.create(
-            input_file_id=upload.id,
-            endpoint="/v1/chat/completions",
-            completion_window="24h",
-        ).to_dict()
-        self._save_state(job_name, [job])
-
-    def check_status(self, job_name: str) -> str:
-        """
-        Checks and returns the current status of the batch job with the given job name.
-        Updates the job state with the latest information from the server.
-        """
-        job = self._load_state(job_name)[0]
-        if not job:
-            return "completed"
-
-        info = self._client.batches.retrieve(job["id"])
-        job = info.to_dict()
-        self._save_state(job_name, [job])
-        logger.info("Batch job status: %s", job)
-        return job["status"]
-
-    def fetch_results(
-        self, job_name: str, remove_cache: bool = True
-    ) -> tuple[dict[str, str], list]:
-        """
-        Fetches the results of a completed batch job. Optionally saves the results to a file and/or removes the job cache.
-        Returns a tuple containing the parsed results and a log of errors (if any).
-        """
-        job = self._load_state(job_name)[0]
-        if not job:
-            return {}
-        batch_id = job["id"]
-
-        info = self._client.batches.retrieve(batch_id)
-        out_file_id = info.output_file_id
-        if not out_file_id:
-            error_file_id = info.error_file_id
-            if error_file_id:
-                err_content = (
-                    self._client.files.content(error_file_id).read().decode("utf-8")
-                )
-                logger.error("Error file content:", err_content)
-            return {}
-
-        content = self._client.files.content(out_file_id).read().decode("utf-8")
-        lines = content.splitlines()
-        results = {}
-        log = []
-        for line in lines:
-            result = json.loads(line)
-            custom_id = result["custom_id"]
-            if result["response"]["status_code"] == 200:
-                content = result["response"]["body"]["choices"][0]["message"]["content"]
-                try:
-                    parsed_content = json.loads(content)
-                    model_instance = self._output_model(**parsed_content)
-                    results[custom_id] = model_instance.model_dump(mode="json")
-                except json.JSONDecodeError:
-                    results[custom_id] = {"error": "Failed to parse content as JSON"}
-                    error_d = {custom_id: results[custom_id]}
-                    log.append(error_d)
-                except Exception as e:
-                    results[custom_id] = {"error": str(e)}
-                    error_d = {custom_id: results[custom_id]}
-                    log.append(error_d)
-            else:
-                error_message = (
-                    result["response"]["body"]
-                    .get("error", {})
-                    .get("message", "Unknown error")
-                )
-                results[custom_id] = {"error": error_message}
-                error_d = {custom_id: results[custom_id]}
-                log.append(error_d)
-
-        if remove_cache:
-            self._clear_state(job_name)
-
-        return results, log
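The removed `BatchManager` wrapped the OpenAI Batch API lifecycle: `start` writes the tasks to a JSONL file, uploads it, and creates the batch; `check_status` polls the job; `fetch_results` downloads the output file and parses each response into the configured Pydantic model. A minimal sketch of that lifecycle against the 1.2.0 code above; the output model, prompt, and payload are hypothetical, and an `OPENAI_API_KEY` is assumed:

```python
import time

from openai import OpenAI
from pydantic import BaseModel

from texttools.batch.manager import BatchManager  # module removed in 1.3.1


class Entities(BaseModel):
    # Hypothetical structured-output model for the batch responses.
    names: list[str]


manager = BatchManager(
    client=OpenAI(),
    model="gpt-4.1-mini",
    output_model=Entities,
    prompt_template="List the person names mentioned in the text.",  # hypothetical
)

# Payloads may be plain strings or {"id": ..., "text": ...} dicts.
manager.start(["Ada Lovelace met Charles Babbage."], job_name="demo_job")

# Poll until the batch finishes; job state is cached under .batch_jobs/.
while manager.check_status("demo_job") not in ("completed", "failed"):
    time.sleep(30)

results, errors = manager.fetch_results("demo_job")
print(results, errors)
```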
texttools/batch/runner.py DELETED
@@ -1,228 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import time
5
- from pathlib import Path
6
- from typing import Any, Type, TypeVar
7
-
8
- from dotenv import load_dotenv
9
- from openai import OpenAI
10
- from pydantic import BaseModel
11
-
12
- from ..core.exceptions import TextToolsError
13
- from ..core.internal_models import Str
14
- from .config import BatchConfig
15
- from .manager import BatchManager
16
-
17
- # Base Model type for output models
18
- T = TypeVar("T", bound=BaseModel)
19
-
20
- logger = logging.getLogger("texttools.batch_runner")
21
-
22
-
23
- class BatchRunner:
24
- """
25
- Handles running batch jobs using a batch manager and configuration.
26
- """
27
-
28
- def __init__(
29
- self, config: BatchConfig = BatchConfig(), output_model: Type[T] = Str
30
- ):
31
- try:
32
- self._config = config
33
- self._system_prompt = config.system_prompt
34
- self._job_name = config.job_name
35
- self._input_data_path = config.input_data_path
36
- self._output_data_filename = config.output_data_filename
37
- self._model = config.model
38
- self._output_model = output_model
39
- self._manager = self._init_manager()
40
- self._data = self._load_data()
41
- self._parts: list[list[dict[str, Any]]] = []
42
- # Map part index to job name
43
- self._part_idx_to_job_name: dict[int, str] = {}
44
- # Track retry attempts per part
45
- self._part_attempts: dict[int, int] = {}
46
- self._partition_data()
47
- Path(self._config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
48
-
49
- except Exception as e:
50
- raise TextToolsError(f"Batch runner initialization failed: {e}")
51
-
52
- def _init_manager(self) -> BatchManager:
53
- load_dotenv()
54
- api_key = os.getenv("OPENAI_API_KEY")
55
- client = OpenAI(api_key=api_key)
56
- return BatchManager(
57
- client=client,
58
- model=self._model,
59
- prompt_template=self._system_prompt,
60
- output_model=self._output_model,
61
- )
62
-
63
- def _load_data(self):
64
- with open(self._input_data_path, "r", encoding="utf-8") as f:
65
- data = json.load(f)
66
- data = self._config.export_function(data)
67
-
68
- # Ensure data is a list of dicts with 'id' and 'content' as strings
69
- if not isinstance(data, list):
70
- raise ValueError(
71
- "Exported data must be a list of dicts with 'id' and 'content' keys"
72
- )
73
- for item in data:
74
- if not (isinstance(item, dict) and "id" in item and "content" in item):
75
- raise ValueError(
76
- f"Item must be a dict with 'id' and 'content' keys. Got: {type(item)}"
77
- )
78
- if not (isinstance(item["id"], str) and isinstance(item["content"], str)):
79
- raise ValueError("'id' and 'content' must be strings.")
80
- return data
81
-
82
- def _partition_data(self):
83
- total_length = sum(len(item["content"]) for item in self._data)
84
- prompt_length = len(self._system_prompt)
85
- total = total_length + (prompt_length * len(self._data))
86
- calculation = total / self._config.CHARS_PER_TOKEN
87
- logger.info(
88
- f"Total chars: {total_length}, Prompt chars: {prompt_length}, Total: {total}, Tokens: {calculation}"
89
- )
90
- if calculation < self._config.MAX_TOTAL_TOKENS:
91
- self._parts = [self._data]
92
- else:
93
- # Partition into chunks of MAX_BATCH_SIZE
94
- self._parts = [
95
- self._data[i : i + self._config.MAX_BATCH_SIZE]
96
- for i in range(0, len(self._data), self._config.MAX_BATCH_SIZE)
97
- ]
98
- logger.info(f"Data split into {len(self._parts)} part(s)")
99
-
100
- def _submit_all_jobs(self) -> None:
101
- for idx, part in enumerate(self._parts):
102
- if self._result_exists(idx):
103
- logger.info(f"Skipping part {idx + 1}: result already exists.")
104
- continue
105
- part_job_name = (
106
- f"{self._job_name}_part_{idx + 1}"
107
- if len(self._parts) > 1
108
- else self._job_name
109
- )
110
- # If a job with this name already exists, register and skip submitting
111
- existing_job = self._manager._load_state(part_job_name)
112
- if existing_job:
113
- logger.info(
114
- f"Skipping part {idx + 1}: job already exists ({part_job_name})."
115
- )
116
- self._part_idx_to_job_name[idx] = part_job_name
117
- self._part_attempts.setdefault(idx, 0)
118
- continue
119
-
120
- payload = part
121
- logger.info(
122
- f"Submitting job for part {idx + 1}/{len(self._parts)}: {part_job_name}"
123
- )
124
- self._manager.start(payload, job_name=part_job_name)
125
- self._part_idx_to_job_name[idx] = part_job_name
126
- self._part_attempts.setdefault(idx, 0)
127
- # This is added for letting file get uploaded, before starting the next part.
128
- logger.info("Uploading...")
129
- time.sleep(30)
130
-
131
- def _save_results(
132
- self,
133
- output_data: list[dict[str, Any]] | dict[str, Any],
134
- log: list[Any],
135
- part_idx: int,
136
- ):
137
- part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
138
- result_path = (
139
- Path(self._config.BASE_OUTPUT_DIR)
140
- / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
141
- )
142
- if not output_data:
143
- logger.info("No output data to save. Skipping this part.")
144
- return
145
- else:
146
- with open(result_path, "w", encoding="utf-8") as f:
147
- json.dump(output_data, f, ensure_ascii=False, indent=4)
148
- if log:
149
- log_path = (
150
- Path(self._config.BASE_OUTPUT_DIR)
151
- / f"{Path(self._output_data_filename).stem}{part_suffix}_log.json"
152
- )
153
- with open(log_path, "w", encoding="utf-8") as f:
154
- json.dump(log, f, ensure_ascii=False, indent=4)
155
-
156
- def _result_exists(self, part_idx: int) -> bool:
157
- part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
158
- result_path = (
159
- Path(self._config.BASE_OUTPUT_DIR)
160
- / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
161
- )
162
- return result_path.exists()
163
-
164
- def run(self):
165
- """
166
- Execute the batch job processing pipeline.
167
-
168
- Submits jobs, monitors progress, handles retries, and saves results.
169
- """
170
- try:
171
- # Submit all jobs up-front for concurrent execution
172
- self._submit_all_jobs()
173
- pending_parts: set[int] = set(self._part_idx_to_job_name.keys())
174
- logger.info(f"Pending parts: {sorted(pending_parts)}")
175
- # Polling loop
176
- while pending_parts:
177
- finished_this_round: list[int] = []
178
- for part_idx in list(pending_parts):
179
- job_name = self._part_idx_to_job_name[part_idx]
180
- status = self._manager.check_status(job_name=job_name)
181
- logger.info(f"Status for {job_name}: {status}")
182
- if status == "completed":
183
- logger.info(
184
- f"Job completed. Fetching results for part {part_idx + 1}..."
185
- )
186
- output_data, log = self._manager.fetch_results(
187
- job_name=job_name, remove_cache=False
188
- )
189
- output_data = self._config.import_function(output_data)
190
- self._save_results(output_data, log, part_idx)
191
- logger.info(
192
- f"Fetched and saved results for part {part_idx + 1}."
193
- )
194
- finished_this_round.append(part_idx)
195
- elif status == "failed":
196
- attempt = self._part_attempts.get(part_idx, 0) + 1
197
- self._part_attempts[part_idx] = attempt
198
- if attempt <= self._config.max_retries:
199
- logger.info(
200
- f"Job {job_name} failed (attempt {attempt}). Retrying after short backoff..."
201
- )
202
- self._manager._clear_state(job_name)
203
- time.sleep(10)
204
- payload = self._to_manager_payload(self._parts[part_idx])
205
- new_job_name = (
206
- f"{self._job_name}_part_{part_idx + 1}_retry_{attempt}"
207
- )
208
- self._manager.start(payload, job_name=new_job_name)
209
- self._part_idx_to_job_name[part_idx] = new_job_name
210
- else:
211
- logger.info(
212
- f"Job {job_name} failed after {attempt - 1} retries. Marking as failed."
213
- )
214
- finished_this_round.append(part_idx)
215
- else:
216
- # Still running or queued
217
- continue
218
- # Remove finished parts
219
- for part_idx in finished_this_round:
220
- pending_parts.discard(part_idx)
221
- if pending_parts:
222
- logger.info(
223
- f"Waiting {self._config.poll_interval_seconds}s before next status check for parts: {sorted(pending_parts)}"
224
- )
225
- time.sleep(self._config.poll_interval_seconds)
226
-
227
- except Exception as e:
228
- raise TextToolsError(f"Batch job execution failed: {e}")
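Taken together, the removed `BatchRunner` tied `BatchConfig` and `BatchManager` into one pipeline: it loads and validates the input JSON, partitions it when the estimated token count exceeds the budget, submits each part, polls with retries, and writes results and error logs under `BASE_OUTPUT_DIR`. A minimal sketch of driving it on 1.2.0; the paths and prompt are hypothetical, and `OPENAI_API_KEY` is read from the environment or a `.env` file:

```python
from texttools.batch.config import BatchConfig  # removed in 1.3.1
from texttools.batch.runner import BatchRunner  # removed in 1.3.1

config = BatchConfig(
    system_prompt="Summarize the text in one sentence.",  # hypothetical
    job_name="summaries",                                 # hypothetical
    # hypothetical path; export_function must yield [{"id": str, "content": str}, ...]
    input_data_path="Data/articles.json",
    output_data_filename="summaries.json",                # hypothetical
)

# output_model defaults to the internal Str model; run() blocks, polling every
# config.poll_interval_seconds, and saves one JSON result file per part.
runner = BatchRunner(config=config)
runner.run()
```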