tlmtc 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tlmtc/__init__.py +54 -0
- tlmtc/__main__.py +6 -0
- tlmtc/api.py +455 -0
- tlmtc/cli.py +345 -0
- tlmtc/data_contracts.py +160 -0
- tlmtc/data_pipeline.py +257 -0
- tlmtc/data_preparation.py +221 -0
- tlmtc/evaluation.py +291 -0
- tlmtc/evaluation_pipeline.py +309 -0
- tlmtc/finetune_pipeline.py +355 -0
- tlmtc/hpo.py +157 -0
- tlmtc/meta.py +86 -0
- tlmtc/paths.py +371 -0
- tlmtc/prediction.py +154 -0
- tlmtc/reporting.py +605 -0
- tlmtc/runtime_output.py +100 -0
- tlmtc/settings.py +456 -0
- tlmtc/training.py +339 -0
- tlmtc-0.1.0.dist-info/METADATA +223 -0
- tlmtc-0.1.0.dist-info/RECORD +23 -0
- tlmtc-0.1.0.dist-info/WHEEL +4 -0
- tlmtc-0.1.0.dist-info/entry_points.txt +2 -0
- tlmtc-0.1.0.dist-info/licenses/LICENSE.md +21 -0
tlmtc/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Public package interface and lazy-loaded API exports."""
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import logging
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
__version__ = "0.1.0"

# Attach a NullHandler to the package logger so tlmtc stays silent unless the
# application configures logging itself (standard library-logging practice).
logging.getLogger("tlmtc").addHandler(logging.NullHandler())

# Public name -> (module path, attribute name), resolved on first access by the
# module-level __getattr__. Both entry points live in tlmtc.api.
_LAZY: dict[str, tuple[str, str]] = {
    public_name: ("tlmtc.api", public_name) for public_name in ("predict_tlmtc", "train_tlmtc")
}

# Exported API: the lazily loaded entry points (in _LAZY order) plus the version string.
__all__ = [*_LAZY, "__version__"]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def __getattr__(
|
|
24
|
+
name: str,
|
|
25
|
+
) -> Any:
|
|
26
|
+
try:
|
|
27
|
+
module_path, attr = _LAZY[name]
|
|
28
|
+
except KeyError as exc:
|
|
29
|
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from exc
|
|
30
|
+
|
|
31
|
+
try:
|
|
32
|
+
value = getattr(importlib.import_module(module_path), attr)
|
|
33
|
+
except ModuleNotFoundError as exc:
|
|
34
|
+
missing = getattr(exc, "name", None)
|
|
35
|
+
|
|
36
|
+
if missing in {"torch", "peft", "accelerate"}:
|
|
37
|
+
raise ImportError(
|
|
38
|
+
f"`torch`, `peft`, and `accelerate` are required for `tlmtc.{name}`. "
|
|
39
|
+
"Install them with: `pip install 'tlmtc[full]'`."
|
|
40
|
+
) from exc
|
|
41
|
+
|
|
42
|
+
raise
|
|
43
|
+
|
|
44
|
+
globals()[name] = value
|
|
45
|
+
return value
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def __dir__() -> list[str]:
    """Return the exported public names, alphabetically, for ``dir(tlmtc)``."""
    exported = list(__all__)
    exported.sort()
    return exported
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
if TYPE_CHECKING:
|
|
53
|
+
from tlmtc.api import predict_tlmtc as predict_tlmtc
|
|
54
|
+
from tlmtc.api import train_tlmtc as train_tlmtc
|
tlmtc/__main__.py
ADDED
tlmtc/api.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
1
|
+
"""Public Python API for running tlmtc training and prediction workflows."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from tlmtc.data_pipeline import DataPipeline
|
|
8
|
+
from tlmtc.data_preparation import create_prediction_dataset, read_prediction_csv, tokenize_prediction_dataset
|
|
9
|
+
from tlmtc.evaluation_pipeline import EvaluationPipeline
|
|
10
|
+
from tlmtc.finetune_pipeline import FinetunePipeline
|
|
11
|
+
from tlmtc.meta import TrainRunMeta, read_run_meta, write_run_meta
|
|
12
|
+
from tlmtc.paths import PredictionPaths, RunPaths, resolve_paths, resolve_prediction_paths
|
|
13
|
+
from tlmtc.prediction import (
|
|
14
|
+
apply_thresholds,
|
|
15
|
+
load_prediction_model,
|
|
16
|
+
make_prediction_frame,
|
|
17
|
+
predict_probabilities,
|
|
18
|
+
)
|
|
19
|
+
from tlmtc.runtime_output import configure_runtime_output, emit_progress
|
|
20
|
+
from tlmtc.settings import UNSET, PredictionSettings, RunSettings, Unset, load_config_file
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(slots=True, frozen=True)
class TrainResult:
    """Immutable outcome returned by ``train_tlmtc``.

    Attributes:
        paths: The resolved ``RunPaths`` layout for the run, covering both the input
            locations and the generated run artifacts.
    """

    paths: RunPaths
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(slots=True, frozen=True)
class PredictResult:
    """Immutable outcome returned by ``predict_tlmtc``.

    Attributes:
        paths: The resolved ``PredictionPaths`` layout, covering the prediction inputs
            and the generated prediction artifacts.
    """

    paths: PredictionPaths
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def train_tlmtc(
    raw_csv: str | Path,
    *,
    raw_test_csv: str | Path | Unset = UNSET,
    work_dir: str | Path | Unset = UNSET,
    config_path: str | Path | Unset = UNSET,
    run_id: str | None | Unset = UNSET,
    target_name: str | Unset = UNSET,
    validation_size: float | Unset = UNSET,
    test_size: float | Unset = UNSET,
    random_seed: int | Unset = UNSET,
    transfer_learning: bool | Unset = UNSET,
    hyperparameter_tuning: bool | Unset = UNSET,
    threshold_optimization: bool | Unset = UNSET,
    threshold_type: str | Unset = UNSET,
    scale_learning_rate: bool | Unset = UNSET,
    wrap_peft: bool | Unset = UNSET,
    proxy_checkpoint: str | Unset = UNSET,
    checkpoint: str | Unset = UNSET,
    sequence_length: int | Unset = UNSET,
    best_model_metric: str | Unset = UNSET,
    batch_size: int | Unset = UNSET,
    train_epochs: int | Unset = UNSET,
    learning_rate: float | Unset = UNSET,
    weight_decay: float | Unset = UNSET,
    lr_scheduler: str | Unset = UNSET,
    best_threshold_metric: str | Unset = UNSET,
    tuning_trials: int | Unset = UNSET,
    optuna_space: dict[str, Any] | Unset = UNSET,
    lora_r: int | Unset = UNSET,
    lora_alpha: int | Unset = UNSET,
    lora_dropout: float | Unset = UNSET,
    lora_bias: str | Unset = UNSET,
    early_stopping_patience: int | Unset = UNSET,
    use_cpu: bool | Unset = UNSET,
    verbosity: str | Unset = UNSET,
) -> TrainResult:
    """Train a multi-label text classifier end to end and persist its artifacts.

    Depending on the workflow flags, one call covers data preparation, hyperparameter
    tuning, model fine-tuning, threshold optimization, evaluation, and reporting.
    Keyword arguments left as ``UNSET`` are forwarded to ``RunSettings.resolve`` together
    with the optional YAML configuration loaded from ``config_path``.

    Args:
        raw_csv: Path to the raw multi-label training CSV. The file must contain a `text` column,
            at least two binary `label_*` columns, and optionally a `text_pair` column.
        raw_test_csv: Path to a separate raw test CSV. If omitted, a test split is created
            from `raw_csv` using `test_size`. Defaults to no separate test CSV.
        work_dir: Base directory for resolving inputs and writing run artifacts. Defaults to the
            current working directory.
        config_path: Path to a YAML configuration file. Defaults to no configuration file.
        run_id: Run identifier used to name the run directory. If omitted, a random
            identifier is generated.
        target_name: Display name for the classification target in logs and reports. Defaults to
            `"Target"`.
        validation_size: Fraction reserved for validation splitting. Defaults to `0.15`.
        test_size: Fraction reserved for test splitting when `raw_test_csv` is omitted. Defaults to
            `0.15`.
        random_seed: Random seed used for reproducible splitting and shuffling. Defaults to `2469`.
        transfer_learning: Whether to fine-tune the target checkpoint and produce model/evaluation
            artifacts. If `False`, data preparation still runs; with `hyperparameter_tuning=True`,
            tlmtc runs proxy-checkpoint hyperparameter tuning only. Defaults to `True`.
        hyperparameter_tuning: Whether to evaluate candidate hyperparameter configurations with
            Optuna before final fine-tuning. If `True` and `transfer_learning=False`, only the
            proxy-checkpoint tuning stage is run after data preparation. If both are `False`,
            the workflow stops after data preparation. Defaults to `True`.
        threshold_optimization: Whether to tune decision thresholds on validation-set predictions
            after fine-tuning. If `False`, evaluation uses the default threshold `0.5`. Ignored
            when `transfer_learning=False`. Defaults to `True`.
        threshold_type: Thresholding mode. Supported values are `"global"` and `"label"`. Defaults to
            `"label"`.
        scale_learning_rate: Whether to scale a proxy-tuned learning rate for the target checkpoint.
            Defaults to `False`.
        wrap_peft: Whether to use parameter-efficient fine-tuning with LoRA adapters. Defaults to `True`.
        proxy_checkpoint: Compatible encoder-only Hugging Face checkpoint identifier used during
            hyperparameter tuning. Defaults to `"EuroBERT/EuroBERT-210m"`. If `checkpoint`
            is supplied and `proxy_checkpoint` is omitted, the proxy checkpoint defaults to the
            selected `checkpoint`. Loaded with `trust_remote_code=False`; checkpoints that require
            custom remote code are not supported. Only use checkpoints you trust.
        checkpoint: Compatible encoder-only Hugging Face checkpoint identifier or local path used for
            final fine-tuning. Defaults to `"EuroBERT/EuroBERT-610m"`. Loaded with `trust_remote_code=False`;
            checkpoints that require custom remote code are not supported. Only use checkpoints and local model
            directories you trust.
        sequence_length: Maximum tokenized sequence length. Defaults to `128`.
        best_model_metric: Metric used to select the best model checkpoint. Supported values are
            `"f1_micro"`, `"f1_macro"`, `"roc_auc_micro"`, and `"roc_auc_macro"`. Defaults to
            `"roc_auc_macro"`.
        batch_size: Initial training and evaluation batch size. Used directly when hyperparameter tuning is
            disabled, otherwise replaced by the tuned value. Defaults to `16`.
        train_epochs: Initial number of training epochs. Used directly when hyperparameter tuning is
            disabled, otherwise replaced by the tuned value. Defaults to `20`.
        learning_rate: Initial optimizer learning rate. Used directly when hyperparameter tuning is
            disabled, otherwise replaced by the tuned value. Defaults to `2e-5`.
        weight_decay: Initial weight decay for training. Used directly when hyperparameter tuning is
            disabled, otherwise replaced by the tuned value. Defaults to `0.01`.
        lr_scheduler: Initial learning-rate scheduler name. Used directly when hyperparameter tuning is
            disabled, otherwise replaced by the tuned value. Defaults to `"linear"`.
        best_threshold_metric: Metric used to select decision thresholds. Supported values are
            `"f1_micro"` and `"f1_macro"`. Defaults to `"f1_macro"`.
        tuning_trials: Number of hyperparameter configurations to evaluate during Optuna tuning. Higher
            values may improve the selected configuration but increase runtime. Defaults to `10`.
        optuna_space: Optional partial override for the hyperparameter tuning ranges and candidate
            values. Supported keys are `lr_low`, `lr_high`, `batch_sizes`, `wd_low`, `wd_high`,
            `schedulers`, `epoch_low`, `epoch_high`. Missing keys are filled from the default tuning space
            selected by `wrap_peft`.

            Defaults to the PEFT search space when `wrap_peft=True`:

                {
                    "lr_low": 5e-5,
                    "lr_high": 4e-4,
                    "batch_sizes": [8, 16, 32],
                    "wd_low": 0.0,
                    "wd_high": 0.01,
                    "schedulers": ["linear", "cosine"],
                    "epoch_low": 5,
                    "epoch_high": 20,
                    "lr_reference_batch_size": 32,
                }

            Defaults to the full fine-tuning search space when `wrap_peft=False`:

                {
                    "lr_low": 1e-5,
                    "lr_high": 8e-5,
                    "batch_sizes": [8, 16, 32],
                    "wd_low": 0.0,
                    "wd_high": 0.1,
                    "schedulers": ["linear", "cosine", "polynomial"],
                    "epoch_low": 5,
                    "epoch_high": 30,
                    "lr_reference_batch_size": 32,
                }
        lora_r: LoRA rank. Defaults to `8`.
        lora_alpha: LoRA scaling factor. Defaults to `32`.
        lora_dropout: LoRA dropout probability. Defaults to `0.1`.
        lora_bias: LoRA bias handling mode. Supported values are `"none"`, `"all"`, and `"lora_only"`.
            Defaults to `"none"`.
        early_stopping_patience: Early stopping patience in epochs without improvement. Defaults to
            `10`.
        use_cpu: Whether to force CPU execution. Defaults to `False`.
        verbosity: Runtime output mode. Supported values are `"progress"` and `"quiet"`. Defaults to
            `"progress"`.

    Returns:
        A ``TrainResult`` holding the resolved input and artifact paths.
    """
    # Only an actual path is loaded; UNSET means "no configuration file".
    file_config = None
    if isinstance(config_path, (str, Path)):
        file_config = load_config_file(config_path)

    # Caller-supplied values, grouped by the settings section they belong to.
    overrides: dict[str, Any] = {
        "raw_csv": raw_csv,
        "raw_test_csv": raw_test_csv,
        "work_dir": work_dir,
        "run_id": run_id,
        "model": {
            "target_name": target_name,
            "proxy_checkpoint": proxy_checkpoint,
            "checkpoint": checkpoint,
            "sequence_length": sequence_length,
        },
        "split": {
            "validation_size": validation_size,
            "test_size": test_size,
            "random_seed": random_seed,
        },
        "workflow": {
            "hyperparameter_tuning": hyperparameter_tuning,
            "threshold_optimization": threshold_optimization,
            "transfer_learning": transfer_learning,
            "scale_learning_rate": scale_learning_rate,
            "wrap_peft": wrap_peft,
        },
        "training": {
            "batch_size": batch_size,
            "train_epochs": train_epochs,
            "weight_decay": weight_decay,
            "learning_rate": learning_rate,
            "lr_scheduler": lr_scheduler,
            "best_model_metric": best_model_metric,
            "early_stopping_patience": early_stopping_patience,
        },
        "threshold": {
            "threshold_type": threshold_type,
            "best_threshold_metric": best_threshold_metric,
        },
        "hpo": {
            "tuning_trials": tuning_trials,
            "optuna_space": optuna_space,
        },
        "peft": {
            "lora_r": lora_r,
            "lora_alpha": lora_alpha,
            "lora_dropout": lora_dropout,
            "lora_bias": lora_bias,
        },
        "hardware": {
            "use_cpu": use_cpu,
        },
        "runtime": {
            "verbosity": verbosity,
        },
    }
    settings = RunSettings.resolve(config=file_config, env=None, overrides=overrides)

    configure_runtime_output(settings.runtime.verbosity)
    emit_progress("Starting training run")

    resolved_paths = resolve_paths(
        raw_csv=settings.raw_csv,
        raw_test_csv=settings.raw_test_csv,
        work_dir=settings.work_dir,
        run_id=settings.run_id,
    )
    paths = resolved_paths.ensure_dirs()

    # Stage 1: split, encode labels, build and tokenize the Hugging Face dataset.
    data = DataPipeline(paths=paths, split=settings.split, model=settings.model)
    for stage in (
        data.split_data,
        data.get_multi_hot_vectors,
        data.create_hf_dataset,
        data.tokenize_data,
    ):
        stage()

    # Stage 2: load, tune, fine-tune, tune thresholds, save (order matters).
    finetune = FinetunePipeline(
        tokenized_dataset=data.tokenized_dataset,
        paths=paths,
        model=settings.model,
        workflow=settings.workflow,
        peft=settings.peft,
        training=settings.training,
        hpo=settings.hpo,
        threshold=settings.threshold,
        hardware=settings.hardware,
    )
    for stage in (
        finetune.load_pretrained,
        finetune.tune_hyperparameters,
        finetune.fine_tune_pretrained,
        finetune.tune_thresholds,
        finetune.save_pretrained,
    ):
        stage()

    # Stage 3: evaluate the updated trainer and render metrics, tables, and figures.
    evaluation = EvaluationPipeline(
        tokenized_dataset=data.tokenized_dataset,
        updated_trainer=finetune.updated_trainer,
        paths=paths,
        model=settings.model,
        workflow=settings.workflow,
        training=settings.training,
        tuned_threshold=finetune.tuned_threshold,
        input_mode=data.input_mode,
    )
    for stage in (
        evaluation.run_evaluation,
        evaluation.save_metrics,
        evaluation.render_tables,
        evaluation.render_figures,
    ):
        stage()

    # Persist the metadata that predict_tlmtc later reads back for this run.
    model_settings = settings.model
    workflow = settings.workflow
    run_meta = TrainRunMeta(
        run_id=settings.run_id,
        target_name=model_settings.target_name,
        checkpoint=model_settings.checkpoint,
        proxy_checkpoint=model_settings.proxy_checkpoint,
        sequence_length=model_settings.sequence_length,
        input_mode=data.input_mode,
        label_names=evaluation.label_names,
        threshold_type=settings.threshold.threshold_type,
        thresholds=finetune.tuned_threshold.tolist(),
        transfer_learning=workflow.transfer_learning,
        hyperparameter_tuning=workflow.hyperparameter_tuning,
        threshold_optimization=workflow.threshold_optimization,
        scale_learning_rate=workflow.scale_learning_rate,
        wrap_peft=workflow.wrap_peft,
    )
    write_run_meta(meta=run_meta, path=paths.train_run_meta_path)

    emit_progress("Training run complete")
    return TrainResult(paths=paths)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def predict_tlmtc(
    prediction_csv: str | Path,
    *,
    work_dir: str | Path | Unset = UNSET,
    config_path: str | Path | Unset = UNSET,
    run_id: str | None | Unset = UNSET,
    batch_size: int | Unset = UNSET,
    use_cpu: bool | Unset = UNSET,
    verbosity: str | Unset = UNSET,
) -> PredictResult:
    """Run the multi-label text classification prediction workflow.

    Prediction consumes persisted metadata and model artifacts from a completed
    training run, applies the persisted decision thresholds, and writes probability
    and binary prediction artifacts.

    Args:
        prediction_csv: Path to the unlabeled prediction CSV. The file must contain a `text`
            column and, for models trained with paired-text inputs, a `text_pair` column.
            Prediction artifacts preserve input text columns unchanged.
        work_dir: Base directory for resolving inputs, reading training artifacts, and writing
            prediction artifacts. Defaults to the current working directory.
        config_path: Path to a YAML configuration file. Defaults to no configuration file.
        run_id: Run identifier used to select the completed training run. If omitted, the latest
            completed training run is selected from persisted training metadata. Prediction reloads
            the trained model or adapter artifacts for this run with `trust_remote_code=False`;
            artifacts that require custom remote code are not supported. Only use saved model
            artifacts and adapters you trust.
        batch_size: Prediction batch size used for batched inference. Defaults to `32`.
        use_cpu: Whether to force CPU execution. Defaults to `False`.
        verbosity: Runtime output mode. Supported values are `"progress"` and `"quiet"`. Defaults to
            `"progress"`.

    Returns:
        Result metadata containing the resolved input and artifact paths.

    Raises:
        RuntimeError: If the selected training run was executed with `transfer_learning=False`
            (no fine-tuned prediction model was persisted), or if its persisted metadata does
            not record the input mode or the label names.
    """
    settings = PredictionSettings.resolve(
        config=load_config_file(config_path) if isinstance(config_path, (str, Path)) else None,
        env=None,
        overrides={
            "prediction_csv": prediction_csv,
            "work_dir": work_dir,
            "run_id": run_id,
            "batch_size": batch_size,
            "hardware": {
                "use_cpu": use_cpu,
            },
            "runtime": {
                "verbosity": verbosity,
            },
        },
    )

    configure_runtime_output(settings.runtime.verbosity)
    emit_progress("Starting prediction run")

    paths = resolve_prediction_paths(
        input_csv=settings.prediction_csv,
        work_dir=settings.work_dir,
        run_id=settings.run_id,
    ).ensure_dirs()

    emit_progress("Reading training metadata")

    meta = read_run_meta(paths.train_run_meta_path)

    if not meta.transfer_learning:
        raise RuntimeError(
            "Prediction requires a training run with transfer_learning=True. "
            f"Run '{meta.run_id}' did not persist a fine-tuned prediction model."
        )

    input_mode = meta.input_mode
    label_names = meta.label_names

    # Explicit check rather than `assert`: asserts are stripped under `python -O`, which would
    # let incomplete metadata reach the CSV reader and model loader with an opaque failure.
    # The `is None` tests also narrow both values for type checkers.
    if input_mode is None or label_names is None:
        raise RuntimeError(
            f"Training metadata for run '{meta.run_id}' is incomplete: "
            "input_mode and label_names must both be recorded to run prediction."
        )

    emit_progress("Reading prediction inputs")

    input_df = read_prediction_csv(
        df_path=paths.input_data_path,
        expected_input_mode=input_mode,
    )
    prediction_dataset = create_prediction_dataset(
        df=input_df,
        input_mode=input_mode,
    )
    emit_progress("Tokenizing prediction inputs")
    # Tokenize with the same checkpoint and sequence length the model was trained with.
    tokenized_dataset = tokenize_prediction_dataset(
        dataset=prediction_dataset,
        checkpoint=meta.checkpoint,
        input_mode=input_mode,
        sequence_length=meta.sequence_length,
    )
    emit_progress("Loading fine-tuned prediction model")
    model = load_prediction_model(
        model_dir=paths.train_run_model_dir,
        checkpoint=meta.checkpoint,
        num_labels=len(label_names),
        wrap_peft=meta.wrap_peft,
    )
    emit_progress("Running prediction")
    probabilities = predict_probabilities(
        model=model,
        dataset=tokenized_dataset,
        batch_size=settings.batch_size,
        use_cpu=settings.hardware.use_cpu,
    )
    probability_df = make_prediction_frame(
        input_df=input_df,
        values=probabilities,
        label_names=label_names,
    )
    # Binarize with the thresholds persisted by the training run.
    predictions = apply_thresholds(
        probabilities=probabilities,
        thresholds=meta.thresholds,
    )
    prediction_df = make_prediction_frame(
        input_df=input_df,
        values=predictions,
        label_names=label_names,
    )

    emit_progress("Writing prediction artifacts")
    probability_df.to_csv(paths.probabilities_path, index=False)
    prediction_df.to_csv(paths.predictions_path, index=False)

    emit_progress("Prediction run complete")
    return PredictResult(paths=paths)
|