tlmtc 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tlmtc/__init__.py ADDED
@@ -0,0 +1,54 @@
1
+ """Public package interface and lazy-loaded API exports."""
2
+
3
+ import importlib
4
+ import logging
5
+ from typing import TYPE_CHECKING, Any
6
+
7
__version__ = "0.1.0"

# Public API of the package. The two workflow entry points are not imported at
# package import time; they are resolved on first attribute access.
__all__ = ["predict_tlmtc", "train_tlmtc", "__version__"]

# Library logging convention: a NullHandler on the package logger stops tlmtc
# records from reaching Python's last-resort stderr handler unless the
# application configures logging itself.
logging.getLogger("tlmtc").addHandler(logging.NullHandler())
16
+
17
# Public attribute name -> (module path, attribute name) for exports imported on first access.
_LAZY: dict[str, tuple[str, str]] = {
    exported: ("tlmtc.api", exported) for exported in ("predict_tlmtc", "train_tlmtc")
}


def __getattr__(
    name: str,
) -> Any:
    """Import and return a lazily exported attribute (module-level ``__getattr__``, PEP 562).

    Python calls this only for names that are not already in the module namespace.
    The resolved value is stored in the module globals, so later lookups of the same
    name bypass this function.

    Args:
        name: Attribute requested on the package.

    Returns:
        The object exported under ``name`` by its backing module.

    Raises:
        AttributeError: ``name`` is not a lazily exported attribute.
        ImportError: ``torch``, ``peft`` or ``accelerate`` is missing; the message
            suggests installing ``tlmtc[full]``.
        ModuleNotFoundError: Any other module required by the backing module is missing.
    """
    try:
        source_module, source_attr = _LAZY[name]
    except KeyError as lookup_error:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from lookup_error

    try:
        exported_value = getattr(importlib.import_module(source_module), source_attr)
    except ModuleNotFoundError as import_error:
        if getattr(import_error, "name", None) not in {"torch", "peft", "accelerate"}:
            # Unrelated missing module: surface the original error untouched.
            raise
        raise ImportError(
            f"`torch`, `peft`, and `accelerate` are required for `tlmtc.{name}`. "
            "Install them with: `pip install 'tlmtc[full]'`."
        ) from import_error

    # Cache in module globals so subsequent attribute access is a plain lookup.
    globals()[name] = exported_value
    return exported_value
46
+
47
+
48
def __dir__() -> list[str]:
    """Limit ``dir()`` on the package to the declared public API.

    Returns:
        A new, alphabetically sorted list built from ``__all__``.
    """
    advertised = list(__all__)
    advertised.sort()
    return advertised
50
+
51
+
52
+ if TYPE_CHECKING:
53
+ from tlmtc.api import predict_tlmtc as predict_tlmtc
54
+ from tlmtc.api import train_tlmtc as train_tlmtc
tlmtc/__main__.py ADDED
@@ -0,0 +1,6 @@
1
+ """Module execution entrypoint for the tlmtc CLI."""
2
+
3
+ from tlmtc.cli import app
4
+
5
if __name__ == "__main__":
    # Launch the CLI only when this module is executed (e.g. ``python -m tlmtc``),
    # never as a side effect of importing it.
    app()
