caskade 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caskade-0.0.1/.github/workflows/cd.yml +99 -0
- caskade-0.0.1/.github/workflows/ci.yml +83 -0
- caskade-0.0.1/.gitignore +162 -0
- caskade-0.0.1/LICENSE +21 -0
- caskade-0.0.1/PKG-INFO +99 -0
- caskade-0.0.1/README.md +53 -0
- caskade-0.0.1/pyproject.toml +64 -0
- caskade-0.0.1/requirements.txt +1 -0
- caskade-0.0.1/src/caskade/__init__.py +14 -0
- caskade-0.0.1/src/caskade/_version.py +16 -0
- caskade-0.0.1/src/caskade/base.py +132 -0
- caskade-0.0.1/src/caskade/context.py +21 -0
- caskade-0.0.1/src/caskade/decorators.py +50 -0
- caskade-0.0.1/src/caskade/module.py +99 -0
- caskade-0.0.1/src/caskade/param.py +147 -0
- caskade-0.0.1/src/caskade/tests.py +47 -0
- caskade-0.0.1/tests/test_base.py +105 -0
- caskade-0.0.1/tests/test_integration.py +40 -0
- caskade-0.0.1/tests/test_module.py +1 -0
- caskade-0.0.1/tests/test_param.py +71 -0
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
name: CD
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
workflow_dispatch:
|
|
5
|
+
push:
|
|
6
|
+
branches:
|
|
7
|
+
- main
|
|
8
|
+
- dev
|
|
9
|
+
release:
|
|
10
|
+
types:
|
|
11
|
+
- published
|
|
12
|
+
|
|
13
|
+
concurrency:
|
|
14
|
+
group: ${{ github.workflow }}-${{ github.ref }}
|
|
15
|
+
cancel-in-progress: true
|
|
16
|
+
|
|
17
|
+
env:
|
|
18
|
+
FORCE_COLOR: 3
|
|
19
|
+
|
|
20
|
+
jobs:
|
|
21
|
+
dist:
|
|
22
|
+
name: Distribution build
|
|
23
|
+
runs-on: ubuntu-latest
|
|
24
|
+
|
|
25
|
+
steps:
|
|
26
|
+
- uses: actions/checkout@v4
|
|
27
|
+
with:
|
|
28
|
+
fetch-depth: 0
|
|
29
|
+
|
|
30
|
+
- name: Build sdist and wheel
|
|
31
|
+
run: pipx run build
|
|
32
|
+
|
|
33
|
+
- uses: actions/upload-artifact@v4
|
|
34
|
+
with:
|
|
35
|
+
path: dist
|
|
36
|
+
|
|
37
|
+
- name: Check products
|
|
38
|
+
run: pipx run twine check dist/*
|
|
39
|
+
|
|
40
|
+
test-built-dist:
|
|
41
|
+
needs: [dist]
|
|
42
|
+
name: Test built distribution
|
|
43
|
+
runs-on: ubuntu-latest
|
|
44
|
+
permissions:
|
|
45
|
+
id-token: write
|
|
46
|
+
steps:
|
|
47
|
+
- uses: actions/setup-python@v5
|
|
48
|
+
name: Install Python
|
|
49
|
+
with:
|
|
50
|
+
python-version: "3.10"
|
|
51
|
+
- uses: actions/download-artifact@v4
|
|
52
|
+
with:
|
|
53
|
+
name: artifact
|
|
54
|
+
path: dist
|
|
55
|
+
- name: List contents of built dist
|
|
56
|
+
run: |
|
|
57
|
+
ls -ltrh
|
|
58
|
+
ls -ltrh dist
|
|
59
|
+
- name: Publish to Test PyPI
|
|
60
|
+
uses: pypa/gh-action-pypi-publish@v1.10.3
|
|
61
|
+
with:
|
|
62
|
+
repository-url: https://test.pypi.org/legacy/
|
|
63
|
+
verbose: true
|
|
64
|
+
skip-existing: true
|
|
65
|
+
- name: Check pypi packages
|
|
66
|
+
run: |
|
|
67
|
+
sleep 3
|
|
68
|
+
python -m pip install --upgrade pip
|
|
69
|
+
|
|
70
|
+
echo "=== Testing wheel file ==="
|
|
71
|
+
# Install wheel to get dependencies and check import
|
|
72
|
+
python -m pip install --extra-index-url https://test.pypi.org/simple --upgrade --pre caskade
|
|
73
|
+
python -c "import caskade; print(caskade.__version__); caskade.test()"
|
|
74
|
+
echo "=== Done testing wheel file ==="
|
|
75
|
+
|
|
76
|
+
echo "=== Testing source tar file ==="
|
|
77
|
+
# Install tar gz and check import
|
|
78
|
+
python -m pip uninstall --yes caskade
|
|
79
|
+
python -m pip install --extra-index-url https://test.pypi.org/simple --upgrade --pre --no-binary=:all: caskade
|
|
80
|
+
python -c "import caskade; print(caskade.__version__); caskade.test()"
|
|
81
|
+
echo "=== Done testing source tar file ==="
|
|
82
|
+
|
|
83
|
+
publish:
|
|
84
|
+
needs: [dist, test-built-dist]
|
|
85
|
+
name: Publish to PyPI
|
|
86
|
+
environment: pypi
|
|
87
|
+
permissions:
|
|
88
|
+
id-token: write
|
|
89
|
+
runs-on: ubuntu-latest
|
|
90
|
+
if: github.event_name == 'release' && github.event.action == 'published'
|
|
91
|
+
|
|
92
|
+
steps:
|
|
93
|
+
- uses: actions/download-artifact@v4
|
|
94
|
+
with:
|
|
95
|
+
name: artifact
|
|
96
|
+
path: dist
|
|
97
|
+
|
|
98
|
+
- uses: pypa/gh-action-pypi-publish@v1.10.3
|
|
99
|
+
if: startsWith(github.ref, 'refs/tags')
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# This workflow installs Python dependencies and runs the test suite on a matrix of Python versions and operating systems
|
|
2
|
+
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
|
|
3
|
+
|
|
4
|
+
name: CI
|
|
5
|
+
|
|
6
|
+
on:
|
|
7
|
+
workflow_dispatch:
|
|
8
|
+
pull_request:
|
|
9
|
+
push:
|
|
10
|
+
branches:
|
|
11
|
+
- main
|
|
12
|
+
- dev
|
|
13
|
+
|
|
14
|
+
concurrency:
|
|
15
|
+
group: ${{ github.workflow }}-${{ github.ref }}
|
|
16
|
+
cancel-in-progress: true
|
|
17
|
+
|
|
18
|
+
env:
|
|
19
|
+
FORCE_COLOR: 3
|
|
20
|
+
PROJECT_NAME: "caskade"
|
|
21
|
+
|
|
22
|
+
jobs:
|
|
23
|
+
build:
|
|
24
|
+
runs-on: ${{matrix.os}}
|
|
25
|
+
strategy:
|
|
26
|
+
fail-fast: false
|
|
27
|
+
matrix:
|
|
28
|
+
python-version: ["3.9", "3.10", "3.11"]
|
|
29
|
+
os: [ubuntu-latest, windows-latest, macOS-latest]
|
|
30
|
+
|
|
31
|
+
steps:
|
|
32
|
+
- name: Checkout caskade
|
|
33
|
+
uses: actions/checkout@v4
|
|
34
|
+
with:
|
|
35
|
+
fetch-depth: 0
|
|
36
|
+
|
|
37
|
+
- name: Set up Python ${{ matrix.python-version }}
|
|
38
|
+
uses: actions/setup-python@v5
|
|
39
|
+
with:
|
|
40
|
+
python-version: ${{ matrix.python-version }}
|
|
41
|
+
allow-prereleases: true
|
|
42
|
+
|
|
43
|
+
- name: Record State
|
|
44
|
+
run: |
|
|
45
|
+
pwd
|
|
46
|
+
echo github.ref is: ${{ github.ref }}
|
|
47
|
+
echo GITHUB_SHA is: $GITHUB_SHA
|
|
48
|
+
echo github.event_name is: ${{ github.event_name }}
|
|
49
|
+
echo github workspace: ${{ github.workspace }}
|
|
50
|
+
pip --version
|
|
51
|
+
|
|
52
|
+
- name: Install dependencies
|
|
53
|
+
run: |
|
|
54
|
+
python -m pip install --upgrade pip
|
|
55
|
+
pip install pytest pytest-cov torch wheel pydantic
|
|
56
|
+
|
|
57
|
+
# We only want to install this for one Python version (note: still one job per OS), because otherwise we'll have
|
|
58
|
+
# duplicate annotations.
|
|
59
|
+
- name: Install error reporter
|
|
60
|
+
if: ${{ matrix.python-version == '3.10' }}
|
|
61
|
+
run: |
|
|
62
|
+
python -m pip install pytest-github-actions-annotate-failures
|
|
63
|
+
|
|
64
|
+
- name: Install caskade
|
|
65
|
+
run: |
|
|
66
|
+
pip install -e ".[dev]"
|
|
67
|
+
pip show ${{ env.PROJECT_NAME }}
|
|
68
|
+
|
|
69
|
+
- name: Test with pytest
|
|
70
|
+
run: |
|
|
71
|
+
pytest -vvv --cov=${{ env.PROJECT_NAME }} --cov-report=xml --cov-report=term tests/
|
|
72
|
+
|
|
73
|
+
- name: Upload coverage reports to Codecov with GitHub Action
|
|
74
|
+
if:
|
|
75
|
+
${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest'}}
|
|
76
|
+
uses: codecov/codecov-action@v4
|
|
77
|
+
env:
|
|
78
|
+
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
|
79
|
+
with:
|
|
80
|
+
files: ./coverage.xml
|
|
81
|
+
flags: unittests
|
|
82
|
+
name: codecov-umbrella
|
|
83
|
+
fail_ci_if_error: true
|
caskade-0.0.1/.gitignore
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# C extensions
|
|
7
|
+
*.so
|
|
8
|
+
|
|
9
|
+
# Distribution / packaging
|
|
10
|
+
.Python
|
|
11
|
+
build/
|
|
12
|
+
develop-eggs/
|
|
13
|
+
dist/
|
|
14
|
+
downloads/
|
|
15
|
+
eggs/
|
|
16
|
+
.eggs/
|
|
17
|
+
lib/
|
|
18
|
+
lib64/
|
|
19
|
+
parts/
|
|
20
|
+
sdist/
|
|
21
|
+
var/
|
|
22
|
+
wheels/
|
|
23
|
+
share/python-wheels/
|
|
24
|
+
*.egg-info/
|
|
25
|
+
.installed.cfg
|
|
26
|
+
*.egg
|
|
27
|
+
MANIFEST
|
|
28
|
+
|
|
29
|
+
# PyInstaller
|
|
30
|
+
# Usually these files are written by a python script from a template
|
|
31
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
32
|
+
*.manifest
|
|
33
|
+
*.spec
|
|
34
|
+
|
|
35
|
+
# Installer logs
|
|
36
|
+
pip-log.txt
|
|
37
|
+
pip-delete-this-directory.txt
|
|
38
|
+
|
|
39
|
+
# Unit test / coverage reports
|
|
40
|
+
htmlcov/
|
|
41
|
+
.tox/
|
|
42
|
+
.nox/
|
|
43
|
+
.coverage
|
|
44
|
+
.coverage.*
|
|
45
|
+
.cache
|
|
46
|
+
nosetests.xml
|
|
47
|
+
coverage.xml
|
|
48
|
+
*.cover
|
|
49
|
+
*.py,cover
|
|
50
|
+
.hypothesis/
|
|
51
|
+
.pytest_cache/
|
|
52
|
+
cover/
|
|
53
|
+
|
|
54
|
+
# Translations
|
|
55
|
+
*.mo
|
|
56
|
+
*.pot
|
|
57
|
+
|
|
58
|
+
# Django stuff:
|
|
59
|
+
*.log
|
|
60
|
+
local_settings.py
|
|
61
|
+
db.sqlite3
|
|
62
|
+
db.sqlite3-journal
|
|
63
|
+
|
|
64
|
+
# Flask stuff:
|
|
65
|
+
instance/
|
|
66
|
+
.webassets-cache
|
|
67
|
+
|
|
68
|
+
# Scrapy stuff:
|
|
69
|
+
.scrapy
|
|
70
|
+
|
|
71
|
+
# Sphinx documentation
|
|
72
|
+
docs/_build/
|
|
73
|
+
|
|
74
|
+
# PyBuilder
|
|
75
|
+
.pybuilder/
|
|
76
|
+
target/
|
|
77
|
+
|
|
78
|
+
# Jupyter Notebook
|
|
79
|
+
.ipynb_checkpoints
|
|
80
|
+
|
|
81
|
+
# IPython
|
|
82
|
+
profile_default/
|
|
83
|
+
ipython_config.py
|
|
84
|
+
|
|
85
|
+
# pyenv
|
|
86
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
87
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
88
|
+
# .python-version
|
|
89
|
+
|
|
90
|
+
# pipenv
|
|
91
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
92
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
93
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
94
|
+
# install all needed dependencies.
|
|
95
|
+
#Pipfile.lock
|
|
96
|
+
|
|
97
|
+
# poetry
|
|
98
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
99
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
100
|
+
# commonly ignored for libraries.
|
|
101
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
102
|
+
#poetry.lock
|
|
103
|
+
|
|
104
|
+
# pdm
|
|
105
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
106
|
+
#pdm.lock
|
|
107
|
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
|
108
|
+
# in version control.
|
|
109
|
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
|
110
|
+
.pdm.toml
|
|
111
|
+
.pdm-python
|
|
112
|
+
.pdm-build/
|
|
113
|
+
|
|
114
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
115
|
+
__pypackages__/
|
|
116
|
+
|
|
117
|
+
# Celery stuff
|
|
118
|
+
celerybeat-schedule
|
|
119
|
+
celerybeat.pid
|
|
120
|
+
|
|
121
|
+
# SageMath parsed files
|
|
122
|
+
*.sage.py
|
|
123
|
+
|
|
124
|
+
# Environments
|
|
125
|
+
.env
|
|
126
|
+
.venv
|
|
127
|
+
env/
|
|
128
|
+
venv/
|
|
129
|
+
ENV/
|
|
130
|
+
env.bak/
|
|
131
|
+
venv.bak/
|
|
132
|
+
|
|
133
|
+
# Spyder project settings
|
|
134
|
+
.spyderproject
|
|
135
|
+
.spyproject
|
|
136
|
+
|
|
137
|
+
# Rope project settings
|
|
138
|
+
.ropeproject
|
|
139
|
+
|
|
140
|
+
# mkdocs documentation
|
|
141
|
+
/site
|
|
142
|
+
|
|
143
|
+
# mypy
|
|
144
|
+
.mypy_cache/
|
|
145
|
+
.dmypy.json
|
|
146
|
+
dmypy.json
|
|
147
|
+
|
|
148
|
+
# Pyre type checker
|
|
149
|
+
.pyre/
|
|
150
|
+
|
|
151
|
+
# pytype static type analyzer
|
|
152
|
+
.pytype/
|
|
153
|
+
|
|
154
|
+
# Cython debug symbols
|
|
155
|
+
cython_debug/
|
|
156
|
+
|
|
157
|
+
# PyCharm
|
|
158
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
159
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
160
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
161
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
162
|
+
#.idea/
|
caskade-0.0.1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 Connor Stone, PhD
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
caskade-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: caskade
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Package for building scientific simulators, with dynamic arguments arranged in a directed acyclic graph.
|
|
5
|
+
Project-URL: Homepage, https://github.com/ConnorStoneAstro/caskade
|
|
6
|
+
Project-URL: Documentation, https://github.com/ConnorStoneAstro/caskade
|
|
7
|
+
Project-URL: Repository, https://github.com/ConnorStoneAstro/caskade
|
|
8
|
+
Project-URL: Issues, https://github.com/ConnorStoneAstro/caskade/issues
|
|
9
|
+
Author-email: Connor Stone <connorstone628@gmail.com>, Alexandre Adam <alexandre.adam@mila.quebec>
|
|
10
|
+
License: MIT License
|
|
11
|
+
|
|
12
|
+
Copyright (c) 2024 Connor Stone, PhD
|
|
13
|
+
|
|
14
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
15
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
16
|
+
in the Software without restriction, including without limitation the rights
|
|
17
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
18
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
19
|
+
furnished to do so, subject to the following conditions:
|
|
20
|
+
|
|
21
|
+
The above copyright notice and this permission notice shall be included in all
|
|
22
|
+
copies or substantial portions of the Software.
|
|
23
|
+
|
|
24
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
25
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
26
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
27
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
28
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
29
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
30
|
+
SOFTWARE.
|
|
31
|
+
License-File: LICENSE
|
|
32
|
+
Keywords: DAG,caskade,differentiable programming,pytorch,scientific python
|
|
33
|
+
Classifier: Development Status :: 1 - Planning
|
|
34
|
+
Classifier: Intended Audience :: Science/Research
|
|
35
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
36
|
+
Classifier: Operating System :: OS Independent
|
|
37
|
+
Classifier: Programming Language :: Python :: 3
|
|
38
|
+
Requires-Python: >=3.9
|
|
39
|
+
Requires-Dist: torch
|
|
40
|
+
Provides-Extra: dev
|
|
41
|
+
Requires-Dist: pre-commit<4,>=3.6; extra == 'dev'
|
|
42
|
+
Requires-Dist: pytest-cov<5,>=4.1; extra == 'dev'
|
|
43
|
+
Requires-Dist: pytest-mock<4,>=3.12; extra == 'dev'
|
|
44
|
+
Requires-Dist: pytest<9,>=8.0; extra == 'dev'
|
|
45
|
+
Description-Content-Type: text/markdown
|
|
46
|
+
|
|
47
|
+
# caskade
|
|
48
|
+
|
|
49
|
+
Build scientific simulators, treating them as a directed acyclic graph. Handles
|
|
50
|
+
argument passing for complex nested simulators.
|
|
51
|
+
|
|
52
|
+
## Install
|
|
53
|
+
|
|
54
|
+
``` bash
|
|
55
|
+
pip install caskade
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
## Usage
|
|
59
|
+
|
|
60
|
+
Make a `Module` object which may have some `Param`s. Define a `forward` method
|
|
61
|
+
using the decorator.
|
|
62
|
+
|
|
63
|
+
``` python
|
|
64
|
+
from caskade import Module, Param, forward
|
|
65
|
+
|
|
66
|
+
class MySim(Module):
|
|
67
|
+
def __init__(self, a, b=None):
|
|
68
|
+
super().__init__()
|
|
69
|
+
self.a = a
|
|
70
|
+
self.b = Param("b", b)
|
|
71
|
+
|
|
72
|
+
@forward
|
|
73
|
+
def myfun(self, x, b=None):
|
|
74
|
+
return x + self.a + b
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
We may now create instances of the simulator and pass the dynamic parameters.
|
|
78
|
+
|
|
79
|
+
``` python
|
|
80
|
+
import torch
|
|
81
|
+
|
|
82
|
+
sim = MySim(1.0)
|
|
83
|
+
|
|
84
|
+
params = [torch.tensor(2.0)]
|
|
85
|
+
|
|
86
|
+
print(sim.myfun(3.0, params=params))
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
Which will print `6` by automatically filling `b` with the value from `params`.
|
|
90
|
+
|
|
91
|
+
### Why do this?
|
|
92
|
+
|
|
93
|
+
The above example is not very impressive, the real power comes from the fact
|
|
94
|
+
that `Module` objects can be nested arbitrarily making a much more complicated
|
|
95
|
+
analysis graph. Further, the `Param` objects can be linked or have other complex
|
|
96
|
+
relationships. All of the complexity of the nested structure and argument
|
|
97
|
+
passing is abstracted away so that at the top one need only pass a list of
|
|
98
|
+
tensors for each parameter, a single large 1d tensor, or a dictionary with the
|
|
99
|
+
same structure as the graph.
|
caskade-0.0.1/README.md
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# caskade
|
|
2
|
+
|
|
3
|
+
Build scientific simulators, treating them as a directed acyclic graph. Handles
|
|
4
|
+
argument passing for complex nested simulators.
|
|
5
|
+
|
|
6
|
+
## Install
|
|
7
|
+
|
|
8
|
+
``` bash
|
|
9
|
+
pip install caskade
|
|
10
|
+
```
|
|
11
|
+
|
|
12
|
+
## Usage
|
|
13
|
+
|
|
14
|
+
Subclass `Module` and give it some `Param`s. Define a `forward` method
|
|
15
|
+
using the `@forward` decorator.
|
|
16
|
+
|
|
17
|
+
``` python
|
|
18
|
+
from caskade import Module, Param, forward
|
|
19
|
+
|
|
20
|
+
class MySim(Module):
|
|
21
|
+
def __init__(self, a, b=None):
|
|
22
|
+
super().__init__("mysim")
|
|
23
|
+
self.a = a
|
|
24
|
+
self.b = Param("b", b)
|
|
25
|
+
|
|
26
|
+
@forward
|
|
27
|
+
def myfun(self, x, b=None):
|
|
28
|
+
return x + self.a + b
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
We may now create instances of the simulator and pass the dynamic parameters.
|
|
32
|
+
|
|
33
|
+
``` python
|
|
34
|
+
import torch
|
|
35
|
+
|
|
36
|
+
sim = MySim(1.0)
|
|
37
|
+
|
|
38
|
+
params = [torch.tensor(2.0)]
|
|
39
|
+
|
|
40
|
+
print(sim.myfun(3.0, params=params))
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
This prints `tensor(6.)` (3.0 + 1.0 + 2.0): `b` is filled automatically with the value from `params`.
|
|
44
|
+
|
|
45
|
+
### Why do this?
|
|
46
|
+
|
|
47
|
+
The above example is not very impressive; the real power comes from the fact
|
|
48
|
+
that `Module` objects can be nested arbitrarily making a much more complicated
|
|
49
|
+
analysis graph. Further, the `Param` objects can be linked or have other complex
|
|
50
|
+
relationships. All of the complexity of the nested structure and argument
|
|
51
|
+
passing is abstracted away so that at the top one need only pass a list of
|
|
52
|
+
tensors (one per dynamic parameter), a single flattened 1D tensor, or a dictionary with the
|
|
53
|
+
same structure as the graph.
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling", "hatch-requirements-txt", "hatch-vcs"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "caskade"
|
|
7
|
+
dynamic = [
|
|
8
|
+
"dependencies",
|
|
9
|
+
"version"
|
|
10
|
+
]
|
|
11
|
+
authors = [
|
|
12
|
+
{ name="Connor Stone", email="connorstone628@gmail.com" },
|
|
13
|
+
{ name="Alexandre Adam", email="alexandre.adam@mila.quebec" },
|
|
14
|
+
]
|
|
15
|
+
description = "Package for building scientific simulators, with dynamic arguments arranged in a directed acyclic graph."
|
|
16
|
+
readme = "README.md"
|
|
17
|
+
requires-python = ">=3.9"
|
|
18
|
+
license = {file = "LICENSE"}
|
|
19
|
+
keywords = [
|
|
20
|
+
"caskade",
|
|
21
|
+
"DAG",
|
|
22
|
+
"scientific python",
|
|
23
|
+
"differentiable programming",
|
|
24
|
+
"pytorch"
|
|
25
|
+
]
|
|
26
|
+
classifiers=[
|
|
27
|
+
"Development Status :: 1 - Planning",
|
|
28
|
+
"Intended Audience :: Science/Research",
|
|
29
|
+
"License :: OSI Approved :: MIT License",
|
|
30
|
+
"Operating System :: OS Independent",
|
|
31
|
+
"Programming Language :: Python :: 3"
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[project.urls]
|
|
35
|
+
Homepage = "https://github.com/ConnorStoneAstro/caskade"
|
|
36
|
+
Documentation = "https://github.com/ConnorStoneAstro/caskade"
|
|
37
|
+
Repository = "https://github.com/ConnorStoneAstro/caskade"
|
|
38
|
+
Issues = "https://github.com/ConnorStoneAstro/caskade/issues"
|
|
39
|
+
|
|
40
|
+
[project.optional-dependencies]
|
|
41
|
+
dev = [
|
|
42
|
+
"pytest>=8.0,<9",
|
|
43
|
+
"pytest-cov>=4.1,<5",
|
|
44
|
+
"pytest-mock>=3.12,<4",
|
|
45
|
+
"pre-commit>=3.6,<4",
|
|
46
|
+
]
|
|
47
|
+
|
|
48
|
+
[tool.hatch.metadata.hooks.requirements_txt]
|
|
49
|
+
files = ["requirements.txt"]
|
|
50
|
+
|
|
51
|
+
[tool.hatch.version]
|
|
52
|
+
source = "vcs"
|
|
53
|
+
|
|
54
|
+
[tool.hatch.build.hooks.vcs]
|
|
55
|
+
version-file = "src/caskade/_version.py"
|
|
56
|
+
|
|
57
|
+
[tool.hatch.version.raw-options]
|
|
58
|
+
local_scheme = "no-local-version"
|
|
59
|
+
|
|
60
|
+
[tool.ruff]
|
|
61
|
+
line-length = 100
|
|
62
|
+
|
|
63
|
+
[tool.pytest.ini_options]
|
|
64
|
+
norecursedirs = "tests/utils"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
torch
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from ._version import version as VERSION # noqa
|
|
2
|
+
|
|
3
|
+
from .base import Node
|
|
4
|
+
from .context import ActiveContext
|
|
5
|
+
from .decorators import forward
|
|
6
|
+
from .module import Module
|
|
7
|
+
from .param import Param, LiveParam
|
|
8
|
+
from .tests import test
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Public version string, re-exported from the generated ``_version`` module
# imported at the top of this file.
__version__ = VERSION
__author__ = "Connor and Alexandre"

# Names exported by ``from caskade import *``.
# NOTE(review): ``test`` is imported above but not listed here, so star-imports
# will not pick it up (``caskade.test()`` still works) -- confirm intended.
__all__ = ("Node", "Module", "Param", "LiveParam", "ActiveContext", "forward")
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# file generated by setuptools_scm
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
# TYPE_CHECKING is a plain local False, so at runtime the ``else`` branch runs
# and the ``typing`` import is skipped; the ``if`` branch is presumably only
# meant for static type checkers.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union
    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

# Declared types of the exported version attributes.
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# Values written when this file was generated (see the header comment).
__version__ = version = '0.0.1'
__version_tuple__ = version_tuple = (0, 0, 1)
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Node:
    """
    Base graph node class for caskade objects.

    The `Node` object is the base class for all caskade objects. It is used to
    construct the directed acyclic graph (DAG). The primary function of the
    `Node` object is to manage the parent-child relationships between nodes in
    the graph. There is limited functionality for the `Node` object, though it
    implements the base versions of the `active` state and `to` /
    `update_dynamic_params` methods. The `active` state is used to communicate
    through the graph that the simulator is currently running. The `to` method
    is used to move and/or cast the values of the parameter. The
    `update_dynamic_params` method is used by `Module` objects to keep track of
    all dynamic `Param` objects below them in the graph.

    `link` keeps the graph acyclic: it refuses to attach a node that is this
    node itself or one of its ancestors.

    Examples
    --------
    ``` python
    n1 = Node("node1")
    n2 = Node("node2")
    n1.link("subnode", n2)  # link n2 as a child of n1, may use any str as the key
    n1.unlink("subnode")  # alternately n1.unlink(n2) to unlink by object
    ```
    """

    def __init__(self, name):
        assert isinstance(name, str), f"{self.__class__.__name__} name must be a string"
        # '|' is reserved: graph_dict uses it to join the name and the type.
        assert "|" not in name, f"{self.__class__.__name__} cannot contain '|'"
        self._name = name
        self._children = {}
        self._parents = set()
        self._active = False
        self._type = "node"

    @property
    def name(self) -> str:
        return self._name

    @property
    def children(self) -> dict:
        return self._children

    @property
    def parents(self) -> set:
        return self._parents

    def link(self, key, child):
        """
        Attach ``child`` below this node under ``key``.

        Parameters
        ----------
        key: (str)
            Name under which the child is stored in ``children``.
        child: (Node)
            The node to attach.

        Raises
        ------
        ValueError
            If ``key`` is already used, ``child`` is already a child of this
            node, or the link would create a cycle (``child`` is this node or
            one of its ancestors).
        """
        # Avoid double linking to the same object
        if key in self.children:
            raise ValueError(f"Child key {key} already linked to parent {self.name}")
        for ownchild in self.children.values():
            if ownchild == child:
                raise ValueError(f"Child {child.name} already linked to parent {self.name}")
        # Keep the graph acyclic. A cycle would make update_dynamic_params, the
        # active setter and topological_ordering recurse forever. Check this
        # before mutating anything so a failed link leaves the graph untouched.
        if any(node is self for node in child.topological_ordering()):
            raise ValueError(f"Linking {child.name} under {self.name} would create a cycle")

        self._children[key] = child
        child._parents.add(self)
        self.update_dynamic_params()

    def unlink(self, key):
        """
        Detach a child, given either its key or the child node itself.

        Raises
        ------
        KeyError
            If no child matches ``key``.
        """
        if isinstance(key, Node):
            for childkey, node in self.children.items():
                if node == key:
                    key = childkey
                    break
            else:
                raise KeyError(f"{key.name} is not a child of {self.name}")
        child = self._children[key]
        child._parents.remove(self)
        # Let the detached child notify its remaining parents before this node
        # forgets it, then refresh this node's own ancestors.
        child.update_dynamic_params()
        del self._children[key]
        self.update_dynamic_params()

    def topological_ordering(self, with_type=None) -> tuple:
        """
        Return this node followed by every node below it.

        The order is depth-first pre-order. A node reachable by several paths
        appears only once, at its first visit. With shared descendants this is
        therefore not guaranteed to be a strict topological order.

        Parameters
        ----------
        with_type: (Optional[str])
            If given, keep only nodes whose ``_type`` equals this value.
        """
        ordering = [self]
        for node in self.children.values():
            for subnode in node.topological_ordering():
                if subnode not in ordering:
                    ordering.append(subnode)
        if with_type is None:
            return tuple(ordering)
        return tuple(filter(lambda n: n._type == with_type, ordering))

    def update_dynamic_params(self):
        """
        Notify ancestors that the graph below them changed.

        The base implementation only forwards the call to every parent;
        subclasses (e.g. ``Module``) override it to refresh cached state.
        """
        for parent in self.parents:
            parent.update_dynamic_params()

    @property
    def active(self) -> bool:
        return self._active

    @active.setter
    def active(self, value):
        """Set the active flag on this node and, recursively, on all children."""
        # Avoid unnecessary updates
        if self._active == value:
            return

        # Set self active level
        self._active = value

        # Propagate active level to children
        for child in self._children.values():
            child.active = value

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """
        Moves and/or casts the PyTorch values of the Node.

        Parameters
        ----------
        device: (Optional[torch.device], optional)
            The device to move the values to. Defaults to None.
        dtype: (Optional[torch.dtype], optional)
            The desired data type. Defaults to None.
        """

        for child in self.children.values():
            child.to(device=device, dtype=dtype)

    def graph_dict(self) -> dict:
        """Return the subgraph as nested dicts keyed by ``"<name>|<type>"``."""
        key = f"{self.name}|{self._type}"
        graph = {key: {}}
        for node in self.children.values():
            graph[key].update(node.graph_dict())
        return graph

    def __str__(self) -> str:
        return str(self.graph_dict())

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.name})"
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from typing import Union, Mapping, Sequence
|
|
2
|
+
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
|
|
5
|
+
from .module import Module
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ActiveContext:
    """
    Context manager that runs a ``Module`` in its active state.

    On entry the module is marked active (which the ``active`` setter
    propagates to its children) and its params are filled from ``params``.
    On exit the params are cleared and the module is marked inactive again.

    Parameters
    ----------
    module: (Module)
        The module to activate.
    params: (Sequence[Tensor], Mapping[str, Tensor] or Tensor)
        Values handed to ``module.fill_params``.
    """

    def __init__(
        self, module: Module, params: Union[Sequence[Tensor], Mapping[str, Tensor], Tensor]
    ):
        self.module = module
        self.params = params

    def __enter__(self):
        self.module.active = True
        try:
            self.module.fill_params(self.params)
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo the
            # activation here. Otherwise the module would stay active with
            # possibly half-filled params, and later forward calls would take
            # the "already active" path and reuse them.
            self._deactivate()
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._deactivate()

    def _deactivate(self):
        """Clear the module's params, then mark it inactive even if clearing fails."""
        try:
            self.module.clear_params()
        finally:
            self.module.active = False
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import functools
|
|
3
|
+
|
|
4
|
+
from .context import ActiveContext
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def forward(method):
    """
    Decorator to define a forward method for a module.

    Every parameter of ``method`` that has a default value is a candidate to
    be filled by the module (via ``self.fill_kwargs``).

    When the module is not already active and has dynamic params, the
    parameter values must be supplied either as ``params=...`` or as the
    first positional argument. They are loaded through ``ActiveContext`` for
    the duration of the call.

    Parameters
    ----------
    method: (Callable)
        The forward method to be decorated.

    Returns
    -------
    Callable
        The decorated forward method.

    Raises
    ------
    ValueError
        If the module has dynamic params but no params were provided.
    """

    # Get kwargs from function signature
    method_kwargs = [
        arg.name
        for arg in inspect.signature(method).parameters.values()
        if arg.default is not arg.empty
    ]

    @functools.wraps(method)
    def wrapped(self, *args, **kwargs):
        # Already inside an active context (e.g. called from another forward
        # method): the params are filled, so just inject them.
        if self.active:
            kwargs.update(self.fill_kwargs(method_kwargs))
            return method(self, *args, **kwargs)

        # Extract params from the arguments
        if len(self.dynamic_params) == 0:
            params = {}
        elif "params" in kwargs:
            params = kwargs.pop("params")
        elif args:
            # args is a tuple (no .pop); split off the first element instead.
            params, args = args[0], args[1:]
        else:
            raise ValueError(
                f"Params must be provided for dynamic modules. Expected {len(self.dynamic_params)} params."
            )

        with ActiveContext(self, params):
            kwargs.update(self.fill_kwargs(method_kwargs))
            return method(self, *args, **kwargs)

    return wrapped
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from typing import Sequence, Mapping
|
|
2
|
+
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
|
|
5
|
+
from .base import Node
|
|
6
|
+
from .param import Param, LiveParam
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Module(Node):
|
|
10
|
+
|
|
11
|
+
def __init__(self, name):
    """
    Create a module node.

    Parameters
    ----------
    name: (str)
        Name of the module; validated by ``Node.__init__`` (must be a str
        that does not contain '|').
    """
    super().__init__(name=name)
    # Cached tuples of "dynamic" / "live" nodes below this module; rebuilt by
    # update_dynamic_params(), which Node.link / Node.unlink trigger.
    self.dynamic_params = ()
    self.live_params = ()
    self._type = "module"
    # When True, fill_params treats a Tensor input as having a leading batch dim.
    self._batch = False
|
|
17
|
+
|
|
18
|
+
@property
def batch(self) -> bool:
    """Whether a Tensor passed to ``fill_params`` has a leading batch dimension."""
    return self._batch
|
|
21
|
+
|
|
22
|
+
@batch.setter
def batch(self, value: bool):
    """Enable or disable batch mode; ``value`` must be a ``bool`` (checked by assert)."""
    assert isinstance(value, bool)
    self._batch = value
|
|
26
|
+
|
|
27
|
+
def update_dynamic_params(self):
    """
    Refresh the cached ``dynamic_params`` and ``live_params`` tuples.

    Ancestors are notified first (``Node.update_dynamic_params``). Then every
    node of type ``"dynamic"`` / ``"live"`` in this module's subgraph is
    collected, in ``topological_ordering`` order.
    """
    super().update_dynamic_params()
    dynamic_nodes = self.topological_ordering(with_type="dynamic")
    live_nodes = self.topological_ordering(with_type="live")
    self.dynamic_params = tuple(dynamic_nodes)
    self.live_params = tuple(live_nodes)
|
|
31
|
+
|
|
32
|
+
def fill_params(self, params):
|
|
33
|
+
assert self.active, "Module must be active to fill params"
|
|
34
|
+
|
|
35
|
+
if isinstance(params, Tensor):
|
|
36
|
+
if self.batch:
|
|
37
|
+
B = params.shape[0]
|
|
38
|
+
pos = 0
|
|
39
|
+
for param in self.dynamic_params:
|
|
40
|
+
try:
|
|
41
|
+
size = param.shape.numel()
|
|
42
|
+
except AttributeError:
|
|
43
|
+
raise ValueError(
|
|
44
|
+
f"Param {param.name} has no shape. dynamic parameters must have a shape to use Tensor input."
|
|
45
|
+
)
|
|
46
|
+
if self.batch:
|
|
47
|
+
param.value = params[:, pos : pos + size].view((B,) + param.shape)
|
|
48
|
+
pos += size * B
|
|
49
|
+
else:
|
|
50
|
+
param.value = params[pos : pos + size].view(param.shape)
|
|
51
|
+
pos += size
|
|
52
|
+
elif isinstance(params, Sequence):
|
|
53
|
+
if len(params) == len(self.dynamic_params):
|
|
54
|
+
for param, value in zip(self.dynamic_params, params):
|
|
55
|
+
param.value = value
|
|
56
|
+
else:
|
|
57
|
+
raise ValueError(
|
|
58
|
+
f"Input params length ({len(params)}) does not match dynamic params length ({len(self.dynamic_params)})"
|
|
59
|
+
)
|
|
60
|
+
elif isinstance(params, Mapping):
|
|
61
|
+
for key in params:
|
|
62
|
+
if key in self.children:
|
|
63
|
+
if isinstance(self.children[key], Param):
|
|
64
|
+
self.children[key].value = params[key]
|
|
65
|
+
elif isinstance(self.children[key], Module):
|
|
66
|
+
self.children[key].fill_params(params[key])
|
|
67
|
+
else:
|
|
68
|
+
raise ValueError(f"Key {key} type {type(self.children[key])} not supported")
|
|
69
|
+
else:
|
|
70
|
+
raise ValueError(f"Key {key} not found in {self.name} children")
|
|
71
|
+
else:
|
|
72
|
+
raise ValueError(
|
|
73
|
+
f"Input params type {type(params)} not supported. Should be Tensor, Sequence or Mapping."
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
def clear_params(self):
|
|
77
|
+
assert self.active, "Module must be active to clear params"
|
|
78
|
+
|
|
79
|
+
for param in self.dynamic_params:
|
|
80
|
+
param.value = None
|
|
81
|
+
|
|
82
|
+
for param in self.live_params:
|
|
83
|
+
param.value = LiveParam
|
|
84
|
+
|
|
85
|
+
def fill_kwargs(self, keys) -> dict[str, Tensor]:
|
|
86
|
+
return {key: getattr(self, key).value for key in keys}
|
|
87
|
+
|
|
88
|
+
def __setattr__(self, key, value):
|
|
89
|
+
try:
|
|
90
|
+
if key in self.children and isinstance(self.children[key], Param):
|
|
91
|
+
self.children[key].value = value
|
|
92
|
+
return
|
|
93
|
+
if isinstance(value, Node):
|
|
94
|
+
self.link(key, value)
|
|
95
|
+
self.update_dynamic_params()
|
|
96
|
+
|
|
97
|
+
super().__setattr__(key, value)
|
|
98
|
+
except AttributeError:
|
|
99
|
+
super().__setattr__(key, value)
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from typing import Optional, Union, Callable
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from .base import Node
|
|
7
|
+
|
|
8
|
+
__all__ = ("Param", "LiveParam")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LiveParamBase:
    """Placeholder to identify a parameter as live updating. Like `None` there
    exists only one instance of this class: constructing it again returns the
    existing `LiveParam` object."""

    _instance = None

    def __new__(cls):
        # Always hand back the single shared instance so identity checks such
        # as `value is LiveParam` stay valid (copies also go through __new__).
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __repr__(self):
        return "LiveParam"


LiveParam = LiveParamBase()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Param(Node):
    """
    Node to represent a parameter in the graph.

    The `Param` object is used to represent a parameter in the graph. During
    runtime this will represent a tensor value which can be used in various
    calculations. The `Param` object can be set to a constant value (`value`);
    `None` meaning the value is to be provided at runtime (`dynamic`);
    `LiveParam` meaning the value will be computed internally in the simulator
    during runtime (`live`); another `Param` object meaning it will take on that
    value at runtime (`pointer`); or a function of other `Param` objects to be
    computed at runtime (`function`). These options allow users to flexibly set
    the behavior of the simulator.

    Examples
    --------
    ``` python
    p1 = Param("test", (1.0, 2.0)) # constant value, length 2 vector
    p2 = Param("test", None, (2,2)) # dynamic 2x2 matrix value
    p3 = Param("test", LiveParam) # live updating value
    p4 = Param("test", p1) # pointer to another parameter
    p5 = Param("test", lambda p: p.children["other"].value * 2) # function of another parameter
    p5.link("other", p2) # link the other parameter needed for the function
    ```

    Parameters
    ----------
    name: (str)
        The name of the parameter.
    value: (Optional[Union[Tensor, float, int]], optional)
        The value of the parameter. Defaults to None meaning dynamic.
    shape: (Optional[tuple[int, ...]], optional)
        The shape of the parameter. Defaults to () meaning scalar. Dynamic
        and live parameters keep this shape; constant values take the shape
        of the value; pointer and function parameters have shape None.
    """

    def __init__(
        self,
        name,
        value: Optional[Union[Tensor, float, int]] = None,
        shape: Optional[tuple[int, ...]] = (),
    ):
        super().__init__(name=name)
        if value is None:
            if shape is None:
                raise ValueError("Either value or shape must be provided")
            if not isinstance(shape, tuple):
                raise ValueError("Shape must be a tuple")
        elif not isinstance(value, (Param, Callable, LiveParamBase)):
            value = torch.as_tensor(value)
            assert (
                shape == () or shape == value.shape
            ), f"Shape {shape} does not match value shape {value.shape}"
        # Record the requested shape before assigning the value. The value
        # setter replaces it for constants (the tensor's shape) and for
        # pointer/function params (None); dynamic and live params keep it.
        self.shape = shape
        self.value = value

    @property
    def dynamic(self) -> bool:
        """True when the value must be supplied at runtime."""
        return self._type == "dynamic"

    @property
    def live(self) -> bool:
        """True when the value is set internally while the simulator runs."""
        return self._type == "live"

    @property
    def shape(self) -> tuple:
        """Declared shape; None for pointer and function params."""
        return self._shape

    @shape.setter
    def shape(self, shape):
        if self._type in ["pointer", "function"]:
            raise RuntimeError("Cannot set shape of parameter with type 'pointer' or 'function'")
        self._shape = shape

    @property
    def value(self) -> Union[Tensor, None]:
        """
        Current value of the parameter.

        Pointers return the target's value, functions are evaluated with this
        param as argument, and every other type returns the stored value
        (``None`` for an unfilled dynamic param, ``LiveParam`` for an unset
        live param).
        """
        if self._type == "pointer":
            return self._value.value
        if self._type == "function":
            return self._value(self)
        return self._value

    @value.setter
    def value(self, value):
        # While active, only dynamic/live params may change; they are updated
        # silently without retyping or touching the graph.
        if self.active:
            if self.dynamic or self.live:
                self._value = value
                return
            raise RuntimeError(f"Cannot set value of non-live parameter {self.name} while active")

        # unlink if pointer to avoid floating references
        if self._type == "pointer":
            self.unlink(self._value)

        if value is None:
            self._type = "dynamic"
        elif isinstance(value, LiveParamBase):
            self._type = "live"
        elif isinstance(value, Param):
            self._type = "pointer"
            self.link(value.name, value)
            self._shape = None
        elif callable(value):
            self._type = "function"
            self._shape = None
        else:
            self._type = "value"
            value = torch.as_tensor(value)
            self.shape = value.shape

        self._value = value
        self.update_dynamic_params()

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """
        Moves and/or casts the values of the parameter.

        Only constant (``value`` type) params hold a tensor of their own, so
        only those are converted here.

        Parameters
        ----------
        device: (Optional[torch.device], optional)
            The device to move the values to. Defaults to None.
        dtype: (Optional[torch.dtype], optional)
            The desired data type. Defaults to None.
        """
        super().to(device=device, dtype=dtype)
        if self._type == "value":
            self._value = self._value.to(device=device, dtype=dtype)
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from caskade import Module, Param, forward, LiveParam
|
|
4
|
+
|
|
5
|
+
__all__ = ("test",)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _test_full_integration():
    """
    Build a two-level simulator and check a single forward evaluation.

    ``TestSim`` holds a dynamic param ``b``, a live param ``c`` and the
    sub-module ``m1``. Inside ``TestSubSim``, ``f`` is re-pointed at
    ``TestSim.c`` and ``e`` is a function reading ``f``, so the result
    exercises constant, dynamic, live, pointer and function params together.
    """

    class TestSim(Module):
        def __init__(self, a, b, c, c_shape, m1):
            super().__init__("test_sim")
            self.a = a
            self.b = Param("b", b)
            self.c = Param("c", c, c_shape)
            self.m1 = m1

        @forward
        def testfun(self, x, b=None):
            # `b` is filled in by @forward from the attribute `self.b`
            self.c.value = b + x
            sub_total = self.m1()
            return x + self.a + b + sub_total

    class TestSubSim(Module):
        def __init__(self, d, e, f):
            super().__init__("test_sub_sim")
            self.d = Param("d", d)
            self.e = Param("e", e)
            self.f = Param("f", f)

        @forward
        def __call__(self, d=None, e=None, f=None):
            return d + e + f

    sub = TestSubSim(d=1.0, e=lambda node: node.children["flink"].value, f=None)
    sub.e.link("flink", sub.f)
    sim = TestSim(a=2.0, b=None, c=LiveParam, c_shape=(), m1=sub)
    # Assigning a Param to an existing Param attribute makes `f` point at `c`.
    sub.f = sim.c

    # c = b + x = 4; f -> c = 4; e = f = 4; sub = d + e + f = 9
    # result = x + a + b + sub = 1 + 2 + 3 + 9 = 15
    result = sim.testfun(1.0, params=[torch.tensor(3.0)])
    assert result.item() == 15.0
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def test():
    """Run the bundled integration check, then report success on stdout."""
    _test_full_integration()
    print("Success!")
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
from caskade import Node
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def test_creation():
    """A fresh Node is unlinked, inactive, of type "node", and has a read-only name."""
    node = Node("test")
    assert node._name == "test"
    assert node._children == {}
    assert node._parents == set()
    assert not node._active
    assert node._type == "node"

    # the name cannot be reassigned after construction
    with pytest.raises(AttributeError):
        node.name = "newname"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_link():
    """link() records parent/child both ways; unlink() removes both sides."""
    node1 = Node("node1")
    node2 = Node("node2")
    node1.link("subnode", node2)

    assert "subnode" in node1._children
    # the stored child must be the very same object that was linked
    assert node1._children["subnode"] is node2
    assert node1._parents == set()
    assert node2._parents == {node1}

    # smoke check: string conversions must not raise on a linked node
    str(node1)
    repr(node1)

    node1.unlink(node2)
    assert "subnode" not in node1._children
    assert node2._parents == set()
    assert node1._parents == set()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def test_topological_ordering():
    """Ordering lists the root, then each child subtree in link order; unlinking prunes it."""
    node1, node2, node3, node4, node5, node6 = (Node(f"node{i}") for i in range(1, 7))

    # node1 -> node2 -> (node4, node5)
    #       -> node3 -> node6
    for parent, key, child in (
        (node1, "subnode1", node2),
        (node1, "subnode2", node3),
        (node2, "subnode3", node4),
        (node2, "subnode4", node5),
        (node3, "subnode5", node6),
    ):
        parent.link(key, child)

    full_tree = (node1, node2, node4, node5, node3, node6)
    assert node1.topological_ordering() == full_tree
    assert node1.topological_ordering(with_type="node") == full_tree
    # plain Nodes never match the "dynamic" filter
    assert node1.topological_ordering(with_type="dynamic") == ()

    node1.unlink("subnode1")
    assert node1.topological_ordering() == (node1, node3, node6)

    node1.unlink("subnode2")
    assert node1.topological_ordering() == (node1,)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def test_active():
    """Setting `active` on a node applies to that node and all of its descendants."""
    node1 = Node("node1")
    node2 = Node("node2")
    node3 = Node("node3")
    node4 = Node("node4")
    node5 = Node("node5")
    node6 = Node("node6")

    # node1 -> node2 -> (node4, node5)
    #       -> node3 -> node6
    node1.link("subnode1", node2)
    node1.link("subnode2", node3)
    node2.link("subnode3", node4)
    node2.link("subnode4", node5)
    node3.link("subnode5", node6)
    everyone = (node1, node2, node3, node4, node5, node6)

    node1.active = True
    for node in everyone:
        assert node.active, f"{node.name} should be active"

    # deactivating node2 switches off only node2's own subtree
    node2.active = False
    for node in (node1, node3, node6):
        assert node.active, f"{node.name} should still be active"
    for node in (node2, node4, node5):
        assert not node.active, f"{node.name} should be inactive"

    node1.active = False
    for node in everyone:
        assert not node.active, f"{node.name} should be inactive"
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from caskade import Module, Param, forward, LiveParam
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def test_full_integration():
    """
    End-to-end evaluation of a module tree mixing every kind of parameter.

    ``b`` is dynamic (passed via ``params``), ``c`` is live (set inside
    ``testfun``), the sub-module's ``f`` points at ``c`` and its ``e`` is a
    function of ``f``.
    """

    class TestSim(Module):
        def __init__(self, a, b, c, c_shape, m1):
            super().__init__("test_sim")
            self.a = a
            self.b = Param("b", b)
            self.c = Param("c", c, c_shape)
            self.m1 = m1

        @forward
        def testfun(self, x, b=None):
            # @forward supplies `b` from `self.b`; `c` is live, so it may be
            # set here while the module is active
            self.c.value = b + x
            inner = self.m1()
            return x + self.a + b + inner

    class TestSubSim(Module):
        def __init__(self, d, e, f):
            super().__init__("test_sub_sim")
            self.d = Param("d", d)
            self.e = Param("e", e)
            self.f = Param("f", f)

        @forward
        def __call__(self, d=None, e=None, f=None):
            return d + e + f

    child = TestSubSim(d=1.0, e=lambda node: node.children["flink"].value, f=None)
    child.e.link("flink", child.f)
    parent = TestSim(a=2.0, b=None, c=LiveParam, c_shape=(), m1=child)
    child.f = parent.c  # `f` becomes a pointer to the live param `c`

    # c = 3 + 1 = 4, child = 1 + 4 + 4 = 9, total = 1 + 2 + 3 + 9 = 15
    outcome = parent.testfun(1.0, params=[torch.tensor(3.0)])
    assert outcome.item() == 15.0
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from caskade import Module
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from caskade import Param, LiveParam
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def test_live_param():
    """Every reference to `LiveParam` is the same object."""
    # NOTE(review): both names read the same imported binding, so this only
    # guards against `LiveParam` being rebound, not against new instances.
    first, second = LiveParam, LiveParam
    assert first is second, "LiveParam is not a singleton"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def test_param_creation():
    """Construction covers dynamic, constant, pointer and function params and their guards."""
    # dynamic: no value given
    p1 = Param("test")
    assert p1.name == "test"
    assert p1.dynamic
    assert not p1.live
    assert p1.value is None

    # constant value
    p2 = Param("test", 1.0)
    assert p2.name == "test"
    assert p2.value.item() == 1.0

    # a constant param cannot be changed while active; only the assignment
    # belongs inside `raises`, so activation itself must succeed
    p3 = Param("test", torch.ones((1, 2, 3)))
    p3.active = True
    with pytest.raises(RuntimeError):
        p3.value = 1.0

    # an explicit shape must agree with the value's shape
    with pytest.raises(AssertionError):
        Param("test", 1.0, shape=(1, 2, 3))

    # pointer params refuse a shape
    p5 = Param("test", p3)
    with pytest.raises(RuntimeError):
        p5.shape = (1, 2, 3)

    # function params refuse a shape
    p6 = Param("test", lambda p: p.children["other"].value * 2)
    p6.link("other", p2)
    with pytest.raises(RuntimeError):
        p6.shape = (1, 2, 3)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def test_value_setter():
    """Assigning `value` retypes a Param: dynamic -> value -> pointer -> function -> live."""
    # dynamic
    p = Param("test")
    assert p._type == "dynamic"

    # value
    p.value = 1.0
    assert p._type == "value"
    assert p.value.item() == 1.0

    # a list is converted to a tensor and the shape follows the value
    p = Param("testshape", shape=(2,))
    p.value = [1.0, 2.0]
    assert p._type == "value"
    assert p.shape == (2,)
    assert p.value.tolist() == [1.0, 2.0]

    # pointer: reads through to the target and drops its own shape
    other = Param("testother", 2.0)
    p.value = other
    assert p._type == "pointer"
    assert p.shape is None
    assert p.value.item() == 2.0

    # function
    p.value = lambda node: node.children["other"].value * 2
    p.link("other", other)
    assert p._type == "function"
    assert p.value.item() == 4.0

    # live
    p.value = LiveParam
    assert p._type == "live"
    assert p.live
    assert p.value is LiveParam
|