caskade 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,99 @@
1
name: CD

on:
  workflow_dispatch:
  push:
    branches:
      - main
      - dev
  release:
    types:
      - published

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

# Least-privilege default for GITHUB_TOKEN; the publishing jobs below add
# id-token: write themselves for PyPI trusted publishing.
permissions:
  contents: read

env:
  FORCE_COLOR: 3

jobs:
  dist:
    name: Distribution build
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
          # Full history so the VCS-based version (hatch-vcs) can see tags.
          fetch-depth: 0

      - name: Build sdist and wheel
        run: pipx run build

      # Validate the metadata before handing the files to downstream jobs.
      - name: Check products
        run: pipx run twine check dist/*

      - uses: actions/upload-artifact@v4
        with:
          name: artifact
          path: dist

  test-built-dist:
    needs: [dist]
    name: Test built distribution
    runs-on: ubuntu-latest
    permissions:
      id-token: write
    steps:
      - uses: actions/setup-python@v5
        name: Install Python
        with:
          python-version: "3.10"
      - uses: actions/download-artifact@v4
        with:
          name: artifact
          path: dist
      - name: List contents of built dist
        run: |
          ls -ltrh
          ls -ltrh dist
      - name: Publish to Test PyPI
        uses: pypa/gh-action-pypi-publish@v1.10.3
        with:
          repository-url: https://test.pypi.org/legacy/
          verbose: true
          skip-existing: true
      - name: Check pypi packages
        run: |
          sleep 3
          python -m pip install --upgrade pip

          echo "=== Testing wheel file ==="
          # Install wheel to get dependencies and check import
          python -m pip install --extra-index-url https://test.pypi.org/simple --upgrade --pre caskade
          python -c "import caskade; print(caskade.__version__); caskade.test()"
          echo "=== Done testing wheel file ==="

          echo "=== Testing source tar file ==="
          # Install tar gz and check import. Only caskade is forced to build
          # from source: --no-binary=:all: would also try to build torch and
          # every other dependency from source.
          python -m pip uninstall --yes caskade
          python -m pip install --extra-index-url https://test.pypi.org/simple --upgrade --pre --no-binary=caskade caskade
          python -c "import caskade; print(caskade.__version__); caskade.test()"
          echo "=== Done testing source tar file ==="

  publish:
    needs: [dist, test-built-dist]
    name: Publish to PyPI
    environment: pypi
    permissions:
      id-token: write
    runs-on: ubuntu-latest
    if: github.event_name == 'release' && github.event.action == 'published'

    steps:
      - uses: actions/download-artifact@v4
        with:
          name: artifact
          path: dist

      - uses: pypa/gh-action-pypi-publish@v1.10.3
        if: startsWith(github.ref, 'refs/tags')
@@ -0,0 +1,83 @@
1
# Run the test suite on every supported Python version and operating system.
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: CI

on:
  workflow_dispatch:
  pull_request:
  push:
    branches:
      - main
      - dev

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

env:
  FORCE_COLOR: 3
  PROJECT_NAME: "caskade"

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.9", "3.10", "3.11"]
        os: [ubuntu-latest, windows-latest, macOS-latest]

    steps:
      - name: Checkout caskade
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          allow-prereleases: true

      # Force bash so $GITHUB_SHA expands on the Windows runners too (their
      # default shell, pwsh, treats it as an undefined PowerShell variable).
      - name: Record State
        shell: bash
        run: |
          pwd
          echo "github.ref is: ${{ github.ref }}"
          echo "GITHUB_SHA is: $GITHUB_SHA"
          echo "github.event_name is: ${{ github.event_name }}"
          echo "github workspace: ${{ github.workspace }}"
          pip --version

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest pytest-cov torch wheel pydantic

      # We only want to install this on one run, because otherwise we'll have
      # duplicate annotations. The matrix spans several OSes, so pin both axes.
      - name: Install error reporter
        if: ${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' }}
        run: |
          python -m pip install pytest-github-actions-annotate-failures

      - name: Install caskade
        run: |
          pip install -e ".[dev]"
          pip show ${{ env.PROJECT_NAME }}

      - name: Test with pytest
        run: |
          pytest -vvv --cov=${{ env.PROJECT_NAME }} --cov-report=xml --cov-report=term tests/

      - name: Upload coverage reports to Codecov with GitHub Action
        if: ${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' }}
        uses: codecov/codecov-action@v4
        env:
          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
        with:
          files: ./coverage.xml
          flags: unittests
          name: codecov-umbrella
          fail_ci_if_error: true
@@ -0,0 +1,162 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
caskade-0.0.1/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Connor Stone, PhD
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
caskade-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,99 @@
1
+ Metadata-Version: 2.3
2
+ Name: caskade
3
+ Version: 0.0.1
4
+ Summary: Package for building scientific simulators, with dynamic arguments arranged in a directed acyclic graph.
5
+ Project-URL: Homepage, https://github.com/ConnorStoneAstro/caskade
6
+ Project-URL: Documentation, https://github.com/ConnorStoneAstro/caskade
7
+ Project-URL: Repository, https://github.com/ConnorStoneAstro/caskade
8
+ Project-URL: Issues, https://github.com/ConnorStoneAstro/caskade/issues
9
+ Author-email: Connor Stone <connorstone628@gmail.com>, Alexandre Adam <alexandre.adam@mila.quebec>
10
+ License: MIT License
11
+
12
+ Copyright (c) 2024 Connor Stone, PhD
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ The above copyright notice and this permission notice shall be included in all
22
+ copies or substantial portions of the Software.
23
+
24
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ SOFTWARE.
31
+ License-File: LICENSE
32
+ Keywords: DAG,caskade,differentiable programming,pytorch,scientific python
33
+ Classifier: Development Status :: 1 - Planning
34
+ Classifier: Intended Audience :: Science/Research
35
+ Classifier: License :: OSI Approved :: MIT License
36
+ Classifier: Operating System :: OS Independent
37
+ Classifier: Programming Language :: Python :: 3
38
+ Requires-Python: >=3.9
39
+ Requires-Dist: torch
40
+ Provides-Extra: dev
41
+ Requires-Dist: pre-commit<4,>=3.6; extra == 'dev'
42
+ Requires-Dist: pytest-cov<5,>=4.1; extra == 'dev'
43
+ Requires-Dist: pytest-mock<4,>=3.12; extra == 'dev'
44
+ Requires-Dist: pytest<9,>=8.0; extra == 'dev'
45
+ Description-Content-Type: text/markdown
46
+
47
+ # caskade
48
+
49
+ Build scientific simulators, treating them as a directed acyclic graph. Handles
50
+ argument passing for complex nested simulators.
51
+
52
+ ## Install
53
+
54
+ ``` bash
55
+ pip install caskade
56
+ ```
57
+
58
+ ## Usage
59
+
60
+ Make a `Module` object which may have some `Param`s. Define a `forward` method
61
+ using the decorator.
62
+
63
+ ``` python
64
+ from caskade import Module, Param, forward
65
+
66
+ class MySim(Module):
67
+ def __init__(self, a, b=None):
68
+ super().__init__()
69
+ self.a = a
70
+ self.b = Param("b", b)
71
+
72
+ @forward
73
+ def myfun(self, x, b=None):
74
+ return x + self.a + b
75
+ ```
76
+
77
+ We may now create instances of the simulator and pass the dynamic parameters.
78
+
79
+ ``` python
80
+ import torch
81
+
82
+ sim = MySim(1.0)
83
+
84
+ params = [torch.tensor(2.0)]
85
+
86
+ print(sim.myfun(3.0, params=params))
87
+ ```
88
+
89
+ This prints `tensor(6.)`: `b` is filled automatically with the value from `params`.
90
+
91
+ ### Why do this?
92
+
93
+ The above example is not very impressive; the real power comes from the fact
94
+ that `Module` objects can be nested arbitrarily making a much more complicated
95
+ analysis graph. Further, the `Param` objects can be linked or have other complex
96
+ relationships. All of the complexity of the nested structure and argument
97
+ passing is abstracted away so that at the top one need only pass a list of
98
+ tensors for each parameter, a single large 1d tensor, or a dictionary with the
99
+ same structure as the graph.
@@ -0,0 +1,53 @@
1
+ # caskade
2
+
3
+ Build scientific simulators, treating them as a directed acyclic graph. Handles
4
+ argument passing for complex nested simulators.
5
+
6
+ ## Install
7
+
8
+ ``` bash
9
+ pip install caskade
10
+ ```
11
+
12
+ ## Usage
13
+
14
+ Make a `Module` object which may have some `Param`s. Define a `forward` method
15
+ using the decorator.
16
+
17
+ ``` python
18
+ from caskade import Module, Param, forward
19
+
20
+ class MySim(Module):
21
+ def __init__(self, a, b=None):
22
+ super().__init__()
23
+ self.a = a
24
+ self.b = Param("b", b)
25
+
26
+ @forward
27
+ def myfun(self, x, b=None):
28
+ return x + self.a + b
29
+ ```
30
+
31
+ We may now create instances of the simulator and pass the dynamic parameters.
32
+
33
+ ``` python
34
+ import torch
35
+
36
+ sim = MySim(1.0)
37
+
38
+ params = [torch.tensor(2.0)]
39
+
40
+ print(sim.myfun(3.0, params=params))
41
+ ```
42
+
43
+ This prints `tensor(6.)`: `b` is filled automatically with the value from `params`.
44
+
45
+ ### Why do this?
46
+
47
+ The above example is not very impressive; the real power comes from the fact
48
+ that `Module` objects can be nested arbitrarily making a much more complicated
49
+ analysis graph. Further, the `Param` objects can be linked or have other complex
50
+ relationships. All of the complexity of the nested structure and argument
51
+ passing is abstracted away so that at the top one need only pass a list of
52
+ tensors for each parameter, a single large 1d tensor, or a dictionary with the
53
+ same structure as the graph.
@@ -0,0 +1,64 @@
1
+ [build-system]
2
+ requires = ["hatchling", "hatch-requirements-txt", "hatch-vcs"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "caskade"
7
+ dynamic = [
8
+ "dependencies",
9
+ "version"
10
+ ]
11
+ authors = [
12
+ { name="Connor Stone", email="connorstone628@gmail.com" },
13
+ { name="Alexandre Adam", email="alexandre.adam@mila.quebec" },
14
+ ]
15
+ description = "Package for building scientific simulators, with dynamic arguments arranged in a directed acyclic graph."
16
+ readme = "README.md"
17
+ requires-python = ">=3.9"
18
+ license = {file = "LICENSE"}
19
+ keywords = [
20
+ "caskade",
21
+ "DAG",
22
+ "scientific python",
23
+ "differentiable programming",
24
+ "pytorch"
25
+ ]
26
+ classifiers=[
27
+ "Development Status :: 1 - Planning",
28
+ "Intended Audience :: Science/Research",
29
+ "License :: OSI Approved :: MIT License",
30
+ "Operating System :: OS Independent",
31
+ "Programming Language :: Python :: 3"
32
+ ]
33
+
34
+ [project.urls]
35
+ Homepage = "https://github.com/ConnorStoneAstro/caskade"
36
+ Documentation = "https://github.com/ConnorStoneAstro/caskade"
37
+ Repository = "https://github.com/ConnorStoneAstro/caskade"
38
+ Issues = "https://github.com/ConnorStoneAstro/caskade/issues"
39
+
40
+ [project.optional-dependencies]
41
+ dev = [
42
+ "pytest>=8.0,<9",
43
+ "pytest-cov>=4.1,<5",
44
+ "pytest-mock>=3.12,<4",
45
+ "pre-commit>=3.6,<4",
46
+ ]
47
+
48
+ [tool.hatch.metadata.hooks.requirements_txt]
49
+ files = ["requirements.txt"]
50
+
51
+ [tool.hatch.version]
52
+ source = "vcs"
53
+
54
+ [tool.hatch.build.hooks.vcs]
55
+ version-file = "src/caskade/_version.py"
56
+
57
+ [tool.hatch.version.raw-options]
58
+ local_scheme = "no-local-version"
59
+
60
+ [tool.ruff]
61
+ line-length = 100
62
+
63
+ [tool.pytest.ini_options]
64
+ norecursedirs = "tests/utils"
@@ -0,0 +1 @@
1
+ torch
@@ -0,0 +1,14 @@
1
+ from ._version import version as VERSION # noqa
2
+
3
+ from .base import Node
4
+ from .context import ActiveContext
5
+ from .decorators import forward
6
+ from .module import Module
7
+ from .param import Param, LiveParam
8
+ from .tests import test
9
+
10
+
11
# Package metadata re-exported at the top level.
__author__ = "Connor and Alexandre"
__version__ = VERSION

# Names exported by ``from caskade import *``. ``test`` is importable as
# ``caskade.test`` but is not listed here.
__all__ = (
    "Node",
    "Module",
    "Param",
    "LiveParam",
    "ActiveContext",
    "forward",
)
@@ -0,0 +1,16 @@
1
# file generated by setuptools_scm
# don't change, don't track in version control
# Local stand-in for typing.TYPE_CHECKING: it is the literal False, so at
# runtime the else-branch runs and VERSION_TUPLE is simply ``object``;
# presumably the if-branch exists so static type checkers see a precise alias.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union
    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

# Bare annotations (no values) for the module attributes assigned below.
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# The package version as a string and as a tuple of its components.
__version__ = version = '0.0.1'
__version_tuple__ = version_tuple = (0, 0, 1)
@@ -0,0 +1,132 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+
6
class Node:
    """
    Base graph node class for caskade objects.

    The `Node` object is the base class for all caskade objects. It is used to
    construct the directed acyclic graph (DAG). The primary function of the
    `Node` object is to manage the parent-child relationships between nodes in
    the graph. There is limited functionality for the `Node` object, though it
    implements the base versions of the `active` state and `to` /
    `update_dynamic_params` methods. The `active` state is used to communicate
    through the graph that the simulator is currently running. The `to` method
    is used to move and/or cast the values of the parameter. The
    `update_dynamic_params` method is used by `Module` objects to keep track of
    all dynamic `Param` objects below them in the graph.

    Examples
    --------
    ``` python
    n1 = Node("node1")
    n2 = Node("node2")
    n1.link("subnode", n2)  # link n2 as a child of n1, may use any str as the key
    n1.unlink("subnode")  # alternately n1.unlink(n2) to unlink by object
    ```
    """

    def __init__(self, name):
        assert isinstance(name, str), f"{self.__class__.__name__} name must be a string"
        # "|" is reserved: graph_dict() uses it to join name and type.
        assert "|" not in name, f"{self.__class__.__name__} cannot contain '|'"
        self._name = name
        self._children = {}
        self._parents = set()
        self._active = False
        self._type = "node"

    @property
    def name(self) -> str:
        """Read-only name of the node."""
        return self._name

    @property
    def children(self) -> dict:
        """Mapping of link key -> child node."""
        return self._children

    @property
    def parents(self) -> set:
        """Set of nodes that hold this node as a child."""
        return self._parents

    def link(self, key, child):
        """
        Attach `child` under `key` and register `self` as one of its parents.

        Parameters
        ----------
        key: (str)
            Key under which the child is stored in `children`.
        child: (Node)
            Node to link below this one.

        Raises
        ------
        ValueError
            If `key` is already used, `child` is already linked to this node,
            or the link would create a cycle (including linking a node to
            itself).
        """
        # Avoid double linking to the same object
        if key in self.children:
            raise ValueError(f"Child key {key} already linked to parent {self.name}")
        for ownchild in self.children.values():
            if ownchild is child:
                raise ValueError(f"Child {child.name} already linked to parent {self.name}")
        # The graph must stay acyclic: if self is reachable from child (or is
        # child), this link would close a loop and every recursive walk
        # (ordering, update_dynamic_params) would recurse forever.
        if self in child.topological_ordering():
            raise ValueError(f"Linking {child.name} to {self.name} would create a cycle")

        self._children[key] = child
        child._parents.add(self)
        self.update_dynamic_params()

    def unlink(self, key):
        """
        Remove a child, given either its key or the child node itself.

        Raises
        ------
        KeyError
            If `key` is not a child key, or is a Node that is not a child.
        """
        if isinstance(key, Node):
            for child_key, child in self.children.items():
                if child is key:
                    key = child_key
                    break
            else:
                raise KeyError(f"{key.name} is not a child of {self.name}")
        child = self._children[key]
        child._parents.remove(self)
        child.update_dynamic_params()
        del self._children[key]
        self.update_dynamic_params()

    def topological_ordering(self, with_type=None) -> tuple:
        """
        Return `self` followed by every node reachable through `children`.

        Nodes are listed depth-first in pre-order, following `children`
        insertion order, and each node appears once (where it is first
        reached). `Module` relies on this order to match positional
        parameter inputs to its dynamic params, so it must stay stable.

        Parameters
        ----------
        with_type: (Optional[str])
            If given, keep only nodes whose `_type` equals this value.
        """
        ordering = []
        seen = set()

        def visit(node):
            seen.add(node)
            ordering.append(node)
            for child in node.children.values():
                # A visited node's whole subgraph is already in `ordering`.
                if child not in seen:
                    visit(child)

        visit(self)
        if with_type is None:
            return tuple(ordering)
        return tuple(node for node in ordering if node._type == with_type)

    def update_dynamic_params(self):
        """Propagate a structure change upward by notifying every parent."""
        for parent in self.parents:
            parent.update_dynamic_params()

    @property
    def active(self) -> bool:
        """Whether the simulator is currently running through this node."""
        return self._active

    @active.setter
    def active(self, value):
        # Avoid unnecessary updates
        if self._active == value:
            return

        # Set self active level
        self._active = value

        # Propagate active level to children
        for child in self._children.values():
            child.active = value

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """
        Moves and/or casts the PyTorch values of the Node.

        The base Node holds no tensors itself; it forwards the call to every
        child.

        Parameters
        ----------
        device: (Optional[torch.device], optional)
            The device to move the values to. Defaults to None.
        dtype: (Optional[torch.dtype], optional)
            The desired data type. Defaults to None.
        """

        for child in self.children.values():
            child.to(device=device, dtype=dtype)

    def graph_dict(self) -> dict:
        """Nested dict of the subgraph, keyed by "name|type" at every level."""
        graph = {
            f"{self.name}|{self._type}": {},
        }
        for node in self.children.values():
            graph[f"{self.name}|{self._type}"].update(node.graph_dict())
        return graph

    def __str__(self) -> str:
        return str(self.graph_dict())

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.name})"
@@ -0,0 +1,21 @@
1
+ from typing import Union, Mapping, Sequence
2
+
3
+ from torch import Tensor
4
+
5
+ from .module import Module
6
+
7
+
8
class ActiveContext:
    """
    Context manager that runs a `Module` graph with a given set of params.

    On entry the module (and, through `Node.active`, all of its children) is
    marked active and `module.fill_params(params)` loads the runtime values.
    On exit `module.clear_params()` resets them and the graph is deactivated.

    Parameters
    ----------
    module: (Module)
        Top-level module to activate.
    params: (Union[Sequence[Tensor], Mapping[str, Tensor], Tensor])
        Values for the module's dynamic params, in any form accepted by
        `Module.fill_params`.
    """

    def __init__(
        self, module: Module, params: Union[Sequence[Tensor], Mapping[str, Tensor], Tensor]
    ):
        self.module = module
        self.params = params

    def __enter__(self):
        self.module.active = True
        try:
            self.module.fill_params(self.params)
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo here;
            # otherwise the graph would stay active and later forward calls
            # would silently run with stale parameter values.
            self._deactivate()
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._deactivate()

    def _deactivate(self):
        """Clear runtime param values and always leave the graph inactive."""
        try:
            self.module.clear_params()
        finally:
            self.module.active = False
@@ -0,0 +1,50 @@
1
+ import inspect
2
+ import functools
3
+
4
+ from .context import ActiveContext
5
+
6
+
7
def forward(method):
    """
    Decorator to define a forward method for a module.

    Every argument of `method` that has a default value is filled at call
    time with `self.<name>.value` (via `Module.fill_kwargs`).

    When the module is already active (i.e. called from inside another
    forward), those values are filled directly. Otherwise the params for the
    graph are taken from the `params=` keyword, or, if the module has dynamic
    params and no such keyword was given, from the first positional
    argument. The graph is then run inside an `ActiveContext`.

    Parameters
    ----------
    method: (Callable)
        The forward method to be decorated.

    Returns
    -------
    Callable
        The decorated forward method.

    Raises
    ------
    ValueError
        If the module has dynamic params and none were provided.
    """

    # Names of every argument with a default; these are filled from params.
    method_kwargs = [
        name
        for name, arg in inspect.signature(method).parameters.items()
        if arg.default is not inspect.Parameter.empty
    ]

    @functools.wraps(method)
    def wrapped(self, *args, **kwargs):
        if self.active:
            kwargs.update(self.fill_kwargs(method_kwargs))
            return method(self, *args, **kwargs)

        # Extract params from the arguments. An explicit params= keyword is
        # always consumed so it is never forwarded to the wrapped method.
        if "params" in kwargs:
            params = kwargs.pop("params")
        elif len(self.dynamic_params) == 0:
            params = {}
        elif args:
            # args is a tuple: take the first element, keep the rest.
            params, args = args[0], args[1:]
        else:
            raise ValueError(
                f"Params must be provided for dynamic modules. Expected {len(self.dynamic_params)} params."
            )

        with ActiveContext(self, params):
            kwargs.update(self.fill_kwargs(method_kwargs))
            return method(self, *args, **kwargs)

    return wrapped
@@ -0,0 +1,99 @@
1
+ from typing import Sequence, Mapping
2
+
3
+ from torch import Tensor
4
+
5
+ from .base import Node
6
+ from .param import Param, LiveParam
7
+
8
+
9
class Module(Node):
    """
    Graph node that owns `Param`s and sub-`Module`s and runs a simulator.

    Attribute assignment is intercepted (see `__setattr__`): assigning a
    `Node` links it as a child under the attribute name, and assigning to an
    attribute that already holds a `Param` child sets that param's value
    instead of replacing it. `dynamic_params` and `live_params` cache, in
    `topological_ordering` order, the dynamic / live params found below this
    module; they are what `fill_params` / `clear_params` operate on.

    Parameters
    ----------
    name: (str)
        The name of the module.
    """

    def __init__(self, name):
        super().__init__(name=name)
        self.dynamic_params = ()
        self.live_params = ()
        self._type = "module"
        self._batch = False

    @property
    def batch(self) -> bool:
        """Whether Tensor params carry a leading batch dimension."""
        return self._batch

    @batch.setter
    def batch(self, value):
        assert isinstance(value, bool)
        self._batch = value

    def update_dynamic_params(self):
        """Notify ancestors, then re-collect the dynamic and live params below this module."""
        super().update_dynamic_params()
        self.dynamic_params = tuple(self.topological_ordering("dynamic"))
        self.live_params = tuple(self.topological_ordering("live"))

    @staticmethod
    def _param_shape(param) -> tuple:
        """Return the declared shape of `param` as a tuple; raise ValueError if it has none."""
        # getattr's default also covers a Param whose shape was never set.
        shape = getattr(param, "shape", None)
        if shape is None:
            raise ValueError(
                f"Param {param.name} has no shape. dynamic parameters must have a shape to use Tensor input."
            )
        return tuple(shape)

    def _fill_from_tensor(self, params):
        """Slice a flat tensor (or a (batch, N) tensor when `batch`) into the dynamic params."""
        shapes = [self._param_shape(param) for param in self.dynamic_params]
        sizes = []
        for shape in shapes:
            # Works for plain tuples as well as torch.Size.
            size = 1
            for dim in shape:
                size *= int(dim)
            sizes.append(size)

        total = sum(sizes)
        provided = params.shape[-1] if self.batch else params.numel()
        if provided != total:
            raise ValueError(
                f"Input params size ({provided}) does not match total dynamic params size ({total})"
            )

        pos = 0
        if self.batch:
            batch_size = params.shape[0]
            for param, shape, size in zip(self.dynamic_params, shapes, sizes):
                # Columns pos:pos+size belong to this param for every batch row.
                param.value = params[:, pos : pos + size].reshape((batch_size, *shape))
                pos += size
        else:
            flat = params.reshape(-1)
            for param, shape, size in zip(self.dynamic_params, shapes, sizes):
                param.value = flat[pos : pos + size].reshape(shape)
                pos += size

    def fill_params(self, params):
        """
        Load runtime values into the dynamic params below this module.

        Parameters
        ----------
        params: (Union[Tensor, Sequence, Mapping])
            - Tensor: all dynamic params concatenated in `dynamic_params`
              order (with a leading batch dimension when `batch` is True).
            - Sequence: one value per dynamic param, in `dynamic_params` order.
            - Mapping: keys are child names; `Param` children get the value,
              `Module` children recurse with the nested value.

        Raises
        ------
        ValueError
            On a size/length mismatch, an unknown key, or an unsupported type.
        """
        assert self.active, "Module must be active to fill params"

        if isinstance(params, Tensor):
            self._fill_from_tensor(params)
        elif isinstance(params, Sequence):
            if len(params) != len(self.dynamic_params):
                raise ValueError(
                    f"Input params length ({len(params)}) does not match dynamic params length ({len(self.dynamic_params)})"
                )
            for param, value in zip(self.dynamic_params, params):
                param.value = value
        elif isinstance(params, Mapping):
            for key, value in params.items():
                if key not in self.children:
                    raise ValueError(f"Key {key} not found in {self.name} children")
                child = self.children[key]
                if isinstance(child, Param):
                    child.value = value
                elif isinstance(child, Module):
                    child.fill_params(value)
                else:
                    raise ValueError(f"Key {key} type {type(child)} not supported")
        else:
            raise ValueError(
                f"Input params type {type(params)} not supported. Should be Tensor, Sequence or Mapping."
            )

    def clear_params(self):
        """Reset dynamic params to None and live params to LiveParam."""
        assert self.active, "Module must be active to clear params"

        for param in self.dynamic_params:
            param.value = None

        for param in self.live_params:
            param.value = LiveParam

    def fill_kwargs(self, keys) -> dict[str, Tensor]:
        """Map each name in `keys` to the current value of the same-named attribute."""
        return {key: getattr(self, key).value for key in keys}

    def __setattr__(self, key, value):
        """Keep attributes and graph links in sync (see class docstring)."""
        children = self.__dict__.get("_children")
        if children is None:
            # Node.__init__ has not created the child registry yet (e.g. while
            # setting _name), so this is a plain attribute assignment.
            super().__setattr__(key, value)
            return

        existing = children.get(key)
        if isinstance(existing, Param):
            # Keep the Param node in the graph; only change its value.
            existing.value = value
            return

        if isinstance(value, Node):
            # link() already refreshes dynamic_params / live_params.
            self.link(key, value)

        super().__setattr__(key, value)
@@ -0,0 +1,147 @@
1
+ from typing import Optional, Union, Callable
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from .base import Node
7
+
8
# Public names exported by ``from <this module> import *``.
__all__ = ("Param", "LiveParam")
9
+
10
+
11
class LiveParamBase:
    """Placeholder to identify a parameter as live updating. Like `None` there
    exists only one instance of this class: every call to `LiveParamBase()`
    returns the same object, which is exported as `LiveParam`."""

    _instance = None

    def __new__(cls):
        # Enforce the singleton promised above so identity checks are reliable.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __repr__(self):
        return "LiveParam"


LiveParam = LiveParamBase()
19
+
20
+
21
class Param(Node):
    """
    Node to represent a parameter in the graph.

    The `Param` object is used to represent a parameter in the graph. During
    runtime this will represent a tensor value which can be used in various
    calculations. The `Param` object can be set to a constant value (`value`);
    `None` meaning the value is to be provided at runtime (`dynamic`);
    `LiveParam` meaning the value will be computed internally in the simulator
    during runtime (`live`); another `Param` object meaning it will take on that
    value at runtime (`pointer`); or a function of other `Param` objects to be
    computed at runtime (`function`). These options allow users to flexibly set
    the behavior of the simulator.

    Examples
    --------
    ``` python
    p1 = Param("test", (1.0, 2.0)) # constant value, length 2 vector
    p2 = Param("test", None, (2,2)) # dynamic 2x2 matrix value
    p3 = Param("test", LiveParam) # live updating value
    p4 = Param("test", p1) # pointer to another parameter
    p5 = Param("test", lambda p: p.children["other"].value * 2) # function of another parameter
    p5.link("other", p2) # link the other parameter needed for the function
    ```

    Parameters
    ----------
    name: (str)
        The name of the parameter.
    value: (Optional[Union[Tensor, float, int]], optional)
        The value of the parameter. Defaults to None meaning dynamic.
    shape: (Optional[tuple[int, ...]], optional)
        The shape of the parameter. Defaults to () meaning scalar. Kept for
        dynamic and live params; for constant values it must match the value.
    """

    def __init__(
        self,
        name,
        value: Optional[Union[Tensor, float, int]] = None,
        shape: Optional[tuple[int, ...]] = (),
    ):
        super().__init__(name=name)
        # Always define _shape so `shape` never raises AttributeError; the
        # value setter overwrites it for constant, pointer and function params.
        self._shape = None
        if value is None:
            if shape is None:
                raise ValueError("Either value or shape must be provided")
            if not isinstance(shape, tuple):
                raise ValueError("Shape must be a tuple")
            self.shape = shape
        elif isinstance(value, LiveParamBase):
            # Live values only exist at runtime; keep the declared shape
            # instead of dropping it.
            self.shape = shape
        elif not isinstance(value, (Param, Callable)):
            value = torch.as_tensor(value)
            assert (
                shape == () or shape == value.shape
            ), f"Shape {shape} does not match value shape {value.shape}"
        self.value = value

    @property
    def dynamic(self) -> bool:
        """True if the value must be supplied at runtime."""
        return self._type == "dynamic"

    @property
    def live(self) -> bool:
        """True if the value is computed by the simulator at runtime."""
        return self._type == "live"

    @property
    def shape(self) -> tuple:
        """Declared/derived shape; None for pointer and function params."""
        return self._shape

    @shape.setter
    def shape(self, shape):
        if self._type in ["pointer", "function"]:
            raise RuntimeError("Cannot set shape of parameter with type 'pointer' or 'function'")
        self._shape = shape

    @property
    def value(self) -> Union[Tensor, None]:
        """Current value: follows pointers and evaluates functions on access."""
        if self._type == "pointer":
            return self._value.value
        if self._type == "function":
            return self._value(self)
        return self._value

    @value.setter
    def value(self, value):
        # While active, update silently
        if self.active:
            if self.dynamic or self.live:
                self._value = value
                return
            raise RuntimeError(f"Cannot set value of non-live parameter {self.name} while active")

        # unlink if pointer to avoid floating references
        if self._type == "pointer":
            self.unlink(self._value)

        if value is None:
            self._type = "dynamic"
        elif isinstance(value, LiveParamBase):
            self._type = "live"
        elif isinstance(value, Param):
            self._type = "pointer"
            self.link(value.name, value)
            self._shape = None
        elif callable(value):
            self._type = "function"
            self._shape = None
        else:
            self._type = "value"
            value = torch.as_tensor(value)
            self.shape = value.shape

        self._value = value
        self.update_dynamic_params()

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """
        Moves and/or casts the values of the parameter.

        Only constant (`value` type) params hold a tensor that is converted;
        children are handled by `Node.to`.

        Parameters
        ----------
        device: (Optional[torch.device], optional)
            The device to move the values to. Defaults to None.
        dtype: (Optional[torch.dtype], optional)
            The desired data type. Defaults to None.
        """
        super().to(device=device, dtype=dtype)
        if self._type == "value":
            self._value = self._value.to(device=device, dtype=dtype)
@@ -0,0 +1,47 @@
1
+ import torch
2
+
3
+ from caskade import Module, Param, forward, LiveParam
4
+
5
+ __all__ = ("test",)
6
+
7
+
8
def _test_full_integration():
    """
    End-to-end check of a two-module simulator.

    Exercises the following:

    - a dynamic parameter supplied through ``params=``;
    - a live parameter written inside a ``@forward`` method;
    - a pointer parameter;
    - a function parameter that reads a linked child.

    Raises
    ------
    AssertionError
        If the simulator output differs from the expected value.
    """

    class TestSim(Module):
        def __init__(self, a, b, c, c_shape, m1):
            super().__init__("test_sim")
            self.a = a
            self.b = Param("b", b)
            self.c = Param("c", c, c_shape)
            self.m1 = m1

        @forward
        def testfun(self, x, b=None):
            # c is a live parameter, so it may be written while active
            self.c.value = b + x
            y = self.m1()
            return x + self.a + b + y

    class TestSubSim(Module):
        def __init__(self, d, e, f):
            super().__init__("test_sub_sim")
            self.d = Param("d", d)
            self.e = Param("e", e)
            self.f = Param("f", f)

        @forward
        def __call__(self, d=None, e=None, f=None):
            return d + e + f

    sub1 = TestSubSim(d=1.0, e=lambda s: s.children["flink"].value, f=None)
    sub1.e.link("flink", sub1.f)
    main1 = TestSim(a=2.0, b=None, c=LiveParam, c_shape=(), m1=sub1)
    # presumably makes sub1.f a pointer to main1.c (see expected value below)
    sub1.f = main1.c

    b_value = torch.tensor(3.0)
    res = main1.testfun(1.0, params=[b_value])

    # Expected value: c = b + x = 4, so the sub-simulator returns
    # d + e + f = 1 + 4 + 4 = 9, and testfun returns 1 + 2 + 3 + 9 = 15.
    expected = 15.0
    got = res.item()
    # Explicit raise instead of ``assert``: this runs from the public
    # ``caskade.test()``, and ``python -O`` strips assert statements, which
    # would let the self-test report success without checking anything.
    if got != expected:
        raise AssertionError(f"full integration test: expected {expected}, got {got}")
43
+
44
+
45
def test():
    """
    Run the package's built-in self-checks, then print a confirmation.

    A failing check raises before the confirmation is printed.
    """
    checks = (_test_full_integration,)
    for check in checks:
        check()
    print("Success!")
@@ -0,0 +1,105 @@
1
+ from caskade import Node
2
+
3
+ import pytest
4
+
5
+
6
def test_creation():
    """A new Node stores its name, has no links, is inactive, and has a read-only name."""
    node = Node("test")
    assert node._name == "test"
    assert node._children == {}
    assert node._parents == set()
    # Identity check instead of `== False` (flake8 E712); also rejects
    # falsy non-bools such as 0.
    assert node._active is False
    assert node._type == "node"

    with pytest.raises(AttributeError):
        node.name = "newname"
16
+
17
+
18
def test_link():
    """Linking registers a named child plus a back-reference; unlinking removes both."""
    parent = Node("node1")
    child = Node("node2")
    parent.link("subnode", child)

    assert "subnode" in parent._children
    assert parent._children["subnode"] == child
    assert parent._parents == set()
    assert child._parents == {parent}

    # string conversions of a linked node must not raise
    str(parent)
    repr(parent)

    parent.unlink(child)
    assert "subnode" not in parent._children
    assert child._parents == set()
    assert parent._parents == set()
35
+
36
+
37
def test_topological_ordering():
    """Ordering is depth-first from the root, filterable by type, and reflects unlinks."""
    names = ("node1", "node2", "node3", "node4", "node5", "node6")
    n1, n2, n3, n4, n5, n6 = [Node(name) for name in names]

    edges = (
        (n1, "subnode1", n2),
        (n1, "subnode2", n3),
        (n2, "subnode3", n4),
        (n2, "subnode4", n5),
        (n3, "subnode5", n6),
    )
    for parent, key, child in edges:
        parent.link(key, child)

    full_order = (n1, n2, n4, n5, n3, n6)
    assert n1.topological_ordering() == full_order
    assert n1.topological_ordering(with_type="node") == full_order
    assert n1.topological_ordering(with_type="dynamic") == ()

    # dropping a branch removes its whole subtree from the ordering
    n1.unlink("subnode1")
    assert n1.topological_ordering() == (n1, n3, n6)

    n1.unlink("subnode2")
    assert n1.topological_ordering() == (n1,)
67
+
68
+
69
def test_active():
    """Setting ``active`` propagates to every descendant but not to ancestors or siblings."""
    node1 = Node("node1")
    node2 = Node("node2")
    node3 = Node("node3")
    node4 = Node("node4")
    node5 = Node("node5")
    node6 = Node("node6")

    node1.link("subnode1", node2)
    node1.link("subnode2", node3)
    node2.link("subnode3", node4)
    node2.link("subnode4", node5)
    node3.link("subnode5", node6)

    nodes = (node1, node2, node3, node4, node5, node6)

    def check(expected):
        # Identity comparison instead of `== True/False` (flake8 E712), so a
        # non-bool truthy/falsy value would also be caught.
        for node, want in zip(nodes, expected):
            assert node.active is want, (node, want)

    node1.active = True
    check((True, True, True, True, True, True))

    # deactivating node2 affects only its own subtree (node4, node5)
    node2.active = False
    check((True, False, True, False, False, True))

    node1.active = False
    check((False, False, False, False, False, False))
@@ -0,0 +1,40 @@
1
+ import torch
2
+
3
+ from caskade import Module, Param, forward, LiveParam
4
+
5
+
6
def test_full_integration():
    """
    End-to-end check of a two-module simulator.

    The simulator mixes dynamic, live, pointer and function parameters and
    must evaluate to 15.
    """

    class TestSim(Module):
        def __init__(self, a, b, c, c_shape, m1):
            super().__init__("test_sim")
            self.a = a
            self.b = Param("b", b)
            self.c = Param("c", c, c_shape)
            self.m1 = m1

        @forward
        def testfun(self, x, b=None):
            # write the live parameter before the sub-module reads it
            self.c.value = b + x
            sub_out = self.m1()
            return x + self.a + b + sub_out

    class TestSubSim(Module):
        def __init__(self, d, e, f):
            super().__init__("test_sub_sim")
            self.d = Param("d", d)
            self.e = Param("e", e)
            self.f = Param("f", f)

        @forward
        def __call__(self, d=None, e=None, f=None):
            return d + e + f

    inner = TestSubSim(d=1.0, e=lambda s: s.children["flink"].value, f=None)
    inner.e.link("flink", inner.f)
    outer = TestSim(a=2.0, b=None, c=LiveParam, c_shape=(), m1=inner)
    inner.f = outer.c

    b_tensor = torch.tensor(3.0)
    result = outer.testfun(1.0, params=[b_tensor])
    assert result.item() == 15.0
@@ -0,0 +1 @@
1
+ from caskade import Module
@@ -0,0 +1,71 @@
1
+ import pytest
2
+ import torch
3
+
4
+ from caskade import Param, LiveParam
5
+
6
+
7
def test_live_param():
    """LiveParam is expected to be one shared sentinel object."""
    # NOTE(review): both names bind the same module-level object, so this
    # identity check cannot fail. A real singleton test would construct a
    # second instance of LiveParam's type and compare.
    first, second = LiveParam, LiveParam
    assert first is second, "LiveParam is not a singleton"
12
+
13
+
14
def test_param_creation():
    """Param construction for each type, plus the errors raised on misuse."""
    # dynamic: no value given
    p1 = Param("test")
    assert p1.name == "test"
    assert p1.dynamic
    assert not p1.live
    assert p1.value is None

    # value: a scalar becomes a tensor
    p2 = Param("test", 1.0)
    assert p2.name == "test"
    assert p2.value.item() == 1.0

    # Setting a non-live value while active must raise. Only the assignment
    # belongs inside the raises-block. With activation inside it too, a
    # RuntimeError from `active = True` would pass the test without the
    # setter ever running.
    p3 = Param("test", torch.ones((1, 2, 3)))
    p3.active = True
    with pytest.raises(RuntimeError):
        p3.value = 1.0
    # restore the default state so later checks don't depend on p3 being active
    p3.active = False

    # an explicit shape that disagrees with the value is rejected
    with pytest.raises(AssertionError):
        Param("test", 1.0, shape=(1, 2, 3))

    # pointer parameters refuse an explicit shape
    p5 = Param("test", p3)
    with pytest.raises(RuntimeError):
        p5.shape = (1, 2, 3)

    # function parameters refuse an explicit shape
    p6 = Param("test", lambda p: p.children["other"].value * 2)
    p6.link("other", p2)
    with pytest.raises(RuntimeError):
        p6.shape = (1, 2, 3)
41
+
42
+
43
def test_value_setter():
    """Assigning to ``value`` re-infers the parameter type on every assignment."""
    # dynamic
    p = Param("test")
    assert p._type == "dynamic"

    # value
    p.value = 1.0
    assert p._type == "value"
    assert p.value.item() == 1.0

    # value with a declared shape (this case was previously never asserted)
    p = Param("testshape", shape=(2,))
    p.value = [1.0, 2.0]
    assert p._type == "value"
    assert tuple(p.shape) == (2,)
    assert p.value.tolist() == [1.0, 2.0]

    # pointer: shape is cleared and the value resolves through the target
    other = Param("testother", 2.0)
    p.value = other
    assert p._type == "pointer"
    assert p.shape is None
    assert p.value.item() == 2.0

    # function
    p.value = lambda p: p.children["other"].value * 2
    p.link("other", other)
    assert p._type == "function"
    assert p.value.item() == 4.0

    # live
    p.value = LiveParam
    assert p._type == "live"
    assert p.live