tlmtc/api.py ADDED
@@ -0,0 +1,455 @@
1
+ """Public Python API for running tlmtc training and prediction workflows."""
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ from tlmtc.data_pipeline import DataPipeline
8
+ from tlmtc.data_preparation import create_prediction_dataset, read_prediction_csv, tokenize_prediction_dataset
9
+ from tlmtc.evaluation_pipeline import EvaluationPipeline
10
+ from tlmtc.finetune_pipeline import FinetunePipeline
11
+ from tlmtc.meta import TrainRunMeta, read_run_meta, write_run_meta
12
+ from tlmtc.paths import PredictionPaths, RunPaths, resolve_paths, resolve_prediction_paths
13
+ from tlmtc.prediction import (
14
+ apply_thresholds,
15
+ load_prediction_model,
16
+ make_prediction_frame,
17
+ predict_probabilities,
18
+ )
19
+ from tlmtc.runtime_output import configure_runtime_output, emit_progress
20
+ from tlmtc.settings import UNSET, PredictionSettings, RunSettings, Unset, load_config_file
21
+
22
+
23
@dataclass(slots=True, frozen=True)
class TrainResult:
    """Immutable summary returned once a tlmtc training run has finished.

    Attributes:
        paths: Resolved filesystem layout of the run, covering its input paths and the
            locations of the artifacts it generated.
    """

    paths: RunPaths
32
+
33
+
34
@dataclass(slots=True, frozen=True)
class PredictResult:
    """Immutable summary returned once a tlmtc prediction run has finished.

    Attributes:
        paths: Resolved filesystem layout of the prediction run, covering its inputs and
            the locations of the artifacts it generated.
    """

    paths: PredictionPaths
