ezmsg-learn 1.0__tar.gz → 1.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. ezmsg_learn-1.2.0/.github/workflows/docs.yml +65 -0
  2. ezmsg_learn-1.0/.github/workflows/python-publish-ezmsg-learn.yml → ezmsg_learn-1.2.0/.github/workflows/python-publish.yml +1 -1
  3. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/.github/workflows/python-tests.yml +8 -3
  4. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/.gitignore +2 -0
  5. ezmsg_learn-1.2.0/LICENSE +21 -0
  6. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/PKG-INFO +5 -9
  7. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/README.md +1 -7
  8. ezmsg_learn-1.2.0/docs/Makefile +20 -0
  9. ezmsg_learn-1.2.0/docs/make.bat +35 -0
  10. ezmsg_learn-1.2.0/docs/source/_templates/autosummary/module.rst +64 -0
  11. ezmsg_learn-1.2.0/docs/source/api/index.rst +11 -0
  12. ezmsg_learn-1.2.0/docs/source/conf.py +123 -0
  13. ezmsg_learn-1.2.0/docs/source/guides/classification.rst +267 -0
  14. ezmsg_learn-1.2.0/docs/source/index.rst +64 -0
  15. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/pyproject.toml +29 -1
  16. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/__version__.py +2 -2
  17. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/dim_reduce/adaptive_decomp.py +9 -19
  18. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/dim_reduce/incremental_decomp.py +8 -16
  19. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/adaptive_linear_regressor.py +6 -0
  20. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/linear_regressor.py +4 -0
  21. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/sgd.py +6 -2
  22. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/slda.py +7 -1
  23. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/mlp.py +8 -14
  24. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/refit_kalman.py +17 -49
  25. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/nlin_model/mlp.py +5 -1
  26. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/adaptive_linear_regressor.py +20 -36
  27. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/base.py +12 -31
  28. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/linear_regressor.py +13 -18
  29. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/mlp_old.py +18 -31
  30. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/refit_kalman.py +8 -13
  31. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/rnn.py +14 -36
  32. ezmsg_learn-1.2.0/src/ezmsg/learn/process/sgd.py +116 -0
  33. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/sklearn.py +17 -51
  34. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/slda.py +6 -15
  35. ezmsg_learn-1.2.0/src/ezmsg/learn/process/ssr.py +374 -0
  36. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/torch.py +12 -29
  37. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/transformer.py +11 -19
  38. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/util.py +5 -4
  39. ezmsg_learn-1.2.0/tests/benchmark/bench_lrr.py +317 -0
  40. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/dim_reduce/test_adaptive_decomp.py +10 -22
  41. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/dim_reduce/test_incremental_decomp.py +10 -19
  42. ezmsg_learn-1.2.0/tests/integration/conftest.py +39 -0
  43. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/integration/test_mlp_system.py +3 -13
  44. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/integration/test_refit_kalman_system.py +15 -16
  45. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/integration/test_rnn_system.py +3 -13
  46. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/integration/test_sklearn_system.py +2 -8
  47. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/integration/test_torch_system.py +3 -13
  48. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/integration/test_transformer_system.py +3 -13
  49. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/unit/test_adaptive_linear_regressor.py +4 -7
  50. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/unit/test_linear_regressor.py +4 -7
  51. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/unit/test_mlp.py +9 -9
  52. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/unit/test_mlp_old.py +15 -22
  53. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/unit/test_refit_kalman.py +26 -60
  54. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/unit/test_rnn.py +19 -25
  55. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/unit/test_sgd.py +18 -15
  56. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/unit/test_sklearn.py +12 -13
  57. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/unit/test_slda.py +6 -4
  58. ezmsg_learn-1.2.0/tests/unit/test_ssr.py +324 -0
  59. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/unit/test_torch.py +12 -21
  60. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/tests/unit/test_transformer.py +17 -19
  61. ezmsg_learn-1.0/src/ezmsg/learn/process/sgd.py +0 -131
  62. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/.pre-commit-config.yaml +0 -0
  63. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/__init__.py +0 -0
  64. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/dim_reduce/__init__.py +0 -0
  65. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/__init__.py +0 -0
  66. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/cca.py +0 -0
  67. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/__init__.py +0 -0
  68. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/cca.py +0 -0
  69. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/mlp_old.py +0 -0
  70. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/rnn.py +0 -0
  71. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/transformer.py +0 -0
  72. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/nlin_model/__init__.py +0 -0
  73. {ezmsg_learn-1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/__init__.py +0 -0
@@ -0,0 +1,65 @@
1
+ name: Documentation
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ tags:
8
+ - 'v*'
9
+ pull_request:
10
+ branches:
11
+ - dev
12
+ workflow_dispatch:
13
+
14
+ permissions:
15
+ contents: read
16
+ pages: write
17
+ id-token: write
18
+
19
+ # Allow only one concurrent deployment
20
+ concurrency:
21
+ group: "pages"
22
+ cancel-in-progress: false
23
+
24
+ jobs:
25
+ build:
26
+ runs-on: ubuntu-latest
27
+ steps:
28
+ - uses: actions/checkout@v4
29
+ with:
30
+ fetch-depth: 0 # Needed for hatch-vcs to determine version
31
+
32
+ - name: Install uv
33
+ uses: astral-sh/setup-uv@v6
34
+ with:
35
+ enable-cache: true
36
+ python-version: "3.12"
37
+
38
+ - name: Install the project
39
+ run: uv sync --only-group docs
40
+
41
+ - name: Build documentation
42
+ run: |
43
+ cd docs
44
+ uv run make html
45
+
46
+ - name: Add .nojekyll file
47
+ run: touch docs/build/html/.nojekyll
48
+
49
+ - name: Upload artifact
50
+ uses: actions/upload-pages-artifact@v3
51
+ with:
52
+ path: 'docs/build/html'
53
+
54
+ deploy:
55
+ # Only deploy on push to main or release tags
56
+ if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/v'))
57
+ environment:
58
+ name: github-pages
59
+ url: ${{ steps.deployment.outputs.page_url }}
60
+ runs-on: ubuntu-latest
61
+ needs: build
62
+ steps:
63
+ - name: Deploy to GitHub Pages
64
+ id: deployment
65
+ uses: actions/deploy-pages@v4
@@ -1,4 +1,4 @@
1
- name: Upload Python Package - ezmsg-learn
1
+ name: Upload Python Package
2
2
 
3
3
  on:
4
4
  release:
@@ -15,7 +15,7 @@ jobs:
15
15
  build:
16
16
  strategy:
17
17
  matrix:
18
- python-version: ["3.12"]
18
+ python-version: ["3.12", "3.13"]
19
19
  os:
20
20
  - "ubuntu-latest"
21
21
  - "windows-latest"
@@ -37,5 +37,10 @@ jobs:
37
37
  run:
38
38
  uv tool run ruff check --output-format=github src
39
39
 
40
- - name: Run tests
41
- run: uv run pytest tests
40
+ - name: Run unit tests
41
+ run: uv run pytest tests/unit tests/dim_reduce -v
42
+
43
+ - name: Run integration tests
44
+ run: uv run pytest tests/integration -v --tb=long
45
+ env:
46
+ PYTHONFAULTHANDLER: 1
@@ -70,6 +70,8 @@ instance/
70
70
 
71
71
  # Sphinx documentation
72
72
  docs/_build/
73
+ docs/build/
74
+ docs/source/api/generated
73
75
 
74
76
  # PyBuilder
75
77
  .pybuilder/
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 ezmsg-org
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -1,11 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-learn
3
- Version: 1.0
3
+ Version: 1.2.0
4
4
  Summary: ezmsg namespace package for machine learning
5
5
  Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
6
6
  License-Expression: MIT
7
+ License-File: LICENSE
7
8
  Requires-Python: >=3.10.15
8
- Requires-Dist: ezmsg-sigproc
9
+ Requires-Dist: ezmsg-baseproc>=1.3.0
10
+ Requires-Dist: ezmsg-sigproc>=2.15.0
9
11
  Requires-Dist: river>=0.22.0
10
12
  Requires-Dist: scikit-learn>=1.6.0
11
13
  Requires-Dist: torch>=2.6.0
@@ -24,11 +26,5 @@ Processing units include dimensionality reduction, linear regression, and classi
24
26
  This ezmsg namespace package is still highly experimental and under active development. It is not yet available on PyPI, so you will need to install it from source. The easiest way to do this is to use the `pip` command to install the package directly from GitHub:
25
27
 
26
28
  ```bash
27
- pip install git+ssh://git@github.com/ezmsg-org/ezmsg-learn
28
- ```
29
-
30
- Note that this package depends on a specific version of `ezmsg-sigproc` (specifically, [this branch]("70-use-protocols-for-axisarray-transformers")) that has yet to be merged and released. This may conflict with your project's separate dependency on ezmsg-sigproc. However, this specific version of ezmsg-sigproc should be backwards compatible with its main branch, so in your project you can modify the dependency on ezmsg-sigproc to point to the new branch. e.g.,
31
-
32
- ```bash
33
- pip install git+ssh://git@github.com/ezmsg-org/ezmsg-sigproc@70-use-protocols-for-axisarray-transformers
29
+ pip install git+https://github.com/ezmsg-org/ezmsg-learn
34
30
  ```
@@ -11,11 +11,5 @@ Processing units include dimensionality reduction, linear regression, and classi
11
11
  This ezmsg namespace package is still highly experimental and under active development. It is not yet available on PyPI, so you will need to install it from source. The easiest way to do this is to use the `pip` command to install the package directly from GitHub:
12
12
 
13
13
  ```bash
14
- pip install git+ssh://git@github.com/ezmsg-org/ezmsg-learn
15
- ```
16
-
17
- Note that this package depends on a specific version of `ezmsg-sigproc` (specifically, [this branch]("70-use-protocols-for-axisarray-transformers")) that has yet to be merged and released. This may conflict with your project's separate dependency on ezmsg-sigproc. However, this specific version of ezmsg-sigproc should be backwards compatible with its main branch, so in your project you can modify the dependency on ezmsg-sigproc to point to the new branch. e.g.,
18
-
19
- ```bash
20
- pip install git+ssh://git@github.com/ezmsg-org/ezmsg-sigproc@70-use-protocols-for-axisarray-transformers
14
+ pip install git+https://github.com/ezmsg-org/ezmsg-learn
21
15
  ```
@@ -0,0 +1,20 @@
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line, and also
5
+ # from the environment for the first two.
6
+ SPHINXOPTS ?=
7
+ SPHINXBUILD ?= sphinx-build
8
+ SOURCEDIR = source
9
+ BUILDDIR = build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@@ -0,0 +1,35 @@
1
+ @ECHO OFF
2
+
3
+ pushd %~dp0
4
+
5
+ REM Command file for Sphinx documentation
6
+
7
+ if "%SPHINXBUILD%" == "" (
8
+ set SPHINXBUILD=sphinx-build
9
+ )
10
+ set SOURCEDIR=source
11
+ set BUILDDIR=build
12
+
13
+ if "%1" == "" goto help
14
+
15
+ %SPHINXBUILD% >NUL 2>NUL
16
+ if errorlevel 9009 (
17
+ echo.
18
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19
+ echo.installed, then set the SPHINXBUILD environment variable to point
20
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
21
+ echo.may add the Sphinx directory to PATH.
22
+ echo.
23
+ echo.If you don't have Sphinx installed, grab it from
24
+ echo.http://sphinx-doc.org/
25
+ exit /b 1
26
+ )
27
+
28
+ %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29
+ goto end
30
+
31
+ :help
32
+ %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33
+
34
+ :end
35
+ popd
@@ -0,0 +1,64 @@
1
+ {{ fullname | escape | underline}}
2
+
3
+ .. automodule:: {{ fullname }}
4
+
5
+ {% block attributes %}
6
+ {% if attributes %}
7
+ .. rubric:: Module Attributes
8
+
9
+ .. autosummary::
10
+ :toctree:
11
+ {% for item in attributes %}
12
+ {{ item }}
13
+ {%- endfor %}
14
+ {% endif %}
15
+ {% endblock %}
16
+
17
+ {% block functions %}
18
+ {% if functions %}
19
+ .. rubric:: Functions
20
+
21
+ {% for item in functions %}
22
+ .. autofunction:: {{ item }}
23
+ {%- endfor %}
24
+ {% endif %}
25
+ {% endblock %}
26
+
27
+ {% block classes %}
28
+ {% if classes %}
29
+ .. rubric:: Classes
30
+
31
+ {% for item in classes %}
32
+ .. autoclass:: {{ item }}
33
+ :members:
34
+ :undoc-members:
35
+ :show-inheritance:
36
+ :special-members: __init__
37
+ {%- endfor %}
38
+ {% endif %}
39
+ {% endblock %}
40
+
41
+ {% block exceptions %}
42
+ {% if exceptions %}
43
+ .. rubric:: Exceptions
44
+
45
+ {% for item in exceptions %}
46
+ .. autoexception:: {{ item }}
47
+ :members:
48
+ :show-inheritance:
49
+ {%- endfor %}
50
+ {% endif %}
51
+ {% endblock %}
52
+
53
+ {% block modules %}
54
+ {% if modules %}
55
+ .. rubric:: Modules
56
+
57
+ .. autosummary::
58
+ :toctree:
59
+ :recursive:
60
+ {% for item in modules %}
61
+ {{ item }}
62
+ {%- endfor %}
63
+ {% endif %}
64
+ {% endblock %}
@@ -0,0 +1,11 @@
1
+ API Reference
2
+ =============
3
+
4
+ This page contains auto-generated API reference documentation.
5
+
6
+ .. autosummary::
7
+ :toctree: generated
8
+ :recursive:
9
+ :template: autosummary/module.rst
10
+
11
+ ezmsg.learn
@@ -0,0 +1,123 @@
1
+ # Configuration file for the Sphinx documentation builder.
2
+
3
+ import os
4
+ import sys
5
+
6
+ # Add the source directory to the path
7
+ sys.path.insert(0, os.path.abspath("../../src"))
8
+
9
+ # -- Project information --------------------------
10
+
11
+ project = "ezmsg.learn"
12
+ copyright = "2024, ezmsg Contributors"
13
+ author = "ezmsg Contributors"
14
+
15
+ # The version is managed by hatch-vcs and stored in __version__.py
16
+ try:
17
+ from ezmsg.learn.__version__ import version as release
18
+ except ImportError:
19
+ release = "unknown"
20
+
21
+ # For display purposes, extract the base version without git commit info
22
+ version = release.split("+")[0] if release != "unknown" else release
23
+
24
+ # -- General configuration --------------------------
25
+
26
+ extensions = [
27
+ "sphinx.ext.autodoc",
28
+ "sphinx.ext.autosummary",
29
+ "sphinx.ext.napoleon",
30
+ "sphinx.ext.intersphinx",
31
+ "sphinx.ext.viewcode",
32
+ "sphinx.ext.duration",
33
+ "sphinx_autodoc_typehints",
34
+ "sphinx_copybutton",
35
+ ]
36
+
37
+ templates_path = ["_templates"]
38
+ source_suffix = [".rst"]
39
+ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
40
+
41
+ # The toctree master document
42
+ master_doc = "index"
43
+
44
+ # -- Autodoc configuration ------------------------------
45
+
46
+ # Auto-generate API docs
47
+ autosummary_generate = True
48
+ autosummary_imported_members = False
49
+ autodoc_typehints = "description"
50
+ autodoc_member_order = "bysource"
51
+ autodoc_typehints_format = "short"
52
+ python_use_unqualified_type_names = True
53
+ autodoc_default_options = {
54
+ "members": True,
55
+ "member-order": "bysource",
56
+ "special-members": "__init__",
57
+ "undoc-members": True,
58
+ "show-inheritance": True,
59
+ }
60
+
61
+ # Don't show the full module path in the docs
62
+ add_module_names = False
63
+
64
+ # -- Intersphinx configuration --------------------------
65
+
66
+ intersphinx_mapping = {
67
+ "python": ("https://docs.python.org/3/", None),
68
+ "numpy": ("https://numpy.org/doc/stable/", None),
69
+ "torch": ("https://pytorch.org/docs/stable/", None),
70
+ "sklearn": ("https://scikit-learn.org/stable/", None),
71
+ "ezmsg": ("https://www.ezmsg.org/ezmsg/", None),
72
+ "ezmsg.sigproc": ("https://www.ezmsg.org/ezmsg-sigproc/", None),
73
+ "ezmsg.event": ("https://www.ezmsg.org/ezmsg-event/", None),
74
+ "ezmsg.lsl": ("https://www.ezmsg.org/ezmsg-lsl/", None),
75
+ }
76
+ intersphinx_disabled_domains = ["std"]
77
+
78
+ # -- Options for HTML output -----------------------------
79
+
80
+ html_theme = "pydata_sphinx_theme"
81
+ html_static_path = ["_static"]
82
+
83
+ # Set the base URL for the documentation
84
+ html_baseurl = "https://www.ezmsg.org/ezmsg-learn/"
85
+
86
+ html_theme_options = {
87
+ "logo": {
88
+ "text": f"ezmsg.learn {version}",
89
+ "link": "https://ezmsg.org", # Link back to main site
90
+ },
91
+ "header_links_before_dropdown": 4,
92
+ "navbar_start": ["navbar-logo"],
93
+ "navbar_end": ["theme-switcher", "navbar-icon-links"],
94
+ "icon_links": [
95
+ {
96
+ "name": "GitHub",
97
+ "url": "https://github.com/ezmsg-org/ezmsg-learn",
98
+ "icon": "fa-brands fa-github",
99
+ },
100
+ {
101
+ "name": "ezmsg.org",
102
+ "url": "https://www.ezmsg.org",
103
+ "icon": "fa-solid fa-house",
104
+ },
105
+ ],
106
+ }
107
+
108
+ # Timestamp is inserted at every page bottom in this strftime format.
109
+ html_last_updated_fmt = "%Y-%m-%d"
110
+
111
+ # -- Options for linkcode -----------------------------
112
+
113
+ branch = "main"
114
+ code_url = f"https://github.com/ezmsg-org/ezmsg-learn/blob/{branch}/"
115
+
116
+
117
+ def linkcode_resolve(domain, info):
118
+ if domain != "py":
119
+ return None
120
+ if not info["module"]:
121
+ return None
122
+ filename = info["module"].replace(".", "/")
123
+ return f"{code_url}src/{filename}.py"
@@ -0,0 +1,267 @@
1
+ Real-Time Classification
2
+ ========================
3
+
4
+ This guide shows how to use ezmsg-learn for real-time classification in streaming pipelines.
5
+
6
+ .. contents:: On this page
7
+ :local:
8
+ :depth: 2
9
+
10
+
11
+ Overview
12
+ --------
13
+
14
+ ezmsg-learn provides machine learning components that integrate with ezmsg pipelines.
15
+ Key features include:
16
+
17
+ - **Pre-trained models**: Load and apply existing classifiers
18
+ - **Online learning**: Update models incrementally with streaming data
19
+ - **Flexible backends**: Support for scikit-learn, PyTorch, and River models
20
+
21
+
22
+ Available Classifiers
23
+ ---------------------
24
+
25
+ ezmsg-learn includes several classifier types:
26
+
27
+ .. list-table::
28
+ :header-rows: 1
29
+ :widths: 25 40 35
30
+
31
+ * - Classifier
32
+ - Description
33
+ - Use Case
34
+ * - ``SLDA``
35
+ - Shrinkage Linear Discriminant Analysis
36
+ - BCI, small datasets
37
+ * - ``SklearnModelUnit``
38
+ - Wrapper for any scikit-learn model
39
+ - General ML tasks
40
+ * - ``SGDClassifier``
41
+ - Stochastic Gradient Descent
42
+ - Online learning
43
+ * - ``MLPUnit``
44
+ - Multi-layer Perceptron (PyTorch)
45
+ - Complex patterns
46
+
47
+
48
+ Using a Pre-Trained SLDA Classifier
49
+ -----------------------------------
50
+
51
+ The simplest approach is to use a pre-trained model:
52
+
53
+ .. code-block:: python
54
+
55
+ from ezmsg.learn.process.slda import SLDA, SLDASettings
56
+
57
+ classifier = SLDA(
58
+ SLDASettings(
59
+ settings_path="path/to/trained_model.pkl",
60
+ axis="time", # Axis containing samples
61
+ )
62
+ )
63
+
64
+ **Input format**: ``AxisArray[time, features]`` where features are flattened from your pipeline.
65
+
66
+ **Output format**: ``ClassifierMessage[time, classes]`` with class probabilities.
67
+
68
+ Training an SLDA model (offline):
69
+
70
+ .. code-block:: python
71
+
72
+ import pickle
73
+ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
74
+
75
+ # Train offline with your data
76
+ X_train = ... # shape: (n_samples, n_features)
77
+ y_train = ... # shape: (n_samples,)
78
+
79
+ lda = LDA(solver="lsqr", shrinkage="auto")
80
+ lda.fit(X_train, y_train)
81
+
82
+ # Save for use in ezmsg
83
+ with open("trained_model.pkl", "wb") as f:
84
+ pickle.dump(lda, f)
85
+
86
+
87
+ Using Scikit-Learn Models
88
+ -------------------------
89
+
90
+ ``SklearnModelUnit`` wraps any scikit-learn compatible model:
91
+
92
+ .. code-block:: python
93
+
94
+ from ezmsg.learn.process.sklearn import SklearnModelUnit, SklearnModelSettings
95
+ import numpy as np
96
+
97
+ classifier = SklearnModelUnit(
98
+ SklearnModelSettings(
99
+ model_class="sklearn.linear_model.SGDClassifier",
100
+ model_kwargs={
101
+ "loss": "log_loss", # For probability outputs
102
+ "warm_start": True,
103
+ },
104
+ partial_fit_classes=np.array([0, 1]), # Required for online learning
105
+ )
106
+ )
107
+
108
+ Loading a pre-trained model:
109
+
110
+ .. code-block:: python
111
+
112
+ classifier = SklearnModelUnit(
113
+ SklearnModelSettings(
114
+ model_class="sklearn.linear_model.SGDClassifier",
115
+ checkpoint_path="path/to/saved_model.pkl",
116
+ )
117
+ )
118
+
119
+
120
+ Online Learning
121
+ ---------------
122
+
123
+ For models that support ``partial_fit``, you can update them during streaming:
124
+
125
+ .. code-block:: python
126
+
127
+ from ezmsg.learn.process.sklearn import SklearnModelProcessor, SklearnModelSettings
128
+ from ezmsg.sigproc.sampler import SampleMessage
129
+
130
+ # Create processor with online learning support
131
+ processor = SklearnModelProcessor(
132
+ settings=SklearnModelSettings(
133
+ model_class="sklearn.linear_model.SGDClassifier",
134
+ model_kwargs={"loss": "log_loss"},
135
+ partial_fit_classes=np.array([0, 1]),
136
+ )
137
+ )
138
+
139
+ # Training with labeled samples
140
+ sample_msg = SampleMessage(
141
+ sample=feature_array, # AxisArray with features
142
+ trigger=label_value, # The class label
143
+ )
144
+ processor.partial_fit(sample_msg)
145
+
146
+ # Prediction (after training)
147
+ prediction = processor(input_features)
148
+
149
+
150
+ Complete Pipeline Example
151
+ -------------------------
152
+
153
+ Here's a complete BCI classification pipeline:
154
+
155
+ .. code-block:: python
156
+
157
+ import ezmsg.core as ez
158
+ from ezmsg.lsl.inlet import LSLInletUnit, LSLInletSettings, LSLInfo
159
+ from ezmsg.lsl.outlet import LSLOutletUnit, LSLOutletSettings
160
+ from ezmsg.sigproc.butterworthfilter import ButterworthFilter, ButterworthFilterSettings
161
+ from ezmsg.sigproc.window import Window, WindowSettings
162
+ from ezmsg.sigproc.spectrum import Spectrum, SpectrumSettings
163
+ from ezmsg.sigproc.aggregate import RangedAggregate, RangedAggregateSettings, AggregationFunction
164
+ from ezmsg.learn.process.slda import SLDA, SLDASettings
165
+
166
+ components = {
167
+ # Data acquisition
168
+ "LSL_IN": LSLInletUnit(
169
+ LSLInletSettings(info=LSLInfo(name="EEG", type="EEG"))
170
+ ),
171
+
172
+ # Signal processing
173
+ "FILTER": ButterworthFilter(
174
+ ButterworthFilterSettings(order=4, cuton=8.0, cutoff=30.0)
175
+ ),
176
+ "WINDOW": Window(
177
+ WindowSettings(window_dur=1.0, window_shift=0.5)
178
+ ),
179
+ "SPECTRUM": Spectrum(SpectrumSettings(window="hann")),
180
+ "BANDPOWER": RangedAggregate(
181
+ RangedAggregateSettings(
182
+ axis="freq",
183
+ bands=[(8.0, 12.0), (18.0, 25.0)],
184
+ operation=AggregationFunction.MEAN,
185
+ )
186
+ ),
187
+
188
+ # Classification
189
+ "CLASSIFIER": SLDA(
190
+ SLDASettings(settings_path="model.pkl", axis="time")
191
+ ),
192
+
193
+ # Output
194
+ "LSL_OUT": LSLOutletUnit(
195
+ LSLOutletSettings(stream_name="Predictions", stream_type="Markers")
196
+ ),
197
+ }
198
+
199
+ connections = (
200
+ (components["LSL_IN"].OUTPUT_SIGNAL, components["FILTER"].INPUT_SIGNAL),
201
+ (components["FILTER"].OUTPUT_SIGNAL, components["WINDOW"].INPUT_SIGNAL),
202
+ (components["WINDOW"].OUTPUT_SIGNAL, components["SPECTRUM"].INPUT_SIGNAL),
203
+ (components["SPECTRUM"].OUTPUT_SIGNAL, components["BANDPOWER"].INPUT_SIGNAL),
204
+ (components["BANDPOWER"].OUTPUT_SIGNAL, components["CLASSIFIER"].INPUT_SIGNAL),
205
+ (components["CLASSIFIER"].OUTPUT_SIGNAL, components["LSL_OUT"].INPUT_SIGNAL),
206
+ )
207
+
208
+ if __name__ == "__main__":
209
+ ez.run(components=components, connections=connections)
210
+
211
+
212
+ Feature Preparation
213
+ -------------------
214
+
215
+ Classifiers expect flattened 2D input ``[samples, features]``. Multi-dimensional arrays
216
+ are automatically flattened so that all non-sample dimensions are combined into a single feature axis.
217
+
218
+ For example, if your bandpower output is ``[time=1, band=2, ch=8]``:
219
+
220
+ - The classifier receives shape ``[1, 16]`` (2 bands × 8 channels)
221
+ - Features are flattened in C-order (row-major)
222
+
223
+
224
+ Output Format
225
+ -------------
226
+
227
+ Classification outputs use ``ClassifierMessage``, which extends ``AxisArray`` with:
228
+
229
+ - **dims**: ``["time", "classes"]``
230
+ - **data**: Probability scores for each class
231
+ - **labels**: List of class names/identifiers
232
+
233
+ Example output shape: ``[time=1, classes=2]`` with probabilities for each class.
234
+
235
+
236
+ Tips for Better Performance
237
+ ---------------------------
238
+
239
+ 1. **Normalize features**: Use ``Scaler`` from ezmsg-sigproc before classification
240
+
241
+ .. code-block:: python
242
+
243
+ from ezmsg.sigproc.scaler import Scaler, ScalerSettings
244
+ scaler = Scaler(ScalerSettings(mode="zscore"))
245
+
246
+ 2. **Match training conditions**: Ensure online features match offline training preprocessing
247
+
248
+ 3. **Window size**: Larger windows give more stable features but higher latency
249
+
250
+ 4. **Feature selection**: Start with relevant frequency bands for your application
251
+
252
+
253
+ Troubleshooting
254
+ ---------------
255
+
256
+ **"Model has not been fit yet"**:
257
+ The model needs training data before prediction. Either:
258
+ - Provide a ``checkpoint_path`` with a pre-trained model
259
+ - Call ``fit()`` or ``partial_fit()`` before processing
260
+
261
+ **Shape mismatch errors**:
262
+ - Verify input feature dimensions match the trained model
263
+ - Check the ``n_features_in_`` attribute of loaded models
264
+
265
+ **NaN in predictions**:
266
+ - Ensure input features don't contain NaN values
267
+ - Check for numerical stability in preprocessing