drinx 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- drinx-1.0.0/.github/actions/setup-python-env/action.yml +30 -0
- drinx-1.0.0/.github/workflows/cicd.yml +72 -0
- drinx-1.0.0/.github/workflows/publish.yml +42 -0
- drinx-1.0.0/.gitignore +9 -0
- drinx-1.0.0/.pre-commit-config.yaml +34 -0
- drinx-1.0.0/.readthedocs.yml +14 -0
- drinx-1.0.0/CLAUDE.md +61 -0
- drinx-1.0.0/LICENSE +21 -0
- drinx-1.0.0/PKG-INFO +190 -0
- drinx-1.0.0/README.md +161 -0
- drinx-1.0.0/codecov.yml +15 -0
- drinx-1.0.0/docs/requirements.txt +20 -0
- drinx-1.0.0/docs/source/_static/android-chrome-192x192.png +0 -0
- drinx-1.0.0/docs/source/_static/drinx.png +0 -0
- drinx-1.0.0/docs/source/_static/drinx_text.png +0 -0
- drinx-1.0.0/docs/source/_static/drinx_text.svg +62 -0
- drinx-1.0.0/docs/source/_static/favicon.ico +0 -0
- drinx-1.0.0/docs/source/_templates/autosummary/class.rst +31 -0
- drinx-1.0.0/docs/source/api.rst +12 -0
- drinx-1.0.0/docs/source/conf.py +86 -0
- drinx-1.0.0/docs/source/examples/basic_usage.ipynb +509 -0
- drinx-1.0.0/docs/source/index.rst +161 -0
- drinx-1.0.0/pyproject.toml +48 -0
- drinx-1.0.0/src/drinx/__init__.py +13 -0
- drinx-1.0.0/src/drinx/attribute.py +184 -0
- drinx-1.0.0/src/drinx/base.py +404 -0
- drinx-1.0.0/src/drinx/transform.py +166 -0
- drinx-1.0.0/tests/test_at_set.py +124 -0
- drinx-1.0.0/tests/test_base.py +1805 -0
- drinx-1.0.0/tests/test_dataclass.py +633 -0
- drinx-1.0.0/tests/test_field.py +204 -0
- drinx-1.0.0/tests/test_jax_transforms.py +423 -0
- drinx-1.0.0/tests/test_readme_examples.py +189 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Composite action: install a given Python, install uv, and sync the project's
# `dev` extra into a uv-managed environment.
name: "Setup Python Environment"
description: "Set up Python environment for the given Python version"

inputs:
  python-version:
    description: "Python version to use"
    # Callers may omit this (the `checks` job in cicd.yml does), so it is
    # optional and falls back to the default below.
    required: false
    default: "3.12"
  uv-version:
    description: "uv version to use"
    required: false
    default: "0.5.8"

runs:
  using: "composite"
  steps:
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: ${{ inputs.python-version }}

    - name: Install uv
      uses: astral-sh/setup-uv@v2
      with:
        version: ${{ inputs.uv-version }}
        enable-cache: "true"
        # Key the uv cache on the Python version this action installs. The
        # caller's `matrix` is not part of this action's interface and is empty
        # in jobs without a matrix, so use the action input instead.
        cache-suffix: ${{ inputs.python-version }}

    - name: Install Python dependencies
      run: uv sync --extra dev
      shell: bash
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# Continuous integration: static checks + docs build, and the test suite with
# coverage upload.
name: CI-CD

on:
  pull_request:
    types: [opened, synchronize, reopened]
  push:
    branches: [main]
  workflow_dispatch:

# Workflow-wide default: read-only token.
permissions:
  contents: read

