autosim 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
autosim-0.0.1/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 The Alan Turing Institute
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
autosim-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,80 @@
1
+ Metadata-Version: 2.3
2
+ Name: autosim
3
+ Version: 0.0.1
4
+ Summary: A package to generate simulation data easily
5
+ Author: AI for Physical Systems Team at The Alan Turing Institute
6
+ Author-email: AI for Physical Systems Team at The Alan Turing Institute <ai4physics@turing.ac.uk>
7
+ License: MIT License
8
+
9
+ Copyright (c) 2026 The Alan Turing Institute
10
+
11
+ Permission is hereby granted, free of charge, to any person obtaining a copy
12
+ of this software and associated documentation files (the "Software"), to deal
13
+ in the Software without restriction, including without limitation the rights
14
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15
+ copies of the Software, and to permit persons to whom the Software is
16
+ furnished to do so, subject to the following conditions:
17
+
18
+ The above copyright notice and this permission notice shall be included in all
19
+ copies or substantial portions of the Software.
20
+
21
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27
+ SOFTWARE.
28
+ Requires-Dist: numpy>=1.24
29
+ Requires-Dist: scipy>=1.10
30
+ Requires-Dist: tqdm>=4.65
31
+ Requires-Dist: torch>=2.0
32
+ Requires-Dist: ipykernel>=7.1.0 ; extra == 'dev'
33
+ Requires-Dist: pytest>=9.0.1 ; extra == 'dev'
34
+ Requires-Dist: pytest-cov>=7.0.0 ; extra == 'dev'
35
+ Requires-Dist: ruff==0.12.11 ; extra == 'dev'
36
+ Requires-Dist: pyright==1.1.407 ; extra == 'dev'
37
+ Requires-Dist: pre-commit>=4.4.0 ; extra == 'dev'
38
+ Requires-Dist: matplotlib ; extra == 'dev'
39
+ Requires-Python: >=3.10, <3.13
40
+ Provides-Extra: dev
41
+ Description-Content-Type: text/markdown
42
+
43
+ # autosim
44
+ Lots of simulations: a package to generate simulation data easily.
45
+
46
+ ## Installation
47
+
48
+ This project uses [uv](https://docs.astral.sh/uv/) for dependency management.
49
+
50
+ ### Install uv
51
+
52
+ ```bash
53
+ curl -LsSf https://astral.sh/uv/install.sh | sh
54
+ ```
55
+
56
+ ### Install the package
57
+
58
+ ```bash
59
+ uv pip install -e .
60
+ ```
61
+
62
+ This installs `autosim` in editable mode along with its runtime dependencies:
63
+ - `numpy>=1.24`
64
+ - `scipy>=1.10`
65
+ - `tqdm>=4.65`
66
+ - `torch>=2.0`
67
+
68
+ ### Install development dependencies (includes pytest)
69
+
70
+ ```bash
71
+ uv sync --extra dev
72
+ ```
73
+
74
+ ## Running tests
75
+
76
+ Once dev dependencies are installed:
77
+
78
+ ```bash
79
+ uv run pytest
80
+ ```
@@ -0,0 +1,38 @@
1
+ # autosim
2
+ Lots of simulations: a package to generate simulation data easily.
3
+
4
+ ## Installation
5
+
6
+ This project uses [uv](https://docs.astral.sh/uv/) for dependency management.
7
+
8
+ ### Install uv
9
+
10
+ ```bash
11
+ curl -LsSf https://astral.sh/uv/install.sh | sh
12
+ ```
13
+
14
+ ### Install the package
15
+
16
+ ```bash
17
+ uv pip install -e .
18
+ ```
19
+
20
+ This installs `autosim` in editable mode along with its runtime dependencies:
21
+ - `numpy>=1.24`
22
+ - `scipy>=1.10`
23
+ - `tqdm>=4.65`
24
+ - `torch>=2.0`
25
+
26
+ ### Install development dependencies (includes pytest)
27
+
28
+ ```bash
29
+ uv sync --extra dev
30
+ ```
31
+
32
+ ## Running tests
33
+
34
+ Once dev dependencies are installed:
35
+
36
+ ```bash
37
+ uv run pytest
38
+ ```
@@ -0,0 +1,98 @@
1
+ [project]
2
+ name = "autosim"
3
+ version = "0.0.1"
4
+ description = "A package to generate simulation data easily"
5
+ readme = "README.md"
6
+ license = { file = "LICENSE" }
7
+ authors = [
8
+ { name = "AI for Physical Systems Team at The Alan Turing Institute", email = "ai4physics@turing.ac.uk" },
9
+ ]
10
+ requires-python = ">=3.10,<3.13"
11
+
12
+ dependencies = [
13
+ "numpy>=1.24",
14
+ "scipy>=1.10",
15
+ "tqdm>=4.65",
16
+ "torch>=2.0",
17
+ ]
18
+
19
+ [project.optional-dependencies]
20
+ dev = [
21
+ "ipykernel>=7.1.0",
22
+ "pytest>=9.0.1",
23
+ "pytest-cov>=7.0.0",
24
+ "ruff==0.12.11",
25
+ "pyright==1.1.407",
26
+ "pre-commit>=4.4.0",
27
+ "matplotlib"
28
+ ]
29
+
30
+ [build-system]
31
+ requires = ["uv_build"]
32
+ build-backend = "uv_build"
33
+
34
+ [tool.coverage.run]
35
+ relative_files = true
36
+ source = [".", "/tmp"]
37
+
38
+ [tool.pytest.ini_options]
39
+ addopts = "--ignore=v0"
40
+
41
+ [tool.pyright]
42
+ venvPath = "."
43
+ venv = ".venv"
44
+
45
+ [tool.ruff]
46
+ line-length = 88
47
+ target-version = "py310"
48
+
49
+ [tool.ruff.format]
50
+ docstring-code-format = true
51
+
52
+ [tool.ruff.lint]
53
+ select = [
54
+ "D", # docstring conventions
55
+ "E",
56
+ "F",
57
+ "W", # flake8
58
+ "B", # flake8-bugbear
59
+ "I", # isort
60
+ "ARG", # flake8-unused-arguments
61
+ "C4", # flake8-comprehensions
62
+ "EM", # flake8-errmsg
63
+ "ICN", # flake8-import-conventions
64
+ "ISC", # flake8-implicit-str-concat
65
+ "G", # flake8-logging-format
66
+ "PGH", # pygrep-hooks
67
+ "PIE", # flake8-pie
68
+ "PL", # pylint
69
+ "PT", # flake8-pytest-style
70
+ "RET", # flake8-return
71
+ "RUF", # Ruff-specific
72
+ "SIM", # flake8-simplify
73
+ "UP", # pyupgrade
74
+ "YTT", # flake8-2020
75
+ "EXE", # flake8-executable
76
+ ]
77
+
78
+ ignore = [
79
+ "PLR2004", # Magic value used in comparison
80
+ "EM102", # Exception must not use an f-string literal, assign to variable first
81
+ "ISC001", # Conflicts with formatter
82
+ # "D417", # Missing argument descriptions in the docstring
83
+ "D100", # Missing docstring in public module
84
+ "D104", # Missing docstring in public package
85
+ "PLR0913", # too many arguments
86
+ ]
87
+
88
+ unfixable = [
89
+ "F401", # Would remove unused imports
90
+ "F841", # Would remove unused variables
91
+ ]
92
+ flake8-unused-arguments.ignore-variadic-names = true # allow unused *args/**kwargs
93
+
94
+ [tool.ruff.lint.pydocstyle]
95
+ convention = "numpy"
96
+
97
+ [tool.ruff.lint.per-file-ignores]
98
+ "tests/*.py" = ["D"]
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0"
@@ -0,0 +1,191 @@
1
+ import logging
2
+ import os
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from autosim.types import DeviceLike, TensorLike
8
+
9
# Device type strings that get_torch_device accepts and that
# check_torch_device_is_available knows how to probe.
SUPPORTED_DEVICES: list[str] = ["cpu", "mps", "cuda", "xpu"]

# Presence of the TURN_OFF_MPS_IF_RUNNING_CI environment variable (any value)
# forces MPS to be reported as unavailable. GitHub runners report
# torch.backends.mps.is_available() as True even though MPS cannot be used.
TURN_OFF_MPS_IF_RUNNING_CI = "TURN_OFF_MPS_IF_RUNNING_CI" in os.environ
18
+
19
+
20
class TorchDeviceError(NotImplementedError):
    """Signal that a requested device backend is not supported."""

    def __init__(self, device: str):
        super().__init__(f"Backend ({device}) not implemented.")
26
+
27
+
28
def get_torch_device(device: DeviceLike | None) -> torch.device:
    """
    Resolve ``device`` to a ``torch.device``, using torch's default if None.

    Parameters
    ----------
    device: DeviceLike | None
        The device to resolve. A ``torch.device`` is returned unchanged; a
        string must be one of ``SUPPORTED_DEVICES``. If None, the current
        default torch device is returned.

    Returns
    -------
    torch.device
        The resolved device.

    Raises
    ------
    TorchDeviceError
        If ``device`` is neither a ``torch.device``, None, nor a supported
        device string.
    """
    if isinstance(device, torch.device):
        return device
    if device is None:
        get_default_device = getattr(torch, "get_default_device", None)
        if get_default_device is not None:
            return get_default_device()
        # torch.get_default_device is not available in older torch releases,
        # while the package declares torch>=2.0. A freshly allocated tensor is
        # placed on the current default device, so read the device from it.
        return torch.empty(()).device
    if device in SUPPORTED_DEVICES:
        return torch.device(device)
    raise TorchDeviceError(device)
54
+
55
+
56
def move_tensors_to_device(
    *args: TensorLike, device: torch.device
) -> tuple[TensorLike, ...]:
    """
    Transfer each of the given tensors onto ``device``.

    Parameters
    ----------
    *args: TensorLike
        Tensors to transfer, in order.
    device: torch.device
        Target device passed to each tensor's ``.to``.

    Returns
    -------
    tuple[TensorLike, ...]
        One result of ``.to(device)`` per input, in the same order.
    """
    moved = [arg.to(device) for arg in args]
    return tuple(moved)
75
+
76
+
77
+ # ruff: noqa: PLR0911
78
def check_torch_device_is_available(device: DeviceLike) -> bool:
    """
    Check if the given device is available on this machine.

    Parameters
    ----------
    device: DeviceLike
        The device to check: one of the strings ``"cpu"``, ``"mps"``,
        ``"cuda"``, ``"xpu"``, or a ``torch.device`` of one of those types.

    Returns
    -------
    bool
        True if the device is available, False otherwise. For CUDA, the
        backend must be available and, when the ``torch.device`` carries an
        explicit index, that index must be below ``torch.cuda.device_count()``.

    Raises
    ------
    TorchDeviceError
        If the device type is not supported.
    """
    # Normalise both accepted forms to (device type, optional index).
    if isinstance(device, torch.device):
        device_type, index = device.type, device.index
    else:
        device_type, index = device, None

    if device_type == "cpu":
        return True
    if device_type == "mps":
        if TURN_OFF_MPS_IF_RUNNING_CI:
            return False
        return torch.backends.mps.is_available()
    if device_type == "cuda":
        # A bare torch.device("cuda") has index None; it is only usable when
        # the CUDA backend is actually available, so check rather than assume.
        if not torch.cuda.is_available():
            return False
        return index is None or index < torch.cuda.device_count()
    if device_type == "xpu":
        # torch.xpu may not exist in every torch build.
        xpu = getattr(torch, "xpu", None)
        return xpu is not None and xpu.is_available()
    raise TorchDeviceError(str(device))
118
+
119
+
120
+ def check_model_device(model: nn.Module, expected_device: str) -> bool:
121
+ """
122
+ Check if the model is on the expected device.
123
+
124
+ Parameters
125
+ ----------
126
+ model: nn.Module
127
+ The model to check.
128
+ expected_device: str
129
+ The expected device.
130
+
131
+ Returns
132
+ -------
133
+ bool
134
+ True if the model is on the expected device (ignoring device index), False
135
+ otherwise.
136
+ """
137
+ return (
138
+ str(next(model.parameters()).device).split(":")[0]
139
+ == str(expected_device).split(":")[0]
140
+ )
141
+
142
+
143
class TorchDeviceMixin:
    """
    Mixin class to add device management to a PyTorch model.

    Attributes
    ----------
    device: torch.device
        The resolved device. When no device is given, the default torch
        device is used; when ``cpu_only`` is True, this is always a CPU device.

    Raises
    ------
    TorchDeviceError
        If the device is not a valid torch device or is not available.
    """

    def __init__(self, device: DeviceLike | None = None, cpu_only: bool = False):
        resolved = get_torch_device(device)

        # A cpu_only model cannot run elsewhere: warn and actually fall back
        # to CPU, as the warning message states. Checking the resolved device
        # also covers device=None when torch's default device is not CPU.
        if cpu_only and resolved.type != "cpu":
            msg = (
                f"The device ({resolved}) must be CPU for given model. "
                "Setting device as 'cpu'."
            )
            logging.warning(msg)
            resolved = torch.device("cpu")

        self.device = resolved

        if not check_torch_device_is_available(self.device):
            raise TorchDeviceError(str(self.device))

    def _move_tensors_to_device(self, *args: TensorLike) -> tuple[TensorLike, ...]:
        """
        Move the given tensors to ``self.device``.

        Parameters
        ----------
        *args: TensorLike
            The tensors to move.

        Returns
        -------
        tuple[TensorLike, ...]
            The tensors on the device, in the same order.
        """
        return move_tensors_to_device(*args, device=self.device)
@@ -0,0 +1,125 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+ from pathlib import Path
5
+
6
+
7
+ def configure_logging(log_to_file=False, level: str = "INFO"):
8
+ """
9
+ Configure the logging system.
10
+
11
+ Parameters
12
+ ----------
13
+ log_to_file: bool or string, optional
14
+ If True, logs will be written to a file.
15
+ If a string, logs will be written to the specified file.
16
+ verbose: str, optional
17
+ The verbosity level. Can be "critical", "error", "warning",
18
+ "info", or "debug". Defaults to "info".
19
+ """
20
+ logger = logging.getLogger("autosim")
21
+ logger.handlers = [] # Clear existing handlers
22
+
23
+ logger.setLevel(logging.DEBUG)
24
+
25
+ verbose_lower = level.lower()
26
+ match verbose_lower:
27
+ case "error":
28
+ console_log_level = logging.ERROR
29
+ case "warning":
30
+ console_log_level = logging.WARNING
31
+ case "info":
32
+ console_log_level = logging.INFO
33
+ case "debug":
34
+ console_log_level = logging.DEBUG
35
+ case "critical":
36
+ console_log_level = logging.CRITICAL
37
+ case _:
38
+ msg = 'verbose must be "critical", "error", "warning", "info", or "debug"'
39
+ raise ValueError(msg)
40
+
41
+ # Create console handler with a higher log level
42
+ ch = logging.StreamHandler(sys.stdout)
43
+ ch.setLevel(console_log_level)
44
+
45
+ # Create formatter and add it to the handler
46
+ # formatter = logging.Formatter("%(name)s - %(message)s")
47
+ formatter = logging.Formatter("%(levelname)-8s%(asctime)s - %(name)s - %(message)s")
48
+ ch.setFormatter(formatter)
49
+
50
+ # Add the handler to the logger
51
+ logger.addHandler(ch)
52
+
53
+ # Optionally log to a file
54
+ if log_to_file:
55
+ if isinstance(log_to_file, bool):
56
+ log_file_path = Path.cwd() / "autosim.log"
57
+ else:
58
+ log_file_path = Path(log_to_file)
59
+ # Create the directory if it doesn't exist
60
+ log_file_path.parent.mkdir(parents=True, exist_ok=True)
61
+
62
+ log_file_dir = os.path.dirname(log_file_path)
63
+ if log_file_dir and not os.path.exists(log_file_dir):
64
+ os.makedirs(log_file_dir)
65
+
66
+ try:
67
+ fh = logging.FileHandler(log_file_path)
68
+ fh.setLevel(logging.DEBUG)
69
+ fh.setFormatter(formatter)
70
+ logger.addHandler(fh)
71
+ except Exception:
72
+ logger.exception("Failed to create log file at %s", log_file_path)
73
+
74
+ # Capture (model) warnings and redirect them to the logging system
75
+ logging.captureWarnings(True)
76
+
77
+ warnings_logger = logging.getLogger("py.warnings")
78
+ for handler in logger.handlers:
79
+ warnings_logger.addHandler(handler)
80
+ warnings_logger.setLevel(logger.getEffectiveLevel())
81
+
82
+ return logger
83
+
84
+
85
def get_configured_logger(
    log_level: str, progress_bar_attr: str = "progress_bar"
) -> tuple[logging.Logger, bool]:
    """
    Configure logger and progress bar flag consistently.

    Parameters
    ----------
    log_level: str
        The logging level to set (case-insensitive). Can be the value of
        ``progress_bar_attr``, "debug", "info", "warning", "error", or
        "critical".
    progress_bar_attr: str
        The level name that selects progress-bar mode. If ``log_level`` equals
        this value, the logger is set to "error" level and the progress bar
        flag is True. Defaults to "progress_bar".

    Returns
    -------
    tuple[logging.Logger, bool]
        The configured logger and the progress bar flag.

    Raises
    ------
    ValueError
        If ``log_level`` is not one of the accepted values.
    """
    progress_bar_attr = progress_bar_attr.lower()
    # The progress-bar level name is whatever the caller configured, not a
    # hardcoded "progress_bar"; otherwise a custom name could never be used.
    valid_log_levels = [
        progress_bar_attr,
        "debug",
        "info",
        "warning",
        "error",
        "critical",
    ]
    log_level = log_level.lower()
    if log_level not in valid_log_levels:
        raise ValueError(
            f"Invalid log level: {log_level}. Must be one of: {valid_log_levels}"
        )
    progress_bar = log_level == progress_bar_attr
    if progress_bar:
        # In progress-bar mode only errors are logged to the console.
        log_level = "error"
    logger = configure_logging(level=log_level)
    return logger, progress_bar
@@ -0,0 +1,22 @@
1
+ from .epidemic import Epidemic
2
+ from .flow_problem import FlowProblem
3
+ from .projectile import Projectile, ProjectileMultioutput
4
+ from .seir import SEIRSimulator
5
+
6
# Every simulator class exposed by this package.
ALL_SIMULATORS = [
    Epidemic,
    SEIRSimulator,
    FlowProblem,
    Projectile,
    ProjectileMultioutput,
]

__all__ = [
    "Epidemic",
    "FlowProblem",
    "Projectile",
    "ProjectileMultioutput",
    "SEIRSimulator",
]

# Map each public simulator name to its class. The mapping is spelled out so
# every name is paired with its own class: ``__all__`` and ``ALL_SIMULATORS``
# are ordered differently, so zipping them would mismatch entries.
SIMULATOR_REGISTRY = {
    "Epidemic": Epidemic,
    "FlowProblem": FlowProblem,
    "Projectile": Projectile,
    "ProjectileMultioutput": ProjectileMultioutput,
    "SEIRSimulator": SEIRSimulator,
}