ezmsg-learn 1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg_learn-1.0/.github/workflows/python-publish-ezmsg-learn.yml +26 -0
- ezmsg_learn-1.0/.github/workflows/python-tests.yml +41 -0
- ezmsg_learn-1.0/.gitignore +174 -0
- ezmsg_learn-1.0/.pre-commit-config.yaml +7 -0
- ezmsg_learn-1.0/PKG-INFO +34 -0
- ezmsg_learn-1.0/README.md +21 -0
- ezmsg_learn-1.0/pyproject.toml +49 -0
- ezmsg_learn-1.0/src/ezmsg/learn/__init__.py +2 -0
- ezmsg_learn-1.0/src/ezmsg/learn/__version__.py +34 -0
- ezmsg_learn-1.0/src/ezmsg/learn/dim_reduce/__init__.py +0 -0
- ezmsg_learn-1.0/src/ezmsg/learn/dim_reduce/adaptive_decomp.py +284 -0
- ezmsg_learn-1.0/src/ezmsg/learn/dim_reduce/incremental_decomp.py +181 -0
- ezmsg_learn-1.0/src/ezmsg/learn/linear_model/__init__.py +1 -0
- ezmsg_learn-1.0/src/ezmsg/learn/linear_model/adaptive_linear_regressor.py +6 -0
- ezmsg_learn-1.0/src/ezmsg/learn/linear_model/cca.py +1 -0
- ezmsg_learn-1.0/src/ezmsg/learn/linear_model/linear_regressor.py +5 -0
- ezmsg_learn-1.0/src/ezmsg/learn/linear_model/sgd.py +5 -0
- ezmsg_learn-1.0/src/ezmsg/learn/linear_model/slda.py +6 -0
- ezmsg_learn-1.0/src/ezmsg/learn/model/__init__.py +0 -0
- ezmsg_learn-1.0/src/ezmsg/learn/model/cca.py +122 -0
- ezmsg_learn-1.0/src/ezmsg/learn/model/mlp.py +133 -0
- ezmsg_learn-1.0/src/ezmsg/learn/model/mlp_old.py +49 -0
- ezmsg_learn-1.0/src/ezmsg/learn/model/refit_kalman.py +401 -0
- ezmsg_learn-1.0/src/ezmsg/learn/model/rnn.py +160 -0
- ezmsg_learn-1.0/src/ezmsg/learn/model/transformer.py +175 -0
- ezmsg_learn-1.0/src/ezmsg/learn/nlin_model/__init__.py +1 -0
- ezmsg_learn-1.0/src/ezmsg/learn/nlin_model/mlp.py +6 -0
- ezmsg_learn-1.0/src/ezmsg/learn/process/__init__.py +0 -0
- ezmsg_learn-1.0/src/ezmsg/learn/process/adaptive_linear_regressor.py +157 -0
- ezmsg_learn-1.0/src/ezmsg/learn/process/base.py +173 -0
- ezmsg_learn-1.0/src/ezmsg/learn/process/linear_regressor.py +99 -0
- ezmsg_learn-1.0/src/ezmsg/learn/process/mlp_old.py +200 -0
- ezmsg_learn-1.0/src/ezmsg/learn/process/refit_kalman.py +407 -0
- ezmsg_learn-1.0/src/ezmsg/learn/process/rnn.py +266 -0
- ezmsg_learn-1.0/src/ezmsg/learn/process/sgd.py +131 -0
- ezmsg_learn-1.0/src/ezmsg/learn/process/sklearn.py +274 -0
- ezmsg_learn-1.0/src/ezmsg/learn/process/slda.py +119 -0
- ezmsg_learn-1.0/src/ezmsg/learn/process/torch.py +378 -0
- ezmsg_learn-1.0/src/ezmsg/learn/process/transformer.py +222 -0
- ezmsg_learn-1.0/src/ezmsg/learn/util.py +66 -0
- ezmsg_learn-1.0/tests/dim_reduce/test_adaptive_decomp.py +254 -0
- ezmsg_learn-1.0/tests/dim_reduce/test_incremental_decomp.py +296 -0
- ezmsg_learn-1.0/tests/integration/test_mlp_system.py +68 -0
- ezmsg_learn-1.0/tests/integration/test_refit_kalman_system.py +126 -0
- ezmsg_learn-1.0/tests/integration/test_rnn_system.py +75 -0
- ezmsg_learn-1.0/tests/integration/test_sklearn_system.py +87 -0
- ezmsg_learn-1.0/tests/integration/test_torch_system.py +68 -0
- ezmsg_learn-1.0/tests/integration/test_transformer_system.py +77 -0
- ezmsg_learn-1.0/tests/unit/test_adaptive_linear_regressor.py +54 -0
- ezmsg_learn-1.0/tests/unit/test_linear_regressor.py +54 -0
- ezmsg_learn-1.0/tests/unit/test_mlp.py +199 -0
- ezmsg_learn-1.0/tests/unit/test_mlp_old.py +249 -0
- ezmsg_learn-1.0/tests/unit/test_refit_kalman.py +535 -0
- ezmsg_learn-1.0/tests/unit/test_rnn.py +370 -0
- ezmsg_learn-1.0/tests/unit/test_sgd.py +81 -0
- ezmsg_learn-1.0/tests/unit/test_sklearn.py +228 -0
- ezmsg_learn-1.0/tests/unit/test_slda.py +110 -0
- ezmsg_learn-1.0/tests/unit/test_torch.py +378 -0
- ezmsg_learn-1.0/tests/unit/test_transformer.py +326 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
name: Upload Python Package - ezmsg-learn
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
release:
|
|
5
|
+
types: [published]
|
|
6
|
+
workflow_dispatch:
|
|
7
|
+
|
|
8
|
+
jobs:
|
|
9
|
+
build:
|
|
10
|
+
name: build and upload release to PyPI
|
|
11
|
+
runs-on: ubuntu-latest
|
|
12
|
+
environment: "release"
|
|
13
|
+
permissions:
|
|
14
|
+
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
|
15
|
+
|
|
16
|
+
steps:
|
|
17
|
+
- uses: actions/checkout@v4
|
|
18
|
+
|
|
19
|
+
- name: Install uv
|
|
20
|
+
uses: astral-sh/setup-uv@v6
|
|
21
|
+
|
|
22
|
+
- name: Build Package
|
|
23
|
+
run: uv build
|
|
24
|
+
|
|
25
|
+
- name: Publish package distributions to PyPI
|
|
26
|
+
run: uv publish
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
name: Test package
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches:
|
|
6
|
+
- main
|
|
7
|
+
- dev
|
|
8
|
+
pull_request:
|
|
9
|
+
branches:
|
|
10
|
+
- main
|
|
11
|
+
- dev
|
|
12
|
+
workflow_dispatch:
|
|
13
|
+
|
|
14
|
+
jobs:
|
|
15
|
+
build:
|
|
16
|
+
strategy:
|
|
17
|
+
matrix:
|
|
18
|
+
python-version: ["3.12"]
|
|
19
|
+
os:
|
|
20
|
+
- "ubuntu-latest"
|
|
21
|
+
- "windows-latest"
|
|
22
|
+
- "macos-latest"
|
|
23
|
+
runs-on: ${{matrix.os}}
|
|
24
|
+
|
|
25
|
+
steps:
|
|
26
|
+
- uses: actions/checkout@v4
|
|
27
|
+
|
|
28
|
+
- name: Install uv
|
|
29
|
+
uses: astral-sh/setup-uv@v6
|
|
30
|
+
with:
|
|
31
|
+
python-version: ${{ matrix.python-version }}
|
|
32
|
+
|
|
33
|
+
- name: Install the project
|
|
34
|
+
run: uv sync
|
|
35
|
+
|
|
36
|
+
- name: Lint
|
|
37
|
+
run:
|
|
38
|
+
uv tool run ruff check --output-format=github src
|
|
39
|
+
|
|
40
|
+
- name: Run tests
|
|
41
|
+
run: uv run pytest tests
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# C extensions
|
|
7
|
+
*.so
|
|
8
|
+
|
|
9
|
+
# Distribution / packaging
|
|
10
|
+
.Python
|
|
11
|
+
build/
|
|
12
|
+
develop-eggs/
|
|
13
|
+
dist/
|
|
14
|
+
downloads/
|
|
15
|
+
eggs/
|
|
16
|
+
.eggs/
|
|
17
|
+
lib/
|
|
18
|
+
lib64/
|
|
19
|
+
parts/
|
|
20
|
+
sdist/
|
|
21
|
+
var/
|
|
22
|
+
wheels/
|
|
23
|
+
share/python-wheels/
|
|
24
|
+
*.egg-info/
|
|
25
|
+
.installed.cfg
|
|
26
|
+
*.egg
|
|
27
|
+
MANIFEST
|
|
28
|
+
|
|
29
|
+
# PyInstaller
|
|
30
|
+
# Usually these files are written by a python script from a template
|
|
31
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
32
|
+
*.manifest
|
|
33
|
+
*.spec
|
|
34
|
+
|
|
35
|
+
# Installer logs
|
|
36
|
+
pip-log.txt
|
|
37
|
+
pip-delete-this-directory.txt
|
|
38
|
+
|
|
39
|
+
# Unit test / coverage reports
|
|
40
|
+
htmlcov/
|
|
41
|
+
.tox/
|
|
42
|
+
.nox/
|
|
43
|
+
.coverage
|
|
44
|
+
.coverage.*
|
|
45
|
+
.cache
|
|
46
|
+
nosetests.xml
|
|
47
|
+
coverage.xml
|
|
48
|
+
*.cover
|
|
49
|
+
*.py,cover
|
|
50
|
+
.hypothesis/
|
|
51
|
+
.pytest_cache/
|
|
52
|
+
cover/
|
|
53
|
+
|
|
54
|
+
# Translations
|
|
55
|
+
*.mo
|
|
56
|
+
*.pot
|
|
57
|
+
|
|
58
|
+
# Django stuff:
|
|
59
|
+
*.log
|
|
60
|
+
local_settings.py
|
|
61
|
+
db.sqlite3
|
|
62
|
+
db.sqlite3-journal
|
|
63
|
+
|
|
64
|
+
# Flask stuff:
|
|
65
|
+
instance/
|
|
66
|
+
.webassets-cache
|
|
67
|
+
|
|
68
|
+
# Scrapy stuff:
|
|
69
|
+
.scrapy
|
|
70
|
+
|
|
71
|
+
# Sphinx documentation
|
|
72
|
+
docs/_build/
|
|
73
|
+
|
|
74
|
+
# PyBuilder
|
|
75
|
+
.pybuilder/
|
|
76
|
+
target/
|
|
77
|
+
|
|
78
|
+
# Jupyter Notebook
|
|
79
|
+
.ipynb_checkpoints
|
|
80
|
+
|
|
81
|
+
# IPython
|
|
82
|
+
profile_default/
|
|
83
|
+
ipython_config.py
|
|
84
|
+
|
|
85
|
+
# pyenv
|
|
86
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
87
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
88
|
+
# .python-version
|
|
89
|
+
|
|
90
|
+
# pipenv
|
|
91
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
92
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
93
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
94
|
+
# install all needed dependencies.
|
|
95
|
+
#Pipfile.lock
|
|
96
|
+
|
|
97
|
+
# UV
|
|
98
|
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
|
99
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
100
|
+
# commonly ignored for libraries.
|
|
101
|
+
#uv.lock
|
|
102
|
+
|
|
103
|
+
# poetry
|
|
104
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
105
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
106
|
+
# commonly ignored for libraries.
|
|
107
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
108
|
+
#poetry.lock
|
|
109
|
+
|
|
110
|
+
# pdm
|
|
111
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
112
|
+
#pdm.lock
|
|
113
|
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
|
114
|
+
# in version control.
|
|
115
|
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
|
116
|
+
.pdm.toml
|
|
117
|
+
.pdm-python
|
|
118
|
+
.pdm-build/
|
|
119
|
+
|
|
120
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
121
|
+
__pypackages__/
|
|
122
|
+
|
|
123
|
+
# Celery stuff
|
|
124
|
+
celerybeat-schedule
|
|
125
|
+
celerybeat.pid
|
|
126
|
+
|
|
127
|
+
# SageMath parsed files
|
|
128
|
+
*.sage.py
|
|
129
|
+
|
|
130
|
+
# Environments
|
|
131
|
+
.env
|
|
132
|
+
.venv
|
|
133
|
+
env/
|
|
134
|
+
venv/
|
|
135
|
+
ENV/
|
|
136
|
+
env.bak/
|
|
137
|
+
venv.bak/
|
|
138
|
+
|
|
139
|
+
# Spyder project settings
|
|
140
|
+
.spyderproject
|
|
141
|
+
.spyproject
|
|
142
|
+
|
|
143
|
+
# Rope project settings
|
|
144
|
+
.ropeproject
|
|
145
|
+
|
|
146
|
+
# mkdocs documentation
|
|
147
|
+
/site
|
|
148
|
+
|
|
149
|
+
# mypy
|
|
150
|
+
.mypy_cache/
|
|
151
|
+
.dmypy.json
|
|
152
|
+
dmypy.json
|
|
153
|
+
|
|
154
|
+
# Pyre type checker
|
|
155
|
+
.pyre/
|
|
156
|
+
|
|
157
|
+
# pytype static type analyzer
|
|
158
|
+
.pytype/
|
|
159
|
+
|
|
160
|
+
# Cython debug symbols
|
|
161
|
+
cython_debug/
|
|
162
|
+
|
|
163
|
+
# PyCharm
|
|
164
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
165
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
166
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
167
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
168
|
+
.idea/
|
|
169
|
+
|
|
170
|
+
# PyPI configuration file
|
|
171
|
+
.pypirc
|
|
172
|
+
|
|
173
|
+
src/ezmsg/learn/__version__.py
|
|
174
|
+
uv.lock
|
ezmsg_learn-1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: ezmsg-learn
|
|
3
|
+
Version: 1.0
|
|
4
|
+
Summary: ezmsg namespace package for machine learning
|
|
5
|
+
Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Requires-Python: >=3.10.15
|
|
8
|
+
Requires-Dist: ezmsg-sigproc
|
|
9
|
+
Requires-Dist: river>=0.22.0
|
|
10
|
+
Requires-Dist: scikit-learn>=1.6.0
|
|
11
|
+
Requires-Dist: torch>=2.6.0
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
|
|
14
|
+
# ezmsg-learn
|
|
15
|
+
|
|
16
|
+
This repository contains a Python package with modules for machine learning (ML)-related processing in the [`ezmsg`](https://www.ezmsg.org) framework. As ezmsg is intended primarily for processing unbounded streaming signals, so are the modules in this repo.
|
|
17
|
+
|
|
18
|
+
> If you are only interested in offline analysis without concern for reproducibility in online applications, then you should probably look elsewhere.
|
|
19
|
+
|
|
20
|
+
Processing units include dimensionality reduction, linear regression, and classification that can be initialized with known weights, or adapted on-the-fly with incoming (labeled) data. Machine-learning code depends on `river`, `scikit-learn`, `numpy`, and `torch`.
|
|
21
|
+
|
|
22
|
+
## Getting Started
|
|
23
|
+
|
|
24
|
+
This ezmsg namespace package is still highly experimental and under active development. It is not yet available on PyPI, so you will need to install it from source. The easiest way to do this is to use the `pip` command to install the package directly from GitHub:
|
|
25
|
+
|
|
26
|
+
```bash
|
|
27
|
+
pip install git+ssh://git@github.com/ezmsg-org/ezmsg-learn
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
Note that this package depends on a specific version of `ezmsg-sigproc` (specifically, [this branch](https://github.com/ezmsg-org/ezmsg-sigproc/tree/70-use-protocols-for-axisarray-transformers)) that has yet to be merged and released. This may conflict with your project's separate dependency on ezmsg-sigproc. However, this specific version of ezmsg-sigproc should be backwards compatible with its main branch, so in your project you can modify the dependency on ezmsg-sigproc to point to the new branch, e.g.:
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
pip install git+ssh://git@github.com/ezmsg-org/ezmsg-sigproc@70-use-protocols-for-axisarray-transformers
|
|
34
|
+
```
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# ezmsg-learn
|
|
2
|
+
|
|
3
|
+
This repository contains a Python package with modules for machine learning (ML)-related processing in the [`ezmsg`](https://www.ezmsg.org) framework. As ezmsg is intended primarily for processing unbounded streaming signals, so are the modules in this repo.
|
|
4
|
+
|
|
5
|
+
> If you are only interested in offline analysis without concern for reproducibility in online applications, then you should probably look elsewhere.
|
|
6
|
+
|
|
7
|
+
Processing units include dimensionality reduction, linear regression, and classification that can be initialized with known weights, or adapted on-the-fly with incoming (labeled) data. Machine-learning code depends on `river`, `scikit-learn`, `numpy`, and `torch`.
|
|
8
|
+
|
|
9
|
+
## Getting Started
|
|
10
|
+
|
|
11
|
+
This ezmsg namespace package is still highly experimental and under active development. It is not yet available on PyPI, so you will need to install it from source. The easiest way to do this is to use the `pip` command to install the package directly from GitHub:
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
pip install git+ssh://git@github.com/ezmsg-org/ezmsg-learn
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
Note that this package depends on a specific version of `ezmsg-sigproc` (specifically, [this branch](https://github.com/ezmsg-org/ezmsg-sigproc/tree/70-use-protocols-for-axisarray-transformers)) that has yet to be merged and released. This may conflict with your project's separate dependency on ezmsg-sigproc. However, this specific version of ezmsg-sigproc should be backwards compatible with its main branch, so in your project you can modify the dependency on ezmsg-sigproc to point to the new branch, e.g.:
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
pip install git+ssh://git@github.com/ezmsg-org/ezmsg-sigproc@70-use-protocols-for-axisarray-transformers
|
|
21
|
+
```
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "ezmsg-learn"
|
|
3
|
+
description = "ezmsg namespace package for machine learning"
|
|
4
|
+
readme = "README.md"
|
|
5
|
+
authors = [
|
|
6
|
+
{ name = "Chadwick Boulay", email = "chadwick.boulay@gmail.com" }
|
|
7
|
+
]
|
|
8
|
+
license = "MIT"
|
|
9
|
+
requires-python = ">=3.10.15"
|
|
10
|
+
dynamic = ["version"]
|
|
11
|
+
dependencies = [
|
|
12
|
+
"ezmsg-sigproc",
|
|
13
|
+
"river>=0.22.0",
|
|
14
|
+
"scikit-learn>=1.6.0",
|
|
15
|
+
"torch>=2.6.0",
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
[dependency-groups]
|
|
19
|
+
dev = [
|
|
20
|
+
{include-group = "lint"},
|
|
21
|
+
{include-group = "test"},
|
|
22
|
+
"pre-commit>=4.3.0",
|
|
23
|
+
]
|
|
24
|
+
lint = [
|
|
25
|
+
"ruff>=0.12.9",
|
|
26
|
+
]
|
|
27
|
+
test = [
|
|
28
|
+
"hmmlearn>=0.3.3",
|
|
29
|
+
"pytest>=8.4.1",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
[build-system]
|
|
33
|
+
requires = ["hatchling", "hatch-vcs"]
|
|
34
|
+
build-backend = "hatchling.build"
|
|
35
|
+
|
|
36
|
+
[tool.hatch.version]
|
|
37
|
+
source = "vcs"
|
|
38
|
+
|
|
39
|
+
[tool.hatch.build.hooks.vcs]
|
|
40
|
+
version-file = "src/ezmsg/learn/__version__.py"
|
|
41
|
+
|
|
42
|
+
[tool.hatch.build.targets.wheel]
|
|
43
|
+
packages = ["src/ezmsg"]
|
|
44
|
+
|
|
45
|
+
[tool.pytest.ini_options]
|
|
46
|
+
pythonpath = [
|
|
47
|
+
"src",
|
|
48
|
+
".",
|
|
49
|
+
]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# file generated by setuptools-scm
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"__version__",
|
|
6
|
+
"__version_tuple__",
|
|
7
|
+
"version",
|
|
8
|
+
"version_tuple",
|
|
9
|
+
"__commit_id__",
|
|
10
|
+
"commit_id",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
TYPE_CHECKING = False
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from typing import Tuple
|
|
16
|
+
from typing import Union
|
|
17
|
+
|
|
18
|
+
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
|
19
|
+
COMMIT_ID = Union[str, None]
|
|
20
|
+
else:
|
|
21
|
+
VERSION_TUPLE = object
|
|
22
|
+
COMMIT_ID = object
|
|
23
|
+
|
|
24
|
+
version: str
|
|
25
|
+
__version__: str
|
|
26
|
+
__version_tuple__: VERSION_TUPLE
|
|
27
|
+
version_tuple: VERSION_TUPLE
|
|
28
|
+
commit_id: COMMIT_ID
|
|
29
|
+
__commit_id__: COMMIT_ID
|
|
30
|
+
|
|
31
|
+
__version__ = version = '1.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (1, 0)
|
|
33
|
+
|
|
34
|
+
__commit_id__ = commit_id = None
|
|
File without changes
|
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
from sklearn.decomposition import IncrementalPCA, MiniBatchNMF
|
|
4
|
+
import numpy as np
|
|
5
|
+
import ezmsg.core as ez
|
|
6
|
+
from ezmsg.sigproc.base import (
|
|
7
|
+
processor_state,
|
|
8
|
+
BaseAdaptiveTransformer,
|
|
9
|
+
BaseAdaptiveTransformerUnit,
|
|
10
|
+
)
|
|
11
|
+
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AdaptiveDecompSettings(ez.Settings):
    """Settings shared by the adaptive decomposition transformers in this module."""

    # Name of the axis to decompose. A leading "!" inverts the meaning: the named
    # axis is iterated over and all remaining axes are collapsed into features.
    axis: str = "!time"
    # Number of components; also the length of the output component axis.
    n_components: int = 2
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@processor_state
class AdaptiveDecompState:
    """Mutable state held by an AdaptiveDecompTransformer."""

    # Zero-length output message built on reset; outputs are derived from it.
    template: AxisArray | None = None
    # (iter_axis, targ_axes, off_targ_axes) computed from the incoming message's dims.
    axis_groups: tuple[str, list[str], list[str]] | None = None
    # Estimator instance; assigned in AdaptiveDecompTransformer.__init__.
    estimator: typing.Any = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Estimator class that a concrete AdaptiveDecompTransformer is parameterized with.
EstimatorType = typing.TypeVar(
    "EstimatorType", bound=typing.Union[IncrementalPCA, MiniBatchNMF]
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AdaptiveDecompTransformer(
    BaseAdaptiveTransformer[
        AdaptiveDecompSettings, AxisArray, AxisArray, AdaptiveDecompState
    ],
    typing.Generic[EstimatorType],
):
    """
    Base class for adaptive decomposition transformers. See IncrementalPCATransformer and MiniBatchNMFTransformer
    for concrete implementations.

    Note that for these classes, adaptation is not automatic. The user must call partial_fit on the transformer.
    For automated adaptation, see IncrementalDecompTransformer.

    Incoming data are arranged as ``[iter_axis] + off_targ_axes + targ_axes``
    (transposing when the message's dims are in a different order) and then
    folded into a 2-D ``(samples, features)`` array for the estimator.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the base transformer, then create the estimator from the settings."""
        super().__init__(*args, **kwargs)
        self._state.estimator = self._create_estimator()

    @classmethod
    def get_message_type(cls, dir: str) -> typing.Type[AxisArray]:
        """Return AxisArray regardless of ``dir``."""
        # Override because we don't reuse the generic types.
        return AxisArray

    @classmethod
    def get_estimator_type(cls) -> typing.Type[EstimatorType]:
        """
        Return the estimator class this transformer was parameterized with.

        Reads the first type argument of the first original base, so it is only
        meaningful on subclasses declared as ``AdaptiveDecompTransformer[SomeEstimator]``.
        """
        return typing.get_args(cls.__orig_bases__[0])[0]

    def _create_estimator(self) -> EstimatorType:
        """Instantiate the estimator with every setting except ``axis`` as a keyword argument."""
        estimator_klass = self.get_estimator_type()
        estimator_settings = self.settings.__dict__.copy()
        # `axis` describes how to arrange the message; it is not an estimator parameter.
        estimator_settings.pop("axis")
        return estimator_klass(**estimator_settings)

    def _calculate_axis_groups(self, message: AxisArray):
        """
        Split the message dims into (iter_axis, targ_axes, off_targ_axes) and store them in state.

        Raises:
            ValueError: if the configured axis is the same as the streaming axis.
        """
        if self.settings.axis.startswith("!"):
            # Iterate over the !axis and collapse all other axes
            iter_axis = self.settings.axis[1:]
            it_ax_ix = message.get_axis_idx(iter_axis)
            targ_axes = list(message.dims[:it_ax_ix]) + list(
                message.dims[it_ax_ix + 1 :]
            )
            off_targ_axes = []
        else:
            # Decompose the parameterized axis
            targ_axes = [self.settings.axis]
            # Iterate over streaming axis
            iter_axis = "win" if "win" in message.dims else "time"
            if iter_axis == self.settings.axis:
                raise ValueError(
                    f"Iterating axis ({iter_axis}) cannot be the same as the target axis ({self.settings.axis})"
                )
            it_ax_ix = message.get_axis_idx(iter_axis)
            # Remaining axes are to be treated independently
            off_targ_axes = [
                _
                for _ in (
                    list(message.dims[:it_ax_ix]) + list(message.dims[it_ax_ix + 1 :])
                )
                if _ != self.settings.axis
            ]
        self._state.axis_groups = iter_axis, targ_axes, off_targ_axes

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the per-sample shape (all axes but the iterated one) together with the message key."""
        iter_axis = (
            self.settings.axis[1:]
            if self.settings.axis.startswith("!")
            else ("win" if "win" in message.dims else "time")
        )
        ax_idx = message.get_axis_idx(iter_axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((sample_shape, message.key))

    def _reset_state(self, message: AxisArray) -> None:
        """Recompute the axis groups and rebuild the zero-length output template."""
        self._calculate_axis_groups(message)
        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups

        # Template
        out_dims = [iter_axis] + off_targ_axes
        out_axes = {
            iter_axis: message.axes[iter_axis],
            **{k: message.axes[k] for k in off_targ_axes},
        }
        # A single target axis keeps its name; several collapsed axes become "components".
        if len(targ_axes) == 1:
            targ_ax_name = targ_axes[0]
        else:
            targ_ax_name = "components"
        out_dims += [targ_ax_name]
        out_axes[targ_ax_name] = AxisArray.CoordinateAxis(
            data=np.arange(self.settings.n_components).astype(str),
            dims=[targ_ax_name],
            unit="component",
        )
        out_shape = [message.data.shape[message.get_axis_idx(_)] for _ in off_targ_axes]
        out_shape = (0,) + tuple(out_shape) + (self.settings.n_components,)
        self._state.template = replace(
            message,
            data=np.zeros(out_shape, dtype=float),
            dims=out_dims,
            axes=out_axes,
        )

    def _fold_input(self, message: AxisArray) -> np.ndarray:
        """
        Return the message data as a 2-D ``(samples, features)`` array.

        The data are first put in ``[iter_axis] + off_targ_axes + targ_axes``
        order, which is the order the output template assumes. Then the leading
        (iter + off-target) axes are folded into rows and the target axes into
        columns.
        """
        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
        sorted_dims_exp = [iter_axis] + list(off_targ_axes) + list(targ_axes)
        in_dat = message.data
        if list(message.dims) != sorted_dims_exp:
            # Without this transpose the reshape below would mix unrelated axes.
            re_order = [message.get_axis_idx(_) for _ in sorted_dims_exp]
            in_dat = np.transpose(in_dat, re_order)
        # int(): np.prod of an empty shape is the float 1.0, not a valid dimension.
        d2 = int(np.prod(in_dat.shape[len(off_targ_axes) + 1 :]))
        return in_dat.reshape((-1, d2))

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Project the message through the estimator.

        Returns the stored template when the message is empty along the iterated
        axis. If the estimator has no ``components_`` attribute yet, the result
        keeps the template's zero-length data and only carries the message's
        iterated-axis info.
        """
        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
        ax_idx = message.get_axis_idx(iter_axis)

        if message.data.shape[ax_idx] == 0:
            return self._state.template

        in_dat = self._fold_input(message)

        replace_kwargs = {
            "axes": {**self._state.template.axes, iter_axis: message.axes[iter_axis]},
        }

        # Transform data
        if hasattr(self._state.estimator, "components_"):
            decomp_dat = self._state.estimator.transform(in_dat).reshape(
                (-1,) + self._state.template.data.shape[1:]
            )
            replace_kwargs["data"] = decomp_dat

        return replace(self._state.template, **replace_kwargs)

    def partial_fit(self, message: AxisArray) -> None:
        """
        Update the estimator with the samples in ``message``.

        State is reset first when the message hash differs from the stored one.
        Messages that are empty along the iterated axis are ignored.
        """
        # Check if we need to reset state
        msg_hash = self._hash_message(message)
        if self._hash != msg_hash:
            self._reset_state(message)
            self._hash = msg_hash

        iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
        ax_idx = message.get_axis_idx(iter_axis)

        if message.data.shape[ax_idx] == 0:
            return

        # Fit the estimator
        self._state.estimator.partial_fit(self._fold_input(message))
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class IncrementalPCASettings(AdaptiveDecompSettings):
    """Settings used by IncrementalPCAUnit."""

    # Additional settings specific to PCA
    # NOTE(review): presumably these mirror sklearn IncrementalPCA constructor
    # arguments of the same name -- confirm against the sklearn version in use.
    whiten: bool = False
    batch_size: typing.Optional[int] = None
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class IncrementalPCATransformer(AdaptiveDecompTransformer[IncrementalPCA]):
    """AdaptiveDecompTransformer parameterized with sklearn's IncrementalPCA."""

    pass
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class MiniBatchNMFSettings(AdaptiveDecompSettings):
    # Settings used by MiniBatchNMFUnit.
    # NOTE(review): the undocumented fields below presumably mirror sklearn
    # MiniBatchNMF constructor arguments of the same name -- confirm.

    # Additional settings specific to NMF
    init: typing.Optional[str] = "random"
    """
    'random', 'nndsvd', 'nndsvda', 'nndsvdar', 'custom', or None
    """

    batch_size: int = 1024
    """
    batch_size is used only when doing a full fit (i.e., a reset),
    or as the exponent to forget_factor, where a very small batch_size
    will cause the model to update more slowly.
    It is better to set batch_size to a larger number than the expected
    chunk size and instead use forget_factor to control the learning rate.
    """

    beta_loss: typing.Union[str, float] = "frobenius"
    """
    'frobenius', 'kullback-leibler', 'itakura-saito'
    Note that values different from 'frobenius'
    (or 2) and 'kullback-leibler' (or 1) lead to significantly slower
    fits. Note that for `beta_loss <= 0` (or 'itakura-saito'), the input
    matrix `X` cannot contain zeros.
    """

    tol: float = 1e-4

    max_no_improvement: typing.Optional[int] = None

    max_iter: int = 200

    alpha_W: float = 0.0

    alpha_H: typing.Union[float, str] = "same"

    l1_ratio: float = 0.0

    forget_factor: float = 0.7
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class MiniBatchNMFTransformer(AdaptiveDecompTransformer[MiniBatchNMF]):
    """AdaptiveDecompTransformer parameterized with sklearn's MiniBatchNMF."""

    pass
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
# Settings class that a BaseAdaptiveDecompUnit subclass is parameterized with.
SettingsType = typing.TypeVar(
    "SettingsType", bound=typing.Union[IncrementalPCASettings, MiniBatchNMFSettings]
)
# Transformer class that a BaseAdaptiveDecompUnit subclass is parameterized with.
TransformerType = typing.TypeVar(
    "TransformerType",
    bound=typing.Union[IncrementalPCATransformer, MiniBatchNMFTransformer],
)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
class BaseAdaptiveDecompUnit(
    BaseAdaptiveTransformerUnit[
        SettingsType,
        AxisArray,
        AxisArray,
        TransformerType,
    ],
    typing.Generic[SettingsType, TransformerType],
):
    """Unit base class that adds an input stream of samples used to fit the decomposition."""

    # Stream of AxisArray messages to fit on.
    INPUT_SAMPLE = ez.InputStream(AxisArray)

    @ez.subscriber(INPUT_SAMPLE)
    async def on_sample(self, msg: AxisArray) -> None:
        """Forward each message received on INPUT_SAMPLE to the processor's apartial_fit."""
        # NOTE(review): apartial_fit is not defined in this module; presumably the
        # async counterpart of partial_fit supplied by the base class -- confirm.
        await self.processor.apartial_fit(msg)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
class IncrementalPCAUnit(
    BaseAdaptiveDecompUnit[
        IncrementalPCASettings,
        IncrementalPCATransformer,
    ]
):
    """Decomposition unit pairing IncrementalPCASettings with IncrementalPCATransformer."""

    SETTINGS = IncrementalPCASettings
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class MiniBatchNMFUnit(
    BaseAdaptiveDecompUnit[
        MiniBatchNMFSettings,
        MiniBatchNMFTransformer,
    ]
):
    """Decomposition unit pairing MiniBatchNMFSettings with MiniBatchNMFTransformer."""

    SETTINGS = MiniBatchNMFSettings
|