jobs:
  checks:
    runs-on: ubuntu-latest
    steps:
      - name: Check out
        uses: actions/checkout@v4
        with:
          persist-credentials: false

      # Reuse pre-commit hook environments until the hook config changes.
      - uses: actions/cache@v4
        with:
          path: ~/.cache/pre-commit
          key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}

      # No inputs: the local action's default Python version is used.
      - name: Set up the environment
        uses: ./.github/actions/setup-python-env

      - name: Run pre-commit
        run: uv run pre-commit run -a --show-diff-on-failure

      # -W turns Sphinx warnings into errors; --keep-going reports all of them.
      - name: Build docs
        run: uv run sphinx-build -W --keep-going docs/source/ docs/build/

  tests:
    runs-on: ubuntu-latest
    # Least privilege: only a read-only checkout is needed. The Codecov upload
    # below authenticates with secrets.CODECOV_TOKEN, not OIDC, so no
    # `id-token: write` grant is required.
    permissions:
      contents: read
    strategy:
      matrix:
        python-version: ["3.11"]
      fail-fast: false
    defaults:
      run:
        shell: bash
    steps:
      - name: Check out
        uses: actions/checkout@v4
        with:
          persist-credentials: false

      - name: Set up the environment
        uses: ./.github/actions/setup-python-env
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run tests
        run: uv run python -m pytest tests --cov --cov-branch --cov-config=pyproject.toml --cov-report=xml

      - name: Typecheck with ty
        run: uvx ty check --error-on-warning

      # Upload from a single matrix entry only, so reports are not duplicated.
      - name: Upload coverage reports to Codecov
        uses: codecov/codecov-action@v4
        if: ${{ matrix.python-version == '3.11' }}
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          files: ./coverage.xml
          fail_ci_if_error: true
          slug: ymahlau/drinx
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# Build an sdist + wheel and upload them to PyPI whenever a GitHub release is
# published (or when triggered by hand).
name: Publish to PyPI

on:
  release:
    types:
      - published
  # Manual trigger from the Actions tab.
  workflow_dispatch:

# Read-only default; the publish job widens this below.
permissions:
  contents: read

jobs:
  pypi-publish:
    name: Build and publish Python 🐍 distributions to PyPI
    runs-on: ubuntu-latest
    # Trusted Publishing requires an OIDC token, hence `id-token: write`.
    permissions:
      contents: read
      id-token: write

    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          persist-credentials: false

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - name: Install build tools
        run: pip install build

      - name: Build package
        run: python -m build

      - name: Publish distribution 📦 to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          # No password/token input: authentication uses Trusted Publishing
          # via the id-token permission granted to this job.
          print-hash: true
|
drinx-1.0.0/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# pre-commit hook configuration.
# Docs:  https://pre-commit.com
# Hooks: https://pre-commit.com/hooks.html
repos:
  # Generic file hygiene checks.
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: check-yaml
      - id: check-added-large-files
      - id: check-shebang-scripts-are-executable
      - id: check-toml
      - id: check-merge-conflict
        args:
          - --assume-in-merge

  # Ruff linter (with autofix) and formatter.
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.2  # Ruff version
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format

  # Security linting for GitHub Actions workflows.
  - repo: https://github.com/woodruffw/zizmor-pre-commit
    rev: v1.5.2
    hooks:
      - id: zizmor

  # ty is invoked through uvx from the system environment.
  - repo: local
    hooks:
      - id: ty
        name: ty (via uvx)
        entry: uvx ty check --error-on-warning
        language: system
        types: [python]
        # Run once without file-name arguments rather than per staged file.
        pass_filenames: false