43
+
44
+
45
def train_tlmtc(
    raw_csv: str | Path,
    *,
    raw_test_csv: str | Path | Unset = UNSET,
    work_dir: str | Path | Unset = UNSET,
    config_path: str | Path | Unset = UNSET,
    run_id: str | None | Unset = UNSET,
    target_name: str | Unset = UNSET,
    validation_size: float | Unset = UNSET,
    test_size: float | Unset = UNSET,
    random_seed: int | Unset = UNSET,
    transfer_learning: bool | Unset = UNSET,
    hyperparameter_tuning: bool | Unset = UNSET,
    threshold_optimization: bool | Unset = UNSET,
    threshold_type: str | Unset = UNSET,
    scale_learning_rate: bool | Unset = UNSET,
    wrap_peft: bool | Unset = UNSET,
    proxy_checkpoint: str | Unset = UNSET,
    checkpoint: str | Unset = UNSET,
    sequence_length: int | Unset = UNSET,
    best_model_metric: str | Unset = UNSET,
    batch_size: int | Unset = UNSET,
    train_epochs: int | Unset = UNSET,
    learning_rate: float | Unset = UNSET,
    weight_decay: float | Unset = UNSET,
    lr_scheduler: str | Unset = UNSET,
    best_threshold_metric: str | Unset = UNSET,
    tuning_trials: int | Unset = UNSET,
    optuna_space: dict[str, Any] | Unset = UNSET,
    lora_r: int | Unset = UNSET,
    lora_alpha: int | Unset = UNSET,
    lora_dropout: float | Unset = UNSET,
    lora_bias: str | Unset = UNSET,
    early_stopping_patience: int | Unset = UNSET,
    use_cpu: bool | Unset = UNSET,
    verbosity: str | Unset = UNSET,
) -> TrainResult:
    """Train a multi-label text classifier end to end.

    Depending on the workflow flags, a run covers data preparation, hyperparameter
    tuning, model fine-tuning, decision-threshold optimization, evaluation, and reporting.

    Args:
        raw_csv: Raw multi-label training CSV. It must provide a `text` column and at least
            two binary `label_*` columns; a `text_pair` column is optional.
        raw_test_csv: Optional separate raw test CSV. When not given, the test split is carved
            out of `raw_csv` according to `test_size`. No separate test CSV by default.
        work_dir: Base directory against which inputs are resolved and under which run
            artifacts are written. Defaults to the current working directory.
        config_path: YAML configuration file. No configuration file is read by default.
        run_id: Identifier naming the run directory. A random identifier is generated when
            omitted.
        target_name: Human-readable name of the classification target used in logs and
            reports. Defaults to `"Target"`.
        validation_size: Fraction held out for the validation split. Defaults to `0.15`.
        test_size: Fraction held out for the test split when `raw_test_csv` is not given.
            Defaults to `0.15`.
        random_seed: Seed for reproducible splitting and shuffling. Defaults to `2469`.
        transfer_learning: Fine-tune the target checkpoint and produce model and evaluation
            artifacts. When `False`, data preparation still runs, and with
            `hyperparameter_tuning=True` only proxy-checkpoint hyperparameter tuning follows.
            Defaults to `True`.
        hyperparameter_tuning: Evaluate candidate hyperparameter configurations with Optuna
            before final fine-tuning. Combined with `transfer_learning=False`, only the
            proxy-checkpoint tuning stage runs after data preparation. When both flags are
            `False`, the workflow ends after data preparation. Defaults to `True`.
        threshold_optimization: Tune decision thresholds on validation-set predictions after
            fine-tuning. Otherwise, evaluation uses the default threshold `0.5`. Ignored when
            `transfer_learning=False`. Defaults to `True`.
        threshold_type: Thresholding mode, either `"global"` or `"label"`. Defaults to
            `"label"`.
        scale_learning_rate: Scale a proxy-tuned learning rate for the target checkpoint.
            Defaults to `False`.
        wrap_peft: Fine-tune with parameter-efficient LoRA adapters. Defaults to `True`.
        proxy_checkpoint: Compatible encoder-only Hugging Face checkpoint identifier used
            during hyperparameter tuning. Defaults to `"EuroBERT/EuroBERT-210m"`. If
            `checkpoint` is given and `proxy_checkpoint` is not, the proxy falls back to
            `checkpoint`. It is loaded with `trust_remote_code=False`, so checkpoints that
            need custom remote code are unsupported. Only use checkpoints you trust.
        checkpoint: Compatible encoder-only Hugging Face checkpoint identifier or local path
            used for final fine-tuning. Defaults to `"EuroBERT/EuroBERT-610m"`. It is loaded
            with `trust_remote_code=False`, so checkpoints that need custom remote code are
            unsupported. Only use checkpoints and local model directories you trust.
        sequence_length: Maximum tokenized sequence length. Defaults to `128`.
        best_model_metric: Metric used to pick the best model checkpoint. One of
            `"f1_micro"`, `"f1_macro"`, `"roc_auc_micro"` or `"roc_auc_macro"`. Defaults to
            `"roc_auc_macro"`.
        batch_size: Starting training and evaluation batch size. Used as-is without
            hyperparameter tuning and replaced by the tuned value otherwise. Defaults to
            `16`.
        train_epochs: Starting number of training epochs. Used as-is without hyperparameter
            tuning and replaced by the tuned value otherwise. Defaults to `20`.
        learning_rate: Starting optimizer learning rate. Used as-is without hyperparameter
            tuning and replaced by the tuned value otherwise. Defaults to `2e-5`.
        weight_decay: Starting weight decay. Used as-is without hyperparameter tuning and
            replaced by the tuned value otherwise. Defaults to `0.01`.
        lr_scheduler: Starting learning-rate scheduler name. Used as-is without
            hyperparameter tuning and replaced by the tuned value otherwise. Defaults to
            `"linear"`.
        best_threshold_metric: Metric used to choose decision thresholds, either
            `"f1_micro"` or `"f1_macro"`. Defaults to `"f1_macro"`.
        tuning_trials: Number of hyperparameter configurations Optuna evaluates. More
            trials may find a better configuration but take longer. Defaults to `10`.
        optuna_space: Optional partial override of the tuning ranges and candidate values.
            Supported keys are `lr_low`, `lr_high`, `batch_sizes`, `wd_low`, `wd_high`,
            `schedulers`, `epoch_low` and `epoch_high`. Keys left out are filled from the
            default space selected by `wrap_peft`.

            Default PEFT search space (`wrap_peft=True`):

                {
                    "lr_low": 5e-5,
                    "lr_high": 4e-4,
                    "batch_sizes": [8, 16, 32],
                    "wd_low": 0.0,
                    "wd_high": 0.01,
                    "schedulers": ["linear", "cosine"],
                    "epoch_low": 5,
                    "epoch_high": 20,
                    "lr_reference_batch_size": 32,
                }

            Default full fine-tuning search space (`wrap_peft=False`):

                {
                    "lr_low": 1e-5,
                    "lr_high": 8e-5,
                    "batch_sizes": [8, 16, 32],
                    "wd_low": 0.0,
                    "wd_high": 0.1,
                    "schedulers": ["linear", "cosine", "polynomial"],
                    "epoch_low": 5,
                    "epoch_high": 30,
                    "lr_reference_batch_size": 32,
                }
        lora_r: LoRA rank. Defaults to `8`.
        lora_alpha: LoRA scaling factor. Defaults to `32`.
        lora_dropout: LoRA dropout probability. Defaults to `0.1`.
        lora_bias: LoRA bias handling mode, one of `"none"`, `"all"` or `"lora_only"`.
            Defaults to `"none"`.
        early_stopping_patience: Number of epochs without improvement before early
            stopping. Defaults to `10`.
        use_cpu: Force CPU execution. Defaults to `False`.
        verbosity: Runtime output mode, either `"progress"` or `"quiet"`. Defaults to
            `"progress"`.

    Returns:
        A `TrainResult` whose `paths` hold the resolved input and artifact locations.
    """
    # Only an explicit str/Path is read as a config file; UNSET means "no config file".
    file_config = None
    if isinstance(config_path, (str, Path)):
        file_config = load_config_file(config_path)

    # Explicit keyword arguments, grouped by settings section. UNSET values are passed
    # through to RunSettings.resolve unchanged.
    model_section = {
        "target_name": target_name,
        "proxy_checkpoint": proxy_checkpoint,
        "checkpoint": checkpoint,
        "sequence_length": sequence_length,
    }
    split_section = {
        "validation_size": validation_size,
        "test_size": test_size,
        "random_seed": random_seed,
    }
    workflow_section = {
        "hyperparameter_tuning": hyperparameter_tuning,
        "threshold_optimization": threshold_optimization,
        "transfer_learning": transfer_learning,
        "scale_learning_rate": scale_learning_rate,
        "wrap_peft": wrap_peft,
    }
    training_section = {
        "batch_size": batch_size,
        "train_epochs": train_epochs,
        "weight_decay": weight_decay,
        "learning_rate": learning_rate,
        "lr_scheduler": lr_scheduler,
        "best_model_metric": best_model_metric,
        "early_stopping_patience": early_stopping_patience,
    }
    threshold_section = {
        "threshold_type": threshold_type,
        "best_threshold_metric": best_threshold_metric,
    }
    hpo_section = {
        "tuning_trials": tuning_trials,
        "optuna_space": optuna_space,
    }
    peft_section = {
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "lora_bias": lora_bias,
    }
    overrides = {
        "raw_csv": raw_csv,
        "raw_test_csv": raw_test_csv,
        "work_dir": work_dir,
        "run_id": run_id,
        "model": model_section,
        "split": split_section,
        "workflow": workflow_section,
        "training": training_section,
        "threshold": threshold_section,
        "hpo": hpo_section,
        "peft": peft_section,
        "hardware": {"use_cpu": use_cpu},
        "runtime": {"verbosity": verbosity},
    }
    settings = RunSettings.resolve(config=file_config, env=None, overrides=overrides)

    configure_runtime_output(settings.runtime.verbosity)
    emit_progress("Starting training run")

    run_paths = resolve_paths(
        raw_csv=settings.raw_csv,
        raw_test_csv=settings.raw_test_csv,
        work_dir=settings.work_dir,
        run_id=settings.run_id,
    )
    paths = run_paths.ensure_dirs()

    # Data preparation stages, executed strictly in this order.
    data_pipeline = DataPipeline(paths=paths, split=settings.split, model=settings.model)
    for prepare_step in (
        data_pipeline.split_data,
        data_pipeline.get_multi_hot_vectors,
        data_pipeline.create_hf_dataset,
        data_pipeline.tokenize_data,
    ):
        prepare_step()

    # Model stages, executed strictly in this order; which of them do real work is
    # decided inside FinetunePipeline from the workflow settings.
    finetune_pipeline = FinetunePipeline(
        tokenized_dataset=data_pipeline.tokenized_dataset,
        paths=paths,
        model=settings.model,
        workflow=settings.workflow,
        peft=settings.peft,
        training=settings.training,
        hpo=settings.hpo,
        threshold=settings.threshold,
        hardware=settings.hardware,
    )
    for finetune_step in (
        finetune_pipeline.load_pretrained,
        finetune_pipeline.tune_hyperparameters,
        finetune_pipeline.fine_tune_pretrained,
        finetune_pipeline.tune_thresholds,
        finetune_pipeline.save_pretrained,
    ):
        finetune_step()

    evaluation_pipeline = EvaluationPipeline(
        tokenized_dataset=data_pipeline.tokenized_dataset,
        updated_trainer=finetune_pipeline.updated_trainer,
        paths=paths,
        model=settings.model,
        workflow=settings.workflow,
        training=settings.training,
        tuned_threshold=finetune_pipeline.tuned_threshold,
        input_mode=data_pipeline.input_mode,
    )
    for evaluation_step in (
        evaluation_pipeline.run_evaluation,
        evaluation_pipeline.save_metrics,
        evaluation_pipeline.render_tables,
        evaluation_pipeline.render_figures,
    ):
        evaluation_step()

    # Persist everything predict_tlmtc needs to reload and apply this run's model.
    run_meta = TrainRunMeta(
        run_id=settings.run_id,
        target_name=settings.model.target_name,
        checkpoint=settings.model.checkpoint,
        proxy_checkpoint=settings.model.proxy_checkpoint,
        sequence_length=settings.model.sequence_length,
        input_mode=data_pipeline.input_mode,
        label_names=evaluation_pipeline.label_names,
        threshold_type=settings.threshold.threshold_type,
        thresholds=finetune_pipeline.tuned_threshold.tolist(),
        transfer_learning=settings.workflow.transfer_learning,
        hyperparameter_tuning=settings.workflow.hyperparameter_tuning,
        threshold_optimization=settings.workflow.threshold_optimization,
        scale_learning_rate=settings.workflow.scale_learning_rate,
        wrap_peft=settings.workflow.wrap_peft,
    )
    write_run_meta(meta=run_meta, path=paths.train_run_meta_path)

    emit_progress("Training run complete")
    return TrainResult(paths=paths)
