hamtaa-texttools 1.3.0__py3-none-any.whl → 1.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hamtaa_texttools-{1.3.0 → 1.3.2}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: hamtaa-texttools
- Version: 1.3.0
+ Version: 1.3.2
  Summary: A high-level NLP toolkit built on top of modern LLMs.
  Author-email: Tohidi <the.mohammad.tohidi@gmail.com>, Erfan Moosavi <erfanmoosavi84@gmail.com>, Montazer <montazerh82@gmail.com>, Givechi <mohamad.m.givechi@gmail.com>, Zareshahi <a.zareshahi1377@gmail.com>
  Maintainer-email: Erfan Moosavi <erfanmoosavi84@gmail.com>, Tohidi <the.mohammad.tohidi@gmail.com>
@@ -21,6 +21,9 @@ Dynamic: license-file
 
  # TextTools
 
+ ![PyPI](https://img.shields.io/pypi/v/hamtaa-texttools)
+ ![License](https://img.shields.io/pypi/l/hamtaa-texttools)
+
  ## 📌 Overview
 
  **TextTools** is a high-level **NLP toolkit** built on top of **LLMs**.
@@ -44,11 +47,11 @@ Each tool is designed to work with structured outputs.
  - **`is_question()`** - Binary question detection
  - **`text_to_question()`** - Generates questions from text
  - **`merge_questions()`** - Merges multiple questions into one
- - **`rewrite()`** - Rewrites text in a diffrent way
- - **`subject_to_question()`** - Generates questions about a specific subject
+ - **`rewrite()`** - Rewrites text in a different way
+ - **`subject_to_question()`** - Generates questions about a given subject
  - **`summarize()`** - Text summarization
  - **`translate()`** - Text translation
- - **`propositionize()`** - Convert text to atomic independence meaningful sentences
+ - **`propositionize()`** - Convert text to atomic independent meaningful sentences
  - **`check_fact()`** - Check whether a statement is relevant to the source text
  - **`run_custom()`** - Allows users to define a custom tool with an arbitrary BaseModel
 
@@ -66,7 +69,7 @@ pip install -U hamtaa-texttools
 
  ## 📊 Tool Quality Tiers
 
- | Status | Meaning | Tools | Use in Production? |
+ | Status | Meaning | Tools | Safe for Production? |
  |--------|---------|----------|-------------------|
  | **✅ Production** | Evaluated, tested, stable. | `categorize()` (list mode), `extract_keywords()`, `extract_entities()`, `is_question()`, `text_to_question()`, `merge_questions()`, `rewrite()`, `subject_to_question()`, `summarize()`, `run_custom()` | **Yes** - ready for reliable use. |
  | **🧪 Experimental** | Added to the package but **not fully evaluated**. Functional, but quality may vary. | `categorize()` (tree mode), `translate()`, `propositionize()`, `check_fact()` | **Use with caution** - outputs not yet validated. |
@@ -177,40 +180,7 @@ Use **TextTools** when you need to:
 
  ---
 
- ## 📚 Batch Processing
-
- Process large datasets efficiently using OpenAI's batch API.
-
- ## ⚡ Quick Start (Batch Runner)
-
- ```python
- from pydantic import BaseModel
- from texttools import BatchRunner, BatchConfig
-
- config = BatchConfig(
-     system_prompt="Extract entities from the text",
-     job_name="entity_extraction",
-     input_data_path="data.json",
-     output_data_filename="results.json",
-     model="gpt-4o-mini"
- )
-
- class Output(BaseModel):
-     entities: list[str]
-
- runner = BatchRunner(config, output_model=Output)
- runner.run()
- ```
-
- ---
-
  ## 🤝 Contributing
 
  Contributions are welcome!
  Feel free to **open issues, suggest new features, or submit pull requests**.
-
- ---
-
- ## 🌿 License
-
- This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
hamtaa_texttools-{1.3.0 → 1.3.2}.dist-info/RECORD CHANGED
@@ -1,17 +1,14 @@
- hamtaa_texttools-1.3.0.dist-info/licenses/LICENSE,sha256=Hb2YOBKy2MJQLnyLrX37B4ZVuac8eaIcE71SvVIMOLg,1082
- texttools/__init__.py,sha256=4z7wInlrgbGSlWlXHQNeZMCGQH1sN2xtARsbgLHOLd8,283
+ hamtaa_texttools-1.3.2.dist-info/licenses/LICENSE,sha256=Hb2YOBKy2MJQLnyLrX37B4ZVuac8eaIcE71SvVIMOLg,1082
+ texttools/__init__.py,sha256=RK1GAU6pq2lGwFtHdrCX5JkPRHmOLGcmGH67hd_7VAQ,175
  texttools/models.py,sha256=5eT2cSrFq8Xa38kANznV7gbi7lwB2PoDxciLKTpsd6c,2516
  texttools/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- texttools/batch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- texttools/batch/config.py,sha256=GDDXuhRZ_bOGVwSIlU4tWP247tx1_A7qzLJn7VqDyLU,1050
- texttools/batch/manager.py,sha256=XZtf8UkdClfQlnRKne4nWEcFvdSKE67EamEePKy7jwI,8730
- texttools/batch/runner.py,sha256=9qxXIMfYRXW5SXDqqKtRr61rnQdYZkbCGqKImhSrY6I,9923
  texttools/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- texttools/core/engine.py,sha256=iRHdlIOPuUwIN6_72HNyTQQE7h_7xUZhC-WO-fDA5k8,9597
+ texttools/core/engine.py,sha256=AjifrcJl6PeRu1W6nu9zcxySn-1439Ef2La4d7GpNKY,9481
  texttools/core/exceptions.py,sha256=6SDjUL1rmd3ngzD3ytF4LyTRj3bQMSFR9ECrLoqXXHw,395
- texttools/core/internal_models.py,sha256=aExdLvhXhSev8NY1kuAJckeXdFBEisQtKZPxybd3rW8,1703
- texttools/core/operators/async_operator.py,sha256=wFs7eZ9QJrL0jBOu00YffgfPnIrCSavNjecSorXh-mE,6452
- texttools/core/operators/sync_operator.py,sha256=NaUS-aLh3y0QNMiKut4qtcSZKYXbuPbw0o2jvPsYKdY,6357
+ texttools/core/internal_models.py,sha256=J1qGEO8V0OoX6_-1yxbSmZSR79tJF0ExAIG1QuvH0L0,1734
+ texttools/core/operators/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ texttools/core/operators/async_operator.py,sha256=-72YQEGFkbk2uYW6PHkLT4wGxhj2p6Uqy3sJtVa9-rk,6386
+ texttools/core/operators/sync_operator.py,sha256=mfXtEOlIAhHo4SHaHRKjGb0Z1T894clv-toUzUcbfpo,6291
  texttools/prompts/categorize.yaml,sha256=42Rp3SgVHaDLKrJ27_uK788LiQud0pOXJthz4r0a40Y,1214
  texttools/prompts/check_fact.yaml,sha256=zWFQDRhEE1ij9wSeeenS9YSTM-bY5zzUaG390zUgmcs,714
  texttools/prompts/extract_entities.yaml,sha256=_zYKHNJDIzVDI_-TnwFCKyMs-XLM5igvmWhvSTc3INQ,637
@@ -28,7 +25,7 @@ texttools/prompts/translate.yaml,sha256=Dd5bs3O8SI-FlVSwHMYGeEjMmdOWeRlcfBHkhixC
  texttools/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  texttools/tools/async_tools.py,sha256=2suwx8N0aRnowaSOpV6C57AqPlmQe5Z0Yx4E5QIMkmU,46939
  texttools/tools/sync_tools.py,sha256=mEuL-nlbxVW30dPE3hGkAUnYXbul-3gN2Le4CMVFCgU,42528
- hamtaa_texttools-1.3.0.dist-info/METADATA,sha256=_CXrOjvT2jwWcs1LHID0vVyo9eKlSIK_BzU8YUeNypo,8024
- hamtaa_texttools-1.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- hamtaa_texttools-1.3.0.dist-info/top_level.txt,sha256=5Mh0jIxxZ5rOXHGJ6Mp-JPKviywwN0MYuH0xk5bEWqE,10
- hamtaa_texttools-1.3.0.dist-info/RECORD,,
+ hamtaa_texttools-1.3.2.dist-info/METADATA,sha256=LjhXLwovneW5Ii1DvAYhFT4JR64ar23UyptCvCO6Hpc,7448
+ hamtaa_texttools-1.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ hamtaa_texttools-1.3.2.dist-info/top_level.txt,sha256=5Mh0jIxxZ5rOXHGJ6Mp-JPKviywwN0MYuH0xk5bEWqE,10
+ hamtaa_texttools-1.3.2.dist-info/RECORD,,
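A note on reading the RECORD hunks above: each entry is `path,sha256=<urlsafe-base64 digest, padding stripped>,<size in bytes>`, so unchanged hashes across versions (LICENSE, WHEEL, top_level.txt) mean byte-identical files under the renamed dist-info directory. A minimal sketch for cross-checking an entry against an unpacked wheel (the path is illustrative):

```python
# Recompute a RECORD-style hash for one file from an unpacked wheel.
import base64
import hashlib

def record_hash(path: str) -> str:
    digest = hashlib.sha256(open(path, "rb").read()).digest()
    # RECORD uses urlsafe base64 with the trailing '=' padding removed
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

# An empty file such as texttools/py.typed hashes to the value seen above:
# sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
print(record_hash("texttools/py.typed"))
```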
texttools/__init__.py CHANGED
@@ -1,7 +1,5 @@
- from .batch.config import BatchConfig
- from .batch.runner import BatchRunner
  from .models import CategoryTree
  from .tools.async_tools import AsyncTheTool
  from .tools.sync_tools import TheTool
 
- __all__ = ["TheTool", "AsyncTheTool", "CategoryTree", "BatchRunner", "BatchConfig"]
+ __all__ = ["TheTool", "AsyncTheTool", "CategoryTree"]
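The practical effect of this hunk is that the batch API leaves the public namespace in 1.3.2. A quick sanity check, with the expected 1.3.2 output shown in comments:

```python
import texttools

print(texttools.__all__)  # 1.3.2: ['TheTool', 'AsyncTheTool', 'CategoryTree']

try:
    from texttools import BatchRunner  # exported in 1.3.0, removed in 1.3.2
except ImportError:
    print("BatchRunner/BatchConfig are no longer exported")
```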
texttools/core/engine.py CHANGED
@@ -4,6 +4,7 @@ import random
  import re
  from functools import lru_cache
  from pathlib import Path
+ from typing import Any
 
  import yaml
 
@@ -20,9 +21,6 @@ class PromptLoader:
 
      @lru_cache(maxsize=32)
      def _load_templates(self, prompt_file: str, mode: str | None) -> dict[str, str]:
-         """
-         Loads prompt templates from YAML file with optional mode selection.
-         """
          try:
              base_dir = Path(__file__).parent.parent / Path("prompts")
              prompt_path = base_dir / prompt_file
@@ -73,13 +71,12 @@ class PromptLoader:
          self, prompt_file: str, text: str, mode: str, **extra_kwargs
      ) -> dict[str, str]:
          try:
-             template_configs = self._load_templates(prompt_file, mode)
              format_args = {"text": text}
              format_args.update(extra_kwargs)
 
-             # Inject variables inside each template
-             for key in template_configs.keys():
-                 template_configs[key] = template_configs[key].format(**format_args)
+             template_configs = self._load_templates(prompt_file, mode)
+             for key, value in template_configs.items():
+                 template_configs[key] = value.format(**format_args)
 
              return template_configs
 
@@ -97,30 +94,27 @@ class OperatorUtils:
          output_lang: str | None,
          user_prompt: str | None,
      ) -> str:
-         main_prompt = ""
+         parts = []
 
          if analysis:
-             main_prompt += f"Based on this analysis:\n{analysis}\n"
-
+             parts.append(f"Based on this analysis: {analysis}")
          if output_lang:
-             main_prompt += f"Respond only in the {output_lang} language.\n"
-
+             parts.append(f"Respond only in the {output_lang} language.")
          if user_prompt:
-             main_prompt += f"Consider this instruction {user_prompt}\n"
+             parts.append(f"Consider this instruction: {user_prompt}")
 
-         main_prompt += main_template
-
-         return main_prompt
+         parts.append(main_template)
+         return "\n".join(parts)
 
      @staticmethod
      def build_message(prompt: str) -> list[dict[str, str]]:
          return [{"role": "user", "content": prompt}]
 
      @staticmethod
-     def extract_logprobs(completion: dict) -> list[dict]:
+     def extract_logprobs(completion: Any) -> list[dict]:
          """
-         Extracts and filters token probabilities from completion logprobs.
-         Skips punctuation and structural tokens, returns cleaned probability data.
+         Extracts and filters logprobs from completion.
+         Skips punctuation and structural tokens.
          """
          logprobs_data = []
 
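For readers skimming the hunk above: `build_main_prompt` switches from incremental string concatenation to collecting optional sections in a list and joining once, dropping the trailing-newline quirks of the old `+=` version. A standalone rendition extracted from the diff (signature simplified with defaults here; the original is a static method with required parameters):

```python
def build_main_prompt(
    main_template: str,
    analysis: str | None = None,
    output_lang: str | None = None,
    user_prompt: str | None = None,
) -> str:
    # 1.3.2 shape: optional preamble lines first, main template last
    parts = []
    if analysis:
        parts.append(f"Based on this analysis: {analysis}")
    if output_lang:
        parts.append(f"Respond only in the {output_lang} language.")
    if user_prompt:
        parts.append(f"Consider this instruction: {user_prompt}")
    parts.append(main_template)
    return "\n".join(parts)

print(build_main_prompt("Summarize the text.", output_lang="Persian"))
# Respond only in the Persian language.
# Summarize the text.
```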
@@ -153,16 +147,17 @@ class OperatorUtils:
 
      @staticmethod
      def get_retry_temp(base_temp: float) -> float:
-         delta_temp = random.choice([-1, 1]) * random.uniform(0.1, 0.9)
-         new_temp = base_temp + delta_temp
-
+         new_temp = base_temp + random.choice([-1, 1]) * random.uniform(0.1, 0.9)
          return max(0.0, min(new_temp, 1.5))
 
 
  def text_to_chunks(text: str, size: int, overlap: int) -> list[str]:
+     """
+     Utility for chunking large texts. Used for translation tool
+     """
      separators = ["\n\n", "\n", " ", ""]
      is_separator_regex = False
-     keep_separator = True  # Equivalent to 'start'
+     keep_separator = True
      length_function = len
      strip_whitespace = True
      chunk_size = size
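The `get_retry_temp` change is behavior-preserving: both versions jitter the base temperature by a uniform ±[0.1, 0.9] step and clamp the result, 1.3.2 merely inlines the computation. Extracted for illustration:

```python
import random

def get_retry_temp(base_temp: float) -> float:
    # jitter by a random sign times U(0.1, 0.9), then clamp into [0.0, 1.5]
    new_temp = base_temp + random.choice([-1, 1]) * random.uniform(0.1, 0.9)
    return max(0.0, min(new_temp, 1.5))

print([round(get_retry_temp(0.7), 2) for _ in range(5)])  # e.g. [1.24, 0.3, ...]
```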
@@ -256,6 +251,9 @@ def text_to_chunks(text: str, size: int, overlap: int) -> list[str]:
 
 
  async def run_with_timeout(coro, timeout: float | None):
+     """
+     Utility for timeout logic defined in AsyncTheTool
+     """
      if timeout is None:
          return await coro
      try:
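The diff truncates `run_with_timeout` after `try:`. A plausible completion of the helper it documents, assuming the standard `asyncio.wait_for` pattern (the package's actual exception handling is not shown in this diff):

```python
import asyncio

async def run_with_timeout(coro, timeout: float | None):
    # None disables the timeout entirely, as in the hunk above
    if timeout is None:
        return await coro
    try:
        return await asyncio.wait_for(coro, timeout=timeout)
    except asyncio.TimeoutError:
        raise  # assumption: the package re-raises or wraps this; not shown

async def demo():
    async def slow():
        await asyncio.sleep(0.05)
        return "done"
    return await run_with_timeout(slow(), timeout=1.0)

print(asyncio.run(demo()))  # done
```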
texttools/core/internal_models.py CHANGED
@@ -21,7 +21,9 @@ class Bool(BaseModel):
 
  class ListStr(BaseModel):
      result: list[str] = Field(
-         ..., description="The output list of strings", example=["text_1", "text_2"]
+         ...,
+         description="The output list of strings",
+         example=["text_1", "text_2", "text_3"],
      )
 
 
@@ -36,11 +38,13 @@ class ListDictStrStr(BaseModel):
  class ReasonListStr(BaseModel):
      reason: str = Field(..., description="Thinking process that led to the output")
      result: list[str] = Field(
-         ..., description="The output list of strings", example=["text_1", "text_2"]
+         ...,
+         description="The output list of strings",
+         example=["text_1", "text_2", "text_3"],
      )
 
 
- # This function is needed to create CategorizerOutput with dynamic categories
+ # Create CategorizerOutput with dynamic categories
  def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
      literal_type = Literal[*allowed_values]
 
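`create_dynamic_model` is only partially visible above; the `Literal[*allowed_values]` line (Python 3.11+ unpacking) pins the model's output field to the caller's category list. One plausible completion using pydantic's `create_model` (the builder call and the `result` field name are assumptions, not shown in the diff):

```python
from typing import Literal, Type

from pydantic import BaseModel, create_model

def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
    literal_type = Literal[*allowed_values]  # the values become the only legal outputs
    return create_model("CategorizerOutput", result=(literal_type, ...))

Model = create_dynamic_model(["sports", "politics"])
print(Model(result="sports"))            # result='sports'
# Model(result="weather") would raise a ValidationError
```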
texttools/core/operators/async_operator.py CHANGED
@@ -54,7 +54,7 @@ class AsyncOperator:
      ) -> tuple[T, Any]:
          """
          Parses a chat completion using OpenAI's structured output format.
-         Returns both the parsed Any and the raw completion for logprobs.
+         Returns both the parsed and the completion for logprobs.
          """
          try:
              request_kwargs = {
@@ -92,7 +92,6 @@
 
      async def run(
          self,
-         # User parameters
          text: str,
          with_analysis: bool,
          output_lang: str | None,
@@ -103,7 +102,6 @@
          validator: Callable[[Any], bool] | None,
          max_validation_retries: int | None,
          priority: int | None,
-         # Internal parameters
          tool_name: str,
          output_model: Type[T],
          mode: str | None,
texttools/core/operators/sync_operator.py CHANGED
@@ -54,7 +54,7 @@ class Operator:
      ) -> tuple[T, Any]:
          """
          Parses a chat completion using OpenAI's structured output format.
-         Returns both the parsed Any and the raw completion for logprobs.
+         Returns both the parsed and the completion for logprobs.
          """
          try:
              request_kwargs = {
@@ -90,7 +90,6 @@
 
      def run(
          self,
-         # User parameters
          text: str,
          with_analysis: bool,
          output_lang: str | None,
@@ -101,7 +100,6 @@
          validator: Callable[[Any], bool] | None,
          max_validation_retries: int | None,
          priority: int | None,
-         # Internal parameters
          tool_name: str,
          output_model: Type[T],
          mode: str | None,
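Both operator hunks only trim comments and docstrings. The "structured output format" the docstrings refer to is the OpenAI SDK's parse helper, which validates the response against a pydantic model. A generic sketch of that call pattern (the request details here are illustrative, not the operators' actual code):

```python
from openai import OpenAI
from pydantic import BaseModel

class Keywords(BaseModel):  # stand-in for the package's internal models
    result: list[str]

client = OpenAI()  # reads OPENAI_API_KEY from the environment
completion = client.beta.chat.completions.parse(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "Extract keywords: LLM toolkits are handy."}],
    response_format=Keywords,
)
print(completion.choices[0].message.parsed)  # Keywords(result=[...])
```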
texttools/batch/config.py DELETED
@@ -1,40 +0,0 @@
- from collections.abc import Callable
- from dataclasses import dataclass
- from typing import Any
-
-
- def export_data(data) -> list[dict[str, str]]:
-     """
-     Produces a structure of the following form from an initial data structure:
-     [{"id": str, "text": str},...]
-     """
-     return data
-
-
- def import_data(data) -> Any:
-     """
-     Takes the output and adds and aggregates it to the original structure.
-     """
-     return data
-
-
- @dataclass
- class BatchConfig:
-     """
-     Configuration for batch job runner.
-     """
-
-     system_prompt: str = ""
-     job_name: str = ""
-     input_data_path: str = ""
-     output_data_filename: str = ""
-     model: str = "gpt-4.1-mini"
-     MAX_BATCH_SIZE: int = 100
-     MAX_TOTAL_TOKENS: int = 2_000_000
-     CHARS_PER_TOKEN: float = 2.7
-     PROMPT_TOKEN_MULTIPLIER: int = 1_000
-     BASE_OUTPUT_DIR: str = "Data/batch_entity_result"
-     import_function: Callable = import_data
-     export_function: Callable = export_data
-     poll_interval_seconds: int = 30
-     max_retries: int = 3
texttools/batch/manager.py DELETED
@@ -1,228 +0,0 @@
- import json
- import logging
- import uuid
- from pathlib import Path
- from typing import Any, Type, TypeVar
-
- from openai import OpenAI
- from openai.lib._pydantic import to_strict_json_schema
- from pydantic import BaseModel
-
- # Base Model type for output models
- T = TypeVar("T", bound=BaseModel)
-
- logger = logging.getLogger("texttools.batch_manager")
-
-
- class BatchManager:
-     """
-     Manages batch processing jobs for OpenAI's chat completions with structured outputs.
-
-     Handles the full lifecycle of a batch job: creating tasks from input texts,
-     starting the job, monitoring status, and fetching results. Results are automatically
-     parsed into the specified Pydantic output model. Job state is persisted to disk.
-     """
-
-     def __init__(
-         self,
-         client: OpenAI,
-         model: str,
-         output_model: Type[T],
-         prompt_template: str,
-         state_dir: Path = Path(".batch_jobs"),
-         custom_json_schema_obj_str: dict | None = None,
-         **client_kwargs: Any,
-     ):
-         self._client = client
-         self._model = model
-         self._output_model = output_model
-         self._prompt_template = prompt_template
-         self._state_dir = state_dir
-         self._custom_json_schema_obj_str = custom_json_schema_obj_str
-         self._client_kwargs = client_kwargs
-         self._dict_input = False
-         self._state_dir.mkdir(parents=True, exist_ok=True)
-
-         if custom_json_schema_obj_str and not isinstance(
-             custom_json_schema_obj_str, dict
-         ):
-             raise ValueError("Schema should be a dict")
-
-     def _state_file(self, job_name: str) -> Path:
-         return self._state_dir / f"{job_name}.json"
-
-     def _load_state(self, job_name: str) -> list[dict[str, Any]]:
-         """
-         Loads the state (job information) from the state file for the given job name.
-         Returns an empty list if the state file does not exist.
-         """
-         path = self._state_file(job_name)
-         if path.exists():
-             with open(path, "r", encoding="utf-8") as f:
-                 return json.load(f)
-         return []
-
-     def _save_state(self, job_name: str, jobs: list[dict[str, Any]]) -> None:
-         """
-         Saves the job state to the state file for the given job name.
-         """
-         with open(self._state_file(job_name), "w", encoding="utf-8") as f:
-             json.dump(jobs, f)
-
-     def _clear_state(self, job_name: str) -> None:
-         """
-         Deletes the state file for the given job name if it exists.
-         """
-         path = self._state_file(job_name)
-         if path.exists():
-             path.unlink()
-
-     def _build_task(self, text: str, idx: str) -> dict[str, Any]:
-         """
-         Builds a single task dictionary for the batch job, including the prompt, model, and response format configuration.
-         """
-         response_format_config: dict[str, Any]
-
-         if self._custom_json_schema_obj_str:
-             response_format_config = {
-                 "type": "json_schema",
-                 "json_schema": self._custom_json_schema_obj_str,
-             }
-         else:
-             raw_schema = to_strict_json_schema(self._output_model)
-             response_format_config = {
-                 "type": "json_schema",
-                 "json_schema": {
-                     "name": self._output_model.__name__,
-                     "schema": raw_schema,
-                 },
-             }
-
-         return {
-             "custom_id": str(idx),
-             "method": "POST",
-             "url": "/v1/chat/completions",
-             "body": {
-                 "model": self.model,
-                 "messages": [
-                     {"role": "system", "content": self._prompt_template},
-                     {"role": "user", "content": text},
-                 ],
-                 "response_format": response_format_config,
-                 **self._client_kwargs,
-             },
-         }
-
-     def _prepare_file(self, payload: list[str] | list[dict[str, str]]) -> Path:
-         """
-         Prepares a JSONL file containing all tasks for the batch job, based on the input payload.
-         Returns the path to the created file.
-         """
-         if not payload:
-             raise ValueError("Payload must not be empty")
-         if isinstance(payload[0], str):
-             tasks = [self._build_task(text, uuid.uuid4().hex) for text in payload]
-         elif isinstance(payload[0], dict):
-             tasks = [self._build_task(dic["text"], dic["id"]) for dic in payload]
-
-         else:
-             raise TypeError(
-                 "The input must be either a list of texts or a dictionary in the form {'id': str, 'text': str}"
-             )
-
-         file_path = self._state_dir / f"batch_{uuid.uuid4().hex}.jsonl"
-         with open(file_path, "w", encoding="utf-8") as f:
-             for task in tasks:
-                 f.write(json.dumps(task) + "\n")
-         return file_path
-
-     def start(self, payload: list[str | dict[str, str]], job_name: str):
-         """
-         Starts a new batch job by uploading the prepared file and creating a batch job on the server.
-         If a job with the same name already exists, it does nothing.
-         """
-         if self._load_state(job_name):
-             return
-
-         path = self._prepare_file(payload)
-         upload = self._client.files.create(file=open(path, "rb"), purpose="batch")
-         job = self._client.batches.create(
-             input_file_id=upload.id,
-             endpoint="/v1/chat/completions",
-             completion_window="24h",
-         ).to_dict()
-         self._save_state(job_name, [job])
-
-     def check_status(self, job_name: str) -> str:
-         """
-         Checks and returns the current status of the batch job with the given job name.
-         Updates the job state with the latest information from the server.
-         """
-         job = self._load_state(job_name)[0]
-         if not job:
-             return "completed"
-
-         info = self._client.batches.retrieve(job["id"])
-         job = info.to_dict()
-         self._save_state(job_name, [job])
-         logger.info("Batch job status: %s", job)
-         return job["status"]
-
-     def fetch_results(
-         self, job_name: str, remove_cache: bool = True
-     ) -> tuple[dict[str, str], list]:
-         """
-         Fetches the results of a completed batch job. Optionally saves the results to a file and/or removes the job cache.
-         Returns a tuple containing the parsed results and a log of errors (if any).
-         """
-         job = self._load_state(job_name)[0]
-         if not job:
-             return {}
-         batch_id = job["id"]
-
-         info = self._client.batches.retrieve(batch_id)
-         out_file_id = info.output_file_id
-         if not out_file_id:
-             error_file_id = info.error_file_id
-             if error_file_id:
-                 err_content = (
-                     self._client.files.content(error_file_id).read().decode("utf-8")
-                 )
-                 logger.error("Error file content:", err_content)
-             return {}
-
-         content = self._client.files.content(out_file_id).read().decode("utf-8")
-         lines = content.splitlines()
-         results = {}
-         log = []
-         for line in lines:
-             result = json.loads(line)
-             custom_id = result["custom_id"]
-             if result["response"]["status_code"] == 200:
-                 content = result["response"]["body"]["choices"][0]["message"]["content"]
-                 try:
-                     parsed_content = json.loads(content)
-                     model_instance = self._output_model(**parsed_content)
-                     results[custom_id] = model_instance.model_dump(mode="json")
-                 except json.JSONDecodeError:
-                     results[custom_id] = {"error": "Failed to parse content as JSON"}
-                     error_d = {custom_id: results[custom_id]}
-                     log.append(error_d)
-                 except Exception as e:
-                     results[custom_id] = {"error": str(e)}
-                     error_d = {custom_id: results[custom_id]}
-                     log.append(error_d)
-             else:
-                 error_message = (
-                     result["response"]["body"]
-                     .get("error", {})
-                     .get("message", "Unknown error")
-                 )
-                 results[custom_id] = {"error": error_message}
-                 error_d = {custom_id: results[custom_id]}
-                 log.append(error_d)
-
-         if remove_cache:
-             self._clear_state(job_name)
-
-         return results, log
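For reference, this is how the deleted manager's lifecycle fit together in 1.3.0 (start, poll, fetch). Note that `_build_task` reads `self.model` while `__init__` stores `self._model`, so as shown in this diff the default task-building path appears to raise `AttributeError`; the sketch below shows the intended flow rather than a guaranteed-working call:

```python
# from texttools.batch.manager import BatchManager  # 1.3.0 import path; removed in 1.3.2
from openai import OpenAI
from pydantic import BaseModel

class Output(BaseModel):
    entities: list[str]

manager = BatchManager(  # definition shown in the deleted file above
    client=OpenAI(),
    model="gpt-4o-mini",
    output_model=Output,
    prompt_template="Extract entities from the text",
)
manager.start(["first text", "second text"], job_name="demo")
if manager.check_status("demo") == "completed":
    results, errors = manager.fetch_results("demo")
```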
texttools/batch/runner.py DELETED
@@ -1,228 +0,0 @@
- import json
- import logging
- import os
- import time
- from pathlib import Path
- from typing import Any, Type, TypeVar
-
- from dotenv import load_dotenv
- from openai import OpenAI
- from pydantic import BaseModel
-
- from ..core.exceptions import TextToolsError
- from ..core.internal_models import Str
- from .config import BatchConfig
- from .manager import BatchManager
-
- # Base Model type for output models
- T = TypeVar("T", bound=BaseModel)
-
- logger = logging.getLogger("texttools.batch_runner")
-
-
- class BatchRunner:
-     """
-     Handles running batch jobs using a batch manager and configuration.
-     """
-
-     def __init__(
-         self, config: BatchConfig = BatchConfig(), output_model: Type[T] = Str
-     ):
-         try:
-             self._config = config
-             self._system_prompt = config.system_prompt
-             self._job_name = config.job_name
-             self._input_data_path = config.input_data_path
-             self._output_data_filename = config.output_data_filename
-             self._model = config.model
-             self._output_model = output_model
-             self._manager = self._init_manager()
-             self._data = self._load_data()
-             self._parts: list[list[dict[str, Any]]] = []
-             # Map part index to job name
-             self._part_idx_to_job_name: dict[int, str] = {}
-             # Track retry attempts per part
-             self._part_attempts: dict[int, int] = {}
-             self._partition_data()
-             Path(self._config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
-
-         except Exception as e:
-             raise TextToolsError(f"Batch runner initialization failed: {e}")
-
-     def _init_manager(self) -> BatchManager:
-         load_dotenv()
-         api_key = os.getenv("OPENAI_API_KEY")
-         client = OpenAI(api_key=api_key)
-         return BatchManager(
-             client=client,
-             model=self._model,
-             prompt_template=self._system_prompt,
-             output_model=self._output_model,
-         )
-
-     def _load_data(self):
-         with open(self._input_data_path, "r", encoding="utf-8") as f:
-             data = json.load(f)
-         data = self._config.export_function(data)
-
-         # Ensure data is a list of dicts with 'id' and 'content' as strings
-         if not isinstance(data, list):
-             raise ValueError(
-                 "Exported data must be a list of dicts with 'id' and 'content' keys"
-             )
-         for item in data:
-             if not (isinstance(item, dict) and "id" in item and "content" in item):
-                 raise ValueError(
-                     f"Item must be a dict with 'id' and 'content' keys. Got: {type(item)}"
-                 )
-             if not (isinstance(item["id"], str) and isinstance(item["content"], str)):
-                 raise ValueError("'id' and 'content' must be strings.")
-         return data
-
-     def _partition_data(self):
-         total_length = sum(len(item["content"]) for item in self._data)
-         prompt_length = len(self._system_prompt)
-         total = total_length + (prompt_length * len(self._data))
-         calculation = total / self._config.CHARS_PER_TOKEN
-         logger.info(
-             f"Total chars: {total_length}, Prompt chars: {prompt_length}, Total: {total}, Tokens: {calculation}"
-         )
-         if calculation < self._config.MAX_TOTAL_TOKENS:
-             self._parts = [self._data]
-         else:
-             # Partition into chunks of MAX_BATCH_SIZE
-             self._parts = [
-                 self._data[i : i + self._config.MAX_BATCH_SIZE]
-                 for i in range(0, len(self._data), self._config.MAX_BATCH_SIZE)
-             ]
-         logger.info(f"Data split into {len(self._parts)} part(s)")
-
-     def _submit_all_jobs(self) -> None:
-         for idx, part in enumerate(self._parts):
-             if self._result_exists(idx):
-                 logger.info(f"Skipping part {idx + 1}: result already exists.")
-                 continue
-             part_job_name = (
-                 f"{self._job_name}_part_{idx + 1}"
-                 if len(self._parts) > 1
-                 else self._job_name
-             )
-             # If a job with this name already exists, register and skip submitting
-             existing_job = self._manager._load_state(part_job_name)
-             if existing_job:
-                 logger.info(
-                     f"Skipping part {idx + 1}: job already exists ({part_job_name})."
-                 )
-                 self._part_idx_to_job_name[idx] = part_job_name
-                 self._part_attempts.setdefault(idx, 0)
-                 continue
-
-             payload = part
-             logger.info(
-                 f"Submitting job for part {idx + 1}/{len(self._parts)}: {part_job_name}"
-             )
-             self._manager.start(payload, job_name=part_job_name)
-             self._part_idx_to_job_name[idx] = part_job_name
-             self._part_attempts.setdefault(idx, 0)
-             # This is added for letting file get uploaded, before starting the next part.
-             logger.info("Uploading...")
-             time.sleep(30)
-
-     def _save_results(
-         self,
-         output_data: list[dict[str, Any]] | dict[str, Any],
-         log: list[Any],
-         part_idx: int,
-     ):
-         part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
-         result_path = (
-             Path(self._config.BASE_OUTPUT_DIR)
-             / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
-         )
-         if not output_data:
-             logger.info("No output data to save. Skipping this part.")
-             return
-         else:
-             with open(result_path, "w", encoding="utf-8") as f:
-                 json.dump(output_data, f, ensure_ascii=False, indent=4)
-         if log:
-             log_path = (
-                 Path(self._config.BASE_OUTPUT_DIR)
-                 / f"{Path(self._output_data_filename).stem}{part_suffix}_log.json"
-             )
-             with open(log_path, "w", encoding="utf-8") as f:
-                 json.dump(log, f, ensure_ascii=False, indent=4)
-
-     def _result_exists(self, part_idx: int) -> bool:
-         part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
-         result_path = (
-             Path(self._config.BASE_OUTPUT_DIR)
-             / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
-         )
-         return result_path.exists()
-
-     def run(self):
-         """
-         Execute the batch job processing pipeline.
-
-         Submits jobs, monitors progress, handles retries, and saves results.
-         """
-         try:
-             # Submit all jobs up-front for concurrent execution
-             self._submit_all_jobs()
-             pending_parts: set[int] = set(self._part_idx_to_job_name.keys())
-             logger.info(f"Pending parts: {sorted(pending_parts)}")
-             # Polling loop
-             while pending_parts:
-                 finished_this_round: list[int] = []
-                 for part_idx in list(pending_parts):
-                     job_name = self._part_idx_to_job_name[part_idx]
-                     status = self._manager.check_status(job_name=job_name)
-                     logger.info(f"Status for {job_name}: {status}")
-                     if status == "completed":
-                         logger.info(
-                             f"Job completed. Fetching results for part {part_idx + 1}..."
-                         )
-                         output_data, log = self._manager.fetch_results(
-                             job_name=job_name, remove_cache=False
-                         )
-                         output_data = self._config.import_function(output_data)
-                         self._save_results(output_data, log, part_idx)
-                         logger.info(
-                             f"Fetched and saved results for part {part_idx + 1}."
-                         )
-                         finished_this_round.append(part_idx)
-                     elif status == "failed":
-                         attempt = self._part_attempts.get(part_idx, 0) + 1
-                         self._part_attempts[part_idx] = attempt
-                         if attempt <= self._config.max_retries:
-                             logger.info(
-                                 f"Job {job_name} failed (attempt {attempt}). Retrying after short backoff..."
-                             )
-                             self._manager._clear_state(job_name)
-                             time.sleep(10)
-                             payload = self._to_manager_payload(self._parts[part_idx])
-                             new_job_name = (
-                                 f"{self._job_name}_part_{part_idx + 1}_retry_{attempt}"
-                             )
-                             self._manager.start(payload, job_name=new_job_name)
-                             self._part_idx_to_job_name[part_idx] = new_job_name
-                         else:
-                             logger.info(
-                                 f"Job {job_name} failed after {attempt - 1} retries. Marking as failed."
-                             )
-                             finished_this_round.append(part_idx)
-                     else:
-                         # Still running or queued
-                         continue
-                 # Remove finished parts
-                 for part_idx in finished_this_round:
-                     pending_parts.discard(part_idx)
-                 if pending_parts:
-                     logger.info(
-                         f"Waiting {self._config.poll_interval_seconds}s before next status check for parts: {sorted(pending_parts)}"
-                     )
-                     time.sleep(self._config.poll_interval_seconds)
-
-         except Exception as e:
-             raise TextToolsError(f"Batch job execution failed: {e}")