|
drinx-1.0.0/CLAUDE.md
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# CLAUDE.md
|
|
2
|
+
|
|
3
|
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
|
4
|
+
|
|
5
|
+
## Project Overview
|
|
6
|
+
|
|
7
|
+
Drinx is a small Python library ("Dataclass Registry in JAX") that wraps Python's standard `dataclasses` module to make dataclasses compatible with JAX transformations (e.g., `jit`, `vmap`, `grad`). It does this by registering each decorated class as a JAX pytree node, with support for marking fields as "static" (excluded from JAX tracing) or "dynamic" (included as leaves).
|
|
8
|
+
|
|
9
|
+
## Commands
|
|
10
|
+
|
|
11
|
+
This project uses `uv` for dependency management and virtual environments.
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
# Install dev dependencies
|
|
15
|
+
uv sync --extra dev
|
|
16
|
+
|
|
17
|
+
# Run tests
|
|
18
|
+
uv run pytest
|
|
19
|
+
|
|
20
|
+
# Run a single test file
|
|
21
|
+
uv run pytest tests/test_foo.py
|
|
22
|
+
|
|
23
|
+
# Run a single test
|
|
24
|
+
uv run pytest tests/test_foo.py::test_name
|
|
25
|
+
|
|
26
|
+
# Lint
|
|
27
|
+
uv run ruff check src/
|
|
28
|
+
|
|
29
|
+
# Format
|
|
30
|
+
uv run ruff format src/
|
|
31
|
+
|
|
32
|
+
# Type check
|
|
33
|
+
uvx ty check --error-on-warning
|
|
34
|
+
|
|
35
|
+
# Build docs
|
|
36
|
+
uv run sphinx-build docs/source docs/build
|
|
37
|
+
|
|
38
|
+
# Live-reload docs
|
|
39
|
+
uv run sphinx-autobuild docs/source docs/build
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
Pre-commit hooks run ruff (lint + format), zizmor (GitHub Actions security), and `ty` (type checker) automatically on commit.
|
|
43
|
+
|
|
44
|
+
## Architecture
|
|
45
|
+
|
|
46
|
+
The library lives in `src/drinx/` with four files:
|
|
47
|
+
|
|
48
|
+
- **`attribute.py`**: Defines `field`, `static_field`, `private_field`, `static_private_field`. All are thin wrappers around `dataclasses.field` that inject `jax_static=True/False` into the field's metadata dict. `private_*` variants set `init=False`. The unified `field()` function accepts a `static: bool` parameter directly.
|
|
49
|
+
|
|
50
|
+
- **`transform.py`**: Defines the `@dataclass` decorator and `_register_jax_tree`. The decorator wraps `dataclasses.dataclass` (always `frozen=True`) then registers the class as a JAX pytree. Flatten/unflatten split fields by `jax_static` metadata: static fields → `aux` (not traced), dynamic fields → `leaves` (traced). A `_jax_tree_registered` guard prevents double-registration.
|
|
51
|
+
|
|
52
|
+
- **`base.py`**: Defines `DataClass`, a base class alternative to the `@dataclass` decorator. Uses `@dataclass_transform` for type checker support and `__init_subclass__` to automatically apply the `dataclass` transform to any subclass.
|
|
53
|
+
|
|
54
|
+
- **`__init__.py`**: Re-exports `dataclass`, `field`, `static_field`, `private_field`, `static_private_field`, `DataClass`.
|
|
55
|
+
|
|
56
|
+
### Key design decisions
|
|
57
|
+
|
|
58
|
+
- All drinx dataclasses are **always frozen** (`frozen=True` is hardcoded). This is required for correctness as JAX pytree nodes must be immutable.
|
|
59
|
+
- The `jax_static` metadata key is the internal marker used to distinguish static vs. dynamic fields.
|
|
60
|
+
- Two usage patterns: decorator (`@drinx.dataclass`) or inheritance (`class Foo(DataClass)`). Both produce identically registered pytrees.
|
|
61
|
+
- The library has a single runtime dependency: `jax>=0.9.0`.
|
drinx-1.0.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Yannik Mahlau
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
drinx-1.0.0/PKG-INFO
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: drinx
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: Drinx: Dataclass Registry in JAX
|
|
5
|
+
Author-email: Yannik Mahlau <mahlau@tnt.uni-hannover.de>
|
|
6
|
+
License-File: LICENSE
|
|
7
|
+
Requires-Python: >=3.11
|
|
8
|
+
Requires-Dist: jax>=0.9.0
|
|
9
|
+
Provides-Extra: dev
|
|
10
|
+
Requires-Dist: black>=24.10.0; extra == 'dev'
|
|
11
|
+
Requires-Dist: ipykernel>=6.29.5; extra == 'dev'
|
|
12
|
+
Requires-Dist: ipywidgets; extra == 'dev'
|
|
13
|
+
Requires-Dist: jupyterlab; extra == 'dev'
|
|
14
|
+
Requires-Dist: myst-nb>=1.3.0; extra == 'dev'
|
|
15
|
+
Requires-Dist: nbsphinx>=0.9; extra == 'dev'
|
|
16
|
+
Requires-Dist: pre-commit>=4.0.1; extra == 'dev'
|
|
17
|
+
Requires-Dist: pydoclint>=0.6.6; extra == 'dev'
|
|
18
|
+
Requires-Dist: pymdown-extensions>=10.12; extra == 'dev'
|
|
19
|
+
Requires-Dist: pytest-cov>=6.0.0; extra == 'dev'
|
|
20
|
+
Requires-Dist: pytest>=8.3.4; extra == 'dev'
|
|
21
|
+
Requires-Dist: ruff>=0.8.2; extra == 'dev'
|
|
22
|
+
Requires-Dist: sphinx-autobuild>=2025.8.25; extra == 'dev'
|
|
23
|
+
Requires-Dist: sphinx-autodoc-typehints>=3.2.0; extra == 'dev'
|
|
24
|
+
Requires-Dist: sphinx-book-theme>=1.1.4; extra == 'dev'
|
|
25
|
+
Requires-Dist: sphinx-copybutton>=0.5.2; extra == 'dev'
|
|
26
|
+
Requires-Dist: sphinx<9.0; extra == 'dev'
|
|
27
|
+
Requires-Dist: tox-uv>=1.16.1; extra == 'dev'
|
|
28
|
+
Description-Content-Type: text/markdown
|
|
29
|
+
|
|
30
|
+

