asdex 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- asdex-0.1.0/.claude/skills/add-handler/SKILL.md +109 -0
- asdex-0.1.0/.github/workflows/benchmarks.yml +33 -0
- asdex-0.1.0/.github/workflows/ci.yml +49 -0
- asdex-0.1.0/.github/workflows/docs.yml +27 -0
- asdex-0.1.0/.github/workflows/release.yml +33 -0
- asdex-0.1.0/.gitignore +20 -0
- asdex-0.1.0/.pre-commit-config.yaml +19 -0
- asdex-0.1.0/CLAUDE.md +103 -0
- asdex-0.1.0/LICENSE +21 -0
- asdex-0.1.0/PKG-INFO +112 -0
- asdex-0.1.0/README.md +84 -0
- asdex-0.1.0/TODO.md +73 -0
- asdex-0.1.0/docs/CLAUDE.md +130 -0
- asdex-0.1.0/docs/explanation/asd.md +152 -0
- asdex-0.1.0/docs/explanation/coloring.md +91 -0
- asdex-0.1.0/docs/explanation/global-sparsity.md +88 -0
- asdex-0.1.0/docs/explanation/sparsity-detection.md +219 -0
- asdex-0.1.0/docs/how-to/brusselator.md +111 -0
- asdex-0.1.0/docs/how-to/hessians.md +258 -0
- asdex-0.1.0/docs/how-to/jacobians.md +240 -0
- asdex-0.1.0/docs/index.md +70 -0
- asdex-0.1.0/docs/javascripts/mathjax.js +16 -0
- asdex-0.1.0/docs/reference/coloring.md +7 -0
- asdex-0.1.0/docs/reference/data-structures.md +4 -0
- asdex-0.1.0/docs/reference/hessian.md +12 -0
- asdex-0.1.0/docs/reference/index.md +34 -0
- asdex-0.1.0/docs/reference/jacobian.md +13 -0
- asdex-0.1.0/docs/reference/sparsity.md +4 -0
- asdex-0.1.0/docs/stylesheets/extra.css +16 -0
- asdex-0.1.0/docs/tutorials/getting-started.md +161 -0
- asdex-0.1.0/mkdocs.yml +94 -0
- asdex-0.1.0/pyproject.toml +137 -0
- asdex-0.1.0/src/asdex/__init__.py +48 -0
- asdex-0.1.0/src/asdex/_display.py +242 -0
- asdex-0.1.0/src/asdex/_interpret/CLAUDE.md +133 -0
- asdex-0.1.0/src/asdex/_interpret/__init__.py +356 -0
- asdex-0.1.0/src/asdex/_interpret/_broadcast.py +80 -0
- asdex-0.1.0/src/asdex/_interpret/_commons.py +343 -0
- asdex-0.1.0/src/asdex/_interpret/_concatenate.py +51 -0
- asdex-0.1.0/src/asdex/_interpret/_cond.py +62 -0
- asdex-0.1.0/src/asdex/_interpret/_conv.py +160 -0
- asdex-0.1.0/src/asdex/_interpret/_dot_general.py +150 -0
- asdex-0.1.0/src/asdex/_interpret/_dynamic_slice.py +124 -0
- asdex-0.1.0/src/asdex/_interpret/_elementwise.py +218 -0
- asdex-0.1.0/src/asdex/_interpret/_equinox/__init__.py +1 -0
- asdex-0.1.0/src/asdex/_interpret/_equinox/_select_if_vmap.py +47 -0
- asdex-0.1.0/src/asdex/_interpret/_gather.py +319 -0
- asdex-0.1.0/src/asdex/_interpret/_mul.py +66 -0
- asdex-0.1.0/src/asdex/_interpret/_pad.py +94 -0
- asdex-0.1.0/src/asdex/_interpret/_platform_index.py +45 -0
- asdex-0.1.0/src/asdex/_interpret/_reduce.py +72 -0
- asdex-0.1.0/src/asdex/_interpret/_reshape.py +73 -0
- asdex-0.1.0/src/asdex/_interpret/_rev.py +35 -0
- asdex-0.1.0/src/asdex/_interpret/_scan.py +112 -0
- asdex-0.1.0/src/asdex/_interpret/_scatter.py +261 -0
- asdex-0.1.0/src/asdex/_interpret/_select.py +55 -0
- asdex-0.1.0/src/asdex/_interpret/_slice.py +51 -0
- asdex-0.1.0/src/asdex/_interpret/_sort.py +69 -0
- asdex-0.1.0/src/asdex/_interpret/_split.py +44 -0
- asdex-0.1.0/src/asdex/_interpret/_squeeze.py +28 -0
- asdex-0.1.0/src/asdex/_interpret/_tile.py +49 -0
- asdex-0.1.0/src/asdex/_interpret/_top_k.py +60 -0
- asdex-0.1.0/src/asdex/_interpret/_transpose.py +46 -0
- asdex-0.1.0/src/asdex/_interpret/_while.py +69 -0
- asdex-0.1.0/src/asdex/coloring.py +603 -0
- asdex-0.1.0/src/asdex/decompression.py +315 -0
- asdex-0.1.0/src/asdex/detection.py +109 -0
- asdex-0.1.0/src/asdex/modes.py +42 -0
- asdex-0.1.0/src/asdex/pattern.py +409 -0
- asdex-0.1.0/src/asdex/verify.py +353 -0
- asdex-0.1.0/tests/CLAUDE.md +66 -0
- asdex-0.1.0/tests/__init__.py +1 -0
- asdex-0.1.0/tests/_interpret/__init__.py +0 -0
- asdex-0.1.0/tests/_interpret/_equinox/__init__.py +0 -0
- asdex-0.1.0/tests/_interpret/_equinox/test_select_if_vmap.py +86 -0
- asdex-0.1.0/tests/_interpret/test_associative_scan.py +165 -0
- asdex-0.1.0/tests/_interpret/test_broadcast.py +93 -0
- asdex-0.1.0/tests/_interpret/test_concatenate.py +120 -0
- asdex-0.1.0/tests/_interpret/test_cond.py +150 -0
- asdex-0.1.0/tests/_interpret/test_conv.py +504 -0
- asdex-0.1.0/tests/_interpret/test_dot_general.py +645 -0
- asdex-0.1.0/tests/_interpret/test_dynamic_slice.py +174 -0
- asdex-0.1.0/tests/_interpret/test_elementwise.py +131 -0
- asdex-0.1.0/tests/_interpret/test_gather.py +808 -0
- asdex-0.1.0/tests/_interpret/test_internals.py +330 -0
- asdex-0.1.0/tests/_interpret/test_nested_jaxpr.py +83 -0
- asdex-0.1.0/tests/_interpret/test_pad.py +405 -0
- asdex-0.1.0/tests/_interpret/test_platform_index.py +88 -0
- asdex-0.1.0/tests/_interpret/test_reduce.py +264 -0
- asdex-0.1.0/tests/_interpret/test_reduce_and.py +123 -0
- asdex-0.1.0/tests/_interpret/test_reshape.py +412 -0
- asdex-0.1.0/tests/_interpret/test_rev.py +274 -0
- asdex-0.1.0/tests/_interpret/test_scan.py +639 -0
- asdex-0.1.0/tests/_interpret/test_scatter.py +705 -0
- asdex-0.1.0/tests/_interpret/test_select.py +143 -0
- asdex-0.1.0/tests/_interpret/test_slice.py +87 -0
- asdex-0.1.0/tests/_interpret/test_sort.py +220 -0
- asdex-0.1.0/tests/_interpret/test_split.py +187 -0
- asdex-0.1.0/tests/_interpret/test_squeeze.py +23 -0
- asdex-0.1.0/tests/_interpret/test_tile.py +173 -0
- asdex-0.1.0/tests/_interpret/test_top_k.py +163 -0
- asdex-0.1.0/tests/_interpret/test_transpose.py +374 -0
- asdex-0.1.0/tests/_interpret/test_while.py +182 -0
- asdex-0.1.0/tests/conftest.py +35 -0
- asdex-0.1.0/tests/test_benchmarks.py +184 -0
- asdex-0.1.0/tests/test_coloring.py +1208 -0
- asdex-0.1.0/tests/test_decompression.py +755 -0
- asdex-0.1.0/tests/test_detection.py +799 -0
- asdex-0.1.0/tests/test_diffrax.py +85 -0
- asdex-0.1.0/tests/test_flax.py +383 -0
- asdex-0.1.0/tests/test_modes.py +81 -0
- asdex-0.1.0/tests/test_multidim.py +282 -0
- asdex-0.1.0/tests/test_pattern.py +398 -0
- asdex-0.1.0/tests/test_scalar.py +297 -0
- asdex-0.1.0/tests/test_sympy.py +591 -0
- asdex-0.1.0/tests/test_verify.py +395 -0
- asdex-0.1.0/tests/test_vmap.py +69 -0
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: add-handler
|
|
3
|
+
description: Add a precise primitive handler to the jaxpr interpreter, replacing a conservative fallback.
|
|
4
|
+
argument-hint: "[primitive-name]"
|
|
5
|
+
disable-model-invocation: true
|
|
6
|
+
---
|
|
7
|
+
|
|
8
|
+
# Add precise handler for `$ARGUMENTS`
|
|
9
|
+
|
|
10
|
+
Add a precise propagation handler for the `$ARGUMENTS` primitive,
|
|
11
|
+
replacing the conservative fallback.
|
|
12
|
+
|
|
13
|
+
## Workflow
|
|
14
|
+
|
|
15
|
+
### 1. Research
|
|
16
|
+
|
|
17
|
+
Do all research in this step using extended **planning mode**.
|
|
18
|
+
|
|
19
|
+
Before writing code:
|
|
20
|
+
- Read the JAX docs for the primitive: fetch `https://docs.jax.dev/en/latest/_autosummary/jax.lax.$ARGUMENTS.html`
|
|
21
|
+
- Read `src/asdex/_interpret/CLAUDE.md` for conventions (docstring style, semantic line breaks, handler structure)
|
|
22
|
+
- Read `src/asdex/_interpret/_commons.py` to understand available utilities
|
|
23
|
+
- Read an existing handler with similar structure (e.g. `_pad.py`, `_transpose.py`, `_reduce.py`) as a reference
|
|
24
|
+
- Read the existing test in `tests/_interpret/test_internals.py` for the primitive (search for `$ARGUMENTS`)
|
|
25
|
+
- Read `src/asdex/_interpret/__init__.py` to see the current dispatch and fallback setup
|
|
26
|
+
|
|
27
|
+
Understand the primitive's semantics:
|
|
28
|
+
How do input and output element indices map to each other?
|
|
29
|
+
What is the Jacobian structure (permutation, selection, block-diagonal, etc.)?
|
|
30
|
+
|
|
31
|
+
### 2. Implement handler
|
|
32
|
+
|
|
33
|
+
- Create `src/asdex/_interpret/_$ARGUMENTS.py` with `prop_$ARGUMENTS(eqn, deps)`.
|
|
34
|
+
- Follow the handler docstring style from `_interpret/CLAUDE.md`.
|
|
35
|
+
- If the primitive preserves or predictably transforms input values,
|
|
36
|
+
propagate `const_vals` so downstream handlers can stay precise.
|
|
37
|
+
|
|
38
|
+
### 3. Wire up dispatch
|
|
39
|
+
|
|
40
|
+
In `src/asdex/_interpret/__init__.py`:
|
|
41
|
+
- Import the new handler
|
|
42
|
+
- Add a `case "$ARGUMENTS":` branch in `prop_dispatch` calling the handler
|
|
43
|
+
- Remove `"$ARGUMENTS"` from the conservative fallback `case` group
|
|
44
|
+
|
|
45
|
+
### 4. Update tests
|
|
46
|
+
|
|
47
|
+
In `tests/_interpret/test_internals.py`:
|
|
48
|
+
- Update the existing test: change expected values from dense (`np.ones`) to the precise pattern
|
|
49
|
+
- Remove the `@pytest.mark.fallback` marker and `TODO` comments
|
|
50
|
+
|
|
51
|
+
Create `tests/_interpret/test_$ARGUMENTS.py` with thorough tests:
|
|
52
|
+
- Multiple dimensionalities (1D, 2D, 3D, 4D where applicable)
|
|
53
|
+
- Broadcasting shapes: size-1 dimensions that broadcast (e.g. `(3,4)` op `(3,1)`)
|
|
54
|
+
- Non-square shapes (e.g. `(3,4)` not `(4,4)`) so dimension mix-ups are caught
|
|
55
|
+
- Edge cases (size-0 dimensions, identity/trivial parameters)
|
|
56
|
+
- Real-world usage patterns (e.g. `jnp` functions that lower to this primitive)
|
|
57
|
+
- For at least one test per dimensionality, verify precision by comparing the detected
|
|
58
|
+
pattern against `(np.abs(jax.jacobian(f)(x)) > 1e-10)` using `assert_array_equal`.
|
|
59
|
+
Choose test functions that avoid local sparsity (e.g. multiply by zero) so the
|
|
60
|
+
numerical Jacobian matches the structural pattern.
|
|
61
|
+
|
|
62
|
+
### 5. Verify
|
|
63
|
+
|
|
64
|
+
Run in order:
|
|
65
|
+
```bash
|
|
66
|
+
uv run ruff check src/asdex/_interpret/_$ARGUMENTS.py
|
|
67
|
+
uv run pytest tests/_interpret/test_$ARGUMENTS.py -v
|
|
68
|
+
uv run pytest tests/_interpret/test_internals.py -v
|
|
69
|
+
uv run pytest tests/ -x
|
|
70
|
+
```
|
|
71
|
+
|
|
72
|
+
### 6. Adversarial tests
|
|
73
|
+
|
|
74
|
+
Reread the JAX docs for the primitive: fetch `https://docs.jax.dev/en/latest/_autosummary/jax.lax.$ARGUMENTS.html`
|
|
75
|
+
|
|
76
|
+
Try to break the implementation by testing inputs the handler might not expect:
|
|
77
|
+
|
|
78
|
+
- **Dimensionality**: 1D, 2D, 3D, and higher — if any are missing, add them.
|
|
79
|
+
- **Asymmetric shapes**: always use non-square shapes (e.g. `(3,4)` not `(4,4)`) so that dimension transposition bugs are caught.
|
|
80
|
+
- **Degenerate shapes**: size-0 dimensions, size-1 dimensions, scalar inputs (where the primitive supports them).
|
|
81
|
+
- **Boundary parameters**: empty parameter lists, all-dimensions, single-dimension, negative indices (if applicable).
|
|
82
|
+
- **Compositions**: the primitive chained with itself (e.g. double-reverse, transpose-of-transpose) or with related ops.
|
|
83
|
+
- **Non-contiguous patterns**: inputs where dependencies are not simply `{i}` per element (e.g. from a prior broadcast or reduction) to verify `.copy()` and set merging behave correctly.
|
|
84
|
+
- **Conservative audit**: for each test case, verify the result is strictly sparser than what `conservative_indices()` would produce. If the handler silently falls back on any shape the primitive supports, investigate.
|
|
85
|
+
If the fallback cannot be fixed immediately, you **must** add a `@pytest.mark.fallback` test with a `TODO(primitive)` comment showing the precise expected pattern.
|
|
86
|
+
Catching conservative patterns is extremely valuable for future development.
|
|
87
|
+
- **Const chain**: if the primitive can appear between a literal and a downstream consumer (e.g. type conversions, reshapes, broadcasts on index arrays), write a test composing it with a downstream gather and verify the gather resolves precisely.
|
|
88
|
+
|
|
89
|
+
For each new test, verify the expected output by hand or against `jax.jacobian`.
|
|
90
|
+
Update and re-verify the handler if any test reveals a bug.
|
|
91
|
+
|
|
92
|
+
### 7. Simplify
|
|
93
|
+
|
|
94
|
+
Review the implementation with fresh eyes and look for opportunities to reduce complexity:
|
|
95
|
+
|
|
96
|
+
- **Vectorize loops**: can per-element Python loops be replaced with numpy operations?
|
|
97
|
+
Pattern: build a flat permutation or index array with `np.arange`, `np.flip`, `np.transpose`, `np.indices`, or `np.ravel_multi_index`,
|
|
98
|
+
then index into `in_indices` in a single list comprehension.
|
|
99
|
+
See `_rev.py`, `_reshape.py`, `_concatenate.py`, and `_broadcast.py` for examples.
|
|
100
|
+
- **Remove unused imports**: after vectorizing, utilities like `flat_to_coords`, `row_strides`, and `numel` may no longer be needed.
|
|
101
|
+
- **Eliminate intermediate variables**: if a value is computed and used only once, inline it.
|
|
102
|
+
- **Simplify special cases**: can a special-case branch be absorbed into the general case?
|
|
103
|
+
|
|
104
|
+
After any change, re-run the verification commands (step 5) and the adversarial tests (step 6).
|
|
105
|
+
|
|
106
|
+
### 8. Update docs
|
|
107
|
+
|
|
108
|
+
- `TODO.md`: check off the primitive and its test items
|
|
109
|
+
- `src/asdex/_interpret/CLAUDE.md`: add the new module to the file listing
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
name: Benchmarks
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches:
|
|
6
|
+
- main
|
|
7
|
+
|
|
8
|
+
jobs:
|
|
9
|
+
Benchmarks:
|
|
10
|
+
runs-on: ubuntu-latest
|
|
11
|
+
permissions:
|
|
12
|
+
contents: write
|
|
13
|
+
deployments: write
|
|
14
|
+
steps:
|
|
15
|
+
- uses: actions/checkout@v4
|
|
16
|
+
- name: Install uv
|
|
17
|
+
uses: astral-sh/setup-uv@v5
|
|
18
|
+
- name: Set up Python
|
|
19
|
+
run: uv python install 3.12
|
|
20
|
+
- name: Install dependencies
|
|
21
|
+
run: uv sync --group dev
|
|
22
|
+
- name: Run benchmarks
|
|
23
|
+
run: uv run pytest tests/test_benchmarks.py -m dashboard --benchmark-json=benchmark.json
|
|
24
|
+
- name: Store benchmark result
|
|
25
|
+
uses: benchmark-action/github-action-benchmark@v1
|
|
26
|
+
with:
|
|
27
|
+
tool: 'pytest'
|
|
28
|
+
output-file-path: benchmark.json
|
|
29
|
+
github-token: ${{ secrets.GITHUB_TOKEN }}
|
|
30
|
+
auto-push: true
|
|
31
|
+
alert-threshold: '150%'
|
|
32
|
+
comment-on-alert: true
|
|
33
|
+
fail-on-alert: false
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
name: CI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [main]
|
|
6
|
+
pull_request:
|
|
7
|
+
branches: [main]
|
|
8
|
+
|
|
9
|
+
jobs:
|
|
10
|
+
Tests:
|
|
11
|
+
runs-on: ${{ matrix.os }}
|
|
12
|
+
strategy:
|
|
13
|
+
fail-fast: false
|
|
14
|
+
matrix:
|
|
15
|
+
os: [ubuntu-latest]
|
|
16
|
+
python-version: ["3.11", "3.14"]
|
|
17
|
+
steps:
|
|
18
|
+
- uses: actions/checkout@v4
|
|
19
|
+
- name: Install uv
|
|
20
|
+
uses: astral-sh/setup-uv@v5
|
|
21
|
+
- name: Set up Python
|
|
22
|
+
run: uv python install ${{ matrix.python-version }}
|
|
23
|
+
- name: Install dependencies
|
|
24
|
+
run: uv sync --group dev
|
|
25
|
+
- name: Test with coverage
|
|
26
|
+
run: uv run pytest -m "" --cov=src/asdex --cov-report=xml # -m "" overrides default marker filter to include slow and benchmark tests
|
|
27
|
+
- name: Upload coverage to Codecov
|
|
28
|
+
if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.14'
|
|
29
|
+
uses: codecov/codecov-action@v5
|
|
30
|
+
with:
|
|
31
|
+
files: coverage.xml
|
|
32
|
+
token: ${{ secrets.CODECOV_TOKEN }}
|
|
33
|
+
|
|
34
|
+
Linting:
|
|
35
|
+
runs-on: ubuntu-latest
|
|
36
|
+
steps:
|
|
37
|
+
- uses: actions/checkout@v4
|
|
38
|
+
- name: Install uv
|
|
39
|
+
uses: astral-sh/setup-uv@v5
|
|
40
|
+
- name: Set up Python
|
|
41
|
+
run: uv python install 3.14
|
|
42
|
+
- name: Install dependencies
|
|
43
|
+
run: uv sync --group dev
|
|
44
|
+
- name: Lint
|
|
45
|
+
run: uv run ruff check .
|
|
46
|
+
- name: Format
|
|
47
|
+
run: uv run ruff format --check .
|
|
48
|
+
- name: Type check
|
|
49
|
+
run: uv run ty check
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
name: Documentation
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [main]
|
|
6
|
+
pull_request:
|
|
7
|
+
branches: [main]
|
|
8
|
+
|
|
9
|
+
jobs:
|
|
10
|
+
docs:
|
|
11
|
+
runs-on: ubuntu-latest
|
|
12
|
+
permissions:
|
|
13
|
+
contents: write
|
|
14
|
+
steps:
|
|
15
|
+
- uses: actions/checkout@v4
|
|
16
|
+
- uses: astral-sh/setup-uv@v5
|
|
17
|
+
- run: uv python install 3.12
|
|
18
|
+
- run: uv sync --group docs
|
|
19
|
+
- run: uv run mkdocs build --strict
|
|
20
|
+
|
|
21
|
+
- name: Deploy to GitHub Pages
|
|
22
|
+
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
|
23
|
+
uses: peaceiris/actions-gh-pages@v4
|
|
24
|
+
with:
|
|
25
|
+
github_token: ${{ secrets.GITHUB_TOKEN }}
|
|
26
|
+
publish_dir: ./site
|
|
27
|
+
keep_files: true
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
name: Release
|
|
2
|
+
# Based on https://github.com/astral-sh/trusted-publishing-examples/blob/main/.github/workflows/release.yml
|
|
3
|
+
|
|
4
|
+
on:
|
|
5
|
+
push:
|
|
6
|
+
tags:
|
|
7
|
+
# Publish on any tag starting with a `v`, e.g., v0.1.0
|
|
8
|
+
- v*
|
|
9
|
+
|
|
10
|
+
jobs:
|
|
11
|
+
run:
|
|
12
|
+
name: Publish to PyPI
|
|
13
|
+
runs-on: ubuntu-latest
|
|
14
|
+
environment:
|
|
15
|
+
name: pypi
|
|
16
|
+
permissions:
|
|
17
|
+
id-token: write
|
|
18
|
+
contents: read
|
|
19
|
+
steps:
|
|
20
|
+
- name: Checkout
|
|
21
|
+
uses: actions/checkout@v6
|
|
22
|
+
- name: Install uv
|
|
23
|
+
uses: astral-sh/setup-uv@v7
|
|
24
|
+
- name: Install Python 3.14
|
|
25
|
+
run: uv python install 3.14
|
|
26
|
+
- name: Build
|
|
27
|
+
run: uv build
|
|
28
|
+
- name: Install dependencies
|
|
29
|
+
run: uv sync --group dev
|
|
30
|
+
- name: Test
|
|
31
|
+
run: uv run pytest -m ""
|
|
32
|
+
- name: Publish
|
|
33
|
+
run: uv publish
|
asdex-0.1.0/.gitignore
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Python-generated files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[oc]
|
|
4
|
+
build/
|
|
5
|
+
dist/
|
|
6
|
+
wheels/
|
|
7
|
+
*.egg-info
|
|
8
|
+
uv.lock
|
|
9
|
+
.python-version
|
|
10
|
+
.benchmarks
|
|
11
|
+
.pytest_cache
|
|
12
|
+
.ruff_cache
|
|
13
|
+
|
|
14
|
+
# MkDocs build output
|
|
15
|
+
site/
|
|
16
|
+
|
|
17
|
+
# Virtual environments
|
|
18
|
+
.venv
|
|
19
|
+
.claude/settings.local.json
|
|
20
|
+
_context
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
repos:
|
|
2
|
+
- repo: local
|
|
3
|
+
hooks:
|
|
4
|
+
- id: ruff
|
|
5
|
+
name: ruff
|
|
6
|
+
entry: uv run ruff check --fix
|
|
7
|
+
language: system
|
|
8
|
+
types: [python]
|
|
9
|
+
- id: ruff-format
|
|
10
|
+
name: ruff-format
|
|
11
|
+
entry: uv run ruff format
|
|
12
|
+
language: system
|
|
13
|
+
types: [python]
|
|
14
|
+
- id: ty
|
|
15
|
+
name: ty
|
|
16
|
+
entry: uv run ty check
|
|
17
|
+
language: system
|
|
18
|
+
pass_filenames: false
|
|
19
|
+
types: [python]
|
asdex-0.1.0/CLAUDE.md
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
# asdex - Automatic Sparse Differentiation in JAX
|
|
2
|
+
|
|
3
|
+
This package implements [Automatic Sparse Differentiation](https://iclr-blogposts.github.io/2025/blog/sparse-autodiff/) (ASD) in JAX.
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
ASD exploits sparsity to reduce the cost of computing sparse Jacobians and Hessians:
|
|
8
|
+
|
|
9
|
+
1. **Detection**: Analyze the jaxpr computation graph to detect the global sparsity pattern
|
|
10
|
+
2. **Coloring**: Assign colors to rows (or columns) so that rows sharing a non-zero column get different colors (and vice versa for columns)
|
|
11
|
+
3. **Decompression**: Compute one VJP/JVP/HVP per color instead of one per row or column, then extract the sparse matrix
|
|
12
|
+
|
|
13
|
+
## Structure
|
|
14
|
+
|
|
15
|
+
```
|
|
16
|
+
src/asdex/
|
|
17
|
+
├── __init__.py # Public API
|
|
18
|
+
├── pattern.py # SparsityPattern and ColoredPattern data structures
|
|
19
|
+
├── detection.py # Jacobian and Hessian sparsity detection via jaxpr analysis
|
|
20
|
+
├── coloring.py # Graph coloring (row, column, symmetric) and convenience functions
|
|
21
|
+
├── decompression.py # Sparse Jacobian (VJP/JVP) and Hessian (HVP) computation
|
|
22
|
+
├── verify.py # Correctness checks (check_jacobian_correctness, check_hessian_correctness)
|
|
23
|
+
├── _display.py # Display/formatting utilities
|
|
24
|
+
└── _interpret/ # Custom jaxpr interpreter for index set propagation
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
The interpreter internals are described in `src/asdex/_interpret/CLAUDE.md`.
|
|
28
|
+
The structure of the test folder is described in `tests/CLAUDE.md`.
|
|
29
|
+
|
|
30
|
+
## Development
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
uv run ruff check --fix . # lint + auto-fix
|
|
34
|
+
uv run ruff format . # format
|
|
35
|
+
uv run ty check # type check
|
|
36
|
+
uv run pytest # run tests
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
## Code style
|
|
40
|
+
|
|
41
|
+
- Favor `match` statements over long if-else chains.
|
|
42
|
+
Use explicit cases and default to `case _ as unreachable: assert_never(unreachable)`.
|
|
43
|
+
- Use plain `# Section name` comments for section separators,
|
|
44
|
+
not banner-style `# -- Section name ---`.
|
|
45
|
+
|
|
46
|
+
## Architecture
|
|
47
|
+
|
|
48
|
+
### Jacobians
|
|
49
|
+
|
|
50
|
+
```
|
|
51
|
+
jacobian(f, input_shape)(x) # one-call API
|
|
52
|
+
jacobian_from_coloring(f, coloring)(x) # from pre-computed coloring
|
|
53
|
+
│
|
|
54
|
+
├─ 1. DETECTION
|
|
55
|
+
│ jacobian_sparsity(f, input_shape)
|
|
56
|
+
│ ├─ make_jaxpr(f) → jaxpr
|
|
57
|
+
│ ├─ prop_jaxpr() → index sets
|
|
58
|
+
│ └─ SparsityPattern
|
|
59
|
+
│
|
|
60
|
+
├─ 2. COLORING
|
|
61
|
+
│ jacobian_coloring_from_sparsity(sparsity)
|
|
62
|
+
│
|
|
63
|
+
└─ 3. DECOMPRESSION
|
|
64
|
+
One VJP or JVP per color
|
|
65
|
+
|
|
66
|
+
Precompute: jacobian_coloring(f, shape) = detect + color
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
### Hessians
|
|
70
|
+
|
|
71
|
+
```
|
|
72
|
+
hessian(f, input_shape)(x) # one-call API
|
|
73
|
+
hessian_from_coloring(f, coloring)(x) # from pre-computed coloring
|
|
74
|
+
│
|
|
75
|
+
├─ 1. DETECTION
|
|
76
|
+
│ hessian_sparsity(f, input_shape)
|
|
77
|
+
│ └─ jacobian_sparsity(grad(f), input_shape)
|
|
78
|
+
│
|
|
79
|
+
├─ 2. COLORING
|
|
80
|
+
│ hessian_coloring_from_sparsity(sparsity)
|
|
81
|
+
│
|
|
82
|
+
└─ 3. DECOMPRESSION
|
|
83
|
+
One HVP per color (fwd-over-rev)
|
|
84
|
+
|
|
85
|
+
Precompute: hessian_coloring(f, shape) = detect + color_symmetric
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
## Commits
|
|
89
|
+
|
|
90
|
+
Use [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) for all commit messages (e.g. `feat:`, `fix:`, `docs:`, `refactor:`, `test:`).
|
|
91
|
+
For breaking changes, add `!` after the type (e.g. `feat!:`).
|
|
92
|
+
|
|
93
|
+
## Design philosophy
|
|
94
|
+
|
|
95
|
+
When writing new code, adhere to these design principles:
|
|
96
|
+
|
|
97
|
+
- **Minimize complexity**: The primary goal of software design is to minimize complexity—anything that makes a system hard to understand and modify.
|
|
98
|
+
|
|
99
|
+
- **Information hiding**: Each module should encapsulate design decisions that other modules don't need to know about, preventing information leakage across boundaries.
|
|
100
|
+
|
|
101
|
+
- **Pull complexity downward**: It's better for a module to be internally complex if it keeps the interface simple for others. Don't expose complexity to callers.
|
|
102
|
+
|
|
103
|
+
- **Favor exceptions over wrong results**: Raise errors for unknown edge cases rather than guessing.
|
asdex-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 Adrian Hill <gh@adrianhill.de>
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
asdex-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: asdex
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Automatic Sparse Differentiation in JAX
|
|
5
|
+
Project-URL: Homepage, https://github.com/adrhill/asdex
|
|
6
|
+
Project-URL: Documentation, https://adrianhill.de/asdex
|
|
7
|
+
Project-URL: Benchmarks, https://adrianhill.de/asdex/benchmarks
|
|
8
|
+
Project-URL: Repository, https://github.com/adrhill/asdex
|
|
9
|
+
Project-URL: Issues, https://github.com/adrhill/asdex/issues
|
|
10
|
+
Author-email: Adrian Hill <gh@adrianhill.de>
|
|
11
|
+
License-Expression: MIT
|
|
12
|
+
License-File: LICENSE
|
|
13
|
+
Keywords: automatic-differentiation,automatic-sparse-differentiation,graph-coloring,hessian,jacobian,jax,jaxpr,scientific-computing,sparse-automatic-differentiation,sparsity,sparsity-detection
|
|
14
|
+
Classifier: Development Status :: 3 - Alpha
|
|
15
|
+
Classifier: Intended Audience :: Science/Research
|
|
16
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
17
|
+
Classifier: Programming Language :: Python :: 3
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
22
|
+
Classifier: Topic :: Scientific/Engineering
|
|
23
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
24
|
+
Requires-Python: >=3.11
|
|
25
|
+
Requires-Dist: jax>=0.9.0
|
|
26
|
+
Requires-Dist: numpy>=1.26.0
|
|
27
|
+
Description-Content-Type: text/markdown
|
|
28
|
+
|
|
29
|
+
# asdex
|
|
30
|
+
|
|
31
|
+
[](https://github.com/adrhill/asdex/actions/workflows/ci.yml)
|
|
32
|
+
[](https://codecov.io/gh/adrhill/asdex)
|
|
33
|
+
[](https://adrianhill.de/asdex/)
|
|
34
|
+
[](https://adrianhill.de/asdex/dev/bench/)
|
|
35
|
+
|
|
36
|
+
[Automatic Sparse Differentiation](https://iclr-blogposts.github.io/2025/blog/sparse-autodiff/) in JAX.
|
|
37
|
+
|
|
38
|
+
`asdex` (pronounced _Aztecs_) exploits sparsity structure to efficiently compute sparse Jacobians and Hessians.
|
|
39
|
+
It implements a custom [Jaxpr](https://docs.jax.dev/en/latest/jaxpr.html) interpreter
|
|
40
|
+
that uses [abstract interpretation](https://en.wikipedia.org/wiki/Abstract_interpretation)
|
|
41
|
+
to detect sparsity patterns from the computation graph,
|
|
42
|
+
then uses graph coloring to minimize the number of AD passes needed.
|
|
43
|
+
|
|
44
|
+
> [!WARNING]
|
|
45
|
+
> `asdex` is in early development.
|
|
46
|
+
> The API may change without notice.
|
|
47
|
+
> Use at your own risk.
|
|
48
|
+
|
|
49
|
+
## Installation
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
pip install git+https://github.com/adrhill/asdex.git
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
Or with [uv](https://docs.astral.sh/uv/):
|
|
56
|
+
|
|
57
|
+
```bash
|
|
58
|
+
uv add git+https://github.com/adrhill/asdex.git
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
## Example
|
|
62
|
+
|
|
63
|
+
```python
|
|
64
|
+
import numpy as np
|
|
65
|
+
from asdex import jacobian
|
|
66
|
+
|
|
67
|
+
def f(x):
|
|
68
|
+
return (x[1:] - x[:-1]) ** 2
|
|
69
|
+
|
|
70
|
+
jac_fn = jacobian(f, input_shape=50)
|
|
71
|
+
# ColoredPattern(49×50, nnz=98, sparsity=96.0%, JVP, 2 colors)
|
|
72
|
+
# 2 JVPs (instead of 49 VJPs or 50 JVPs)
|
|
73
|
+
# ⎡⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤ ⎡⣿⎤
|
|
74
|
+
# ⎢⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
75
|
+
# ⎢⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
76
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
77
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
78
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
79
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ → ⎢⣿⎥
|
|
80
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
81
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
82
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
83
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⎥ ⎢⣿⎥
|
|
84
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⎥ ⎢⣿⎥
|
|
85
|
+
# ⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⎦ ⎣⠉⎦
|
|
86
|
+
|
|
87
|
+
for x in inputs:
|
|
88
|
+
J = jac_fn(x)
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
Instead of 49 VJPs or 50 JVPs,
|
|
92
|
+
`asdex` computes the full sparse Jacobian with just 2 JVPs.
|
|
93
|
+
|
|
94
|
+
## Documentation
|
|
95
|
+
|
|
96
|
+
- [Getting Started](https://adrianhill.de/asdex/tutorials/getting-started/) — step-by-step tutorial
|
|
97
|
+
- [How-To Guides](https://adrianhill.de/asdex/how-to/jacobians/) — task-oriented recipes
|
|
98
|
+
- [Explanation](https://adrianhill.de/asdex/explanation/sparsity-detection/) — how and why it works
|
|
99
|
+
- [API Reference](https://adrianhill.de/asdex/reference/) — full API documentation
|
|
100
|
+
|
|
101
|
+
## Acknowledgements
|
|
102
|
+
|
|
103
|
+
This package is built with Claude Code based on previous work by [Adrian Hill](https://github.com/adrhill), [Guillaume Dalle](https://github.com/gdalle), and [Alexis Montoison](https://github.com/amontoison) in the [Julia programming language](https://julialang.org):
|
|
104
|
+
|
|
105
|
+
- [_An Illustrated Guide to Automatic Sparse Differentiation_](https://iclr-blogposts.github.io/2025/blog/sparse-autodiff/), A. Hill, G. Dalle, A. Montoison (2025)
|
|
106
|
+
- [_Sparser, Better, Faster, Stronger: Efficient Automatic Differentiation for Sparse Jacobians and Hessians_](https://openreview.net/forum?id=GtXSN52nIW), A. Hill & G. Dalle (2025)
|
|
107
|
+
- [_Revisiting Sparse Matrix Coloring and Bicoloring_](https://arxiv.org/abs/2505.07308), A. Montoison, G. Dalle, A. Gebremedhin (2025)
|
|
108
|
+
- [_SparseConnectivityTracer.jl_](https://github.com/adrhill/SparseConnectivityTracer.jl), A. Hill, G. Dalle
|
|
109
|
+
- [_SparseMatrixColorings.jl_](https://github.com/gdalle/SparseMatrixColorings.jl), G. Dalle, A. Montoison
|
|
110
|
+
- [_sparsediffax_](https://github.com/gdalle/sparsediffax), G. Dalle
|
|
111
|
+
|
|
112
|
+
which in turn stands on the shoulders of giants — notably Andreas Griewank, Andrea Walther, and Assefaw Gebremedhin.
|
asdex-0.1.0/README.md
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# asdex
|
|
2
|
+
|
|
3
|
+
[](https://github.com/adrhill/asdex/actions/workflows/ci.yml)
|
|
4
|
+
[](https://codecov.io/gh/adrhill/asdex)
|
|
5
|
+
[](https://adrianhill.de/asdex/)
|
|
6
|
+
[](https://adrianhill.de/asdex/dev/bench/)
|
|
7
|
+
|
|
8
|
+
[Automatic Sparse Differentiation](https://iclr-blogposts.github.io/2025/blog/sparse-autodiff/) in JAX.
|
|
9
|
+
|
|
10
|
+
`asdex` (pronounced _Aztecs_) exploits sparsity structure to efficiently compute sparse Jacobians and Hessians.
|
|
11
|
+
It implements a custom [Jaxpr](https://docs.jax.dev/en/latest/jaxpr.html) interpreter
|
|
12
|
+
that uses [abstract interpretation](https://en.wikipedia.org/wiki/Abstract_interpretation)
|
|
13
|
+
to detect sparsity patterns from the computation graph,
|
|
14
|
+
then uses graph coloring to minimize the number of AD passes needed.
|
|
15
|
+
|
|
16
|
+
> [!WARNING]
|
|
17
|
+
> `asdex` is in early development.
|
|
18
|
+
> The API may change without notice.
|
|
19
|
+
> Use at your own risk.
|
|
20
|
+
|
|
21
|
+
## Installation
|
|
22
|
+
|
|
23
|
+
```bash
|
|
24
|
+
pip install git+https://github.com/adrhill/asdex.git
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
Or with [uv](https://docs.astral.sh/uv/):
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
uv add git+https://github.com/adrhill/asdex.git
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
## Example
|
|
34
|
+
|
|
35
|
+
```python
|
|
36
|
+
import numpy as np
|
|
37
|
+
from asdex import jacobian
|
|
38
|
+
|
|
39
|
+
def f(x):
|
|
40
|
+
return (x[1:] - x[:-1]) ** 2
|
|
41
|
+
|
|
42
|
+
jac_fn = jacobian(f, input_shape=50)
|
|
43
|
+
# ColoredPattern(49×50, nnz=98, sparsity=96.0%, JVP, 2 colors)
|
|
44
|
+
# 2 JVPs (instead of 49 VJPs or 50 JVPs)
|
|
45
|
+
# ⎡⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤ ⎡⣿⎤
|
|
46
|
+
# ⎢⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
47
|
+
# ⎢⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
48
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
49
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
50
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
51
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ → ⎢⣿⎥
|
|
52
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
53
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
54
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⎥ ⎢⣿⎥
|
|
55
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⎥ ⎢⣿⎥
|
|
56
|
+
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⎥ ⎢⣿⎥
|
|
57
|
+
# ⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⎦ ⎣⠉⎦
|
|
58
|
+
|
|
59
|
+
for x in np.random.randn(10, 50):  # any sequence of inputs with shape (50,)
|
|
60
|
+
J = jac_fn(x)
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
Instead of 49 VJPs or 50 JVPs,
|
|
64
|
+
`asdex` computes the full sparse Jacobian with just 2 JVPs.
|
|
65
|
+
|
|
66
|
+
## Documentation
|
|
67
|
+
|
|
68
|
+
- [Getting Started](https://adrianhill.de/asdex/tutorials/getting-started/) — step-by-step tutorial
|
|
69
|
+
- [How-To Guides](https://adrianhill.de/asdex/how-to/jacobians/) — task-oriented recipes
|
|
70
|
+
- [Explanation](https://adrianhill.de/asdex/explanation/sparsity-detection/) — how and why it works
|
|
71
|
+
- [API Reference](https://adrianhill.de/asdex/reference/) — full API documentation
|
|
72
|
+
|
|
73
|
+
## Acknowledgements
|
|
74
|
+
|
|
75
|
+
This package is built with Claude Code based on previous work by [Adrian Hill](https://github.com/adrhill), [Guillaume Dalle](https://github.com/gdalle), and [Alexis Montoison](https://github.com/amontoison) in the [Julia programming language](https://julialang.org):
|
|
76
|
+
|
|
77
|
+
- [_An Illustrated Guide to Automatic Sparse Differentiation_](https://iclr-blogposts.github.io/2025/blog/sparse-autodiff/), A. Hill, G. Dalle, A. Montoison (2025)
|
|
78
|
+
- [_Sparser, Better, Faster, Stronger: Efficient Automatic Differentiation for Sparse Jacobians and Hessians_](https://openreview.net/forum?id=GtXSN52nIW), A. Hill & G. Dalle (2025)
|
|
79
|
+
- [_Revisiting Sparse Matrix Coloring and Bicoloring_](https://arxiv.org/abs/2505.07308), A. Montoison, G. Dalle, A. Gebremedhin (2025)
|
|
80
|
+
- [_SparseConnectivityTracer.jl_](https://github.com/adrhill/SparseConnectivityTracer.jl), A. Hill, G. Dalle
|
|
81
|
+
- [_SparseMatrixColorings.jl_](https://github.com/gdalle/SparseMatrixColorings.jl), G. Dalle, A. Montoison
|
|
82
|
+
- [_sparsediffax_](https://github.com/gdalle/sparsediffax), G. Dalle
|
|
83
|
+
|
|
84
|
+
which in turn stands on the shoulders of giants — notably Andreas Griewank, Andrea Walther, and Assefaw Gebremedhin.
|