ezmsg-learn 1.0.tar.gz → 1.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg_learn-1.1.0/.github/workflows/docs.yml +65 -0
- ezmsg_learn-1.0/.github/workflows/python-publish-ezmsg-learn.yml → ezmsg_learn-1.1.0/.github/workflows/python-publish.yml +1 -1
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/.github/workflows/python-tests.yml +8 -3
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/.gitignore +2 -0
- ezmsg_learn-1.1.0/LICENSE +21 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/PKG-INFO +5 -9
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/README.md +1 -7
- ezmsg_learn-1.1.0/docs/Makefile +20 -0
- ezmsg_learn-1.1.0/docs/make.bat +35 -0
- ezmsg_learn-1.1.0/docs/source/_templates/autosummary/module.rst +64 -0
- ezmsg_learn-1.1.0/docs/source/api/index.rst +11 -0
- ezmsg_learn-1.1.0/docs/source/conf.py +123 -0
- ezmsg_learn-1.1.0/docs/source/guides/classification.rst +267 -0
- ezmsg_learn-1.1.0/docs/source/index.rst +64 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/pyproject.toml +29 -1
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/__version__.py +2 -2
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/dim_reduce/adaptive_decomp.py +9 -19
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/dim_reduce/incremental_decomp.py +8 -16
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/linear_model/adaptive_linear_regressor.py +6 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/linear_model/linear_regressor.py +4 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/linear_model/sgd.py +6 -2
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/linear_model/slda.py +7 -1
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/model/mlp.py +8 -14
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/model/refit_kalman.py +17 -49
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/nlin_model/mlp.py +5 -1
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/process/adaptive_linear_regressor.py +10 -25
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/process/base.py +12 -31
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/process/linear_regressor.py +8 -12
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/process/mlp_old.py +16 -28
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/process/refit_kalman.py +3 -7
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/process/rnn.py +10 -31
- ezmsg_learn-1.1.0/src/ezmsg/learn/process/sgd.py +117 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/process/sklearn.py +11 -44
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/process/slda.py +6 -15
- ezmsg_learn-1.1.0/src/ezmsg/learn/process/ssr.py +374 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/process/torch.py +9 -25
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/process/transformer.py +8 -15
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/util.py +5 -4
- ezmsg_learn-1.1.0/tests/benchmark/bench_lrr.py +317 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/dim_reduce/test_adaptive_decomp.py +10 -22
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/dim_reduce/test_incremental_decomp.py +10 -19
- ezmsg_learn-1.1.0/tests/integration/conftest.py +39 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/integration/test_mlp_system.py +3 -13
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/integration/test_refit_kalman_system.py +15 -16
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/integration/test_rnn_system.py +3 -13
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/integration/test_sklearn_system.py +2 -8
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/integration/test_torch_system.py +3 -13
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/integration/test_transformer_system.py +3 -13
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/unit/test_adaptive_linear_regressor.py +3 -6
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/unit/test_linear_regressor.py +3 -6
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/unit/test_mlp_old.py +9 -21
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/unit/test_refit_kalman.py +19 -52
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/unit/test_sgd.py +15 -13
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/unit/test_slda.py +6 -4
- ezmsg_learn-1.1.0/tests/unit/test_ssr.py +324 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/unit/test_torch.py +3 -7
- ezmsg_learn-1.0/src/ezmsg/learn/process/sgd.py +0 -131
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/.pre-commit-config.yaml +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/__init__.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/dim_reduce/__init__.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/linear_model/__init__.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/linear_model/cca.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/model/__init__.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/model/cca.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/model/mlp_old.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/model/rnn.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/model/transformer.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/nlin_model/__init__.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/src/ezmsg/learn/process/__init__.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/unit/test_mlp.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/unit/test_rnn.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/unit/test_sklearn.py +0 -0
- {ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/tests/unit/test_transformer.py +0 -0

ezmsg_learn-1.1.0/.github/workflows/docs.yml (new file)
@@ -0,0 +1,65 @@
+name: Documentation
+
+on:
+  push:
+    branches:
+      - main
+    tags:
+      - 'v*'
+  pull_request:
+    branches:
+      - dev
+  workflow_dispatch:
+
+permissions:
+  contents: read
+  pages: write
+  id-token: write
+
+# Allow only one concurrent deployment
+concurrency:
+  group: "pages"
+  cancel-in-progress: false
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0  # Needed for hatch-vcs to determine version
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v6
+        with:
+          enable-cache: true
+          python-version: "3.12"
+
+      - name: Install the project
+        run: uv sync --only-group docs
+
+      - name: Build documentation
+        run: |
+          cd docs
+          uv run make html
+
+      - name: Add .nojekyll file
+        run: touch docs/build/html/.nojekyll
+
+      - name: Upload artifact
+        uses: actions/upload-pages-artifact@v3
+        with:
+          path: 'docs/build/html'
+
+  deploy:
+    # Only deploy on push to main or release tags
+    if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/v'))
+    environment:
+      name: github-pages
+      url: ${{ steps.deployment.outputs.page_url }}
+    runs-on: ubuntu-latest
+    needs: build
+    steps:
+      - name: Deploy to GitHub Pages
+        id: deployment
+        uses: actions/deploy-pages@v4

{ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/.github/workflows/python-tests.yml
@@ -15,7 +15,7 @@ jobs:
   build:
     strategy:
       matrix:
-        python-version: ["3.12"]
+        python-version: ["3.12", "3.13"]
         os:
           - "ubuntu-latest"
           - "windows-latest"
@@ -37,5 +37,10 @@ jobs:
         run:
          uv tool run ruff check --output-format=github src
 
-      - name: Run tests
-        run: uv run pytest tests
+      - name: Run unit tests
+        run: uv run pytest tests/unit tests/dim_reduce -v
+
+      - name: Run integration tests
+        run: uv run pytest tests/integration -v --tb=long
+        env:
+          PYTHONFAULTHANDLER: 1

ezmsg_learn-1.1.0/LICENSE (new file)
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 ezmsg-org
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

{ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/PKG-INFO
@@ -1,11 +1,13 @@
 Metadata-Version: 2.4
 Name: ezmsg-learn
-Version: 1.0
+Version: 1.1.0
 Summary: ezmsg namespace package for machine learning
 Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
 License-Expression: MIT
+License-File: LICENSE
 Requires-Python: >=3.10.15
-Requires-Dist: ezmsg-
+Requires-Dist: ezmsg-baseproc>=1.0.2
+Requires-Dist: ezmsg-sigproc>=2.14.0
 Requires-Dist: river>=0.22.0
 Requires-Dist: scikit-learn>=1.6.0
 Requires-Dist: torch>=2.6.0
@@ -24,11 +26,5 @@ Processing units include dimensionality reduction, linear regression, and classi
 This ezmsg namespace package is still highly experimental and under active development. It is not yet available on PyPI, so you will need to install it from source. The easiest way to do this is to use the `pip` command to install the package directly from GitHub:
 
 ```bash
-pip install git+
-```
-
-Note that this package depends on a specific version of `ezmsg-sigproc` (specifically, [this branch]("70-use-protocols-for-axisarray-transformers")) that has yet to be merged and released. This may conflict with your project's separate dependency on ezmsg-sigproc. However, this specific version of ezmsg-sigproc should be backwards compatible with its main branch, so in your project you can modify the dependency on ezmsg-sigproc to point to the new branch. e.g.,
-
-```bash
-pip install git+ssh://git@github.com/ezmsg-org/ezmsg-sigproc@70-use-protocols-for-axisarray-transformers
+pip install git+https://github.com/ezmsg-org/ezmsg-learn
 ```

{ezmsg_learn-1.0 → ezmsg_learn-1.1.0}/README.md
@@ -11,11 +11,5 @@ Processing units include dimensionality reduction, linear regression, and classi
 This ezmsg namespace package is still highly experimental and under active development. It is not yet available on PyPI, so you will need to install it from source. The easiest way to do this is to use the `pip` command to install the package directly from GitHub:
 
 ```bash
-pip install git+
-```
-
-Note that this package depends on a specific version of `ezmsg-sigproc` (specifically, [this branch]("70-use-protocols-for-axisarray-transformers")) that has yet to be merged and released. This may conflict with your project's separate dependency on ezmsg-sigproc. However, this specific version of ezmsg-sigproc should be backwards compatible with its main branch, so in your project you can modify the dependency on ezmsg-sigproc to point to the new branch. e.g.,
-
-```bash
-pip install git+ssh://git@github.com/ezmsg-org/ezmsg-sigproc@70-use-protocols-for-axisarray-transformers
+pip install git+https://github.com/ezmsg-org/ezmsg-learn
 ```

ezmsg_learn-1.1.0/docs/Makefile (new file)
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = source
+BUILDDIR      = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

ezmsg_learn-1.1.0/docs/make.bat (new file)
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.http://sphinx-doc.org/
+	exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd

ezmsg_learn-1.1.0/docs/source/_templates/autosummary/module.rst (new file)
@@ -0,0 +1,64 @@
+{{ fullname | escape | underline}}
+
+.. automodule:: {{ fullname }}
+
+{% block attributes %}
+{% if attributes %}
+.. rubric:: Module Attributes
+
+.. autosummary::
+   :toctree:
+{% for item in attributes %}
+   {{ item }}
+{%- endfor %}
+{% endif %}
+{% endblock %}
+
+{% block functions %}
+{% if functions %}
+.. rubric:: Functions
+
+{% for item in functions %}
+.. autofunction:: {{ item }}
+{%- endfor %}
+{% endif %}
+{% endblock %}
+
+{% block classes %}
+{% if classes %}
+.. rubric:: Classes
+
+{% for item in classes %}
+.. autoclass:: {{ item }}
+   :members:
+   :undoc-members:
+   :show-inheritance:
+   :special-members: __init__
+{%- endfor %}
+{% endif %}
+{% endblock %}
+
+{% block exceptions %}
+{% if exceptions %}
+.. rubric:: Exceptions
+
+{% for item in exceptions %}
+.. autoexception:: {{ item }}
+   :members:
+   :show-inheritance:
+{%- endfor %}
+{% endif %}
+{% endblock %}
+
+{% block modules %}
+{% if modules %}
+.. rubric:: Modules
+
+.. autosummary::
+   :toctree:
+   :recursive:
+{% for item in modules %}
+   {{ item }}
+{%- endfor %}
+{% endif %}
+{% endblock %}

ezmsg_learn-1.1.0/docs/source/conf.py (new file)
@@ -0,0 +1,123 @@
+# Configuration file for the Sphinx documentation builder.
+
+import os
+import sys
+
+# Add the source directory to the path
+sys.path.insert(0, os.path.abspath("../../src"))
+
+# -- Project information --------------------------
+
+project = "ezmsg.learn"
+copyright = "2024, ezmsg Contributors"
+author = "ezmsg Contributors"
+
+# The version is managed by hatch-vcs and stored in __version__.py
+try:
+    from ezmsg.learn.__version__ import version as release
+except ImportError:
+    release = "unknown"
+
+# For display purposes, extract the base version without git commit info
+version = release.split("+")[0] if release != "unknown" else release
+
+# -- General configuration --------------------------
+
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.napoleon",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.duration",
+    "sphinx_autodoc_typehints",
+    "sphinx_copybutton",
+]
+
+templates_path = ["_templates"]
+source_suffix = [".rst"]
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+
+# The toctree master document
+master_doc = "index"
+
+# -- Autodoc configuration ------------------------------
+
+# Auto-generate API docs
+autosummary_generate = True
+autosummary_imported_members = False
+autodoc_typehints = "description"
+autodoc_member_order = "bysource"
+autodoc_typehints_format = "short"
+python_use_unqualified_type_names = True
+autodoc_default_options = {
+    "members": True,
+    "member-order": "bysource",
+    "special-members": "__init__",
+    "undoc-members": True,
+    "show-inheritance": True,
+}
+
+# Don't show the full module path in the docs
+add_module_names = False
+
+# -- Intersphinx configuration --------------------------
+
+intersphinx_mapping = {
+    "python": ("https://docs.python.org/3/", None),
+    "numpy": ("https://numpy.org/doc/stable/", None),
+    "torch": ("https://pytorch.org/docs/stable/", None),
+    "sklearn": ("https://scikit-learn.org/stable/", None),
+    "ezmsg": ("https://www.ezmsg.org/ezmsg/", None),
+    "ezmsg.sigproc": ("https://www.ezmsg.org/ezmsg-sigproc/", None),
+    "ezmsg.event": ("https://www.ezmsg.org/ezmsg-event/", None),
+    "ezmsg.lsl": ("https://www.ezmsg.org/ezmsg-lsl/", None),
+}
+intersphinx_disabled_domains = ["std"]
+
+# -- Options for HTML output -----------------------------
+
+html_theme = "pydata_sphinx_theme"
+html_static_path = ["_static"]
+
+# Set the base URL for the documentation
+html_baseurl = "https://www.ezmsg.org/ezmsg-learn/"
+
+html_theme_options = {
+    "logo": {
+        "text": f"ezmsg.learn {version}",
+        "link": "https://ezmsg.org",  # Link back to main site
+    },
+    "header_links_before_dropdown": 4,
+    "navbar_start": ["navbar-logo"],
+    "navbar_end": ["theme-switcher", "navbar-icon-links"],
+    "icon_links": [
+        {
+            "name": "GitHub",
+            "url": "https://github.com/ezmsg-org/ezmsg-learn",
+            "icon": "fa-brands fa-github",
+        },
+        {
+            "name": "ezmsg.org",
+            "url": "https://www.ezmsg.org",
+            "icon": "fa-solid fa-house",
+        },
+    ],
+}
+
+# Timestamp is inserted at every page bottom in this strftime format.
+html_last_updated_fmt = "%Y-%m-%d"
+
+# -- Options for linkcode -----------------------------
+
+branch = "main"
+code_url = f"https://github.com/ezmsg-org/ezmsg-learn/blob/{branch}/"
+
+
+def linkcode_resolve(domain, info):
+    if domain != "py":
+        return None
+    if not info["module"]:
+        return None
+    filename = info["module"].replace(".", "/")
+    return f"{code_url}src/{filename}.py"
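
The ``linkcode_resolve`` hook at the bottom of the new conf.py maps a Python module path onto a GitHub URL for the published source tree. As a minimal standalone sketch of what that mapping produces (the module name below is only an example taken from the file list above, and the two URL variables are copied from conf.py so the snippet runs on its own; the hook itself does not check that the file exists):

```python
branch = "main"
code_url = f"https://github.com/ezmsg-org/ezmsg-learn/blob/{branch}/"


def linkcode_resolve(domain, info):
    # Same logic as conf.py: only resolve Python objects with a known module.
    if domain != "py":
        return None
    if not info["module"]:
        return None
    filename = info["module"].replace(".", "/")
    return f"{code_url}src/{filename}.py"


print(linkcode_resolve("py", {"module": "ezmsg.learn.process.slda"}))
# -> https://github.com/ezmsg-org/ezmsg-learn/blob/main/src/ezmsg/learn/process/slda.py
```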

ezmsg_learn-1.1.0/docs/source/guides/classification.rst (new file)
@@ -0,0 +1,267 @@
+Real-Time Classification
+========================
+
+This guide shows how to use ezmsg-learn for real-time classification in streaming pipelines.
+
+.. contents:: On this page
+   :local:
+   :depth: 2
+
+
+Overview
+--------
+
+ezmsg-learn provides machine learning components that integrate with ezmsg pipelines.
+Key features include:
+
+- **Pre-trained models**: Load and apply existing classifiers
+- **Online learning**: Update models incrementally with streaming data
+- **Flexible backends**: Support for scikit-learn, PyTorch, and River models
+
+
+Available Classifiers
+---------------------
+
+ezmsg-learn includes several classifier types:
+
+.. list-table::
+   :header-rows: 1
+   :widths: 25 40 35
+
+   * - Classifier
+     - Description
+     - Use Case
+   * - ``SLDA``
+     - Shrinkage Linear Discriminant Analysis
+     - BCI, small datasets
+   * - ``SklearnModelUnit``
+     - Wrapper for any scikit-learn model
+     - General ML tasks
+   * - ``SGDClassifier``
+     - Stochastic Gradient Descent
+     - Online learning
+   * - ``MLPUnit``
+     - Multi-layer Perceptron (PyTorch)
+     - Complex patterns
+
+
+Using a Pre-Trained SLDA Classifier
+-----------------------------------
+
+The simplest approach is to use a pre-trained model:
+
+.. code-block:: python
+
+   from ezmsg.learn.process.slda import SLDA, SLDASettings
+
+   classifier = SLDA(
+       SLDASettings(
+           settings_path="path/to/trained_model.pkl",
+           axis="time",  # Axis containing samples
+       )
+   )
+
+**Input format**: ``AxisArray[time, features]`` where features are flattened from your pipeline.
+
+**Output format**: ``ClassifierMessage[time, classes]`` with class probabilities.
+
+Training an SLDA model (offline):
+
+.. code-block:: python
+
+   import pickle
+   from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
+
+   # Train offline with your data
+   X_train = ...  # shape: (n_samples, n_features)
+   y_train = ...  # shape: (n_samples,)
+
+   lda = LDA(solver="lsqr", shrinkage="auto")
+   lda.fit(X_train, y_train)
+
+   # Save for use in ezmsg
+   with open("trained_model.pkl", "wb") as f:
+       pickle.dump(lda, f)
+
+
+Using Scikit-Learn Models
+-------------------------
+
+``SklearnModelUnit`` wraps any scikit-learn compatible model:
+
+.. code-block:: python
+
+   from ezmsg.learn.process.sklearn import SklearnModelUnit, SklearnModelSettings
+   import numpy as np
+
+   classifier = SklearnModelUnit(
+       SklearnModelSettings(
+           model_class="sklearn.linear_model.SGDClassifier",
+           model_kwargs={
+               "loss": "log_loss",  # For probability outputs
+               "warm_start": True,
+           },
+           partial_fit_classes=np.array([0, 1]),  # Required for online learning
+       )
+   )
+
+Loading a pre-trained model:
+
+.. code-block:: python
+
+   classifier = SklearnModelUnit(
+       SklearnModelSettings(
+           model_class="sklearn.linear_model.SGDClassifier",
+           checkpoint_path="path/to/saved_model.pkl",
+       )
+   )
+
+
+Online Learning
+---------------
+
+For models that support ``partial_fit``, you can update them during streaming:
+
+.. code-block:: python
+
+   from ezmsg.learn.process.sklearn import SklearnModelProcessor, SklearnModelSettings
+   from ezmsg.sigproc.sampler import SampleMessage
+
+   # Create processor with online learning support
+   processor = SklearnModelProcessor(
+       settings=SklearnModelSettings(
+           model_class="sklearn.linear_model.SGDClassifier",
+           model_kwargs={"loss": "log_loss"},
+           partial_fit_classes=np.array([0, 1]),
+       )
+   )
+
+   # Training with labeled samples
+   sample_msg = SampleMessage(
+       sample=feature_array,  # AxisArray with features
+       trigger=label_value,  # The class label
+   )
+   processor.partial_fit(sample_msg)
+
+   # Prediction (after training)
+   prediction = processor(input_features)
+
+
+Complete Pipeline Example
+-------------------------
+
+Here's a complete BCI classification pipeline:
+
+.. code-block:: python
+
+   import ezmsg.core as ez
+   from ezmsg.lsl.inlet import LSLInletUnit, LSLInletSettings, LSLInfo
+   from ezmsg.lsl.outlet import LSLOutletUnit, LSLOutletSettings
+   from ezmsg.sigproc.butterworthfilter import ButterworthFilter, ButterworthFilterSettings
+   from ezmsg.sigproc.window import Window, WindowSettings
+   from ezmsg.sigproc.spectrum import Spectrum, SpectrumSettings
+   from ezmsg.sigproc.aggregate import RangedAggregate, RangedAggregateSettings, AggregationFunction
+   from ezmsg.learn.process.slda import SLDA, SLDASettings
+
+   components = {
+       # Data acquisition
+       "LSL_IN": LSLInletUnit(
+           LSLInletSettings(info=LSLInfo(name="EEG", type="EEG"))
+       ),
+
+       # Signal processing
+       "FILTER": ButterworthFilter(
+           ButterworthFilterSettings(order=4, cuton=8.0, cutoff=30.0)
+       ),
+       "WINDOW": Window(
+           WindowSettings(window_dur=1.0, window_shift=0.5)
+       ),
+       "SPECTRUM": Spectrum(SpectrumSettings(window="hann")),
+       "BANDPOWER": RangedAggregate(
+           RangedAggregateSettings(
+               axis="freq",
+               bands=[(8.0, 12.0), (18.0, 25.0)],
+               operation=AggregationFunction.MEAN,
+           )
+       ),
+
+       # Classification
+       "CLASSIFIER": SLDA(
+           SLDASettings(settings_path="model.pkl", axis="time")
+       ),
+
+       # Output
+       "LSL_OUT": LSLOutletUnit(
+           LSLOutletSettings(stream_name="Predictions", stream_type="Markers")
+       ),
+   }
+
+   connections = (
+       (components["LSL_IN"].OUTPUT_SIGNAL, components["FILTER"].INPUT_SIGNAL),
+       (components["FILTER"].OUTPUT_SIGNAL, components["WINDOW"].INPUT_SIGNAL),
+       (components["WINDOW"].OUTPUT_SIGNAL, components["SPECTRUM"].INPUT_SIGNAL),
+       (components["SPECTRUM"].OUTPUT_SIGNAL, components["BANDPOWER"].INPUT_SIGNAL),
+       (components["BANDPOWER"].OUTPUT_SIGNAL, components["CLASSIFIER"].INPUT_SIGNAL),
+       (components["CLASSIFIER"].OUTPUT_SIGNAL, components["LSL_OUT"].INPUT_SIGNAL),
+   )
+
+   if __name__ == "__main__":
+       ez.run(components=components, connections=connections)
+
+
+Feature Preparation
+-------------------
+
+Classifiers expect flattened 2D input ``[samples, features]``. Multi-dimensional arrays
+are automatically flattened along the channel dimension.
+
+For example, if your bandpower output is ``[time=1, band=2, ch=8]``:
+
+- The classifier receives shape ``[1, 16]`` (2 bands × 8 channels)
+- Features are flattened in C-order (row-major)
+
+
+Output Format
+-------------
+
+Classification outputs use ``ClassifierMessage``, which extends ``AxisArray`` with:
+
+- **dims**: ``["time", "classes"]``
+- **data**: Probability scores for each class
+- **labels**: List of class names/identifiers
+
+Example output shape: ``[time=1, classes=2]`` with probabilities for each class.
+
+
+Tips for Better Performance
+---------------------------
+
+1. **Normalize features**: Use ``Scaler`` from ezmsg-sigproc before classification
+
+   .. code-block:: python
+
+      from ezmsg.sigproc.scaler import Scaler, ScalerSettings
+      scaler = Scaler(ScalerSettings(mode="zscore"))
+
+2. **Match training conditions**: Ensure online features match offline training preprocessing
+
+3. **Window size**: Larger windows give more stable features but higher latency
+
+4. **Feature selection**: Start with relevant frequency bands for your application
+
+
+Troubleshooting
+---------------
+
+**"Model has not been fit yet"**:
+The model needs training data before prediction. Either:
+- Provide a ``checkpoint_path`` with a pre-trained model
+- Call ``fit()`` or ``partial_fit()`` before processing
+
+**Shape mismatch errors**:
+- Verify input feature dimensions match trained model
+- Check ``n_features_in_`` attribute of loaded models
+
+**NaN in predictions**:
+- Ensure input features don't contain NaN values
+- Check for numerical stability in preprocessing
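
The Feature Preparation and Output Format sections of the guide above describe the flattening convention in prose only. Below is a minimal NumPy-only sketch (an editorial illustration, not code shipped in the package) of how a ``[time=1, band=2, ch=8]`` bandpower block becomes the ``[1, 16]`` feature array the guide says the classifier receives:

```python
import numpy as np

# The guide states a bandpower block shaped [time=1, band=2, ch=8] reaches the
# classifier as [1, 16], flattened in C-order (row-major).
bandpower = np.arange(16).reshape(1, 2, 8)            # [time, band, ch]
features = bandpower.reshape(bandpower.shape[0], -1)  # [time, band * ch]

print(features.shape)    # (1, 16)
# C-order keeps all 8 channels of band 0 first, then all 8 channels of band 1:
print(features[0, :8])   # [0 1 2 3 4 5 6 7]          -> band 0
print(features[0, 8:])   # [ 8  9 10 11 12 13 14 15]  -> band 1
```

The same ordering is what a pre-trained model must assume, so offline training features should be flattened the same way before fitting.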