asdex 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117) hide show
  1. asdex-0.1.0/.claude/skills/add-handler/SKILL.md +109 -0
  2. asdex-0.1.0/.github/workflows/benchmarks.yml +33 -0
  3. asdex-0.1.0/.github/workflows/ci.yml +49 -0
  4. asdex-0.1.0/.github/workflows/docs.yml +27 -0
  5. asdex-0.1.0/.github/workflows/release.yml +33 -0
  6. asdex-0.1.0/.gitignore +20 -0
  7. asdex-0.1.0/.pre-commit-config.yaml +19 -0
  8. asdex-0.1.0/CLAUDE.md +103 -0
  9. asdex-0.1.0/LICENSE +21 -0
  10. asdex-0.1.0/PKG-INFO +112 -0
  11. asdex-0.1.0/README.md +84 -0
  12. asdex-0.1.0/TODO.md +73 -0
  13. asdex-0.1.0/docs/CLAUDE.md +130 -0
  14. asdex-0.1.0/docs/explanation/asd.md +152 -0
  15. asdex-0.1.0/docs/explanation/coloring.md +91 -0
  16. asdex-0.1.0/docs/explanation/global-sparsity.md +88 -0
  17. asdex-0.1.0/docs/explanation/sparsity-detection.md +219 -0
  18. asdex-0.1.0/docs/how-to/brusselator.md +111 -0
  19. asdex-0.1.0/docs/how-to/hessians.md +258 -0
  20. asdex-0.1.0/docs/how-to/jacobians.md +240 -0
  21. asdex-0.1.0/docs/index.md +70 -0
  22. asdex-0.1.0/docs/javascripts/mathjax.js +16 -0
  23. asdex-0.1.0/docs/reference/coloring.md +7 -0
  24. asdex-0.1.0/docs/reference/data-structures.md +4 -0
  25. asdex-0.1.0/docs/reference/hessian.md +12 -0
  26. asdex-0.1.0/docs/reference/index.md +34 -0
  27. asdex-0.1.0/docs/reference/jacobian.md +13 -0
  28. asdex-0.1.0/docs/reference/sparsity.md +4 -0
  29. asdex-0.1.0/docs/stylesheets/extra.css +16 -0
  30. asdex-0.1.0/docs/tutorials/getting-started.md +161 -0
  31. asdex-0.1.0/mkdocs.yml +94 -0
  32. asdex-0.1.0/pyproject.toml +137 -0
  33. asdex-0.1.0/src/asdex/__init__.py +48 -0
  34. asdex-0.1.0/src/asdex/_display.py +242 -0
  35. asdex-0.1.0/src/asdex/_interpret/CLAUDE.md +133 -0
  36. asdex-0.1.0/src/asdex/_interpret/__init__.py +356 -0
  37. asdex-0.1.0/src/asdex/_interpret/_broadcast.py +80 -0
  38. asdex-0.1.0/src/asdex/_interpret/_commons.py +343 -0
  39. asdex-0.1.0/src/asdex/_interpret/_concatenate.py +51 -0
  40. asdex-0.1.0/src/asdex/_interpret/_cond.py +62 -0
  41. asdex-0.1.0/src/asdex/_interpret/_conv.py +160 -0
  42. asdex-0.1.0/src/asdex/_interpret/_dot_general.py +150 -0
  43. asdex-0.1.0/src/asdex/_interpret/_dynamic_slice.py +124 -0
  44. asdex-0.1.0/src/asdex/_interpret/_elementwise.py +218 -0
  45. asdex-0.1.0/src/asdex/_interpret/_equinox/__init__.py +1 -0
  46. asdex-0.1.0/src/asdex/_interpret/_equinox/_select_if_vmap.py +47 -0
  47. asdex-0.1.0/src/asdex/_interpret/_gather.py +319 -0
  48. asdex-0.1.0/src/asdex/_interpret/_mul.py +66 -0
  49. asdex-0.1.0/src/asdex/_interpret/_pad.py +94 -0
  50. asdex-0.1.0/src/asdex/_interpret/_platform_index.py +45 -0
  51. asdex-0.1.0/src/asdex/_interpret/_reduce.py +72 -0
  52. asdex-0.1.0/src/asdex/_interpret/_reshape.py +73 -0
  53. asdex-0.1.0/src/asdex/_interpret/_rev.py +35 -0
  54. asdex-0.1.0/src/asdex/_interpret/_scan.py +112 -0
  55. asdex-0.1.0/src/asdex/_interpret/_scatter.py +261 -0
  56. asdex-0.1.0/src/asdex/_interpret/_select.py +55 -0
  57. asdex-0.1.0/src/asdex/_interpret/_slice.py +51 -0
  58. asdex-0.1.0/src/asdex/_interpret/_sort.py +69 -0
  59. asdex-0.1.0/src/asdex/_interpret/_split.py +44 -0
  60. asdex-0.1.0/src/asdex/_interpret/_squeeze.py +28 -0
  61. asdex-0.1.0/src/asdex/_interpret/_tile.py +49 -0
  62. asdex-0.1.0/src/asdex/_interpret/_top_k.py +60 -0
  63. asdex-0.1.0/src/asdex/_interpret/_transpose.py +46 -0
  64. asdex-0.1.0/src/asdex/_interpret/_while.py +69 -0
  65. asdex-0.1.0/src/asdex/coloring.py +603 -0
  66. asdex-0.1.0/src/asdex/decompression.py +315 -0
  67. asdex-0.1.0/src/asdex/detection.py +109 -0
  68. asdex-0.1.0/src/asdex/modes.py +42 -0
  69. asdex-0.1.0/src/asdex/pattern.py +409 -0
  70. asdex-0.1.0/src/asdex/verify.py +353 -0
  71. asdex-0.1.0/tests/CLAUDE.md +66 -0
  72. asdex-0.1.0/tests/__init__.py +1 -0
  73. asdex-0.1.0/tests/_interpret/__init__.py +0 -0
  74. asdex-0.1.0/tests/_interpret/_equinox/__init__.py +0 -0
  75. asdex-0.1.0/tests/_interpret/_equinox/test_select_if_vmap.py +86 -0
  76. asdex-0.1.0/tests/_interpret/test_associative_scan.py +165 -0
  77. asdex-0.1.0/tests/_interpret/test_broadcast.py +93 -0
  78. asdex-0.1.0/tests/_interpret/test_concatenate.py +120 -0
  79. asdex-0.1.0/tests/_interpret/test_cond.py +150 -0
  80. asdex-0.1.0/tests/_interpret/test_conv.py +504 -0
  81. asdex-0.1.0/tests/_interpret/test_dot_general.py +645 -0
  82. asdex-0.1.0/tests/_interpret/test_dynamic_slice.py +174 -0
  83. asdex-0.1.0/tests/_interpret/test_elementwise.py +131 -0
  84. asdex-0.1.0/tests/_interpret/test_gather.py +808 -0
  85. asdex-0.1.0/tests/_interpret/test_internals.py +330 -0
  86. asdex-0.1.0/tests/_interpret/test_nested_jaxpr.py +83 -0
  87. asdex-0.1.0/tests/_interpret/test_pad.py +405 -0
  88. asdex-0.1.0/tests/_interpret/test_platform_index.py +88 -0
  89. asdex-0.1.0/tests/_interpret/test_reduce.py +264 -0
  90. asdex-0.1.0/tests/_interpret/test_reduce_and.py +123 -0
  91. asdex-0.1.0/tests/_interpret/test_reshape.py +412 -0
  92. asdex-0.1.0/tests/_interpret/test_rev.py +274 -0
  93. asdex-0.1.0/tests/_interpret/test_scan.py +639 -0
  94. asdex-0.1.0/tests/_interpret/test_scatter.py +705 -0
  95. asdex-0.1.0/tests/_interpret/test_select.py +143 -0
  96. asdex-0.1.0/tests/_interpret/test_slice.py +87 -0
  97. asdex-0.1.0/tests/_interpret/test_sort.py +220 -0
  98. asdex-0.1.0/tests/_interpret/test_split.py +187 -0
  99. asdex-0.1.0/tests/_interpret/test_squeeze.py +23 -0
  100. asdex-0.1.0/tests/_interpret/test_tile.py +173 -0
  101. asdex-0.1.0/tests/_interpret/test_top_k.py +163 -0
  102. asdex-0.1.0/tests/_interpret/test_transpose.py +374 -0
  103. asdex-0.1.0/tests/_interpret/test_while.py +182 -0
  104. asdex-0.1.0/tests/conftest.py +35 -0
  105. asdex-0.1.0/tests/test_benchmarks.py +184 -0
  106. asdex-0.1.0/tests/test_coloring.py +1208 -0
  107. asdex-0.1.0/tests/test_decompression.py +755 -0
  108. asdex-0.1.0/tests/test_detection.py +799 -0
  109. asdex-0.1.0/tests/test_diffrax.py +85 -0
  110. asdex-0.1.0/tests/test_flax.py +383 -0
  111. asdex-0.1.0/tests/test_modes.py +81 -0
  112. asdex-0.1.0/tests/test_multidim.py +282 -0
  113. asdex-0.1.0/tests/test_pattern.py +398 -0
  114. asdex-0.1.0/tests/test_scalar.py +297 -0
  115. asdex-0.1.0/tests/test_sympy.py +591 -0
  116. asdex-0.1.0/tests/test_verify.py +395 -0
  117. asdex-0.1.0/tests/test_vmap.py +69 -0