324
+
325
+
326
def predict_tlmtc(
    prediction_csv: str | Path,
    *,
    work_dir: str | Path | Unset = UNSET,
    config_path: str | Path | Unset = UNSET,
    run_id: str | None | Unset = UNSET,
    batch_size: int | Unset = UNSET,
    use_cpu: bool | Unset = UNSET,
    verbosity: str | Unset = UNSET,
) -> PredictResult:
    """Run the multi-label text classification prediction workflow.

    Prediction consumes persisted metadata and model artifacts from a completed
    training run, applies the persisted decision thresholds, and writes probability
    and binary prediction artifacts.

    Args:
        prediction_csv: Path to the unlabeled prediction CSV. The file must contain a `text`
            column and, for models trained with paired-text inputs, a `text_pair` column.
            Prediction artifacts preserve input text columns unchanged.
        work_dir: Base directory for resolving inputs, reading training artifacts, and writing
            prediction artifacts. Defaults to the current working directory.
        config_path: Path to a YAML configuration file. Defaults to no configuration file.
        run_id: Run identifier used to select the completed training run. If omitted, the latest
            completed training run is selected from persisted training metadata. Prediction reloads
            the trained model or adapter artifacts for this run with `trust_remote_code=False`;
            artifacts that require custom remote code are not supported. Only use saved model
            artifacts and adapters you trust.
        batch_size: Prediction batch size used for batched inference. Defaults to `32`.
        use_cpu: Whether to force CPU execution. Defaults to `False`.
        verbosity: Runtime output mode. Supported values are `"progress"` and `"quiet"`. Defaults to
            `"progress"`.

    Returns:
        Result metadata containing the resolved input and artifact paths.

    Raises:
        RuntimeError: If the selected training run was trained with `transfer_learning=False`
            (no fine-tuned model was persisted), or if its metadata lacks `input_mode` or
            `label_names`.
    """
    settings = PredictionSettings.resolve(
        config=load_config_file(config_path) if isinstance(config_path, (str, Path)) else None,
        env=None,
        overrides={
            "prediction_csv": prediction_csv,
            "work_dir": work_dir,
            "run_id": run_id,
            "batch_size": batch_size,
            "hardware": {
                "use_cpu": use_cpu,
            },
            "runtime": {
                "verbosity": verbosity,
            },
        },
    )

    configure_runtime_output(settings.runtime.verbosity)
    emit_progress("Starting prediction run")

    paths = resolve_prediction_paths(
        input_csv=settings.prediction_csv,
        work_dir=settings.work_dir,
        run_id=settings.run_id,
    ).ensure_dirs()

    emit_progress("Reading training metadata")

    meta = read_run_meta(paths.train_run_meta_path)

    if not meta.transfer_learning:
        raise RuntimeError(
            "Prediction requires a training run with transfer_learning=True. "
            f"Run '{meta.run_id}' did not persist a fine-tuned prediction model."
        )

    input_mode = meta.input_mode
    label_names = meta.label_names

    # Metadata is read from disk, so validate it explicitly: `assert` statements are
    # stripped under `python -O`, which would let incomplete metadata fail later with
    # an obscure error (e.g. len(None) when sizing the model).
    if input_mode is None or not label_names:
        raise RuntimeError(
            f"Training metadata for run '{meta.run_id}' is incomplete: "
            "`input_mode` and a non-empty `label_names` are required for prediction."
        )

    emit_progress("Reading prediction inputs")

    input_df = read_prediction_csv(
        df_path=paths.input_data_path,
        expected_input_mode=input_mode,
    )
    prediction_dataset = create_prediction_dataset(
        df=input_df,
        input_mode=input_mode,
    )
    emit_progress("Tokenizing prediction inputs")
    tokenized_dataset = tokenize_prediction_dataset(
        dataset=prediction_dataset,
        checkpoint=meta.checkpoint,
        input_mode=input_mode,
        sequence_length=meta.sequence_length,
    )
    emit_progress("Loading fine-tuned prediction model")
    model = load_prediction_model(
        model_dir=paths.train_run_model_dir,
        checkpoint=meta.checkpoint,
        num_labels=len(label_names),
        wrap_peft=meta.wrap_peft,
    )
    emit_progress("Running prediction")
    probabilities = predict_probabilities(
        model=model,
        dataset=tokenized_dataset,
        batch_size=settings.batch_size,
        use_cpu=settings.hardware.use_cpu,
    )
    # Two parallel frames: raw per-label probabilities and thresholded binary predictions.
    probability_df = make_prediction_frame(
        input_df=input_df,
        values=probabilities,
        label_names=label_names,
    )
    predictions = apply_thresholds(
        probabilities=probabilities,
        thresholds=meta.thresholds,
    )
    prediction_df = make_prediction_frame(
        input_df=input_df,
        values=predictions,
        label_names=label_names,
    )

    emit_progress("Writing prediction artifacts")
    probability_df.to_csv(paths.probabilities_path, index=False)
    prediction_df.to_csv(paths.predictions_path, index=False)

    emit_progress("Prediction run complete")
    return PredictResult(paths=paths)