drinx 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,30 @@
1
+ name: "Setup Python Environment"
2
+ description: "Set up Python environment for the given Python version"
3
+
4
+ inputs:
5
+ python-version:
6
+ description: "Python version to use"
7
+ required: false  # has a default, so callers (e.g. the checks job) may omit it
8
+ default: "3.12"
9
+ uv-version:
10
+ description: "uv version to use"
11
+ required: false  # has a default, so callers may omit it
12
+ default: "0.5.8"
13
+
14
+ runs:
15
+ using: "composite"
16
+ steps:
17
+ - uses: actions/setup-python@v5
18
+ with:
19
+ python-version: ${{ inputs.python-version }}
20
+
21
+ - name: Install uv
22
+ uses: astral-sh/setup-uv@v2
23
+ with:
24
+ version: ${{ inputs.uv-version }}
25
+ enable-cache: "true"
26
+ cache-suffix: ${{ inputs.python-version }}  # key the cache on the Python actually installed, not the caller's matrix
27
+
28
+ - name: Install Python dependencies
29
+ run: uv sync --extra dev
30
+ shell: bash
@@ -0,0 +1,72 @@
1
+ name: CI-CD
2
+
3
+ on:
4
+ pull_request:
5
+ types: [opened, synchronize, reopened]
6
+ push:
7
+ branches: [main]
8
+ workflow_dispatch:
9
+
10
+ permissions:
11
+ contents: read
12
+
13
+ jobs:
14
+ checks:
15
+ runs-on: ubuntu-latest
16
+ steps:
17
+ - name: Check out
18
+ uses: actions/checkout@v4
19
+ with:
20
+ persist-credentials: false
21
+
22
+ - uses: actions/cache@v4
23
+ with:
24
+ path: ~/.cache/pre-commit
25
+ key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
26
+
27
+ - name: Set up the environment
28
+ uses: ./.github/actions/setup-python-env
29
+
30
+ - name: Run pre-commit
31
+ run: uv run pre-commit run -a --show-diff-on-failure
32
+
33
+ - name: build docs
34
+ run: uv run sphinx-build -W --keep-going docs/source/ docs/build/
35
+
36
+ tests:
37
+ runs-on: ubuntu-latest
38
+ permissions:
39
+ contents: read
40
+ # no id-token here: the Codecov upload below authenticates with CODECOV_TOKEN, not OIDC
41
+ strategy:
42
+ matrix:
43
+ python-version: ["3.11"]
44
+ fail-fast: false
45
+ defaults:
46
+ run:
47
+ shell: bash
48
+ steps:
49
+ - name: Check out
50
+ uses: actions/checkout@v4
51
+ with:
52
+ persist-credentials: false
53
+
54
+ - name: Set up the environment
55
+ uses: ./.github/actions/setup-python-env
56
+ with:
57
+ python-version: ${{ matrix.python-version }}
58
+
59
+ - name: Run tests
60
+ run: uv run python -m pytest tests --cov --cov-branch --cov-config=pyproject.toml --cov-report=xml
61
+
62
+ - name: Typecheck with ty
63
+ run: uvx ty check --error-on-warning
64
+
65
+ - name: Upload coverage reports to Codecov
66
+ uses: codecov/codecov-action@v4
67
+ if: ${{ matrix.python-version == '3.11' }}
68
+ with:
69
+ token: ${{ secrets.CODECOV_TOKEN }}
70
+ files: ./coverage.xml
71
+ fail_ci_if_error: true
72
+ slug: ymahlau/drinx
@@ -0,0 +1,42 @@
1
+ name: Publish to PyPI
2
+
3
+ on:
4
+ release:
5
+ types: [published]
6
+ workflow_dispatch: # Allows manual triggering from the Actions tab
7
+
8
+ permissions:
9
+ contents: read
10
+
11
+ jobs:
12
+ pypi-publish:
13
+ name: Build and publish Python 🐍 distributions to PyPI
14
+ runs-on: ubuntu-latest
15
+ # id-token: write is REQUIRED for PyPI Trusted Publishing (OIDC)
16
+ permissions:
17
+ id-token: write
18
+ contents: read
19
+
20
+ steps:
21
+ - name: Checkout code
22
+ uses: actions/checkout@v4
23
+ with:
24
+ persist-credentials: false
25
+
26
+ - name: Set up Python
27
+ uses: actions/setup-python@v5
28
+ with:
29
+ python-version: "3.12"
30
+
31
+ - name: Install build tools
32
+ run: python -m pip install build  # same interpreter as the `python -m build` step below
33
+
34
+ - name: Build package
35
+ run: python -m build
36
+
37
+ - name: Publish distribution 📦 to PyPI
38
+ uses: pypa/gh-action-pypi-publish@release/v1
39
+ with:
40
+ # We don't need a password/token input because we use
41
+ # Trusted Publishing (id-token permission above).
42
+ print-hash: true
drinx-1.0.0/.gitignore ADDED
@@ -0,0 +1,9 @@
1
+ uv.lock
2
+ .venv
3
+ outputs/
4
+ *.pyc
5
+ *.npz
6
+ docs/build
7
+ docs/jupyter_execute
8
+ docs/source/api
9
+ scripts_dev/
@@ -0,0 +1,34 @@
1
+ # pre-commit configuration; documentation: https://pre-commit.com
2
+ # Catalogue of available hooks: https://pre-commit.com/hooks.html
3
+ repos:
4
+ - repo: https://github.com/pre-commit/pre-commit-hooks
5
+ rev: 'v5.0.0'
6
+ hooks:
7
+ - id: check-yaml
8
+ - id: check-added-large-files
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-merge-conflict
12
+ args: ['--assume-in-merge']  # look for conflict markers even when no merge is in progress
13
+ - repo: https://github.com/astral-sh/ruff-pre-commit
14
+ # Pinned ruff release used by both hooks below.
15
+ rev: 'v0.11.2'
16
+ hooks:
17
+ # Lint and apply autofixes.
18
+ - id: ruff
19
+ args:
20
+ - --fix
21
+ # Format code.
22
+ - id: ruff-format
23
+ - repo: https://github.com/woodruffw/zizmor-pre-commit  # GitHub Actions security audit
24
+ rev: 'v1.5.2'
25
+ hooks:
26
+ - id: zizmor
27
+ - repo: local
28
+ hooks:
29
+ - id: ty
30
+ name: 'ty (via uvx)'
31
+ entry: uvx ty check --error-on-warning
32
+ language: system
33
+ types: ['python']
34
+ pass_filenames: false  # run `ty check` once on the project rather than per file
@@ -0,0 +1,14 @@
1
+ version: 2  # Read the Docs configuration schema version
2
+ formats: ['htmlzip']  # also offer an offline HTML archive
3
+
4
+ sphinx:
5
+ configuration: docs/source/conf.py
6
+
7
+ build:
8
+ os: 'ubuntu-22.04'
9
+ tools:
10
+ python: '3.12'  # quoted so the version stays a string, not a float
11
+ jobs:
12
+ install:
13
+ - pip install -r docs/requirements.txt  # docs toolchain plus runtime deps such as jax
14
+ - pip install --no-deps .  # the package itself; its deps were installed above
drinx-1.0.0/CLAUDE.md ADDED
@@ -0,0 +1,61 @@
1
+ # CLAUDE.md
2
+
3
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4
+
5
+ ## Project Overview
6
+
7
+ Drinx is a small Python library ("Dataclass Registry in JAX") that wraps Python's standard `dataclasses` module to make dataclasses compatible with JAX transformations (e.g., `jit`, `vmap`, `grad`). It does this by registering each decorated class as a JAX pytree node, with support for marking fields as "static" (excluded from JAX tracing) or "dynamic" (included as leaves).
8
+
9
+ ## Commands
10
+
11
+ This project uses `uv` for dependency management and virtual environments.
12
+
13
+ ```bash
14
+ # Install dev dependencies
15
+ uv sync --extra dev
16
+
17
+ # Run tests
18
+ uv run pytest
19
+
20
+ # Run a single test file
21
+ uv run pytest tests/test_foo.py
22
+
23
+ # Run a single test
24
+ uv run pytest tests/test_foo.py::test_name
25
+
26
+ # Lint
27
+ uv run ruff check src/
28
+
29
+ # Format
30
+ uv run ruff format src/
31
+
32
+ # Type check
33
+ uvx ty check --error-on-warning
34
+
35
+ # Build docs
36
+ uv run sphinx-build docs/source docs/build
37
+
38
+ # Live-reload docs
39
+ uv run sphinx-autobuild docs/source docs/build
40
+ ```
41
+
42
+ Pre-commit hooks run ruff (lint + format), zizmor (GitHub Actions security), and `ty` (type checker) automatically on commit.
43
+
44
+ ## Architecture
45
+
46
+ The library lives in `src/drinx/` with four files:
47
+
48
+ - **`attribute.py`**: Defines `field`, `static_field`, `private_field`, `static_private_field`. All are thin wrappers around `dataclasses.field` that inject `jax_static=True/False` into the field's metadata dict. `private_*` variants set `init=False`. The unified `field()` function accepts a `static: bool` parameter directly.
49
+
50
+ - **`transform.py`**: Defines the `@dataclass` decorator and `_register_jax_tree`. The decorator wraps `dataclasses.dataclass` (always `frozen=True`) then registers the class as a JAX pytree. Flatten/unflatten split fields by `jax_static` metadata: static fields → `aux` (not traced), dynamic fields → `leaves` (traced). A `_jax_tree_registered` guard prevents double-registration.
51
+
52
+ - **`base.py`**: Defines `DataClass`, a base class alternative to the `@dataclass` decorator. Uses `@dataclass_transform` for type checker support and `__init_subclass__` to automatically apply the `dataclass` transform to any subclass.
53
+
54
+ - **`__init__.py`**: Re-exports `dataclass`, `field`, `static_field`, `private_field`, `static_private_field`, `DataClass`.
55
+
56
+ ### Key design decisions
57
+
58
+ - All drinx dataclasses are **always frozen** (`frozen=True` is hardcoded). This is required for correctness as JAX pytree nodes must be immutable.
59
+ - The `jax_static` metadata key is the internal marker used to distinguish static vs. dynamic fields.
60
+ - Two usage patterns: decorator (`@drinx.dataclass`) or inheritance (`class Foo(DataClass)`). Both produce identically registered pytrees.
61
+ - The library has a single runtime dependency: `jax>=0.9.0`.
drinx-1.0.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Yannik Mahlau
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
drinx-1.0.0/PKG-INFO ADDED
@@ -0,0 +1,190 @@
1
+ Metadata-Version: 2.4
2
+ Name: drinx
3
+ Version: 1.0.0
4
+ Summary: Drinx: Dataclass Registry in JAX
5
+ Author-email: Yannik Mahlau <mahlau@tnt.uni-hannover.de>
6
+ License-File: LICENSE
7
+ Requires-Python: >=3.11
8
+ Requires-Dist: jax>=0.9.0
9
+ Provides-Extra: dev
10
+ Requires-Dist: black>=24.10.0; extra == 'dev'
11
+ Requires-Dist: ipykernel>=6.29.5; extra == 'dev'
12
+ Requires-Dist: ipywidgets; extra == 'dev'
13
+ Requires-Dist: jupyterlab; extra == 'dev'
14
+ Requires-Dist: myst-nb>=1.3.0; extra == 'dev'
15
+ Requires-Dist: nbsphinx>=0.9; extra == 'dev'
16
+ Requires-Dist: pre-commit>=4.0.1; extra == 'dev'
17
+ Requires-Dist: pydoclint>=0.6.6; extra == 'dev'
18
+ Requires-Dist: pymdown-extensions>=10.12; extra == 'dev'
19
+ Requires-Dist: pytest-cov>=6.0.0; extra == 'dev'
20
+ Requires-Dist: pytest>=8.3.4; extra == 'dev'
21
+ Requires-Dist: ruff>=0.8.2; extra == 'dev'
22
+ Requires-Dist: sphinx-autobuild>=2025.8.25; extra == 'dev'
23
+ Requires-Dist: sphinx-autodoc-typehints>=3.2.0; extra == 'dev'
24
+ Requires-Dist: sphinx-book-theme>=1.1.4; extra == 'dev'
25
+ Requires-Dist: sphinx-copybutton>=0.5.2; extra == 'dev'
26
+ Requires-Dist: sphinx<9.0; extra == 'dev'
27
+ Requires-Dist: tox-uv>=1.16.1; extra == 'dev'
28
+ Description-Content-Type: text/markdown
29
+
30
+ ![title image](https://github.com/ymahlau/drinx/blob/main/docs/source/_static/drinx.png?raw=true)
31
+
32
+ [![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://drinx.readthedocs.io/en/latest/)
33
+ [![PyPI version](https://img.shields.io/pypi/v/drinx)](https://pypi.org/project/drinx/)
34
+ [![codecov](https://codecov.io/gh/ymahlau/drinx/branch/main/graph/badge.svg)](https://codecov.io/gh/ymahlau/drinx)
35
+ [![Continuous integration](https://github.com/ymahlau/drinx/actions/workflows/cicd.yml/badge.svg?branch=main)](https://github.com/ymahlau/drinx/actions/workflows/cicd.yml)
36
+
37
+ # Drinx: Dataclass Registry in JAX 🥂
38
+
39
+ Often it is useful to have structures in a program containing a mixture of JAX arrays and non-JAX types (e.g. strings, ...).
40
+ However, such objects are difficult to pass through JAX transformations, which expect array-like leaves.
41
+ Drinx solves this by allowing dataclass fields to be declared as static.
42
+ Moreover, drinx introduces numerous quality-of-life features when working with dataclasses in JAX.
43
+
44
+ ## Installation
45
+
46
+ You can install drinx simply via
47
+
48
+ ```bash
49
+ pip install drinx
50
+ ```
51
+ To use GPU acceleration with JAX, install a CUDA-enabled JAX build afterwards (see the JAX installation guide for other accelerators and CUDA versions):
52
+ ```bash
53
+ pip install -U "jax[cuda12]"
54
+ ```
55
+
56
+ ## Quickstart
57
+
58
+ Below you can find some examples to get you quickly started with drinx.
59
+ Many more features are available; they are documented in detail in our [documentation](https://drinx.readthedocs.io/en/latest/).
60
+
61
+ ### Decorator style
62
+
63
+ Use `@drinx.dataclass` as a drop-in replacement for `@dataclasses.dataclass`.
64
+ The class is automatically frozen and registered as a JAX pytree:
65
+
66
+ ```python
67
+ import jax
68
+ import jax.numpy as jnp
69
+ import drinx
70
+
71
+ @drinx.dataclass
72
+ class Params:
73
+ weights: jax.Array
74
+ bias: jax.Array
75
+
76
+ params = Params(weights=jnp.ones((3,)), bias=jnp.zeros((3,)))
77
+
78
+ # Works transparently with JAX transforms
79
+ doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
80
+ ```
81
+
82
+ ### Static fields
83
+
84
+ Fields that should not be traced by JAX (e.g. shapes, dtypes, hyperparameters)
85
+ are marked with `static_field` or `field(static=True)`. Changing a static
86
+ field triggers recompilation under `jit`:
87
+
88
+ ```python
89
+ @drinx.dataclass
90
+ class Model:
91
+ weights: jax.Array
92
+ hidden_size: int = drinx.static_field(default=128)
93
+
94
+ @jax.jit
95
+ def forward(model, x):
96
+ # hidden_size is a compile-time constant; weights are traced
97
+ return model.weights[:model.hidden_size] @ x
98
+
99
+ model = Model(weights=jnp.ones((128, 32)))
100
+ ```
101
+
102
+ ### Inheritance style
103
+
104
+ Subclass `DataClass` instead of using the decorator. The transform is applied
105
+ automatically — no `@dataclass` needed:
106
+
107
+ ```python
108
+ class Model(drinx.DataClass):
109
+ weights: jax.Array
110
+ learning_rate: float = drinx.static_field(default=1e-3)
111
+
112
+ model = Model(weights=jnp.ones((10,)))
113
+ ```
114
+
115
+ Dataclass options can be passed as keyword arguments in the class definition, or alternatively by combining inheritance with the decorator:
116
+
117
+ ```python
118
+ class Config(drinx.DataClass, kw_only=True, order=True):
119
+ hidden_size: int = drinx.static_field(default=128)
120
+ num_layers: int = drinx.static_field(default=4)
121
+
122
+ # This is the recommended way: Typechecker will recognize the kw_only argument correctly
123
+ @drinx.dataclass(kw_only=True, order=True)
124
+ class Config(drinx.DataClass):
125
+ hidden_size: int = drinx.static_field(default=128)
126
+ num_layers: int = drinx.static_field(default=4)
127
+ ```
128
+
129
+ ### Functional updates with `aset`
130
+
131
+ Because drinx dataclasses are frozen, fields cannot be mutated in place.
132
+ `aset` performs a functional update and returns a new instance. It supports
133
+ nested paths using `->` as a separator, integer indices `[n]`, and string
134
+ dictionary keys `['k']`.
135
+ Note that `aset` is only available on classes that inherit from `drinx.DataClass`; the `@drinx.dataclass` decorator alone does not add it.
136
+
137
+ ```python
138
+ class Inner(drinx.DataClass):
139
+ w: jax.Array
140
+
141
+ class Outer(drinx.DataClass):
142
+ inner: Inner
143
+ bias: jax.Array
144
+
145
+ outer = Outer(inner=Inner(w=jnp.ones((3,))), bias=jnp.zeros((1,)))
146
+
147
+ # Update a top-level field
148
+ outer2 = outer.aset("bias", jnp.ones((1,)))
149
+
150
+ # Update a nested field
151
+ outer3 = outer.aset("inner->w", jnp.zeros((3,)))
152
+ ```
153
+
154
+ ### JAX transforms
155
+
156
+ Drinx dataclasses work with all JAX transforms out of the box:
157
+
158
+ ```python
159
+ class State(drinx.DataClass):
160
+ x: jax.Array
161
+ step_size: float = drinx.static_field(default=0.1)
162
+
163
+ # jit
164
+ @jax.jit
165
+ def update(state):
166
+ # updated_copy is convenience wrapper for altering top-level attributes
167
+ return state.updated_copy(x=state.x - state.step_size)
168
+
169
+ def loss(state):
170
+ return jnp.sum(state.x ** 2)
171
+
172
+ grads = jax.grad(loss)(State(x=jnp.array([1.0, 2.0, 3.0])))
173
+
174
+ @jax.vmap
175
+ def scale(state):
176
+ return state.x * 2
177
+
178
+ batched = State(x=jnp.array([[1.0, 2.0], [3.0, 4.0]]))
179
+ result = scale(batched) # shape (2, 2)
180
+ ```
181
+
182
+ ## Documentation
183
+
184
+ For more examples and detailed API documentation, see the [documentation](https://drinx.readthedocs.io/en/latest/).
185
+
186
+
187
+ ## Citation
188
+
189
+ TODO: add citation once published
190
+
drinx-1.0.0/README.md ADDED
@@ -0,0 +1,161 @@
1
+ ![title image](https://github.com/ymahlau/drinx/blob/main/docs/source/_static/drinx.png?raw=true)
2
+
3
+ [![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://drinx.readthedocs.io/en/latest/)
4
+ [![PyPI version](https://img.shields.io/pypi/v/drinx)](https://pypi.org/project/drinx/)
5
+ [![codecov](https://codecov.io/gh/ymahlau/drinx/branch/main/graph/badge.svg)](https://codecov.io/gh/ymahlau/drinx)
6
+ [![Continuous integration](https://github.com/ymahlau/drinx/actions/workflows/cicd.yml/badge.svg?branch=main)](https://github.com/ymahlau/drinx/actions/workflows/cicd.yml)
7
+
8
+ # Drinx: Dataclass Registry in JAX 🥂
9
+
10
+ Often it is useful to have structures in a program containing a mixture of JAX arrays and non-JAX types (e.g. strings, ...).
11
+ However, such objects are difficult to pass through JAX transformations, which expect array-like leaves.
12
+ Drinx solves this by allowing dataclass fields to be declared as static.
13
+ Moreover, drinx introduces numerous quality-of-life features when working with dataclasses in JAX.
14
+
15
+ ## Installation
16
+
17
+ You can install drinx simply via
18
+
19
+ ```bash
20
+ pip install drinx
21
+ ```
22
+ To use GPU acceleration with JAX, install a CUDA-enabled JAX build afterwards (see the JAX installation guide for other accelerators and CUDA versions):
23
+ ```bash
24
+ pip install -U "jax[cuda12]"
25
+ ```
26
+
27
+ ## Quickstart
28
+
29
+ Below you can find some examples to get you quickly started with drinx.
30
+ Many more features are available; they are documented in detail in our [documentation](https://drinx.readthedocs.io/en/latest/).
31
+
32
+ ### Decorator style
33
+
34
+ Use `@drinx.dataclass` as a drop-in replacement for `@dataclasses.dataclass`.
35
+ The class is automatically frozen and registered as a JAX pytree:
36
+
37
+ ```python
38
+ import jax
39
+ import jax.numpy as jnp
40
+ import drinx
41
+
42
+ @drinx.dataclass
43
+ class Params:
44
+ weights: jax.Array
45
+ bias: jax.Array
46
+
47
+ params = Params(weights=jnp.ones((3,)), bias=jnp.zeros((3,)))
48
+
49
+ # Works transparently with JAX transforms
50
+ doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
51
+ ```
52
+
53
+ ### Static fields
54
+
55
+ Fields that should not be traced by JAX (e.g. shapes, dtypes, hyperparameters)
56
+ are marked with `static_field` or `field(static=True)`. Changing a static
57
+ field triggers recompilation under `jit`:
58
+
59
+ ```python
60
+ @drinx.dataclass
61
+ class Model:
62
+ weights: jax.Array
63
+ hidden_size: int = drinx.static_field(default=128)
64
+
65
+ @jax.jit
66
+ def forward(model, x):
67
+ # hidden_size is a compile-time constant; weights are traced
68
+ return model.weights[:model.hidden_size] @ x
69
+
70
+ model = Model(weights=jnp.ones((128, 32)))
71
+ ```
72
+
73
+ ### Inheritance style
74
+
75
+ Subclass `DataClass` instead of using the decorator. The transform is applied
76
+ automatically — no `@dataclass` needed:
77
+
78
+ ```python
79
+ class Model(drinx.DataClass):
80
+ weights: jax.Array
81
+ learning_rate: float = drinx.static_field(default=1e-3)
82
+
83
+ model = Model(weights=jnp.ones((10,)))
84
+ ```
85
+
86
+ Dataclass options can be passed as keyword arguments in the class definition, or alternatively by combining inheritance with the decorator:
87
+
88
+ ```python
89
+ class Config(drinx.DataClass, kw_only=True, order=True):
90
+ hidden_size: int = drinx.static_field(default=128)
91
+ num_layers: int = drinx.static_field(default=4)
92
+
93
+ # This is the recommended way: Typechecker will recognize the kw_only argument correctly
94
+ @drinx.dataclass(kw_only=True, order=True)
95
+ class Config(drinx.DataClass):
96
+ hidden_size: int = drinx.static_field(default=128)
97
+ num_layers: int = drinx.static_field(default=4)
98
+ ```
99
+
100
+ ### Functional updates with `aset`
101
+
102
+ Because drinx dataclasses are frozen, fields cannot be mutated in place.
103
+ `aset` performs a functional update and returns a new instance. It supports
104
+ nested paths using `->` as a separator, integer indices `[n]`, and string
105
+ dictionary keys `['k']`.
106
+ Note that `aset` is only available on classes that inherit from `drinx.DataClass`; the `@drinx.dataclass` decorator alone does not add it.
107
+
108
+ ```python
109
+ class Inner(drinx.DataClass):
110
+ w: jax.Array
111
+
112
+ class Outer(drinx.DataClass):
113
+ inner: Inner
114
+ bias: jax.Array
115
+
116
+ outer = Outer(inner=Inner(w=jnp.ones((3,))), bias=jnp.zeros((1,)))
117
+
118
+ # Update a top-level field
119
+ outer2 = outer.aset("bias", jnp.ones((1,)))
120
+
121
+ # Update a nested field
122
+ outer3 = outer.aset("inner->w", jnp.zeros((3,)))
123
+ ```
124
+
125
+ ### JAX transforms
126
+
127
+ Drinx dataclasses work with all JAX transforms out of the box:
128
+
129
+ ```python
130
+ class State(drinx.DataClass):
131
+ x: jax.Array
132
+ step_size: float = drinx.static_field(default=0.1)
133
+
134
+ # jit
135
+ @jax.jit
136
+ def update(state):
137
+ # updated_copy is convenience wrapper for altering top-level attributes
138
+ return state.updated_copy(x=state.x - state.step_size)
139
+
140
+ def loss(state):
141
+ return jnp.sum(state.x ** 2)
142
+
143
+ grads = jax.grad(loss)(State(x=jnp.array([1.0, 2.0, 3.0])))
144
+
145
+ @jax.vmap
146
+ def scale(state):
147
+ return state.x * 2
148
+
149
+ batched = State(x=jnp.array([[1.0, 2.0], [3.0, 4.0]]))
150
+ result = scale(batched) # shape (2, 2)
151
+ ```
152
+
153
+ ## Documentation
154
+
155
+ For more examples and detailed API documentation, see the [documentation](https://drinx.readthedocs.io/en/latest/).
156
+
157
+
158
+ ## Citation
159
+
160
+ TODO: add citation once published
161
+
@@ -0,0 +1,15 @@
1
+ coverage:
2
+ ignore:  # exclude test code from coverage metrics
3
+ - "tests/"
4
+ - "**/test_*.py"
5
+ status:
6
+ patch:  # coverage of the lines changed by a commit/PR
7
+ default:
8
+ informational: true  # report only; this status never fails the check
9
+ target: 90%
10
+ threshold: 0%
11
+ project:  # overall project coverage
12
+ default:
13
+ informational: true  # report only; this status never fails the check
14
+ target: 45%  # desired overall coverage (not enforced while informational is true)
15
+ threshold: 2%  # tolerated coverage drop before the status would report a failure
@@ -0,0 +1,20 @@
1
+ black>=24.10.0
2
+ ipykernel>=6.29.5
3
+ pre-commit>=4.0.1
4
+ pymdown-extensions>=10.12
5
+ pytest>=8.3.4
6
+ pytest-cov>=6.0.0
7
+ ruff>=0.8.2
8
+ tox-uv>=1.16.1
9
+ pydoclint>=0.6.6
10
+ sphinx<9.0  # keep in sync with the `dev` extra, which CI's `sphinx-build -W` step uses
11
+ sphinx-autobuild>=2025.8.25
12
+ sphinx-book-theme>=1.1.4
13
+ nbsphinx>=0.9
14
+ sphinx-autodoc-typehints>=3.2.0
15
+ myst_nb>=1.3.0
16
+ sphinx-copybutton>=0.5.2
17
+
18
+ jax>=0.9.0
19
+ jupyterlab
20
+ ipywidgets