fkat 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/schedule/mlflow.py ADDED
@@ -0,0 +1,143 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import logging
+ import lightning as L
+ from typing_extensions import override
+
+ from fkat.pytorch.schedule import Schedule
+ from fkat.pytorch.callbacks.loggers import CallbackLogger
+
+ logger: logging.Logger = logging.getLogger(__name__)
+
+
+ class HasTag(Schedule):
+     """
+     A schedule that activates only when a specific MLflow tag is present AND the trigger schedule is satisfied.
+
+     This schedule combines another schedule (the trigger schedule) with MLflow tag validation.
+     It allows callbacks to be dynamically enabled or disabled through experiment configuration rather than
+     code changes, which is particularly useful for performance-intensive callbacks like FLOP measurement
+     or detailed logging that should only run conditionally.
+
+     The schedule checks two conditions:
+
+     1. If the trigger schedule is satisfied
+     2. If the specified MLflow tag exists in the current MLflow run
+
+     Both conditions must be true for the schedule to activate. If the trigger schedule doesn't activate or
+     the trainer is not provided, the schedule will never activate.
+
+     Note:
+         - The trainer can be optionally provided to the ``check`` method for MLflow tag validation.
+           If trainer is None, the schedule will never activate.
+         - Tag checking occurs only when the trigger schedule condition is already satisfied,
+           minimizing MLflow API calls.
+         - If an exception occurs during tag checking, it will be logged and the schedule will not activate.
+
+     Example:
+         Python code example::
+
+             # Create a schedule that checks every 5 batches if the 'enable_flops' tag exists
+             from fkat.pytorch.schedule import Every
+
+             trigger = Every(n_batches=5)
+             flops_schedule = HasTag(tag="enable_flops", schedule=trigger)
+             flops_callback = Flops(schedule=flops_schedule)
+             trainer = L.Trainer(callbacks=[flops_callback])
+
+         Hydra configuration example:
+
+         .. code-block:: yaml
+
+             # In your config.yaml file
+             callbacks:
+               - _target_: fkat.pytorch.callbacks.profiling.Flops
+                 schedule:
+                   _target_: fkat.pytorch.schedule.mlflow.HasTag
+                   tag: ENABLE_FLOPS
+                   schedule:
+                     _target_: fkat.pytorch.schedule.Every
+                     n_steps: 20
+
+             # Another example using Fixed schedule
+             callbacks:
+               - _target_: fkat.pytorch.callbacks.heartbeat.Heartbeat
+                 schedule:
+                   _target_: fkat.pytorch.schedule.mlflow.HasTag
+                   tag: ENABLE_HEARTBEAT
+                   schedule:
+                     _target_: fkat.pytorch.schedule.Fixed
+                     warmup_steps: 100
+                     active_steps: 1000
+
+             # Example with Elapsed time-based schedule
+             callbacks:
+               - _target_: fkat.pytorch.callbacks.custom_logging.DetailedMetrics
+                 schedule:
+                   _target_: fkat.pytorch.schedule.mlflow.HasTag
+                   tag: DETAILED_LOGGING
+                   schedule:
+                     _target_: fkat.pytorch.schedule.Elapsed
+                     interval: ${timedelta:minutes=15}
+     """
+
+     def __init__(self, tag: str, schedule: Schedule) -> None:
+         """
+         Initialize a new MLflow HasTag schedule.
+
+         Args:
+             tag (str): The name of the tag that must be present in the MLflow run.
+             schedule (Schedule): The schedule that determines when to check for the tag.
+                 This can be any implementation of the Schedule protocol (e.g., Every, Fixed, Elapsed).
+         """
+         self._tag: str = tag
+         self._schedule: Schedule = schedule
+         self._cb_logger: CallbackLogger | None = None
+
+     @override
+     def check(
+         self,
+         *,
+         stage: str | None = None,
+         batch_idx: int | None = None,
+         step: int | None = None,
+         trainer: L.Trainer | None = None,
+     ) -> bool:
+         """
+         Check if the schedule should activate based on the trigger schedule and MLflow tag presence.
+
+         This method first checks if the trigger schedule is satisfied.
+         If this condition is met, it then checks if the specified MLflow tag exists in the current run.
+         Both conditions must be true for the method to return True.
+
+         Args:
+             stage (str, optional): Current training stage (e.g., "train", "validate", "test").
+                 Passed to the trigger schedule.
+             batch_idx (int, optional): Current batch index within the epoch.
+                 Passed to the trigger schedule.
+             step (int, optional): Current global step (cumulative across epochs).
+                 Passed to the trigger schedule.
+             trainer (L.Trainer, optional): The Lightning Trainer instance.
+                 Required for MLflow tag validation.
+
+         Returns:
+             bool: True if both the trigger schedule is satisfied AND the specified tag is present,
+                 False otherwise.
+
+         Note:
+             - The trainer must be provided for MLflow tag validation.
+             - Tag checking occurs only when the trigger schedule is already satisfied.
+             - If an exception occurs during tag checking, it will be logged as a warning
+               and the method will return False.
+         """
+         triggered = self._schedule.check(stage=stage, batch_idx=batch_idx, step=step, trainer=trainer)
+         if not triggered or trainer is None:
+             return False
+         try:
+             if self._cb_logger is None:
+                 self._cb_logger = CallbackLogger(trainer)
+
+             tags = self._cb_logger.tags()
+             return self._tag in tags
+         except Exception as e:
+             logger.warning(f"Error when checking if tag {self._tag} exists: {e}")
+             return False
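Editor's note: ``HasTag`` delegates to any trigger schedule with a compatible ``check`` method. A minimal sketch of such a trigger (illustrative only; the real base class lives in ``fkat/pytorch/schedule/base.py``, which is not shown in this diff, and the name ``EveryNSteps`` is hypothetical):

    import lightning as L

    class EveryNSteps:
        """Hypothetical trigger: activates on every n-th global step."""

        def __init__(self, n: int) -> None:
            self._n = n

        def check(
            self,
            *,
            stage: str | None = None,
            batch_idx: int | None = None,
            step: int | None = None,
            trainer: L.Trainer | None = None,
        ) -> bool:
            # HasTag calls this first; only if it returns True does it query MLflow tags.
            return step is not None and step % self._n == 0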
fkat/pytorch/utilities.py ADDED
@@ -0,0 +1,49 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ from functools import wraps
+ from typing import TypeVar, ParamSpec
+ from collections.abc import Callable
+ from typing_extensions import overload
+
+
+ def get_rank() -> int:
+     return int(os.getenv("RANK", "0"))
+
+
+ def get_local_rank() -> int:
+     return int(os.getenv("LOCAL_RANK", "0"))
+
+
+ T = TypeVar("T")
+ P = ParamSpec("P")
+
+
+ @overload
+ def local_rank_zero_only(fn: Callable[P, T]) -> Callable[P, T | None]: ...
+
+
+ @overload
+ def local_rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]: ...
+
+
+ def local_rank_zero_only(fn: Callable[P, T], default: T | None = None) -> Callable[P, T | None]:
+     """Wrap a function so that it is only invoked on local rank zero.
+
+     Can be used as a decorator to enable a function/method being called only on local rank 0.
+     """
+
+     @wraps(fn)
+     def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> T | None:
+         local_rank = getattr(local_rank_zero_only, "local_rank", None)
+         if local_rank is None:
+             raise RuntimeError("The `local_rank_zero_only.local_rank` needs to be set before use")
+         if local_rank == 0:
+             return fn(*args, **kwargs)
+         return default
+
+     return wrapped_fn
+
+
+ local_rank_zero_only.local_rank = getattr(local_rank_zero_only, "local_rank", get_local_rank() or 0)  # type: ignore[attr-defined]
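Editor's note: a minimal usage sketch of the decorator above (illustrative; assumes a multi-process launch where each process has ``LOCAL_RANK`` set in its environment):

    from fkat.pytorch.utilities import local_rank_zero_only

    @local_rank_zero_only
    def report(msg: str) -> None:
        # Runs only in the process whose local rank is 0;
        # on all other ranks the wrapper returns the default (None).
        print(msg)

    report("downloading dataset once per node")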
fkat/test.py ADDED
@@ -0,0 +1,31 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #!/usr/bin/env python
+
+ """
+ The ``fkat.test`` entrypoint processes the provided config,
+ instantiates the ``trainer``, ``model`` and ``data`` sections and calls ``trainer.test()``.
+ """
+
+ import hydra
+ import lightning as L
+ from omegaconf import DictConfig
+
+ from fkat import initialize, run_main
+
+
+ @hydra.main(version_base="1.3")
+ def main(cfg: DictConfig) -> None:
+     s = initialize(cfg)
+     kwargs = {
+         "ckpt_path": s.ckpt_path,
+     }
+     if isinstance(s.data, L.LightningDataModule):
+         kwargs["datamodule"] = s.data
+     else:
+         kwargs["test_dataloaders"] = s.data.test_dataloader() if s.data else None
+     s.trainer.test(s.model, **kwargs)
+
+
+ if __name__ == "__main__":
+     run_main(main)
fkat/train.py ADDED
@@ -0,0 +1,32 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #!/usr/bin/env python
+
+ """
+ The ``fkat.train`` entrypoint processes the provided config,
+ instantiates the ``trainer``, ``model`` and ``data`` sections and calls ``trainer.fit()``.
+ """
+
+ import hydra
+ import lightning as L
+ from omegaconf import DictConfig
+
+ from fkat import initialize, run_main
+
+
+ @hydra.main(version_base="1.3")
+ def main(cfg: DictConfig) -> None:
+     s = initialize(cfg)
+     kwargs = {
+         "ckpt_path": s.ckpt_path,
+     }
+     if isinstance(s.data, L.LightningDataModule):
+         kwargs["datamodule"] = s.data
+     else:
+         kwargs["train_dataloaders"] = s.data.train_dataloader() if s.data else None
+         kwargs["val_dataloaders"] = s.data.val_dataloader() if s.data else None
+     s.trainer.fit(s.model, **kwargs)
+
+
+ if __name__ == "__main__":
+     run_main(main)
fkat/utils/__init__.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from datetime import datetime, timezone
+ from typing import TypeVar
+
+ T = TypeVar("T")
+
+
+ def assert_not_none(obj: T | None, name: str = "obj") -> T:
+     assert obj is not None, f"{name} cannot be None"
+     return obj
+
+
+ def safe_timestamp(dt: datetime | None = None) -> str:
+     """
+     Generate a filesystem-safe timestamp string.
+
+     Format: YYYY-MM-DD_HH-MM-SS-mmm (e.g., 2026-01-14_16-09-15-123)
+
+     Args:
+         dt: datetime object to format. If None, uses current UTC time.
+
+     Returns:
+         Formatted timestamp string safe for use in filenames
+     """
+     if dt is None:
+         dt = datetime.now(timezone.utc)
+     return dt.strftime("%Y-%m-%d_%H-%M-%S-") + f"{dt.microsecond // 1000:03d}"
fkat/utils/aws/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
fkat/utils/aws/imds.py ADDED
@@ -0,0 +1,137 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ """
+ EC2 Instance Metadata Service related methods
+ """
+
+ import logging
+ import socket
+ from dataclasses import dataclass
+ from functools import lru_cache
+
+ import requests
+ from requests import HTTPError
+
+ from fkat.utils.logging import rank0_logger
+
+ IMDS_URL = "http://169.254.169.254/latest"
+ IMDS_METADATA_URL = f"{IMDS_URL}/meta-data"
+ IMDS_V2_TOKEN_URL = f"{IMDS_URL}/api/token"
+ NULL = "_NULL_"  # sentinel value used to mark null (not-available) values
+
+
+ log: logging.Logger = rank0_logger(__name__)
+
+ Token = str
+
+
+ @dataclass
+ class InstanceMetadata:
+     """
+     Struct representing the instance metadata as fetched from IMDS on the current host.
+     Use :py:func:`fkat.utils.aws.imds.instance_metadata` to get a filled-out
+     instance of this object.
+     """
+
+     instance_id: str
+     instance_type: str
+     hostname: str
+     public_hostname: str
+     local_hostname: str
+     local_ipv4: str
+     availability_zone: str
+     region: str
+     ami_id: str
+
+
+ @lru_cache
+ def fetch(metadata: str = "", token: Token | None = None) -> str | None:
+     """
+     Fetches the specified ``metadata`` from EC2's Instance Metadata Service (IMDS) running
+     on the current host, by sending an HTTP GET request to ``http://169.254.169.254/latest/meta-data/<metadata>``.
+
+     To get a list of all valid values of ``metadata``, run this method with no arguments, then split
+     the return value by new-line.
+
+     Arguments:
+         metadata: Name of the instance metadata to query (e.g. ``instance-type``)
+         token: IMDS token
+
+     Returns:
+         the specified ``metadata`` or ``None`` if IMDS cannot be reached
+     """
+
+     try:
+         response = requests.get(f"{IMDS_METADATA_URL}/{metadata}", headers={"X-aws-ec2-metadata-token": token or ""})
+     except Exception as e:
+         log.warning("Error querying IMDS; instance metadata won't be available", exc_info=e)
+         return None
+
+     if response.ok:
+         return response.text
+     else:  # response NOT ok
+         try:
+             response.raise_for_status()
+         except HTTPError:
+             return None
+
+     raise AssertionError("Unreachable code!")
+
+
+ def token(timeout: int = 60) -> Token:
+     """
+     Fetches an IMDS ``token`` from EC2's Instance Metadata Service (IMDSv2) running
+     on the current host, by sending an HTTP PUT request to ``http://169.254.169.254/latest/api/token``.
+
+     Arguments:
+         timeout: request timeout
+
+     Returns:
+         the IMDSv2 ``token`` or ``""`` if IMDSv2 cannot be reached
+     """
+
+     try:
+         response = requests.put(
+             IMDS_V2_TOKEN_URL, headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, timeout=timeout
+         )
+     except Exception as e:
+         log.warning("Error querying IMDSv2; instance token won't be available", exc_info=e)
+         return ""
+
+     if response.ok:
+         return response.text
+     else:  # response NOT ok
+         error_code_msg = f"IMDS Token response is not ok with status code {response.status_code}"
+         log.warning("Error querying IMDSv2; instance token won't be available", exc_info=Exception(error_code_msg))
+         return ""
+
+
+ @lru_cache
+ def instance_metadata() -> InstanceMetadata:
+     """
+     Fetches IMDS instance metadata for the current host from EC2's Instance Metadata Service (IMDS),
+     which typically runs on localhost at ``http://169.254.169.254``.
+     If IMDS cannot be reached for any reason, returns an instance of :py:class:`InstanceMetadata`
+     where every field holds its fallback value (the ``NULL`` sentinel, ``"localhost"``, or the local hostname).
+
+     .. note::
+         This method is memoized (value is cached); hence, only the first call
+         will actually hit IMDS, and subsequent calls will return the memoized
+         value. Therefore, it is ok to call this function multiple times.
+     """
+     tkn = token()
+     return InstanceMetadata(
+         instance_id=fetch("instance-id", tkn) or "localhost",
+         instance_type=fetch("instance-type", tkn) or NULL,
+         availability_zone=fetch("placement/availability-zone", tkn) or NULL,
+         region=fetch("placement/region", tkn) or NULL,
+         hostname=fetch("hostname", tkn) or socket.gethostname(),
+         local_ipv4=fetch("local-ipv4", tkn) or NULL,
+         public_hostname=fetch("public-hostname", tkn) or NULL,
+         local_hostname=fetch("local-hostname", tkn) or NULL,
+         ami_id=fetch("ami-id", tkn) or NULL,
+     )
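Editor's note: a short usage sketch of the IMDS helpers above (illustrative; runs against the real IMDS only on an EC2 host, and falls back gracefully elsewhere):

    from fkat.utils.aws.imds import fetch, instance_metadata, token

    md = instance_metadata()            # cached after the first call
    print(md.instance_type, md.region)  # e.g. "p4d.24xlarge us-east-1", or NULL sentinels off-EC2

    # Lower-level access: list available metadata keys, then query one.
    tkn = token()
    keys = (fetch("", tkn) or "").splitlines()
    az = fetch("placement/availability-zone", tkn)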
fkat/utils/boto3.py ADDED
@@ -0,0 +1,24 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from typing import Literal
+
+ import boto3
+ from botocore.config import Config
+
+
+ def session(
+     max_attempts: int = 6,
+     mode: Literal["legacy", "standard", "adaptive"] = "standard",
+     clients: list[str] | None = None,
+ ) -> boto3.Session:
+     config = Config(
+         retries={
+             "max_attempts": max_attempts,
+             "mode": mode,
+         }
+     )
+     session = boto3.Session()
+     if clients:
+         for client in clients:
+             session.client(client, config=config)  # type: ignore
+     return session
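Editor's note: a usage sketch of the helper above (illustrative only; the service names are examples):

    from fkat.utils.boto3 import session

    # Build a boto3 session; the optional `clients` list eagerly constructs
    # those clients with the retry Config shown above.
    sess = session(max_attempts=8, mode="adaptive", clients=["s3", "ec2"])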
fkat/utils/config.py ADDED
@@ -0,0 +1,194 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ from typing import Any, Protocol
+ from tempfile import TemporaryDirectory
+
+ from lightning import Trainer as LightningTrainer
+ from lightning.pytorch.utilities import rank_zero_only
+ from omegaconf import OmegaConf
+
+ from fkat.utils.mlflow import mlflow_logger
+
+
+ class Trainer(Protocol):
+     """
+     Protocol defining the interface for training models.
+
+     This protocol establishes the required methods that any trainer implementation
+     must provide. It serves as a structural subtyping interface for objects that
+     can train, evaluate, and make predictions with machine learning models.
+
+     Implementations of this protocol should handle the complete training lifecycle,
+     including model fitting, prediction, testing, and validation.
+
+     Note:
+         As a Protocol class, this is not meant to be instantiated directly but
+         rather used for type checking and interface definition.
+     """
+
+     def fit(self, *args: Any, **kwargs: Any) -> Any:
+         """
+         Train a model on the provided data.
+
+         This method handles the model training process, including data loading,
+         optimization, and potentially checkpointing.
+
+         Args:
+             *args: Variable length argument list.
+             **kwargs: Arbitrary keyword arguments.
+
+         Returns:
+             Any: Training results, which may include training history,
+                 trained model, or other relevant information.
+         """
+         ...
+
+     def predict(self, *args: Any, **kwargs: Any) -> Any:
+         """
+         Generate predictions using a trained model.
+
+         This method applies the trained model to new data to produce predictions.
+
+         Args:
+             *args: Variable length argument list.
+             **kwargs: Arbitrary keyword arguments.
+
+         Returns:
+             Any: Model predictions, which could be probabilities, class labels,
+                 regression values, or other outputs depending on the model type.
+         """
+         ...
+
+     def test(self, *args: Any, **kwargs: Any) -> Any:
+         """
+         Evaluate a trained model on test data.
+
+         This method assesses model performance on a test dataset, typically
+         calculating metrics such as accuracy, loss, or other relevant measures.
+
+         Args:
+             *args: Variable length argument list.
+             **kwargs: Arbitrary keyword arguments.
+
+         Returns:
+             Any: Test results, which may include metrics, predictions,
+                 or other evaluation information.
+         """
+         ...
+
+     def validate(self, *args: Any, **kwargs: Any) -> Any:
+         """
+         Evaluate a trained model on validation data.
+
+         This method assesses model performance on a validation dataset, which is
+         typically used during the training process for hyperparameter tuning
+         or early stopping.
+
+         Args:
+             *args: Variable length argument list.
+             **kwargs: Arbitrary keyword arguments.
+
+         Returns:
+             Any: Validation results, which may include metrics, predictions,
+                 or other evaluation information.
+         """
+         ...
+
+
+ class SingletonResolver:
+     """
+     A singleton class that resolves and manages training components.
+
+     This class serves as a central registry for training-related objects such as
+     trainers, data, models, and checkpoint paths. It ensures that these components
+     are accessible throughout the application while maintaining a single instance.
+
+     Example:
+         >>> resolver = SingletonResolver()
+         >>> resolver.trainer = MyTrainer()
+         >>> resolver.model = MyModel()
+         >>> # Access the same instance elsewhere
+         >>> same_resolver = SingletonResolver()
+         >>> assert same_resolver.trainer is resolver.trainer
+     """
+
+     trainer: Trainer
+     """The trainer instance responsible for executing the training process."""
+
+     data: Any | None = None
+     """The dataset or data loader used for training and evaluation.
+     Defaults to None."""
+
+     model: Any | None = None
+     """The model architecture to be trained or evaluated.
+     Defaults to None."""
+
+     ckpt_path: Any | None = None
+     """Path to checkpoint files for model loading/saving.
+     Defaults to None."""
+
+     return_predictions: Any | None = None
+     """Flag or configuration for returning predictions.
+     Defaults to None."""
+
+     tuners: Any | None = None
+     """Hyperparameter tuners or optimization components.
+     Defaults to None."""
+
+
+ def register_singleton_resolver() -> Any:
+     resolver = SingletonResolver()
+
+     def resolve(key: str) -> Any:
+         res = resolver
+         for attr in key.split("."):
+             res = getattr(res, attr)
+         return res
+
+     OmegaConf.register_new_resolver("fkat", resolve)
+     return resolver
+
+
+ def to_str(cfg: Any) -> str:
+     """
+     Convert a configuration object to a formatted string representation.
+
+     This function takes a configuration object and converts it to a human-readable
+     YAML string. It's useful for logging, debugging, or displaying configuration settings.
+
+     Args:
+         cfg (Any): The configuration object to convert to string.
+
+     Returns:
+         str: A formatted string representation of the configuration.
+
+     Example:
+         >>> config = {"model": {"type": "resnet", "layers": 50}, "batch_size": 32}
+         >>> print(to_str(config))
+         Config:
+         model:
+           type: resnet
+           layers: 50
+         batch_size: 32
+     """
+     return "Config:\n" + OmegaConf.to_yaml(cfg)
+
+
+ def to_primitive_container(cfg: Any) -> Any:
+     if OmegaConf.is_config(cfg):
+         return OmegaConf.to_container(cfg)
+     return cfg
+
+
+ @rank_zero_only
+ def save(cfg: Any, trainer: LightningTrainer) -> None:
+     yaml_str = OmegaConf.to_yaml(cfg)
+     with TemporaryDirectory() as temp_dir:
+         yaml_path = os.path.join(temp_dir, "config.yaml")
+         os.makedirs(os.path.dirname(yaml_path), exist_ok=True)
+         with open(yaml_path, "w") as f:
+             f.write(yaml_str)
+         if mlflow := mlflow_logger(trainer):
+             if mlflow.run_id:
+                 mlflow.experiment.log_artifact(mlflow.run_id, yaml_path)
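Editor's note: ``register_singleton_resolver()`` registers an OmegaConf resolver named ``fkat`` that walks dotted attribute paths on the shared ``SingletonResolver``. A minimal sketch of how a config could reference it (the ``resume_from`` key and checkpoint path are examples, not package API):

    from omegaconf import OmegaConf
    from fkat.utils.config import register_singleton_resolver

    resolver = register_singleton_resolver()  # registers the "fkat" resolver
    resolver.ckpt_path = "/tmp/last.ckpt"

    # "${fkat:ckpt_path}" resolves via getattr(resolver, "ckpt_path").
    cfg = OmegaConf.create({"resume_from": "${fkat:ckpt_path}"})
    assert cfg.resume_from == "/tmp/last.ckpt"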
fkat/utils/cuda/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
fkat/utils/cuda/preflight/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+