@@ -0,0 +1,109 @@
1
+ ---
2
+ name: add-handler
3
+ description: Add a precise primitive handler to the jaxpr interpreter, replacing a conservative fallback.
4
+ argument-hint: "[primitive-name]"
5
+ disable-model-invocation: true
6
+ ---
7
+
8
+ # Add precise handler for `$ARGUMENTS`
9
+
10
+ Add a precise propagation handler for the `$ARGUMENTS` primitive,
11
+ replacing the conservative fallback.
12
+
13
+ ## Workflow
14
+
15
+ ### 1. Research
16
+
17
+ Do all research in this step using extended **planning mode**.
18
+
19
+ Before writing code:
20
+ - Read the JAX docs for the primitive: fetch `https://docs.jax.dev/en/latest/_autosummary/jax.lax.$ARGUMENTS.html`
21
+ - Read `src/asdex/_interpret/CLAUDE.md` for conventions (docstring style, semantic line breaks, handler structure)
22
+ - Read `src/asdex/_interpret/_commons.py` to understand available utilities
23
+ - Read an existing handler with similar structure (e.g. `_pad.py`, `_transpose.py`, `_reduce.py`) as a reference
24
+ - Read the existing test in `tests/_interpret/test_internals.py` for the primitive (search for `$ARGUMENTS`)
25
+ - Read `src/asdex/_interpret/__init__.py` to see the current dispatch and fallback setup
26
+
27
+ Understand the primitive's semantics:
28
+ how do input and output element indices map to each other?
29
+ What is the Jacobian structure (permutation, selection, block-diagonal, etc.)?
30
+
31
+ ### 2. Implement handler
32
+
33
+ - Create `src/asdex/_interpret/_$ARGUMENTS.py` with `prop_$ARGUMENTS(eqn, deps)`.
34
+ - Follow the handler docstring style from `_interpret/CLAUDE.md`.
35
+ - If the primitive preserves or predictably transforms input values,
36
+ propagate `const_vals` so downstream handlers can stay precise.
37
+
38
+ ### 3. Wire up dispatch
39
+
40
+ In `src/asdex/_interpret/__init__.py`:
41
+ - Import the new handler
42
+ - Add a `case "$ARGUMENTS":` branch in `prop_dispatch` calling the handler
43
+ - Remove `"$ARGUMENTS"` from the conservative fallback `case` group
44
+
45
+ ### 4. Update tests
46
+
47
+ In `tests/_interpret/test_internals.py`:
48
+ - Update the existing test: change expected values from dense (`np.ones`) to the precise pattern
49
+ - Remove the `@pytest.mark.fallback` marker and `TODO` comments
50
+
51
+ Create `tests/_interpret/test_$ARGUMENTS.py` with thorough tests:
52
+ - Multiple dimensionalities (1D, 2D, 3D, 4D where applicable)
53
+ - Broadcasting shapes: size-1 dimensions that broadcast (e.g. `(3,4)` op `(3,1)`)
54
+ - Non-square shapes (e.g. `(3,4)` not `(4,4)`) so dimension mix-ups are caught
55
+ - Edge cases (size-0 dimensions, identity/trivial parameters)
56
+ - Real-world usage patterns (e.g. `jnp` functions that lower to this primitive)
57
+ - For at least one test per dimensionality, verify precision by comparing the detected
58
+ pattern against `(np.abs(jax.jacobian(f)(x)) > 1e-10)` using `assert_array_equal`.
59
+ Choose test functions that avoid local sparsity (e.g. multiplication by zero) so the
60
+ numerical Jacobian matches the structural pattern.
61
+
62
+ ### 5. Verify
63
+
64
+ Run in order:
65
+ ```bash
66
+ uv run ruff check src/asdex/_interpret/_$ARGUMENTS.py
67
+ uv run pytest tests/_interpret/test_$ARGUMENTS.py -v
68
+ uv run pytest tests/_interpret/test_internals.py -v
69
+ uv run pytest tests/ -x
70
+ ```
71
+
72
+ ### 6. Adversarial tests
73
+
74
+ Reread the JAX docs for the primitive: fetch `https://docs.jax.dev/en/latest/_autosummary/jax.lax.$ARGUMENTS.html`
75
+
76
+ Try to break the implementation by testing inputs the handler might not expect:
77
+
78
+ - **Dimensionality**: 1D, 2D, 3D, and higher — if any are missing, add them.
79
+ - **Asymmetric shapes**: always use non-square shapes (e.g. `(3,4)` not `(4,4)`) so that dimension transposition bugs are caught.
80
+ - **Degenerate shapes**: size-0 dimensions, size-1 dimensions, scalar inputs (where the primitive supports them).
81
+ - **Boundary parameters**: empty parameter lists, all-dimensions, single-dimension, negative indices (if applicable).
82
+ - **Compositions**: the primitive chained with itself (e.g. double-reverse, transpose-of-transpose) or with related ops.
83
+ - **Non-contiguous patterns**: inputs where dependencies are not simply `{i}` per element (e.g. from a prior broadcast or reduction) to verify `.copy()` and set merging behave correctly.
84
+ - **Conservative audit**: for each test case, verify the result is sparser than what `conservative_indices()` would produce (the two may only coincide when the true pattern is itself dense, e.g. for some degenerate shapes). If the handler silently falls back on any shape the primitive supports, investigate.
85
+ If the fallback cannot be fixed immediately, you **must** add a `@pytest.mark.fallback` test with a `TODO(primitive)` comment showing the precise expected pattern.
86
+ Catching conservative patterns is extremely valuable for future development.
87
+ - **Const chain**: if the primitive can appear between a literal and a downstream consumer (e.g. type conversions, reshapes, broadcasts on index arrays), write a test composing it with a downstream gather and verify the gather resolves precisely.
88
+
89
+ For each new test, verify the expected output by hand or against `jax.jacobian`.
90
+ Update and re-verify the handler if any test reveals a bug.
91
+
92
+ ### 7. Simplify
93
+
94
+ Review the implementation with fresh eyes and look for opportunities to reduce complexity:
95
+
96
+ - **Vectorize loops**: can per-element Python loops be replaced with numpy operations?
97
+ Pattern: build a flat permutation or index array with `np.arange`, `np.flip`, `np.transpose`, `np.indices`, or `np.ravel_multi_index`,
98
+ then index into `in_indices` in a single list comprehension.
99
+ See `_rev.py`, `_reshape.py`, `_concatenate.py`, and `_broadcast.py` for examples.
100
+ - **Remove unused imports**: after vectorizing, utilities like `flat_to_coords`, `row_strides`, and `numel` may no longer be needed.
101
+ - **Eliminate intermediate variables**: if a value is computed and used only once, inline it.
102
+ - **Simplify special cases**: can a special-case branch be absorbed into the general case?
103
+
104
+ After any change, re-run verification (step 5).
105
+
106
+ ### 8. Update docs
107
+
108
+ - `TODO.md`: check off the primitive and its test items
109
+ - `src/asdex/_interpret/CLAUDE.md`: add the new module to the file listing
@@ -0,0 +1,33 @@
1
+ name: Benchmarks
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+
8
+ jobs:
9
+ Benchmarks:
10
+ runs-on: ubuntu-latest
11
+ permissions:
12
+ contents: write
13
+ deployments: write
14
+ steps:
15
+ - uses: actions/checkout@v4
16
+ - name: Install uv
17
+ uses: astral-sh/setup-uv@v5
18
+ - name: Set up Python
19
+ run: uv python install 3.12
20
+ - name: Install dependencies
21
+ run: uv sync --group dev
22
+ - name: Run benchmarks
23
+ run: uv run pytest tests/test_benchmarks.py -m dashboard --benchmark-json=benchmark.json
24
+ - name: Store benchmark result
25
+ uses: benchmark-action/github-action-benchmark@v1
26
+ with:
27
+ tool: 'pytest'
28
+ output-file-path: benchmark.json
29
+ github-token: ${{ secrets.GITHUB_TOKEN }}
30
+ auto-push: true
31
+ alert-threshold: '150%'
32
+ comment-on-alert: true
33
+ fail-on-alert: false
@@ -0,0 +1,49 @@
1
+ name: CI
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ jobs:
10
+ Tests:
11
+ runs-on: ${{ matrix.os }}
12
+ strategy:
13
+ fail-fast: false
14
+ matrix:
15
+ os: [ubuntu-latest]
16
+ python-version: ["3.11", "3.14"]
17
+ steps:
18
+ - uses: actions/checkout@v4
19
+ - name: Install uv
20
+ uses: astral-sh/setup-uv@v5
21
+ - name: Set up Python
22
+ run: uv python install ${{ matrix.python-version }}
23
+ - name: Install dependencies
24
+ run: uv sync --group dev
25
+ - name: Test with coverage
26
+ run: uv run pytest -m "" --cov=src/asdex --cov-report=xml # -m "" overrides default marker filter to include slow and benchmark tests
27
+ - name: Upload coverage to Codecov
28
+ if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.14'
29
+ uses: codecov/codecov-action@v5
30
+ with:
31
+ files: coverage.xml
32
+ token: ${{ secrets.CODECOV_TOKEN }}
33
+
34
+ Linting:
35
+ runs-on: ubuntu-latest
36
+ steps:
37
+ - uses: actions/checkout@v4
38
+ - name: Install uv
39
+ uses: astral-sh/setup-uv@v5
40
+ - name: Set up Python
41
+ run: uv python install 3.14
42
+ - name: Install dependencies
43
+ run: uv sync --group dev
44
+ - name: Lint
45
+ run: uv run ruff check .
46
+ - name: Format
47
+ run: uv run ruff format --check .
48
+ - name: Type check
49
+ run: uv run ty check
@@ -0,0 +1,27 @@
1
+ name: Documentation
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ jobs:
10
+ docs:
11
+ runs-on: ubuntu-latest
12
+ permissions:
13
+ contents: write
14
+ steps:
15
+ - uses: actions/checkout@v4
16
+ - uses: astral-sh/setup-uv@v5
17
+ - run: uv python install 3.12
18
+ - run: uv sync --group docs
19
+ - run: uv run mkdocs build --strict
20
+
21
+ - name: Deploy to GitHub Pages
22
+ if: github.event_name == 'push' && github.ref == 'refs/heads/main'
23
+ uses: peaceiris/actions-gh-pages@v4
24
+ with:
25
+ github_token: ${{ secrets.GITHUB_TOKEN }}
26
+ publish_dir: ./site
27
+ keep_files: true
@@ -0,0 +1,33 @@
1
+ name: Release
2
+ # Based on https://github.com/astral-sh/trusted-publishing-examples/blob/main/.github/workflows/release.yml
3
+
4
+ on:
5
+ push:
6
+ tags:
7
+ # Publish on any tag starting with a `v`, e.g., v0.1.0
8
+ - v*
9
+
10
+ jobs:
11
+ run:
12
+ name: Publish to PyPI
13
+ runs-on: ubuntu-latest
14
+ environment:
15
+ name: pypi
16
+ permissions:
17
+ id-token: write
18
+ contents: read
19
+ steps:
20
+ - name: Checkout
21
+ uses: actions/checkout@v6
22
+ - name: Install uv
23
+ uses: astral-sh/setup-uv@v7
24
+ - name: Install Python 3.14
25
+ run: uv python install 3.14
26
+ - name: Build
27
+ run: uv build
28
+ - name: Install dependencies
29
+ run: uv sync --group dev
30
+ - name: Test
31
+ run: uv run pytest -m ""
32
+ - name: Publish
33
+ run: uv publish
asdex-0.1.0/.gitignore ADDED
@@ -0,0 +1,20 @@
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+ uv.lock
9
+ .python-version
10
+ .benchmarks
11
+ .pytest_cache
12
+ .ruff_cache
13
+
14
+ # MkDocs build output
15
+ site/
16
+
17
+ # Virtual environments
18
+ .venv
19
+ .claude/settings.local.json
20
+ _context
@@ -0,0 +1,19 @@
1
+ repos:
2
+ - repo: local
3
+ hooks:
4
+ - id: ruff
5
+ name: ruff
6
+ entry: uv run ruff check --fix
7
+ language: system
8
+ types: [python]
9
+ - id: ruff-format
10
+ name: ruff-format
11
+ entry: uv run ruff format
12
+ language: system
13
+ types: [python]
14
+ - id: ty
15
+ name: ty
16
+ entry: uv run ty check
17
+ language: system
18
+ pass_filenames: false
19
+ types: [python]
asdex-0.1.0/CLAUDE.md ADDED
@@ -0,0 +1,103 @@
1
+ # asdex - Automatic Sparse Differentiation in JAX
2
+
3
+ This package implements [Automatic Sparse Differentiation](https://iclr-blogposts.github.io/2025/blog/sparse-autodiff/) (ASD) in JAX.
4
+
5
+ ## Overview
6
+
7
+ ASD exploits sparsity to reduce the cost of computing sparse Jacobians and Hessians:
8
+
9
+ 1. **Detection**: Analyze the jaxpr computation graph to detect the global sparsity pattern
10
+ 2. **Coloring**: Assign colors to rows (or columns) so that rows sharing a non-zero column (or columns sharing a non-zero row) get different colors
11
+ 3. **Decompression**: Compute one VJP/JVP/HVP per color instead of one per row or column, then extract the sparse matrix
12
+
13
+ ## Structure
14
+
15
+ ```
16
+ src/asdex/
17
+ ├── __init__.py # Public API
18
+ ├── pattern.py # SparsityPattern and ColoredPattern data structures
19
+ ├── detection.py # Jacobian and Hessian sparsity detection via jaxpr analysis
20
+ ├── coloring.py # Graph coloring (row, column, symmetric) and convenience functions
21
+ ├── decompression.py # Sparse Jacobian (VJP/JVP) and Hessian (HVP) computation
22
+ ├── verify.py # Correctness checks (check_jacobian_correctness, check_hessian_correctness)
23
+ ├── _display.py # Display/formatting utilities
24
+ └── _interpret/ # Custom jaxpr interpreter for index set propagation
25
+ ```
26
+
27
+ The interpreter internals are described in `src/asdex/_interpret/CLAUDE.md`.
28
+ The structure of the test folder is described in `tests/CLAUDE.md`.
29
+
30
+ ## Development
31
+
32
+ ```bash
33
+ uv run ruff check --fix . # lint + auto-fix
34
+ uv run ruff format . # format
35
+ uv run ty check # type check
36
+ uv run pytest # run tests
37
+ ```
38
+
39
+ ## Code style
40
+
41
+ - Favor `match` statements over long if-else chains.
42
+ Use explicit cases and default to `case _ as unreachable: assert_never(unreachable)`.
43
+ - Use plain `# Section name` comments for section separators,
44
+ not banner-style `# -- Section name ---`.
45
+
46
+ ## Architecture
47
+
48
+ ### Jacobians
49
+
50
+ ```
51
+ jacobian(f, input_shape)(x) # one-call API
52
+ jacobian_from_coloring(f, coloring)(x) # from pre-computed coloring
53
+
54
+ ├─ 1. DETECTION
55
+ │ jacobian_sparsity(f, input_shape)
56
+ │ ├─ make_jaxpr(f) → jaxpr
57
+ │ ├─ prop_jaxpr() → index sets
58
+ │ └─ SparsityPattern
59
+
60
+ ├─ 2. COLORING
61
+ │ jacobian_coloring_from_sparsity(sparsity)
62
+
63
+ └─ 3. DECOMPRESSION
64
+ One VJP or JVP per color
65
+
66
+ Precompute: jacobian_coloring(f, shape) = detect + color
67
+ ```
68
+
69
+ ### Hessians
70
+
71
+ ```
72
+ hessian(f, input_shape)(x) # one-call API
73
+ hessian_from_coloring(f, coloring)(x) # from pre-computed coloring
74
+
75
+ ├─ 1. DETECTION
76
+ │ hessian_sparsity(f, input_shape)
77
+ │ └─ jacobian_sparsity(grad(f), input_shape)
78
+
79
+ ├─ 2. COLORING
80
+ │ hessian_coloring_from_sparsity(sparsity)
81
+
82
+ └─ 3. DECOMPRESSION
83
+ One HVP per color (fwd-over-rev)
84
+
85
+ Precompute: hessian_coloring(f, shape) = detect + color_symmetric
86
+ ```
87
+
88
+ ## Commits
89
+
90
+ Use [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) for all commit messages (e.g. `feat:`, `fix:`, `docs:`, `refactor:`, `test:`).
91
+ For breaking changes, add `!` after the type (e.g. `feat!:`).
92
+
93
+ ## Design philosophy
94
+
95
+ When writing new code, adhere to these design principles:
96
+
97
+ - **Minimize complexity**: The primary goal of software design is to minimize complexity—anything that makes a system hard to understand and modify.
98
+
99
+ - **Information hiding**: Each module should encapsulate design decisions that other modules don't need to know about, preventing information leakage across boundaries.
100
+
101
+ - **Pull complexity downward**: It's better for a module to be internally complex if it keeps the interface simple for others. Don't expose complexity to callers.
102
+
103
+ - **Favor exceptions over wrong results**: Raise errors for unknown edge cases rather than guessing.
asdex-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Adrian Hill <gh@adrianhill.de>
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
asdex-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,112 @@
1
+ Metadata-Version: 2.4
2
+ Name: asdex
3
+ Version: 0.1.0
4
+ Summary: Automatic Sparse Differentiation in JAX
5
+ Project-URL: Homepage, https://github.com/adrhill/asdex
6
+ Project-URL: Documentation, https://adrianhill.de/asdex
7
+ Project-URL: Benchmarks, https://adrianhill.de/asdex/benchmarks
8
+ Project-URL: Repository, https://github.com/adrhill/asdex
9
+ Project-URL: Issues, https://github.com/adrhill/asdex/issues
10
+ Author-email: Adrian Hill <gh@adrianhill.de>
11
+ License-Expression: MIT
12
+ License-File: LICENSE
13
+ Keywords: automatic-differentiation,automatic-sparse-differentiation,graph-coloring,hessian,jacobian,jax,jaxpr,scientific-computing,sparse-automatic-differentiation,sparsity,sparsity-detection
14
+ Classifier: Development Status :: 3 - Alpha
15
+ Classifier: Intended Audience :: Science/Research
16
+ Classifier: License :: OSI Approved :: MIT License
17
+ Classifier: Programming Language :: Python :: 3
18
+ Classifier: Programming Language :: Python :: 3.11
19
+ Classifier: Programming Language :: Python :: 3.12
20
+ Classifier: Programming Language :: Python :: 3.13
21
+ Classifier: Programming Language :: Python :: 3.14
22
+ Classifier: Topic :: Scientific/Engineering
23
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
24
+ Requires-Python: >=3.11
25
+ Requires-Dist: jax>=0.9.0
26
+ Requires-Dist: numpy>=1.26.0
27
+ Description-Content-Type: text/markdown
28
+
29
+ # asdex
30
+
31
+ [![CI](https://github.com/adrhill/asdex/actions/workflows/ci.yml/badge.svg)](https://github.com/adrhill/asdex/actions/workflows/ci.yml)
32
+ [![codecov](https://codecov.io/gh/adrhill/asdex/graph/badge.svg)](https://codecov.io/gh/adrhill/asdex)
33
+ [![Docs](https://img.shields.io/badge/docs-online-blue)](https://adrianhill.de/asdex/)
34
+ [![Benchmarks](https://img.shields.io/badge/benchmarks-view-blue)](https://adrianhill.de/asdex/dev/bench/)
35
+
36
+ [Automatic Sparse Differentiation](https://iclr-blogposts.github.io/2025/blog/sparse-autodiff/) in JAX.
37
+
38
+ `asdex` (pronounced _Aztecs_) exploits sparsity structure to efficiently compute sparse Jacobians and Hessians.
39
+ It implements a custom [Jaxpr](https://docs.jax.dev/en/latest/jaxpr.html) interpreter
40
+ that uses [abstract interpretation](https://en.wikipedia.org/wiki/Abstract_interpretation)
41
+ to detect sparsity patterns from the computation graph,
42
+ then uses graph coloring to minimize the number of AD passes needed.
43
+
44
+ > [!WARNING]
45
+ > `asdex` is in early development.
46
+ > The API may change without notice.
47
+ > Use at your own risk.
48
+
49
+ ## Installation
50
+
51
+ ```bash
52
+ pip install git+https://github.com/adrhill/asdex.git
53
+ ```
54
+
55
+ Or with [uv](https://docs.astral.sh/uv/):
56
+
57
+ ```bash
58
+ uv add git+https://github.com/adrhill/asdex.git
59
+ ```
60
+
61
+ ## Example
62
+
63
+ ```python
64
+ import numpy as np
65
+ from asdex import jacobian
66
+
67
+ def f(x):
68
+ return (x[1:] - x[:-1]) ** 2
69
+
70
+ jac_fn = jacobian(f, input_shape=50)
71
+ # ColoredPattern(49×50, nnz=98, sparsity=96.0%, JVP, 2 colors)
72
+ # 2 JVPs (instead of 49 VJPs or 50 JVPs)
73
+ # ⎡⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤ ⎡⣿⎤
74
+ # ⎢⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
75
+ # ⎢⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
76
+ # ⎢⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
77
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
78
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
79
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ → ⎢⣿⎥
80
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
81
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
82
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⎥ ⎢⣿⎥
83
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⎥ ⎢⣿⎥
84
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⎥ ⎢⣿⎥
85
+ # ⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⎦ ⎣⠉⎦
86
+
87
+ for x in inputs:
88
+ J = jac_fn(x)
89
+ ```
90
+
91
+ Instead of 49 VJPs or 50 JVPs,
92
+ `asdex` computes the full sparse Jacobian with just 2 JVPs.
93
+
94
+ ## Documentation
95
+
96
+ - [Getting Started](https://adrianhill.de/asdex/tutorials/getting-started/) — step-by-step tutorial
97
+ - [How-To Guides](https://adrianhill.de/asdex/how-to/jacobians/) — task-oriented recipes
98
+ - [Explanation](https://adrianhill.de/asdex/explanation/sparsity-detection/) — how and why it works
99
+ - [API Reference](https://adrianhill.de/asdex/reference/) — full API documentation
100
+
101
+ ## Acknowledgements
102
+
103
+ This package is built with Claude Code based on previous work by [Adrian Hill](https://github.com/adrhill), [Guillaume Dalle](https://github.com/gdalle), and [Alexis Montoison](https://github.com/amontoison) in the [Julia programming language](https://julialang.org):
104
+
105
+ - [_An Illustrated Guide to Automatic Sparse Differentiation_](https://iclr-blogposts.github.io/2025/blog/sparse-autodiff/), A. Hill, G. Dalle, A. Montoison (2025)
106
+ - [_Sparser, Better, Faster, Stronger: Efficient Automatic Differentiation for Sparse Jacobians and Hessians_](https://openreview.net/forum?id=GtXSN52nIW), A. Hill & G. Dalle (2025)
107
+ - [_Revisiting Sparse Matrix Coloring and Bicoloring_](https://arxiv.org/abs/2505.07308), A. Montoison, G. Dalle, A. Gebremedhin (2025)
108
+ - [_SparseConnectivityTracer.jl_](https://github.com/adrhill/SparseConnectivityTracer.jl), A. Hill, G. Dalle
109
+ - [_SparseMatrixColorings.jl_](https://github.com/gdalle/SparseMatrixColorings.jl), G. Dalle, A. Montoison
110
+ - [_sparsediffax_](https://github.com/gdalle/sparsediffax), G. Dalle
111
+
112
+ which in turn stands on the shoulders of giants — notably Andreas Griewank, Andrea Walther, and Assefaw Gebremedhin.
asdex-0.1.0/README.md ADDED
@@ -0,0 +1,84 @@
1
+ # asdex
2
+
3
+ [![CI](https://github.com/adrhill/asdex/actions/workflows/ci.yml/badge.svg)](https://github.com/adrhill/asdex/actions/workflows/ci.yml)
4
+ [![codecov](https://codecov.io/gh/adrhill/asdex/graph/badge.svg)](https://codecov.io/gh/adrhill/asdex)
5
+ [![Docs](https://img.shields.io/badge/docs-online-blue)](https://adrianhill.de/asdex/)
6
+ [![Benchmarks](https://img.shields.io/badge/benchmarks-view-blue)](https://adrianhill.de/asdex/dev/bench/)
7
+
8
+ [Automatic Sparse Differentiation](https://iclr-blogposts.github.io/2025/blog/sparse-autodiff/) in JAX.
9
+
10
+ `asdex` (pronounced _Aztecs_) exploits sparsity structure to efficiently compute sparse Jacobians and Hessians.
11
+ It implements a custom [Jaxpr](https://docs.jax.dev/en/latest/jaxpr.html) interpreter
12
+ that uses [abstract interpretation](https://en.wikipedia.org/wiki/Abstract_interpretation)
13
+ to detect sparsity patterns from the computation graph,
14
+ then uses graph coloring to minimize the number of AD passes needed.
15
+
16
+ > [!WARNING]
17
+ > `asdex` is in early development.
18
+ > The API may change without notice.
19
+ > Use at your own risk.
20
+
21
+ ## Installation
22
+
23
+ ```bash
24
+ pip install git+https://github.com/adrhill/asdex.git
25
+ ```
26
+
27
+ Or with [uv](https://docs.astral.sh/uv/):
28
+
29
+ ```bash
30
+ uv add git+https://github.com/adrhill/asdex.git
31
+ ```
32
+
33
+ ## Example
34
+
35
+ ```python
36
+ import numpy as np
37
+ from asdex import jacobian
38
+
39
+ def f(x):
40
+ return (x[1:] - x[:-1]) ** 2
41
+
42
+ jac_fn = jacobian(f, input_shape=50)
43
+ # ColoredPattern(49×50, nnz=98, sparsity=96.0%, JVP, 2 colors)
44
+ # 2 JVPs (instead of 49 VJPs or 50 JVPs)
45
+ # ⎡⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤ ⎡⣿⎤
46
+ # ⎢⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
47
+ # ⎢⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
48
+ # ⎢⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
49
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
50
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
51
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ → ⎢⣿⎥
52
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
53
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⎥ ⎢⣿⎥
54
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⎥ ⎢⣿⎥
55
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⎥ ⎢⣿⎥
56
+ # ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⎥ ⎢⣿⎥
57
+ # ⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⎦ ⎣⠉⎦
58
+
59
+ for x in inputs:
60
+ J = jac_fn(x)
61
+ ```
62
+
63
+ Instead of 49 VJPs or 50 JVPs,
64
+ `asdex` computes the full sparse Jacobian with just 2 JVPs.
65
+
66
+ ## Documentation
67
+
68
+ - [Getting Started](https://adrianhill.de/asdex/tutorials/getting-started/) — step-by-step tutorial
69
+ - [How-To Guides](https://adrianhill.de/asdex/how-to/jacobians/) — task-oriented recipes
70
+ - [Explanation](https://adrianhill.de/asdex/explanation/sparsity-detection/) — how and why it works
71
+ - [API Reference](https://adrianhill.de/asdex/reference/) — full API documentation
72
+
73
+ ## Acknowledgements
74
+
75
+ This package is built with Claude Code based on previous work by [Adrian Hill](https://github.com/adrhill), [Guillaume Dalle](https://github.com/gdalle), and [Alexis Montoison](https://github.com/amontoison) in the [Julia programming language](https://julialang.org):
76
+
77
+ - [_An Illustrated Guide to Automatic Sparse Differentiation_](https://iclr-blogposts.github.io/2025/blog/sparse-autodiff/), A. Hill, G. Dalle, A. Montoison (2025)
78
+ - [_Sparser, Better, Faster, Stronger: Efficient Automatic Differentiation for Sparse Jacobians and Hessians_](https://openreview.net/forum?id=GtXSN52nIW), A. Hill & G. Dalle (2025)
79
+ - [_Revisiting Sparse Matrix Coloring and Bicoloring_](https://arxiv.org/abs/2505.07308), A. Montoison, G. Dalle, A. Gebremedhin (2025)
80
+ - [_SparseConnectivityTracer.jl_](https://github.com/adrhill/SparseConnectivityTracer.jl), A. Hill, G. Dalle
81
+ - [_SparseMatrixColorings.jl_](https://github.com/gdalle/SparseMatrixColorings.jl), G. Dalle, A. Montoison
82
+ - [_sparsediffax_](https://github.com/gdalle/sparsediffax), G. Dalle
83
+
84
+ which in turn stands on the shoulders of giants — notably Andreas Griewank, Andrea Walther, and Assefaw Gebremedhin.