hamtaa-texttools: hamtaa_texttools-1.1.9-py3-none-any.whl → hamtaa_texttools-1.1.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hamtaa_texttools-1.1.9.dist-info → hamtaa_texttools-1.1.10.dist-info}/METADATA +22 -20
- hamtaa_texttools-1.1.10.dist-info/RECORD +30 -0
- texttools/__init__.py +4 -2
- texttools/batch/batch_config.py +26 -0
- texttools/batch/batch_runner.py +66 -103
- texttools/batch/{batch_manager.py → internals/batch_manager.py} +24 -24
- texttools/batch/internals/utils.py +16 -0
- texttools/tools/async_tools.py +34 -45
- texttools/tools/internals/async_operator.py +16 -65
- texttools/tools/internals/base_operator.py +3 -32
- texttools/tools/internals/operator.py +16 -65
- texttools/tools/internals/prompt_loader.py +8 -7
- texttools/tools/sync_tools.py +34 -45
- hamtaa_texttools-1.1.9.dist-info/RECORD +0 -30
- texttools/batch/__init__.py +0 -3
- texttools/tools/__init__.py +0 -4
- {hamtaa_texttools-1.1.9.dist-info → hamtaa_texttools-1.1.10.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.1.9.dist-info → hamtaa_texttools-1.1.10.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.1.9.dist-info → hamtaa_texttools-1.1.10.dist-info}/top_level.txt +0 -0
{hamtaa_texttools-1.1.9.dist-info → hamtaa_texttools-1.1.10.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hamtaa-texttools
-Version: 1.1.9
+Version: 1.1.10
 Summary: A high-level NLP toolkit built on top of modern LLMs.
 Author-email: Tohidi <the.mohammad.tohidi@gmail.com>, Montazer <montazerh82@gmail.com>, Givechi <mohamad.m.givechi@gmail.com>, MoosaviNejad <erfanmoosavi84@gmail.com>
 License: MIT License
@@ -29,6 +29,7 @@ Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: openai==1.97.1
+Requires-Dist: pydantic>=2.0.0
 Requires-Dist: pyyaml>=6.0
 Dynamic: license-file
 
@@ -40,26 +41,26 @@ Dynamic: license-file
 
 It provides both **sync (`TheTool`)** and **async (`AsyncTheTool`)** APIs for maximum flexibility.
 
-It provides ready-to-use utilities for **translation, question detection, keyword extraction, categorization, NER extraction, and more**
+It provides ready-to-use utilities for **translation, question detection, keyword extraction, categorization, NER extraction, and more** - designed to help you integrate AI-powered text processing into your applications with minimal effort.
 
 ---
 
 ## ✨ Features
 
-TextTools provides a rich collection of high-level NLP utilities
+TextTools provides a rich collection of high-level NLP utilities,
 Each tool is designed to work with structured outputs (JSON / Pydantic).
 
 - **`categorize()`** - Classifies text into Islamic studies categories
-- **`is_question()`** - Binary detection of whether input is a question
 - **`extract_keywords()`** - Extracts keywords from text
 - **`extract_entities()`** - Named Entity Recognition (NER) system
-- **`
+- **`is_question()`** - Binary detection of whether input is a question
 - **`text_to_question()`** - Generates questions from text
 - **`merge_questions()`** - Merges multiple questions with different modes
 - **`rewrite()`** - Rewrites text with different wording/meaning
 - **`subject_to_question()`** - Generates questions about a specific subject
+- **`summarize()`** - Text summarization
 - **`translate()`** - Text translation between languages
-- **`run_custom()`** - Allows users to define a custom tool with arbitrary BaseModel
+- **`run_custom()`** - Allows users to define a custom tool with an arbitrary BaseModel
 
 ---
 
@@ -67,18 +68,18 @@ Each tool is designed to work with structured outputs (JSON / Pydantic).
 
 TextTools provides several optional flags to customize LLM behavior:
 
-- **`with_analysis
-Note
+- **`with_analysis (bool)`** → Adds a reasoning step before generating the final output.
+**Note:** This doubles token usage per call because it triggers an additional LLM request.
 
-- **`logprobs
+- **`logprobs (bool)`** → Returns token-level probabilities for the generated output. You can also specify `top_logprobs=<N>` to get the top N alternative tokens and their probabilities.
 
-- **`output_lang
+- **`output_lang (str)`** → Forces the model to respond in a specific language. The model will ignore other instructions about language and respond strictly in the requested language.
 
-- **`user_prompt
+- **`user_prompt (str)`** → Allows you to inject a custom instruction or prompt into the model alongside the main template. This gives you fine-grained control over how the model interprets or modifies the input text.
 
-- **`temperature
+- **`temperature (float)`** → Determines how creatively the model responds. Takes a float from `0.0` to `2.0`.
 
-- **`validator
+- **`validator (Callable)`** → Forces TheTool to validate the output with your custom validator. The validator should return a bool (True if there were no problems, False if validation failed). If validation fails, TheTool retries to get another output by modifying `temperature`.
 
 All these parameters can be used individually or together to tailor the behavior of any tool in **TextTools**.
 
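For illustration, a minimal sketch of combining several of these flags on a single call. `TheTool`'s constructor is not shown in this diff, so the `client`/`model` setup below is an assumption:

```python
from openai import OpenAI
from texttools import TheTool

tool = TheTool(client=OpenAI(), model="gpt-4.1-mini")  # assumed constructor

out = tool.translate(
    "Some input text",
    output_lang="German",            # respond strictly in German
    with_analysis=True,              # adds a reasoning step (doubles token usage)
    temperature=0.3,                 # 0.0-2.0; lower is less creative
    validator=lambda r: len(r) > 0,  # on False, the tool retries with a modified temperature
)
```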
@@ -89,10 +90,10 @@ All these parameters can be used individually or together to tailor the behavior
 ## 🧩 ToolOutput
 
 Every tool of `TextTools` returns a `ToolOutput` object which is a BaseModel with attributes:
-- **`result`** → The output of LLM
-- **`analysis`** → The reasoning step before generating the final output
-- **`logprobs`** → Token-level probabilities for the generated output
-- **`errors`** → Any error that have occured during calling LLM
+- **`result (Any)`** → The output of the LLM
+- **`analysis (str)`** → The reasoning step before generating the final output
+- **`logprobs (list)`** → Token-level probabilities for the generated output
+- **`errors (list[str])`** → Any errors that occurred while calling the LLM
 
 **Note:** You can use `repr(ToolOutput)` to see details of an output.
 
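A sketch of consuming these fields (the call and values are illustrative; the constructor is assumed, as above):

```python
from openai import OpenAI
from texttools import TheTool

tool = TheTool(client=OpenAI(), model="gpt-4.1-mini")  # assumed constructor
out = tool.is_question("Is water wet?", with_analysis=True, logprobs=True)

if not out.errors:         # errors (list[str]) from the LLM call
    print(out.result)      # result (Any): the parsed LLM output
    print(out.analysis)    # analysis (str): present when with_analysis=True
    print(out.logprobs)    # logprobs (list): present when logprobs=True
print(repr(out))           # per the note above, repr() shows the full details
```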
@@ -108,7 +109,7 @@ pip install -U hamtaa-texttools
 
 ---
 
-## Sync vs Async
+## 🧨 Sync vs Async
 | Tool | Style | Use case |
 |--------------|---------|---------------------------------------------|
 | `TheTool` | Sync | Simple scripts, sequential workflows |
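A sketch of the two styles side by side; as before, the constructor arguments are assumptions, since they are not part of this diff:

```python
import asyncio
from openai import OpenAI, AsyncOpenAI
from texttools import TheTool, AsyncTheTool

# Sync: simple scripts, sequential workflows
tool = TheTool(client=OpenAI(), model="gpt-4.1-mini")  # assumed constructor
print(tool.summarize("A long passage...").result)

# Async: concurrent pipelines
async def main():
    atool = AsyncTheTool(client=AsyncOpenAI(), model="gpt-4.1-mini")  # assumed constructor
    a, b = await asyncio.gather(
        atool.translate("A long passage...", output_lang="English"),
        atool.extract_keywords("A long passage..."),
    )
    print(a.result, b.result)

asyncio.run(main())
```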
@@ -211,9 +212,10 @@ logging.basicConfig(level=logging.CRITICAL)
 
 Process large datasets efficiently using OpenAI's batch API.
 
-## Quick Start
+## ⚡ Quick Start (Batch)
 
 ```python
+from pydantic import BaseModel
 from texttools import BatchJobRunner, BatchConfig
 
 # Configure your batch job
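The hunk cuts the README example off after this line. A plausible continuation, sketched from the `BatchConfig` fields added in `texttools/batch/batch_config.py` further down; the prompt, paths, job name, and the `run()` entry point are all assumptions:

```python
config = BatchConfig(
    system_prompt="Extract the named entities from each text.",  # invented prompt
    job_name="ner_batch",                                        # invented
    input_data_path="Data/input.json",                           # invented
    output_data_filename="entities.json",                        # invented
)

runner = BatchJobRunner(config)  # output_model defaults to StrOutput
runner.run()                     # assumed entry point; not visible in this diff
```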
@@ -243,6 +245,6 @@ Feel free to **open issues, suggest new features, or submit pull requests**.
 
 ---
 
-## License
+## 🌿 License
 
 This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
hamtaa_texttools-1.1.10.dist-info/RECORD ADDED

@@ -0,0 +1,30 @@
+hamtaa_texttools-1.1.10.dist-info/licenses/LICENSE,sha256=Hb2YOBKy2MJQLnyLrX37B4ZVuac8eaIcE71SvVIMOLg,1082
+texttools/__init__.py,sha256=EZPPNPafVGvBaxjG9anP0piqH3gAC0DdjdAckQeAgNU,251
+texttools/batch/batch_config.py,sha256=FCDXy9TfH7xjd1PHvn_CtdwEQSq-YO5sktiaMZEId58,740
+texttools/batch/batch_runner.py,sha256=zzzVIXedmaq-8fqsFtGRR64F7CtYRLlhQeBu8uMwJQg,9385
+texttools/batch/internals/batch_manager.py,sha256=UoBe76vmFG72qrSaGKDZf4HzkykFBkkkbL9TLfV8TuQ,8730
+texttools/batch/internals/utils.py,sha256=F1_7YlVFKhjUROAFX4m0SaP8KiZVZyHRMIIB87VUGQc,373
+texttools/prompts/README.md,sha256=-5YO93CN93QLifqZpUeUnCOCBbDiOTV-cFQeJ7Gg0I4,1377
+texttools/prompts/categorizer.yaml,sha256=GMqIIzQFhgnlpkgU1qi3FAD3mD4A2jiWD5TilQ2XnnE,1204
+texttools/prompts/extract_entities.yaml,sha256=KiKjeDpHaeh3JVtZ6q1pa3k4DYucUIU9WnEcRTCA-SE,651
+texttools/prompts/extract_keywords.yaml,sha256=0O7ypL_OsEOxtvlQ2CZjnsv9637DJwAKprZsf9Vo2_s,769
+texttools/prompts/is_question.yaml,sha256=d0-vKRbXWkxvO64ikvxRjEmpAXGpCYIPGhgexvPPjws,471
+texttools/prompts/merge_questions.yaml,sha256=0J85GvTirZB4ELwH3sk8ub_WcqqpYf6PrMKr3djlZeo,1792
+texttools/prompts/rewrite.yaml,sha256=LO7He_IA3MZKz8a-LxH9DHJpOjpYwaYN1pbjp1Y0tFo,5392
+texttools/prompts/run_custom.yaml,sha256=38OkCoVITbuuS9c08UZSP1jZW4WjSmRIi8fR0RAiPu4,108
+texttools/prompts/subject_to_question.yaml,sha256=C7x7rNNm6U_ZG9HOn6zuzYOtvJUZ2skuWbL1-aYdd3E,1147
+texttools/prompts/summarize.yaml,sha256=o6rxGPfWtZd61Duvm8NVvCJqfq73b-wAuMSKR6UYUqY,459
+texttools/prompts/text_to_question.yaml,sha256=UheKYpDn6iyKI8NxunHZtFpNyfCLZZe5cvkuXpurUJY,783
+texttools/prompts/translate.yaml,sha256=mGT2uBCei6uucWqVbs4silk-UV060v3G0jnt0P6sr50,634
+texttools/tools/async_tools.py,sha256=yEj4dM2bdW_12hxvimhxPOGfGhl1PqFsHM3Z4toCTaM,14813
+texttools/tools/sync_tools.py,sha256=wHY0O8R9HipUz0P268zk1w3SlFxEIffm5EjX4tcWxNM,14579
+texttools/tools/internals/async_operator.py,sha256=5sVc5K5-Vuulsxly0IfrLmzd8W7ySI4cY09myyOGL0I,7022
+texttools/tools/internals/base_operator.py,sha256=l2Mg59MGIf396yPx1CAgcplKclOptQWeTR0UIz9VTdk,2255
+texttools/tools/internals/formatters.py,sha256=tACNLP6PeoqaRpNudVxBaHA25zyWqWYPZQuYysIu88g,941
+texttools/tools/internals/operator.py,sha256=W0DxTGB3cbtDfzLqwMCM8x5xiVWgN0vZWX8PzJwAQKE,6795
+texttools/tools/internals/output_models.py,sha256=ekpbyocmXj_dee7ieOT1zOkMo9cPHT7xcUFCZoUaXA0,1886
+texttools/tools/internals/prompt_loader.py,sha256=4g6-U8kqrGN7VpNaRcrBcnF-h03PXjUDBP0lL0_4EZY,1953
+hamtaa_texttools-1.1.10.dist-info/METADATA,sha256=6X7Qd8nOAIn7reZ5l-66CyK5fczMADtlIfdIbypJIUE,9101
+hamtaa_texttools-1.1.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+hamtaa_texttools-1.1.10.dist-info/top_level.txt,sha256=5Mh0jIxxZ5rOXHGJ6Mp-JPKviywwN0MYuH0xk5bEWqE,10
+hamtaa_texttools-1.1.10.dist-info/RECORD,,
texttools/__init__.py CHANGED

@@ -1,4 +1,6 @@
-from .batch import BatchJobRunner
-from .
+from .batch.batch_runner import BatchJobRunner
+from .batch.batch_config import BatchConfig
+from .tools.sync_tools import TheTool
+from .tools.async_tools import AsyncTheTool
 
 __all__ = ["TheTool", "AsyncTheTool", "BatchJobRunner", "BatchConfig"]
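Net effect on the public surface: the same four names in `__all__` are exported, now imported from their concrete modules, so user code keeps working unchanged:

```python
from texttools import TheTool, AsyncTheTool, BatchJobRunner, BatchConfig
```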
texttools/batch/batch_config.py ADDED

@@ -0,0 +1,26 @@
+from dataclasses import dataclass
+from typing import Callable
+
+from texttools.batch.internals.utils import import_data, export_data
+
+
+@dataclass
+class BatchConfig:
+    """
+    Configuration for batch job runner.
+    """
+
+    system_prompt: str = ""
+    job_name: str = ""
+    input_data_path: str = ""
+    output_data_filename: str = ""
+    model: str = "gpt-4.1-mini"
+    MAX_BATCH_SIZE: int = 100
+    MAX_TOTAL_TOKENS: int = 2_000_000
+    CHARS_PER_TOKEN: float = 2.7
+    PROMPT_TOKEN_MULTIPLIER: int = 1_000
+    BASE_OUTPUT_DIR: str = "Data/batch_entity_result"
+    import_function: Callable = import_data
+    export_function: Callable = export_data
+    poll_interval_seconds: int = 30
+    max_retries: int = 3
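Because `BatchJobRunner.__init__` (shown in the next file) accepts `output_model: Type[T] = StrOutput`, a custom Pydantic model can be dropped in and `BatchManager` will derive a strict JSON schema from it. A sketch with an invented model:

```python
from pydantic import BaseModel
from texttools import BatchConfig, BatchJobRunner

class Entities(BaseModel):  # invented output model
    persons: list[str]
    locations: list[str]

runner = BatchJobRunner(
    BatchConfig(system_prompt="Extract entities as JSON."),  # invented prompt
    output_model=Entities,
)
```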
texttools/batch/batch_runner.py CHANGED

@@ -1,16 +1,16 @@
 import json
 import os
 import time
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Any,
+from typing import Any, Type, TypeVar
 import logging
 
 from dotenv import load_dotenv
 from openai import OpenAI
 from pydantic import BaseModel
 
-from texttools.batch.batch_manager import BatchManager
+from texttools.batch.internals.batch_manager import BatchManager
+from texttools.batch.batch_config import BatchConfig
 from texttools.tools.internals.output_models import StrOutput
 
 # Base Model type for output models
@@ -19,43 +19,6 @@ T = TypeVar("T", bound=BaseModel)
 logger = logging.getLogger("texttools.batch_runner")
 
 
-def export_data(data) -> list[dict[str, str]]:
-    """
-    Produces a structure of the following form from an initial data structure:
-    [{"id": str, "text": str},...]
-    """
-    return data
-
-
-def import_data(data) -> Any:
-    """
-    Takes the output and adds and aggregates it to the original structure.
-    """
-    return data
-
-
-@dataclass
-class BatchConfig:
-    """
-    Configuration for batch job runner.
-    """
-
-    system_prompt: str = ""
-    job_name: str = ""
-    input_data_path: str = ""
-    output_data_filename: str = ""
-    model: str = "gpt-4.1-mini"
-    MAX_BATCH_SIZE: int = 100
-    MAX_TOTAL_TOKENS: int = 2_000_000
-    CHARS_PER_TOKEN: float = 2.7
-    PROMPT_TOKEN_MULTIPLIER: int = 1_000
-    BASE_OUTPUT_DIR: str = "Data/batch_entity_result"
-    import_function: Callable = import_data
-    export_function: Callable = export_data
-    poll_interval_seconds: int = 30
-    max_retries: int = 3
-
-
 class BatchJobRunner:
     """
     Handles running batch jobs using a batch manager and configuration.
@@ -64,22 +27,22 @@ class BatchJobRunner:
     def __init__(
         self, config: BatchConfig = BatchConfig(), output_model: Type[T] = StrOutput
     ):
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self._partition_data()
-        Path(self.config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
+        self._config = config
+        self._system_prompt = config.system_prompt
+        self._job_name = config.job_name
+        self._input_data_path = config.input_data_path
+        self._output_data_filename = config.output_data_filename
+        self._model = config.model
+        self._output_model = output_model
+        self._manager = self._init_manager()
+        self._data = self._load_data()
+        self._parts: list[list[dict[str, Any]]] = []
         # Map part index to job name
-        self.
+        self._part_idx_to_job_name: dict[int, str] = {}
         # Track retry attempts per part
-        self.
+        self._part_attempts: dict[int, int] = {}
+        self._partition_data()
+        Path(self._config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
 
     def _init_manager(self) -> BatchManager:
         load_dotenv()
@@ -87,15 +50,15 @@ class BatchJobRunner:
         client = OpenAI(api_key=api_key)
         return BatchManager(
             client=client,
-            model=self.
-            prompt_template=self.
-            output_model=self.
+            model=self._model,
+            prompt_template=self._system_prompt,
+            output_model=self._output_model,
         )
 
     def _load_data(self):
-        with open(self.
+        with open(self._input_data_path, "r", encoding="utf-8") as f:
             data = json.load(f)
-        data = self.
+        data = self._config.export_function(data)
 
         # Ensure data is a list of dicts with 'id' and 'content' as strings
         if not isinstance(data, list):
@@ -112,50 +75,50 @@ class BatchJobRunner:
         return data
 
     def _partition_data(self):
-        total_length = sum(len(item["content"]) for item in self.
-        prompt_length = len(self.
-        total = total_length + (prompt_length * len(self.
-        calculation = total / self.
+        total_length = sum(len(item["content"]) for item in self._data)
+        prompt_length = len(self._system_prompt)
+        total = total_length + (prompt_length * len(self._data))
+        calculation = total / self._config.CHARS_PER_TOKEN
         logger.info(
             f"Total chars: {total_length}, Prompt chars: {prompt_length}, Total: {total}, Tokens: {calculation}"
         )
-        if calculation < self.
-            self.
+        if calculation < self._config.MAX_TOTAL_TOKENS:
+            self._parts = [self._data]
         else:
             # Partition into chunks of MAX_BATCH_SIZE
-            self.
-            self.
-            for i in range(0, len(self.
+            self._parts = [
+                self._data[i : i + self._config.MAX_BATCH_SIZE]
+                for i in range(0, len(self._data), self._config.MAX_BATCH_SIZE)
             ]
-        logger.info(f"Data split into {len(self.
+        logger.info(f"Data split into {len(self._parts)} part(s)")
 
     def _submit_all_jobs(self) -> None:
-        for idx, part in enumerate(self.
+        for idx, part in enumerate(self._parts):
             if self._result_exists(idx):
                 logger.info(f"Skipping part {idx + 1}: result already exists.")
                 continue
             part_job_name = (
-                f"{self.
-                if len(self.
-                else self.
+                f"{self._job_name}_part_{idx + 1}"
+                if len(self._parts) > 1
+                else self._job_name
             )
             # If a job with this name already exists, register and skip submitting
-            existing_job = self.
+            existing_job = self._manager._load_state(part_job_name)
             if existing_job:
                 logger.info(
                     f"Skipping part {idx + 1}: job already exists ({part_job_name})."
                 )
-                self.
-                self.
+                self._part_idx_to_job_name[idx] = part_job_name
+                self._part_attempts.setdefault(idx, 0)
                 continue
 
             payload = part
             logger.info(
-                f"Submitting job for part {idx + 1}/{len(self.
+                f"Submitting job for part {idx + 1}/{len(self._parts)}: {part_job_name}"
             )
-            self.
-            self.
-            self.
+            self._manager.start(payload, job_name=part_job_name)
+            self._part_idx_to_job_name[idx] = part_job_name
+            self._part_attempts.setdefault(idx, 0)
             # This is added for letting file get uploaded, before starting the next part.
             logger.info("Uploading...")
            time.sleep(30)
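For intuition, the partitioning logic above is a pure character-count heuristic. A worked example with invented sizes:

```python
# The heuristic from _partition_data, with invented sizes:
n_items = 1_000          # len(self._data)
avg_chars = 3_000        # average len(item["content"])
prompt_chars = 400       # len(system_prompt), counted once per item

total = n_items * avg_chars + prompt_chars * n_items   # 3_400_000 chars
est_tokens = total / 2.7                               # CHARS_PER_TOKEN -> ~1_259_259
# 1_259_259 < MAX_TOTAL_TOKENS (2_000_000): everything fits in one part;
# otherwise the data is chunked into parts of MAX_BATCH_SIZE (100) items each.
```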
@@ -166,10 +129,10 @@ class BatchJobRunner:
         log: list[Any],
         part_idx: int,
     ):
-        part_suffix = f"_part_{part_idx + 1}" if len(self.
+        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
         result_path = (
-            Path(self.
-            / f"{Path(self.
+            Path(self._config.BASE_OUTPUT_DIR)
+            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
         )
         if not output_data:
             logger.info("No output data to save. Skipping this part.")
@@ -179,17 +142,17 @@ class BatchJobRunner:
             json.dump(output_data, f, ensure_ascii=False, indent=4)
         if log:
             log_path = (
-                Path(self.
-                / f"{Path(self.
+                Path(self._config.BASE_OUTPUT_DIR)
+                / f"{Path(self._output_data_filename).stem}{part_suffix}_log.json"
             )
             with open(log_path, "w", encoding="utf-8") as f:
                 json.dump(log, f, ensure_ascii=False, indent=4)
 
     def _result_exists(self, part_idx: int) -> bool:
-        part_suffix = f"_part_{part_idx + 1}" if len(self.
+        part_suffix = f"_part_{part_idx + 1}" if len(self._parts) > 1 else ""
         result_path = (
-            Path(self.
-            / f"{Path(self.
+            Path(self._config.BASE_OUTPUT_DIR)
+            / f"{Path(self._output_data_filename).stem}{part_suffix}.json"
         )
         return result_path.exists()
 
@@ -201,41 +164,41 @@ class BatchJobRunner:
         """
         # Submit all jobs up-front for concurrent execution
         self._submit_all_jobs()
-        pending_parts: set[int] = set(self.
+        pending_parts: set[int] = set(self._part_idx_to_job_name.keys())
         logger.info(f"Pending parts: {sorted(pending_parts)}")
         # Polling loop
         while pending_parts:
             finished_this_round: list[int] = []
             for part_idx in list(pending_parts):
-                job_name = self.
-                status = self.
+                job_name = self._part_idx_to_job_name[part_idx]
+                status = self._manager.check_status(job_name=job_name)
                 logger.info(f"Status for {job_name}: {status}")
                 if status == "completed":
                     logger.info(
                         f"Job completed. Fetching results for part {part_idx + 1}..."
                     )
-                    output_data, log = self.
+                    output_data, log = self._manager.fetch_results(
                         job_name=job_name, remove_cache=False
                     )
-                    output_data = self.
+                    output_data = self._config.import_function(output_data)
                     self._save_results(output_data, log, part_idx)
                     logger.info(f"Fetched and saved results for part {part_idx + 1}.")
                     finished_this_round.append(part_idx)
                 elif status == "failed":
-                    attempt = self.
-                    self.
-                    if attempt <= self.
+                    attempt = self._part_attempts.get(part_idx, 0) + 1
+                    self._part_attempts[part_idx] = attempt
+                    if attempt <= self._config.max_retries:
                         logger.info(
                             f"Job {job_name} failed (attempt {attempt}). Retrying after short backoff..."
                         )
-                        self.
+                        self._manager._clear_state(job_name)
                         time.sleep(10)
-                        payload = self._to_manager_payload(self.
+                        payload = self._to_manager_payload(self._parts[part_idx])
                         new_job_name = (
-                            f"{self.
+                            f"{self._job_name}_part_{part_idx + 1}_retry_{attempt}"
                         )
-                        self.
-                        self.
+                        self._manager.start(payload, job_name=new_job_name)
+                        self._part_idx_to_job_name[part_idx] = new_job_name
                     else:
                         logger.info(
                             f"Job {job_name} failed after {attempt - 1} retries. Marking as failed."
@@ -249,6 +212,6 @@ class BatchJobRunner:
             pending_parts.discard(part_idx)
             if pending_parts:
                 logger.info(
-                    f"Waiting {self.
+                    f"Waiting {self._config.poll_interval_seconds}s before next status check for parts: {sorted(pending_parts)}"
                 )
-                time.sleep(self.
+                time.sleep(self._config.poll_interval_seconds)
texttools/batch/{batch_manager.py → internals/batch_manager.py} RENAMED

@@ -33,15 +33,15 @@ class BatchManager:
         custom_json_schema_obj_str: dict | None = None,
         **client_kwargs: Any,
     ):
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
+        self._client = client
+        self._model = model
+        self._output_model = output_model
+        self._prompt_template = prompt_template
+        self._state_dir = state_dir
+        self._custom_json_schema_obj_str = custom_json_schema_obj_str
+        self._client_kwargs = client_kwargs
+        self._dict_input = False
+        self._state_dir.mkdir(parents=True, exist_ok=True)
 
         if custom_json_schema_obj_str and not isinstance(
             custom_json_schema_obj_str, dict
@@ -49,7 +49,7 @@ class BatchManager:
             raise ValueError("Schema should be a dict")
 
     def _state_file(self, job_name: str) -> Path:
-        return self.
+        return self._state_dir / f"{job_name}.json"
 
     def _load_state(self, job_name: str) -> list[dict[str, Any]]:
         """
@@ -83,17 +83,17 @@ class BatchManager:
         """
         response_format_config: dict[str, Any]
 
-        if self.
+        if self._custom_json_schema_obj_str:
             response_format_config = {
                 "type": "json_schema",
-                "json_schema": self.
+                "json_schema": self._custom_json_schema_obj_str,
             }
         else:
-            raw_schema = to_strict_json_schema(self.
+            raw_schema = to_strict_json_schema(self._output_model)
             response_format_config = {
                 "type": "json_schema",
                 "json_schema": {
-                    "name": self.
+                    "name": self._output_model.__name__,
                     "schema": raw_schema,
                 },
             }
@@ -105,11 +105,11 @@ class BatchManager:
                 "body": {
                     "model": self.model,
                     "messages": [
-                        {"role": "system", "content": self.
+                        {"role": "system", "content": self._prompt_template},
                         {"role": "user", "content": text},
                     ],
                     "response_format": response_format_config,
-                    **self.
+                    **self._client_kwargs,
                 },
             }
 
@@ -130,7 +130,7 @@ class BatchManager:
                 "The input must be either a list of texts or a dictionary in the form {'id': str, 'text': str}"
             )
 
-        file_path = self.
+        file_path = self._state_dir / f"batch_{uuid.uuid4().hex}.jsonl"
         with open(file_path, "w", encoding="utf-8") as f:
             for task in tasks:
                 f.write(json.dumps(task) + "\n")
@@ -145,8 +145,8 @@ class BatchManager:
             return
 
         path = self._prepare_file(payload)
-        upload = self.
-        job = self.
+        upload = self._client.files.create(file=open(path, "rb"), purpose="batch")
+        job = self._client.batches.create(
             input_file_id=upload.id,
             endpoint="/v1/chat/completions",
             completion_window="24h",
@@ -162,7 +162,7 @@ class BatchManager:
         if not job:
             return "completed"
 
-        info = self.
+        info = self._client.batches.retrieve(job["id"])
         job = info.to_dict()
         self._save_state(job_name, [job])
         logger.info("Batch job status: %s", job)
@@ -180,18 +180,18 @@ class BatchManager:
             return {}
         batch_id = job["id"]
 
-        info = self.
+        info = self._client.batches.retrieve(batch_id)
         out_file_id = info.output_file_id
         if not out_file_id:
             error_file_id = info.error_file_id
             if error_file_id:
                 err_content = (
-                    self.
+                    self._client.files.content(error_file_id).read().decode("utf-8")
                 )
                 logger.error("Error file content:", err_content)
                 return {}
 
-        content = self.
+        content = self._client.files.content(out_file_id).read().decode("utf-8")
         lines = content.splitlines()
         results = {}
         log = []
@@ -202,7 +202,7 @@ class BatchManager:
             content = result["response"]["body"]["choices"][0]["message"]["content"]
             try:
                 parsed_content = json.loads(content)
-                model_instance = self.
+                model_instance = self._output_model(**parsed_content)
                 results[custom_id] = model_instance.model_dump(mode="json")
             except json.JSONDecodeError:
                 results[custom_id] = {"error": "Failed to parse content as JSON"}
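To make `_prepare_file`'s output concrete, one request line of the generated `.jsonl` would look roughly like this. The structure is taken from the request body assembled above; the id and text are invented, and the `"method"`/`"url"` keys are assumptions (the diff only shows the `"body"` dict and the `/v1/chat/completions` endpoint):

```python
# One task line from the generated batch .jsonl (sketch, not the exact output):
task = {
    "custom_id": "doc-001",         # invented id
    "method": "POST",               # assumed
    "url": "/v1/chat/completions",  # matches the batch endpoint above
    "body": {
        "model": "gpt-4.1-mini",
        "messages": [
            {"role": "system", "content": "<prompt_template>"},
            {"role": "user", "content": "<input text>"},
        ],
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": "StrOutput",
                "schema": {},  # strict schema from to_strict_json_schema(output_model)
            },
        },
    },
}
```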
texttools/batch/internals/utils.py ADDED

@@ -0,0 +1,16 @@
+from typing import Any
+
+
+def export_data(data) -> list[dict[str, str]]:
+    """
+    Produces a structure of the following form from an initial data structure:
+    [{"id": str, "text": str},...]
+    """
+    return data
+
+
+def import_data(data) -> Any:
+    """
+    Takes the output and adds and aggregates it to the original structure.
+    """
+    return data