|
|
31
|
+
|
|
32
|
+
[](https://drinx.readthedocs.io/en/latest/)
|
|
33
|
+
[](https://pypi.org/project/drinx/)
|
|
34
|
+
[](https://codecov.io/gh/ymahlau/drinx)
|
|
35
|
+
[](https://github.com/ymahlau/drinx/actions/workflows/cicd.yml/badge.svg?branch=main)
|
|
36
|
+
|
|
37
|
+
# Drinx: Dataclass Registry in JAX 🥂
|
|
38
|
+
|
|
39
|
+
Often it is useful to have structures in a program containing a mixture of JAX arrays and non-JAX types (e.g. strings, ...).
|
|
40
|
+
However, this makes it difficult to pass such objects through JAX transformations.
|
|
41
|
+
Drinx solves this by allowing dataclass fields to be declared as static.
|
|
42
|
+
Moreover, drinx introduces numerous quality-of-life features when working with dataclasses in JAX.
|
|
43
|
+
|
|
44
|
+
## Installation
|
|
45
|
+
|
|
46
|
+
You can install drinx simply via
|
|
47
|
+
|
|
48
|
+
```bash
|
|
49
|
+
pip install drinx
|
|
50
|
+
```
|
|
51
|
+
If you want to use JAX's GPU acceleration, you can additionally install:
|
|
52
|
+
```bash
|
|
53
|
+
pip install jax[cuda]
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
## Quickstart
|
|
57
|
+
|
|
58
|
+
Below you can find some examples to get you quickly started with drinx.
|
|
59
|
+
There are many more features available, all documented in detail in our [documentation](https://drinx.readthedocs.io/en/latest/).
|
|
60
|
+
|
|
61
|
+
### Decorator style
|
|
62
|
+
|
|
63
|
+
Use `@drinx.dataclass` as a drop-in replacement for `@dataclasses.dataclass`.
|
|
64
|
+
The class is automatically frozen and registered as a JAX pytree:
|
|
65
|
+
|
|
66
|
+
```python
|
|
67
|
+
import jax
|
|
68
|
+
import jax.numpy as jnp
|
|
69
|
+
import drinx
|
|
70
|
+
|
|
71
|
+
@drinx.dataclass
|
|
72
|
+
class Params:
|
|
73
|
+
weights: jax.Array
|
|
74
|
+
bias: jax.Array
|
|
75
|
+
|
|
76
|
+
params = Params(weights=jnp.ones((3,)), bias=jnp.zeros((3,)))
|
|
77
|
+
|
|
78
|
+
# Works transparently with JAX transforms
|
|
79
|
+
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
### Static fields
|
|
83
|
+
|
|
84
|
+
Fields that should not be traced by JAX (e.g. shapes, dtypes, hyperparameters)
|
|
85
|
+
are marked with `static_field` or `field(static=True)`. Changing a static
|
|
86
|
+
field triggers recompilation under `jit`:
|
|
87
|
+
|
|
88
|
+
```python
|
|
89
|
+
@drinx.dataclass
|
|
90
|
+
class Model:
|
|
91
|
+
weights: jax.Array
|
|
92
|
+
hidden_size: int = drinx.static_field(default=128)
|
|
93
|
+
|
|
94
|
+
@jax.jit
|
|
95
|
+
def forward(model, x):
|
|
96
|
+
# hidden_size is a compile-time constant; weights are traced
|
|
97
|
+
return model.weights[:model.hidden_size] @ x
|
|
98
|
+
|
|
99
|
+
model = Model(weights=jnp.ones((128, 32)))
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
### Inheritance style
|
|
103
|
+
|
|
104
|
+
Subclass `DataClass` instead of using the decorator. The transform is applied
|
|
105
|
+
automatically — no `@dataclass` needed:
|
|
106
|
+
|
|
107
|
+
```python
|
|
108
|
+
class Model(drinx.DataClass):
|
|
109
|
+
weights: jax.Array
|
|
110
|
+
learning_rate: float = drinx.static_field(default=1e-3)
|
|
111
|
+
|
|
112
|
+
model = Model(weights=jnp.ones((10,)))
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
Dataclass options (e.g. `kw_only`, `order`) can be passed as keyword arguments in the class definition, or alternatively by combining inheritance with the decorator.
|
|
116
|
+
|
|
117
|
+
```python
|
|
118
|
+
class Config(drinx.DataClass, kw_only=True, order=True):
|
|
119
|
+
hidden_size: int = drinx.static_field(default=128)
|
|
120
|
+
num_layers: int = drinx.static_field(default=4)
|
|
121
|
+
|
|
122
|
+
# Recommended: with this form, type checkers recognize the kw_only argument correctly
|
|
123
|
+
@drinx.dataclass(kw_only=True, order=True)
|
|
124
|
+
class Config(drinx.DataClass):
|
|
125
|
+
hidden_size: int = drinx.static_field(default=128)
|
|
126
|
+
num_layers: int = drinx.static_field(default=4)
|
|
127
|
+
```
|
|
128
|
+
|
|
129
|
+
### Functional updates with `aset`
|
|
130
|
+
|
|
131
|
+
Because drinx dataclasses are frozen, fields cannot be mutated in place.
|
|
132
|
+
`aset` performs a functional update and returns a new instance. It supports
|
|
133
|
+
nested paths using `->` as a separator, integer indices `[n]`, and string
|
|
134
|
+
dictionary keys `['k']`.
|
|
135
|
+
Note that `aset` is only available on classes that inherit from `drinx.DataClass`; it is not provided by the `@drinx.dataclass` decorator.
|
|
136
|
+
|
|
137
|
+
```python
|
|
138
|
+
class Inner(drinx.DataClass):
|
|
139
|
+
w: jax.Array
|
|
140
|
+
|
|
141
|
+
class Outer(drinx.DataClass):
|
|
142
|
+
inner: Inner
|
|
143
|
+
bias: jax.Array
|
|
144
|
+
|
|
145
|
+
outer = Outer(inner=Inner(w=jnp.ones((3,))), bias=jnp.zeros((1,)))
|
|
146
|
+
|
|
147
|
+
# Update a top-level field
|
|
148
|
+
outer2 = outer.aset("bias", jnp.ones((1,)))
|
|
149
|
+
|
|
150
|
+
# Update a nested field
|
|
151
|
+
outer3 = outer.aset("inner->w", jnp.zeros((3,)))
|
|
152
|
+
```
|
|
153
|
+
|
|
154
|
+
### JAX transforms
|
|
155
|
+
|
|
156
|
+
Drinx dataclasses work with all JAX transforms out of the box:
|
|
157
|
+
|
|
158
|
+
```python
|
|
159
|
+
class State(drinx.DataClass):
|
|
160
|
+
x: jax.Array
|
|
161
|
+
step_size: float = drinx.static_field(default=0.1)
|
|
162
|
+
|
|
163
|
+
# jit
|
|
164
|
+
@jax.jit
|
|
165
|
+
def update(state):
|
|
166
|
+
# updated_copy is a convenience wrapper for replacing top-level attributes
|
|
167
|
+
return state.updated_copy(x=state.x - state.step_size)
|
|
168
|
+
|
|
169
|
+
def loss(state):
|
|
170
|
+
return jnp.sum(state.x ** 2)
|
|
171
|
+
|
|
172
|
+
grads = jax.grad(loss)(State(x=jnp.array([1.0, 2.0, 3.0])))
|
|
173
|
+
|
|
174
|
+
@jax.vmap
|
|
175
|
+
def scale(state):
|
|
176
|
+
return state.x * 2
|
|
177
|
+
|
|
178
|
+
batched = State(x=jnp.array([[1.0, 2.0], [3.0, 4.0]]))
|
|
179
|
+
result = scale(batched) # shape (2, 2)
|
|
180
|
+
```
|
|
181
|
+
|
|
182
|
+
## Documentation
|
|
183
|
+
|
|
184
|
+
For more examples and detailed API documentation, check out the [documentation](https://drinx.readthedocs.io/en/latest/).
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
## Citation
|
|
188
|
+
|
|
189
|
+
TODO: add citation once published
|
|
190
|
+
|
drinx-1.0.0/README.md
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+

|
|
2
|
+
|
|
3
|
+
[](https://drinx.readthedocs.io/en/latest/)
|
|
4
|
+
[](https://pypi.org/project/drinx/)
|
|
5
|
+
[](https://codecov.io/gh/ymahlau/drinx)
|
|
6
|
+
[](https://github.com/ymahlau/drinx/actions/workflows/cicd.yml/badge.svg?branch=main)
|
|
7
|
+
|
|
8
|
+
# Drinx: Dataclass Registry in JAX 🥂
|
|
9
|
+
|
|
10
|
+
Often it is useful to have structures in a program containing a mixture of JAX arrays and non-JAX types (e.g. strings, ...).
|
|
11
|
+
However, this makes it difficult to pass such objects through JAX transformations.
|
|
12
|
+
Drinx solves this by allowing dataclass fields to be declared as static.
|
|
13
|
+
Moreover, drinx introduces numerous quality-of-life features when working with dataclasses in JAX.
|
|
14
|
+
|
|
15
|
+
## Installation
|
|
16
|
+
|
|
17
|
+
You can install drinx simply via
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
pip install drinx
|
|
21
|
+
```
|
|
22
|
+
If you want to use JAX's GPU acceleration, you can additionally install:
|
|
23
|
+
```bash
|
|
24
|
+
pip install jax[cuda]
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
## Quickstart
|
|
28
|
+
|
|
29
|
+
Below you can find some examples to get you quickly started with drinx.
|
|
30
|
+
There are many more features available, all documented in detail in our [documentation](https://drinx.readthedocs.io/en/latest/).
|
|
31
|
+
|
|
32
|
+
### Decorator style
|
|
33
|
+
|
|
34
|
+
Use `@drinx.dataclass` as a drop-in replacement for `@dataclasses.dataclass`.
|
|
35
|
+
The class is automatically frozen and registered as a JAX pytree:
|
|
36
|
+
|
|
37
|
+
```python
|
|
38
|
+
import jax
|
|
39
|
+
import jax.numpy as jnp
|
|
40
|
+
import drinx
|
|
41
|
+
|
|
42
|
+
@drinx.dataclass
|
|
43
|
+
class Params:
|
|
44
|
+
weights: jax.Array
|
|
45
|
+
bias: jax.Array
|
|
46
|
+
|
|
47
|
+
params = Params(weights=jnp.ones((3,)), bias=jnp.zeros((3,)))
|
|
48
|
+
|
|
49
|
+
# Works transparently with JAX transforms
|
|
50
|
+
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
|
|
51
|
+
```
|
|
52
|
+
|
|
53
|
+
### Static fields
|
|
54
|
+
|
|
55
|
+
Fields that should not be traced by JAX (e.g. shapes, dtypes, hyperparameters)
|
|
56
|
+
are marked with `static_field` or `field(static=True)`. Changing a static
|
|
57
|
+
field triggers recompilation under `jit`:
|
|
58
|
+
|
|
59
|
+
```python
|
|
60
|
+
@drinx.dataclass
|
|
61
|
+
class Model:
|
|
62
|
+
weights: jax.Array
|
|
63
|
+
hidden_size: int = drinx.static_field(default=128)
|
|
64
|
+
|
|
65
|
+
@jax.jit
|
|
66
|
+
def forward(model, x):
|
|
67
|
+
# hidden_size is a compile-time constant; weights are traced
|
|
68
|
+
return model.weights[:model.hidden_size] @ x
|
|
69
|
+
|
|
70
|
+
model = Model(weights=jnp.ones((128, 32)))
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
### Inheritance style
|
|
74
|
+
|
|
75
|
+
Subclass `DataClass` instead of using the decorator. The transform is applied
|
|
76
|
+
automatically — no `@dataclass` needed:
|
|
77
|
+
|
|
78
|
+
```python
|
|
79
|
+
class Model(drinx.DataClass):
|
|
80
|
+
weights: jax.Array
|
|
81
|
+
learning_rate: float = drinx.static_field(default=1e-3)
|
|
82
|
+
|
|
83
|
+
model = Model(weights=jnp.ones((10,)))
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
Dataclass options (e.g. `kw_only`, `order`) can be passed as keyword arguments in the class definition, or alternatively by combining inheritance with the decorator.
|
|
87
|
+
|
|
88
|
+
```python
|
|
89
|
+
class Config(drinx.DataClass, kw_only=True, order=True):
|
|
90
|
+
hidden_size: int = drinx.static_field(default=128)
|
|
91
|
+
num_layers: int = drinx.static_field(default=4)
|
|
92
|
+
|
|
93
|
+
# Recommended: with this form, type checkers recognize the kw_only argument correctly
|
|
94
|
+
@drinx.dataclass(kw_only=True, order=True)
|
|
95
|
+
class Config(drinx.DataClass):
|
|
96
|
+
hidden_size: int = drinx.static_field(default=128)
|
|
97
|
+
num_layers: int = drinx.static_field(default=4)
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
### Functional updates with `aset`
|
|
101
|
+
|
|
102
|
+
Because drinx dataclasses are frozen, fields cannot be mutated in place.
|
|
103
|
+
`aset` performs a functional update and returns a new instance. It supports
|
|
104
|
+
nested paths using `->` as a separator, integer indices `[n]`, and string
|
|
105
|
+
dictionary keys `['k']`.
|
|
106
|
+
Note that `aset` is only available on classes that inherit from `drinx.DataClass`; it is not provided by the `@drinx.dataclass` decorator.
|
|
107
|
+
|
|
108
|
+
```python
|
|
109
|
+
class Inner(drinx.DataClass):
|
|
110
|
+
w: jax.Array
|
|
111
|
+
|
|
112
|
+
class Outer(drinx.DataClass):
|
|
113
|
+
inner: Inner
|
|
114
|
+
bias: jax.Array
|
|
115
|
+
|
|
116
|
+
outer = Outer(inner=Inner(w=jnp.ones((3,))), bias=jnp.zeros((1,)))
|
|
117
|
+
|
|
118
|
+
# Update a top-level field
|
|
119
|
+
outer2 = outer.aset("bias", jnp.ones((1,)))
|
|
120
|
+
|
|
121
|
+
# Update a nested field
|
|
122
|
+
outer3 = outer.aset("inner->w", jnp.zeros((3,)))
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
### JAX transforms
|
|
126
|
+
|
|
127
|
+
Drinx dataclasses work with all JAX transforms out of the box:
|
|
128
|
+
|
|
129
|
+
```python
|
|
130
|
+
class State(drinx.DataClass):
|
|
131
|
+
x: jax.Array
|
|
132
|
+
step_size: float = drinx.static_field(default=0.1)
|
|
133
|
+
|
|
134
|
+
# jit
|
|
135
|
+
@jax.jit
|
|
136
|
+
def update(state):
|
|
137
|
+
# updated_copy is a convenience wrapper for replacing top-level attributes
|
|
138
|
+
return state.updated_copy(x=state.x - state.step_size)
|
|
139
|
+
|
|
140
|
+
def loss(state):
|
|
141
|
+
return jnp.sum(state.x ** 2)
|
|
142
|
+
|
|
143
|
+
grads = jax.grad(loss)(State(x=jnp.array([1.0, 2.0, 3.0])))
|
|
144
|
+
|
|
145
|
+
@jax.vmap
|
|
146
|
+
def scale(state):
|
|
147
|
+
return state.x * 2
|
|
148
|
+
|
|
149
|
+
batched = State(x=jnp.array([[1.0, 2.0], [3.0, 4.0]]))
|
|
150
|
+
result = scale(batched) # shape (2, 2)
|
|
151
|
+
```
|
|
152
|
+
|
|
153
|
+
## Documentation
|
|
154
|
+
|
|
155
|
+
For more examples and detailed API documentation, check out the [documentation](https://drinx.readthedocs.io/en/latest/).
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
## Citation
|
|
159
|
+
|
|
160
|
+
TODO: add citation once published
|
|
161
|
+
|
drinx-1.0.0/codecov.yml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Paths excluded from coverage reporting. `ignore` is a top-level key in the
# Codecov YAML schema; it is not a child of `coverage`.
ignore:
  - "tests/"
  - "**/test_*.py"

coverage:
  status:
    # Coverage of the lines changed in a PR.
    patch:
      default:
        informational: true  # report the status, never fail the check
        target: 90%
        threshold: 0%
    # Overall project coverage.
    project:
      default:
        informational: true  # report the status, never fail the check
        target: 45%  # flag (not fail, since informational) if overall coverage drops below this
        threshold: 2%  # allowed decrease relative to the base commit's coverage
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Requirements for building the documentation.
# Keep these in sync with the `dev` extra (see PKG-INFO / pyproject.toml).
black>=24.10.0
ipykernel>=6.29.5
pre-commit>=4.0.1
pymdown-extensions>=10.12
pytest>=8.3.4
pytest-cov>=6.0.0
ruff>=0.8.2
tox-uv>=1.16.1
pydoclint>=0.6.6
# Matches the `sphinx<9.0` pin in the dev extra. CI builds the docs from that
# extra, so this file must not pull in a different Sphinx major version.
sphinx<9.0
sphinx-autobuild>=2025.8.25
sphinx-book-theme>=1.1.4
nbsphinx>=0.9
sphinx-autodoc-typehints>=3.2.0
myst_nb>=1.3.0
sphinx-copybutton>=0.5.2

jax>=0.9.0
jupyterlab
ipywidgets
|
|
Binary file
|
|
Binary file
|
|
Binary file
|