ezmsg_learn-1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. ezmsg_learn-1.0/.github/workflows/python-publish-ezmsg-learn.yml +26 -0
  2. ezmsg_learn-1.0/.github/workflows/python-tests.yml +41 -0
  3. ezmsg_learn-1.0/.gitignore +174 -0
  4. ezmsg_learn-1.0/.pre-commit-config.yaml +7 -0
  5. ezmsg_learn-1.0/PKG-INFO +34 -0
  6. ezmsg_learn-1.0/README.md +21 -0
  7. ezmsg_learn-1.0/pyproject.toml +49 -0
  8. ezmsg_learn-1.0/src/ezmsg/learn/__init__.py +2 -0
  9. ezmsg_learn-1.0/src/ezmsg/learn/__version__.py +34 -0
  10. ezmsg_learn-1.0/src/ezmsg/learn/dim_reduce/__init__.py +0 -0
  11. ezmsg_learn-1.0/src/ezmsg/learn/dim_reduce/adaptive_decomp.py +284 -0
  12. ezmsg_learn-1.0/src/ezmsg/learn/dim_reduce/incremental_decomp.py +181 -0
  13. ezmsg_learn-1.0/src/ezmsg/learn/linear_model/__init__.py +1 -0
  14. ezmsg_learn-1.0/src/ezmsg/learn/linear_model/adaptive_linear_regressor.py +6 -0
  15. ezmsg_learn-1.0/src/ezmsg/learn/linear_model/cca.py +1 -0
  16. ezmsg_learn-1.0/src/ezmsg/learn/linear_model/linear_regressor.py +5 -0
  17. ezmsg_learn-1.0/src/ezmsg/learn/linear_model/sgd.py +5 -0
  18. ezmsg_learn-1.0/src/ezmsg/learn/linear_model/slda.py +6 -0
  19. ezmsg_learn-1.0/src/ezmsg/learn/model/__init__.py +0 -0
  20. ezmsg_learn-1.0/src/ezmsg/learn/model/cca.py +122 -0
  21. ezmsg_learn-1.0/src/ezmsg/learn/model/mlp.py +133 -0
  22. ezmsg_learn-1.0/src/ezmsg/learn/model/mlp_old.py +49 -0
  23. ezmsg_learn-1.0/src/ezmsg/learn/model/refit_kalman.py +401 -0
  24. ezmsg_learn-1.0/src/ezmsg/learn/model/rnn.py +160 -0
  25. ezmsg_learn-1.0/src/ezmsg/learn/model/transformer.py +175 -0
  26. ezmsg_learn-1.0/src/ezmsg/learn/nlin_model/__init__.py +1 -0
  27. ezmsg_learn-1.0/src/ezmsg/learn/nlin_model/mlp.py +6 -0
  28. ezmsg_learn-1.0/src/ezmsg/learn/process/__init__.py +0 -0
  29. ezmsg_learn-1.0/src/ezmsg/learn/process/adaptive_linear_regressor.py +157 -0
  30. ezmsg_learn-1.0/src/ezmsg/learn/process/base.py +173 -0
  31. ezmsg_learn-1.0/src/ezmsg/learn/process/linear_regressor.py +99 -0
  32. ezmsg_learn-1.0/src/ezmsg/learn/process/mlp_old.py +200 -0
  33. ezmsg_learn-1.0/src/ezmsg/learn/process/refit_kalman.py +407 -0
  34. ezmsg_learn-1.0/src/ezmsg/learn/process/rnn.py +266 -0
  35. ezmsg_learn-1.0/src/ezmsg/learn/process/sgd.py +131 -0
  36. ezmsg_learn-1.0/src/ezmsg/learn/process/sklearn.py +274 -0
  37. ezmsg_learn-1.0/src/ezmsg/learn/process/slda.py +119 -0
  38. ezmsg_learn-1.0/src/ezmsg/learn/process/torch.py +378 -0
  39. ezmsg_learn-1.0/src/ezmsg/learn/process/transformer.py +222 -0
  40. ezmsg_learn-1.0/src/ezmsg/learn/util.py +66 -0
  41. ezmsg_learn-1.0/tests/dim_reduce/test_adaptive_decomp.py +254 -0
  42. ezmsg_learn-1.0/tests/dim_reduce/test_incremental_decomp.py +296 -0
  43. ezmsg_learn-1.0/tests/integration/test_mlp_system.py +68 -0
  44. ezmsg_learn-1.0/tests/integration/test_refit_kalman_system.py +126 -0
  45. ezmsg_learn-1.0/tests/integration/test_rnn_system.py +75 -0
  46. ezmsg_learn-1.0/tests/integration/test_sklearn_system.py +87 -0
  47. ezmsg_learn-1.0/tests/integration/test_torch_system.py +68 -0
  48. ezmsg_learn-1.0/tests/integration/test_transformer_system.py +77 -0
  49. ezmsg_learn-1.0/tests/unit/test_adaptive_linear_regressor.py +54 -0
  50. ezmsg_learn-1.0/tests/unit/test_linear_regressor.py +54 -0
  51. ezmsg_learn-1.0/tests/unit/test_mlp.py +199 -0
  52. ezmsg_learn-1.0/tests/unit/test_mlp_old.py +249 -0
  53. ezmsg_learn-1.0/tests/unit/test_refit_kalman.py +535 -0
  54. ezmsg_learn-1.0/tests/unit/test_rnn.py +370 -0
  55. ezmsg_learn-1.0/tests/unit/test_sgd.py +81 -0
  56. ezmsg_learn-1.0/tests/unit/test_sklearn.py +228 -0
  57. ezmsg_learn-1.0/tests/unit/test_slda.py +110 -0
  58. ezmsg_learn-1.0/tests/unit/test_torch.py +378 -0
  59. ezmsg_learn-1.0/tests/unit/test_transformer.py +326 -0
@@ -0,0 +1,26 @@ ezmsg_learn-1.0/.github/workflows/python-publish-ezmsg-learn.yml
+ name: Upload Python Package - ezmsg-learn
+
+ on:
+   release:
+     types: [published]
+   workflow_dispatch:
+
+ jobs:
+   build:
+     name: build and upload release to PyPI
+     runs-on: ubuntu-latest
+     environment: "release"
+     permissions:
+       id-token: write  # IMPORTANT: this permission is mandatory for trusted publishing
+
+     steps:
+       - uses: actions/checkout@v4
+
+       - name: Install uv
+         uses: astral-sh/setup-uv@v6
+
+       - name: Build Package
+         run: uv build
+
+       - name: Publish package distributions to PyPI
+         run: uv publish
@@ -0,0 +1,41 @@ ezmsg_learn-1.0/.github/workflows/python-tests.yml
+ name: Test package
+
+ on:
+   push:
+     branches:
+       - main
+       - dev
+   pull_request:
+     branches:
+       - main
+       - dev
+   workflow_dispatch:
+
+ jobs:
+   build:
+     strategy:
+       matrix:
+         python-version: ["3.12"]
+         os:
+           - "ubuntu-latest"
+           - "windows-latest"
+           - "macos-latest"
+     runs-on: ${{matrix.os}}
+
+     steps:
+       - uses: actions/checkout@v4
+
+       - name: Install uv
+         uses: astral-sh/setup-uv@v6
+         with:
+           python-version: ${{ matrix.python-version }}
+
+       - name: Install the project
+         run: uv sync
+
+       - name: Lint
+         run:
+           uv tool run ruff check --output-format=github src
+
+       - name: Run tests
+         run: uv run pytest tests
@@ -0,0 +1,174 @@ ezmsg_learn-1.0/.gitignore
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+
+ # PyPI configuration file
+ .pypirc
+
+ src/ezmsg/learn/__version__.py
+ uv.lock
@@ -0,0 +1,7 @@ ezmsg_learn-1.0/.pre-commit-config.yaml
+ repos:
+   - repo: https://github.com/astral-sh/ruff-pre-commit
+     rev: v0.11.12
+     hooks:
+       - id: ruff
+         args: [ --fix ]
+       - id: ruff-format
@@ -0,0 +1,34 @@ ezmsg_learn-1.0/PKG-INFO
+ Metadata-Version: 2.4
+ Name: ezmsg-learn
+ Version: 1.0
+ Summary: ezmsg namespace package for machine learning
+ Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
+ License-Expression: MIT
+ Requires-Python: >=3.10.15
+ Requires-Dist: ezmsg-sigproc
+ Requires-Dist: river>=0.22.0
+ Requires-Dist: scikit-learn>=1.6.0
+ Requires-Dist: torch>=2.6.0
+ Description-Content-Type: text/markdown
+
+ # ezmsg-learn
+
+ This repository contains a Python package with modules for machine learning (ML)-related processing in the [`ezmsg`](https://www.ezmsg.org) framework. ezmsg is intended primarily for processing unbounded streaming signals, and so are the modules in this repo.
+
+ > If you are only interested in offline analysis without concern for reproducibility in online applications, then you should probably look elsewhere.
+
+ Processing units include dimensionality reduction, linear regression, and classification; each can be initialized with known weights or adapted on the fly with incoming (labeled) data. The machine-learning code depends on `river`, `scikit-learn`, `numpy`, and `torch`.
+
+ ## Getting Started
+
+ This ezmsg namespace package is still highly experimental and under active development. It is not yet available on PyPI, so you will need to install it from source. The easiest way to do this is to use `pip` to install the package directly from GitHub:
+
+ ```bash
+ pip install git+ssh://git@github.com/ezmsg-org/ezmsg-learn
+ ```
+
+ Note that this package depends on a specific version of `ezmsg-sigproc` (specifically, [this branch](https://github.com/ezmsg-org/ezmsg-sigproc/tree/70-use-protocols-for-axisarray-transformers)) that has yet to be merged and released. This may conflict with your project's separate dependency on ezmsg-sigproc. However, this version of ezmsg-sigproc should be backward compatible with its main branch, so in your project you can point the ezmsg-sigproc dependency at the same branch, e.g.:
+
+ ```bash
+ pip install git+ssh://git@github.com/ezmsg-org/ezmsg-sigproc@70-use-protocols-for-axisarray-transformers
+ ```
@@ -0,0 +1,21 @@ ezmsg_learn-1.0/README.md
+ # ezmsg-learn
+
+ This repository contains a Python package with modules for machine learning (ML)-related processing in the [`ezmsg`](https://www.ezmsg.org) framework. ezmsg is intended primarily for processing unbounded streaming signals, and so are the modules in this repo.
+
+ > If you are only interested in offline analysis without concern for reproducibility in online applications, then you should probably look elsewhere.
+
+ Processing units include dimensionality reduction, linear regression, and classification; each can be initialized with known weights or adapted on the fly with incoming (labeled) data. The machine-learning code depends on `river`, `scikit-learn`, `numpy`, and `torch`.
+
+ ## Getting Started
+
+ This ezmsg namespace package is still highly experimental and under active development. It is not yet available on PyPI, so you will need to install it from source. The easiest way to do this is to use `pip` to install the package directly from GitHub:
+
+ ```bash
+ pip install git+ssh://git@github.com/ezmsg-org/ezmsg-learn
+ ```
+
+ Note that this package depends on a specific version of `ezmsg-sigproc` (specifically, [this branch](https://github.com/ezmsg-org/ezmsg-sigproc/tree/70-use-protocols-for-axisarray-transformers)) that has yet to be merged and released. This may conflict with your project's separate dependency on ezmsg-sigproc. However, this version of ezmsg-sigproc should be backward compatible with its main branch, so in your project you can point the ezmsg-sigproc dependency at the same branch, e.g.:
+
+ ```bash
+ pip install git+ssh://git@github.com/ezmsg-org/ezmsg-sigproc@70-use-protocols-for-axisarray-transformers
+ ```
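To make the README's description concrete, here is a minimal sketch of driving one of the package's adaptive transformers directly in Python, outside a full ezmsg pipeline. It uses `IncrementalPCATransformer` from `adaptive_decomp.py` (shown later in this diff); the settings-as-kwargs construction, the `AxisArray.TimeAxis` helper, and the callable-transformer protocol are assumptions about the installed ezmsg / ezmsg-sigproc versions and may need adjusting:

```python
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.learn.dim_reduce.adaptive_decomp import IncrementalPCATransformer

# Collapse all non-time axes down to 2 components.
# For these classes, adaptation is explicit: call partial_fit yourself.
pca = IncrementalPCATransformer(axis="!time", n_components=2)

# One 50-sample, 32-channel chunk at 100 Hz.
msg = AxisArray(
    data=np.random.randn(50, 32),
    dims=["time", "ch"],
    axes={
        "time": AxisArray.TimeAxis(fs=100.0),  # assumed helper; adjust per ezmsg version
        "ch": AxisArray.CoordinateAxis(data=np.arange(32).astype(str), dims=["ch"]),
    },
    key="example",
)

pca.partial_fit(msg)  # update the IncrementalPCA estimator on this chunk
out = pca(msg)        # transform; assumes the transformer protocol is callable
print(out.data.shape)  # (50, 2) once the estimator has been fit
```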
@@ -0,0 +1,49 @@ ezmsg_learn-1.0/pyproject.toml
+ [project]
+ name = "ezmsg-learn"
+ description = "ezmsg namespace package for machine learning"
+ readme = "README.md"
+ authors = [
+     { name = "Chadwick Boulay", email = "chadwick.boulay@gmail.com" }
+ ]
+ license = "MIT"
+ requires-python = ">=3.10.15"
+ dynamic = ["version"]
+ dependencies = [
+     "ezmsg-sigproc",
+     "river>=0.22.0",
+     "scikit-learn>=1.6.0",
+     "torch>=2.6.0",
+ ]
+
+ [dependency-groups]
+ dev = [
+     {include-group = "lint"},
+     {include-group = "test"},
+     "pre-commit>=4.3.0",
+ ]
+ lint = [
+     "ruff>=0.12.9",
+ ]
+ test = [
+     "hmmlearn>=0.3.3",
+     "pytest>=8.4.1",
+ ]
+
+ [build-system]
+ requires = ["hatchling", "hatch-vcs"]
+ build-backend = "hatchling.build"
+
+ [tool.hatch.version]
+ source = "vcs"
+
+ [tool.hatch.build.hooks.vcs]
+ version-file = "src/ezmsg/learn/__version__.py"
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["src/ezmsg"]
+
+ [tool.pytest.ini_options]
+ pythonpath = [
+     "src",
+     ".",
+ ]
@@ -0,0 +1,2 @@ ezmsg_learn-1.0/src/ezmsg/learn/__init__.py
+ def hello() -> str:
+     return "Hello from ezmsg-learn!"
@@ -0,0 +1,34 @@ ezmsg_learn-1.0/src/ezmsg/learn/__version__.py
+ # file generated by setuptools-scm
+ # don't change, don't track in version control
+
+ __all__ = [
+     "__version__",
+     "__version_tuple__",
+     "version",
+     "version_tuple",
+     "__commit_id__",
+     "commit_id",
+ ]
+
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple
+     from typing import Union
+
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+     COMMIT_ID = Union[str, None]
+ else:
+     VERSION_TUPLE = object
+     COMMIT_ID = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+ commit_id: COMMIT_ID
+ __commit_id__: COMMIT_ID
+
+ __version__ = version = '1.0'
+ __version_tuple__ = version_tuple = (1, 0)
+
+ __commit_id__ = commit_id = None
ezmsg_learn-1.0/src/ezmsg/learn/dim_reduce/__init__.py: File without changes (empty file)
@@ -0,0 +1,284 @@ ezmsg_learn-1.0/src/ezmsg/learn/dim_reduce/adaptive_decomp.py
+ import typing
+
+ from sklearn.decomposition import IncrementalPCA, MiniBatchNMF
+ import numpy as np
+ import ezmsg.core as ez
+ from ezmsg.sigproc.base import (
+     processor_state,
+     BaseAdaptiveTransformer,
+     BaseAdaptiveTransformerUnit,
+ )
+ from ezmsg.util.messages.axisarray import AxisArray, replace
+
+
+ class AdaptiveDecompSettings(ez.Settings):
+     axis: str = "!time"
+     n_components: int = 2
+
+
+ @processor_state
+ class AdaptiveDecompState:
+     template: AxisArray | None = None
+     axis_groups: tuple[str, list[str], list[str]] | None = None
+     estimator: typing.Any = None
+
+
+ EstimatorType = typing.TypeVar(
+     "EstimatorType", bound=typing.Union[IncrementalPCA, MiniBatchNMF]
+ )
+
+
+ class AdaptiveDecompTransformer(
+     BaseAdaptiveTransformer[
+         AdaptiveDecompSettings, AxisArray, AxisArray, AdaptiveDecompState
+     ],
+     typing.Generic[EstimatorType],
+ ):
+     """
+     Base class for adaptive decomposition transformers. See IncrementalPCATransformer and MiniBatchNMFTransformer
+     for concrete implementations.
+
+     Note that for these classes, adaptation is not automatic. The user must call partial_fit on the transformer.
+     For automated adaptation, see IncrementalDecompTransformer.
+     """
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._state.estimator = self._create_estimator()
+
+     @classmethod
+     def get_message_type(cls, dir: str) -> typing.Type[AxisArray]:
+         # Override because we don't reuse the generic types.
+         return AxisArray
+
+     @classmethod
+     def get_estimator_type(cls) -> typing.Type[EstimatorType]:
+         return typing.get_args(cls.__orig_bases__[0])[0]
+
+     def _create_estimator(self) -> EstimatorType:
+         estimator_klass = self.get_estimator_type()
+         estimator_settings = self.settings.__dict__.copy()
+         estimator_settings.pop("axis")
+         return estimator_klass(**estimator_settings)
+
+     def _calculate_axis_groups(self, message: AxisArray):
+         if self.settings.axis.startswith("!"):
+             # Iterate over the !axis and collapse all other axes
+             iter_axis = self.settings.axis[1:]
+             it_ax_ix = message.get_axis_idx(iter_axis)
+             targ_axes = message.dims[:it_ax_ix] + message.dims[it_ax_ix + 1 :]
+             off_targ_axes = []
+         else:
+             # Do PCA on the parameterized axis
+             targ_axes = [self.settings.axis]
+             # Iterate over streaming axis
+             iter_axis = "win" if "win" in message.dims else "time"
+             if iter_axis == self.settings.axis:
+                 raise ValueError(
+                     f"Iterating axis ({iter_axis}) cannot be the same as the target axis ({self.settings.axis})"
+                 )
+             it_ax_ix = message.get_axis_idx(iter_axis)
+             # Remaining axes are to be treated independently
+             off_targ_axes = [
+                 _
+                 for _ in (message.dims[:it_ax_ix] + message.dims[it_ax_ix + 1 :])
+                 if _ != self.settings.axis
+             ]
+         self._state.axis_groups = iter_axis, targ_axes, off_targ_axes
+
+     def _hash_message(self, message: AxisArray) -> int:
+         iter_axis = (
+             self.settings.axis[1:]
+             if self.settings.axis.startswith("!")
+             else ("win" if "win" in message.dims else "time")
+         )
+         ax_idx = message.get_axis_idx(iter_axis)
+         sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
+         return hash((sample_shape, message.key))
+
+     def _reset_state(self, message: AxisArray) -> None:
+         """Reset state"""
+         self._calculate_axis_groups(message)
+         iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
+
+         # Template
+         out_dims = [iter_axis] + off_targ_axes
+         out_axes = {
+             iter_axis: message.axes[iter_axis],
+             **{k: message.axes[k] for k in off_targ_axes},
+         }
+         if len(targ_axes) == 1:
+             targ_ax_name = targ_axes[0]
+         else:
+             targ_ax_name = "components"
+         out_dims += [targ_ax_name]
+         out_axes[targ_ax_name] = AxisArray.CoordinateAxis(
+             data=np.arange(self.settings.n_components).astype(str),
+             dims=[targ_ax_name],
+             unit="component",
+         )
+         out_shape = [message.data.shape[message.get_axis_idx(_)] for _ in off_targ_axes]
+         out_shape = (0,) + tuple(out_shape) + (self.settings.n_components,)
+         self._state.template = replace(
+             message,
+             data=np.zeros(out_shape, dtype=float),
+             dims=out_dims,
+             axes=out_axes,
+         )
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
+         ax_idx = message.get_axis_idx(iter_axis)
+         in_dat = message.data
+
+         if in_dat.shape[ax_idx] == 0:
+             return self._state.template
+
+         # Re-order axes
+         sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
+         if message.dims != sorted_dims_exp:
+             # TODO: Implement axes transposition if needed
+             # re_order = [ax_idx] + off_targ_inds + targ_inds
+             # np.transpose(in_dat, re_order)
+             pass
+
+         # fold [iter_axis] + off_targ_axes together and fold targ_axes together
+         d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
+         in_dat = in_dat.reshape((-1, d2))
+
+         replace_kwargs = {
+             "axes": {**self._state.template.axes, iter_axis: message.axes[iter_axis]},
+         }
+
+         # Transform data
+         if hasattr(self._state.estimator, "components_"):
+             decomp_dat = self._state.estimator.transform(in_dat).reshape(
+                 (-1,) + self._state.template.data.shape[1:]
+             )
+             replace_kwargs["data"] = decomp_dat
+
+         return replace(self._state.template, **replace_kwargs)
+
+     def partial_fit(self, message: AxisArray) -> None:
+         # Check if we need to reset state
+         msg_hash = self._hash_message(message)
+         if self._hash != msg_hash:
+             self._reset_state(message)
+             self._hash = msg_hash
+
+         iter_axis, targ_axes, off_targ_axes = self._state.axis_groups
+         ax_idx = message.get_axis_idx(iter_axis)
+         in_dat = message.data
+
+         if in_dat.shape[ax_idx] == 0:
+             return
+
+         # Re-order axes if needed
+         sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
+         if message.dims != sorted_dims_exp:
+             # TODO: Implement axes transposition if needed
+             pass
+
+         # fold [iter_axis] + off_targ_axes together and fold targ_axes together
+         d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
+         in_dat = in_dat.reshape((-1, d2))
+
+         # Fit the estimator
+         self._state.estimator.partial_fit(in_dat)
+
+
+ class IncrementalPCASettings(AdaptiveDecompSettings):
+     # Additional settings specific to PCA
+     whiten: bool = False
+     batch_size: typing.Optional[int] = None
+
+
+ class IncrementalPCATransformer(AdaptiveDecompTransformer[IncrementalPCA]):
+     pass
+
+
+ class MiniBatchNMFSettings(AdaptiveDecompSettings):
+     # Additional settings specific to NMF
+     init: typing.Optional[str] = "random"
+     """
+     'random', 'nndsvd', 'nndsvda', 'nndsvdar', 'custom', or None
+     """
+
+     batch_size: int = 1024
+     """
+     batch_size is used only when doing a full fit (i.e., a reset),
+     or as the exponent to forget_factor, where a very small batch_size
+     will cause the model to update more slowly.
+     It is better to set batch_size to a larger number than the expected
+     chunk size and instead use forget_factor to control the learning rate.
+     """
+
+     beta_loss: typing.Union[str, float] = "frobenius"
+     """
+     'frobenius', 'kullback-leibler', 'itakura-saito'
+     Note that values different from 'frobenius'
+     (or 2) and 'kullback-leibler' (or 1) lead to significantly slower
+     fits. Note that for `beta_loss <= 0` (or 'itakura-saito'), the input
+     matrix `X` cannot contain zeros.
+     """
+
+     tol: float = 1e-4
+
+     max_no_improvement: typing.Optional[int] = None
+
+     max_iter: int = 200
+
+     alpha_W: float = 0.0
+
+     alpha_H: typing.Union[float, str] = "same"
+
+     l1_ratio: float = 0.0
+
+     forget_factor: float = 0.7
+
+
+ class MiniBatchNMFTransformer(AdaptiveDecompTransformer[MiniBatchNMF]):
+     pass
+
+
+ SettingsType = typing.TypeVar(
+     "SettingsType", bound=typing.Union[IncrementalPCASettings, MiniBatchNMFSettings]
+ )
+ TransformerType = typing.TypeVar(
+     "TransformerType",
+     bound=typing.Union[IncrementalPCATransformer, MiniBatchNMFTransformer],
+ )
+
+
+ class BaseAdaptiveDecompUnit(
+     BaseAdaptiveTransformerUnit[
+         SettingsType,
+         AxisArray,
+         AxisArray,
+         TransformerType,
+     ],
+     typing.Generic[SettingsType, TransformerType],
+ ):
+     INPUT_SAMPLE = ez.InputStream(AxisArray)
+
+     @ez.subscriber(INPUT_SAMPLE)
+     async def on_sample(self, msg: AxisArray) -> None:
+         await self.processor.apartial_fit(msg)
+
+
+ class IncrementalPCAUnit(
+     BaseAdaptiveDecompUnit[
+         IncrementalPCASettings,
+         IncrementalPCATransformer,
+     ]
+ ):
+     SETTINGS = IncrementalPCASettings
+
+
+ class MiniBatchNMFUnit(
+     BaseAdaptiveDecompUnit[
+         MiniBatchNMFSettings,
+         MiniBatchNMFTransformer,
+     ]
+ ):
+     SETTINGS = MiniBatchNMFSettings
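
A note on the generics trick in the file above: `AdaptiveDecompTransformer.get_estimator_type` reads the concrete scikit-learn estimator class out of the subclass's parameterized base via `__orig_bases__`, which is why `IncrementalPCATransformer` and `MiniBatchNMFTransformer` need no body. A self-contained sketch of the same mechanism, using plain `typing` with a placeholder `dict` standing in for the estimator class:

```python
import typing

EstT = typing.TypeVar("EstT")


class Decomp(typing.Generic[EstT]):
    @classmethod
    def estimator_type(cls) -> type:
        # cls.__orig_bases__[0] is the parameterized base, e.g. Decomp[dict];
        # typing.get_args() recovers the type argument (dict,) from it.
        return typing.get_args(cls.__orig_bases__[0])[0]


class DictDecomp(Decomp[dict]):
    pass


assert DictDecomp.estimator_type() is dict
```

`_create_estimator` then forwards the remaining settings (everything except `axis`) as keyword arguments to the resolved estimator's constructor, which is why each settings class mirrors its estimator's constructor parameters, e.g. `forget_factor` for `MiniBatchNMF`.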