jax-shapeguard 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_shapeguard-0.3.0/.github/workflows/lint.yml +37 -0
- jax_shapeguard-0.3.0/.github/workflows/publish.yml +29 -0
- jax_shapeguard-0.3.0/.github/workflows/test.yml +39 -0
- jax_shapeguard-0.3.0/.gitignore +44 -0
- jax_shapeguard-0.3.0/LICENSE +21 -0
- jax_shapeguard-0.3.0/MILESTONES.md +241 -0
- jax_shapeguard-0.3.0/PKG-INFO +69 -0
- jax_shapeguard-0.3.0/README.md +42 -0
- jax_shapeguard-0.3.0/pyproject.toml +94 -0
- jax_shapeguard-0.3.0/shapeguard/__init__.py +61 -0
- jax_shapeguard-0.3.0/shapeguard/_compat.py +96 -0
- jax_shapeguard-0.3.0/shapeguard/broadcast.py +211 -0
- jax_shapeguard-0.3.0/shapeguard/config.py +51 -0
- jax_shapeguard-0.3.0/shapeguard/context.py +104 -0
- jax_shapeguard-0.3.0/shapeguard/core.py +159 -0
- jax_shapeguard-0.3.0/shapeguard/decorator.py +472 -0
- jax_shapeguard-0.3.0/shapeguard/errors.py +208 -0
- jax_shapeguard-0.3.0/shapeguard/spec.py +223 -0
- jax_shapeguard-0.3.0/tests/__init__.py +1 -0
- jax_shapeguard-0.3.0/tests/conftest.py +49 -0
- jax_shapeguard-0.3.0/tests/test_batch.py +88 -0
- jax_shapeguard-0.3.0/tests/test_broadcast.py +221 -0
- jax_shapeguard-0.3.0/tests/test_config.py +61 -0
- jax_shapeguard-0.3.0/tests/test_context.py +140 -0
- jax_shapeguard-0.3.0/tests/test_contract.py +163 -0
- jax_shapeguard-0.3.0/tests/test_core.py +113 -0
- jax_shapeguard-0.3.0/tests/test_decorator.py +211 -0
- jax_shapeguard-0.3.0/tests/test_ellipsis.py +165 -0
- jax_shapeguard-0.3.0/tests/test_ensures.py +272 -0
- jax_shapeguard-0.3.0/tests/test_errors.py +171 -0
- jax_shapeguard-0.3.0/tests/test_jit_modes.py +236 -0
- jax_shapeguard-0.3.0/tests/test_pytree.py +161 -0
- jax_shapeguard-0.3.0/tests/test_spec.py +154 -0
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
name: Lint
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [main]
|
|
6
|
+
pull_request:
|
|
7
|
+
branches: [main]
|
|
8
|
+
|
|
9
|
+
jobs:
|
|
10
|
+
lint:
|
|
11
|
+
runs-on: ubuntu-latest
|
|
12
|
+
|
|
13
|
+
steps:
|
|
14
|
+
- uses: actions/checkout@v4
|
|
15
|
+
|
|
16
|
+
- name: Install uv
|
|
17
|
+
uses: astral-sh/setup-uv@v4
|
|
18
|
+
with:
|
|
19
|
+
version: "latest"
|
|
20
|
+
|
|
21
|
+
- name: Set up Python
|
|
22
|
+
run: uv python install 3.11
|
|
23
|
+
|
|
24
|
+
- name: Install dependencies
|
|
25
|
+
run: uv sync --dev
|
|
26
|
+
|
|
27
|
+
- name: Install linting tools
|
|
28
|
+
run: uv pip install ruff mypy
|
|
29
|
+
|
|
30
|
+
- name: Run ruff check
|
|
31
|
+
run: uv run ruff check shapeguard tests
|
|
32
|
+
|
|
33
|
+
- name: Run ruff format check
|
|
34
|
+
run: uv run ruff format --check shapeguard tests
|
|
35
|
+
|
|
36
|
+
- name: Run mypy
|
|
37
|
+
run: uv run mypy shapeguard --ignore-missing-imports
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
name: Publish to PyPI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
release:
|
|
5
|
+
types: [published]
|
|
6
|
+
|
|
7
|
+
jobs:
|
|
8
|
+
publish:
|
|
9
|
+
runs-on: ubuntu-latest
|
|
10
|
+
environment: pypi
|
|
11
|
+
permissions:
|
|
12
|
+
id-token: write # Required for trusted publishing
|
|
13
|
+
|
|
14
|
+
steps:
|
|
15
|
+
- uses: actions/checkout@v4
|
|
16
|
+
|
|
17
|
+
- name: Install uv
|
|
18
|
+
uses: astral-sh/setup-uv@v4
|
|
19
|
+
with:
|
|
20
|
+
version: "latest"
|
|
21
|
+
|
|
22
|
+
- name: Set up Python
|
|
23
|
+
run: uv python install 3.11
|
|
24
|
+
|
|
25
|
+
- name: Build package
|
|
26
|
+
run: uv build
|
|
27
|
+
|
|
28
|
+
- name: Publish to PyPI
|
|
29
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
name: Tests
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [main]
|
|
6
|
+
pull_request:
|
|
7
|
+
branches: [main]
|
|
8
|
+
|
|
9
|
+
jobs:
|
|
10
|
+
test:
|
|
11
|
+
runs-on: ubuntu-latest
|
|
12
|
+
strategy:
|
|
13
|
+
fail-fast: false
|
|
14
|
+
matrix:
|
|
15
|
+
python-version: ["3.10", "3.11", "3.12"]
|
|
16
|
+
|
|
17
|
+
steps:
|
|
18
|
+
- uses: actions/checkout@v4
|
|
19
|
+
|
|
20
|
+
- name: Install uv
|
|
21
|
+
uses: astral-sh/setup-uv@v4
|
|
22
|
+
with:
|
|
23
|
+
version: "latest"
|
|
24
|
+
|
|
25
|
+
- name: Set up Python ${{ matrix.python-version }}
|
|
26
|
+
run: uv python install ${{ matrix.python-version }}
|
|
27
|
+
|
|
28
|
+
- name: Install dependencies
|
|
29
|
+
run: uv sync --dev
|
|
30
|
+
|
|
31
|
+
- name: Run tests
|
|
32
|
+
run: uv run pytest -v --cov=shapeguard --cov-report=xml
|
|
33
|
+
|
|
34
|
+
- name: Upload coverage
|
|
35
|
+
uses: codecov/codecov-action@v4
|
|
36
|
+
if: matrix.python-version == '3.11'
|
|
37
|
+
with:
|
|
38
|
+
files: coverage.xml
|
|
39
|
+
fail_ci_if_error: false
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# Python
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
*.so
|
|
6
|
+
.Python
|
|
7
|
+
build/
|
|
8
|
+
develop-eggs/
|
|
9
|
+
dist/
|
|
10
|
+
downloads/
|
|
11
|
+
eggs/
|
|
12
|
+
.eggs/
|
|
13
|
+
lib/
|
|
14
|
+
lib64/
|
|
15
|
+
parts/
|
|
16
|
+
sdist/
|
|
17
|
+
var/
|
|
18
|
+
wheels/
|
|
19
|
+
*.egg-info/
|
|
20
|
+
.installed.cfg
|
|
21
|
+
*.egg
|
|
22
|
+
|
|
23
|
+
# Virtual environments
|
|
24
|
+
.venv/
|
|
25
|
+
venv/
|
|
26
|
+
ENV/
|
|
27
|
+
|
|
28
|
+
# Testing
|
|
29
|
+
.pytest_cache/
|
|
30
|
+
.coverage
|
|
31
|
+
htmlcov/
|
|
32
|
+
.tox/
|
|
33
|
+
.nox/
|
|
34
|
+
|
|
35
|
+
# IDE
|
|
36
|
+
.idea/
|
|
37
|
+
.vscode/
|
|
38
|
+
*.swp
|
|
39
|
+
*.swo
|
|
40
|
+
*~
|
|
41
|
+
|
|
42
|
+
# OS
|
|
43
|
+
.DS_Store
|
|
44
|
+
Thumbs.db
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Jayendra Parmar
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
# ShapeGuard Milestones
|
|
2
|
+
|
|
3
|
+
## Design Decisions
|
|
4
|
+
|
|
5
|
+
- **Dim identity**: Dims unify by object identity, not by name — reuse the same object (e.g. `n = Dim("n")`) everywhere that dimension must match
|
|
6
|
+
- **Performance**: Dev-only tool, prioritize error quality over speed
|
|
7
|
+
- **Backend priority**: NumPy-first, JAX support in v0.2
|
|
8
|
+
|
|
9
|
+
---
|
|
10
|
+
|
|
11
|
+
## Milestone 1: Core Foundation (v0.1-alpha)
|
|
12
|
+
|
|
13
|
+
### Goal
|
|
14
|
+
Minimal working library with symbolic dimensions, shape checking, and decorator.
|
|
15
|
+
|
|
16
|
+
### Files
|
|
17
|
+
```
|
|
18
|
+
shapeguard/
|
|
19
|
+
__init__.py # Public API exports
|
|
20
|
+
core.py # Dim class, UnificationContext
|
|
21
|
+
spec.py # Shape specification matching
|
|
22
|
+
decorator.py # @expects decorator
|
|
23
|
+
errors.py # ShapeGuardError
|
|
24
|
+
_compat.py # Backend detection
|
|
25
|
+
tests/
|
|
26
|
+
test_core.py
|
|
27
|
+
test_spec.py
|
|
28
|
+
test_decorator.py
|
|
29
|
+
test_errors.py
|
|
30
|
+
conftest.py
|
|
31
|
+
pyproject.toml
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
### Deliverables
|
|
35
|
+
- [x] `Dim` class with identity-based unification
|
|
36
|
+
- [x] `UnificationContext` for tracking bindings across arguments
|
|
37
|
+
- [x] `ShapeSpec` matching: concrete `(3, 4)`, symbolic `(n, m)`, wildcard `(None, 4)`
|
|
38
|
+
- [x] `@expects` decorator for input validation
|
|
39
|
+
- [x] `ShapeGuardError` with function, argument, expected, actual, reason
|
|
40
|
+
- [x] `check_shape(x, spec, name)` standalone function
|
|
41
|
+
- [x] Backend-agnostic shape extraction (works with any `.shape` attribute)
|
|
42
|
+
- [x] Unit tests with 90%+ coverage (91% achieved)
|
|
43
|
+
|
|
44
|
+
### API Surface
|
|
45
|
+
```python
|
|
46
|
+
from shapeguard import Dim, expects, check_shape, ShapeGuardError
|
|
47
|
+
|
|
48
|
+
n, m = Dim("n"), Dim("m")
|
|
49
|
+
|
|
50
|
+
@expects(x=(n, m), y=(m,))
|
|
51
|
+
def forward(x, y):
|
|
52
|
+
return x @ y
|
|
53
|
+
|
|
54
|
+
check_shape(arr, (n, 128), name="input")
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
---
|
|
58
|
+
|
|
59
|
+
## Milestone 2: ML-Practical Features (v0.1-beta)
|
|
60
|
+
|
|
61
|
+
### Goal
|
|
62
|
+
Ergonomic features for real ML workflows.
|
|
63
|
+
|
|
64
|
+
### Deliverables
|
|
65
|
+
- [x] `Batch` dimension (always first, flexible size per call)
|
|
66
|
+
- [x] Ellipsis support `(..., n, m)` for variable leading dims
|
|
67
|
+
- [x] `ShapeContext` manager for grouped checks with shared bindings
|
|
68
|
+
- [x] Improved error messages with binding trace (92% coverage)
|
|
69
|
+
|
|
70
|
+
### API Additions
|
|
71
|
+
```python
|
|
72
|
+
from shapeguard import Batch, ShapeContext
|
|
73
|
+
|
|
74
|
+
B = Batch()
|
|
75
|
+
|
|
76
|
+
@expects(x=(B, n, m))
|
|
77
|
+
def layer(x): ...
|
|
78
|
+
|
|
79
|
+
@expects(x=(..., n, m))
|
|
80
|
+
def normalize(x): ...
|
|
81
|
+
|
|
82
|
+
with ShapeContext() as ctx:
|
|
83
|
+
ctx.check(x, (n, m), "x")
|
|
84
|
+
ctx.check(y, (m, k), "y")
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
---
|
|
88
|
+
|
|
89
|
+
## Milestone 3: JAX Integration (v0.2)
|
|
90
|
+
|
|
91
|
+
### Goal
|
|
92
|
+
Seamless JAX compatibility including JIT behavior.
|
|
93
|
+
|
|
94
|
+
### Deliverables
|
|
95
|
+
- [x] JIT/tracing detection
|
|
96
|
+
- [x] Configurable JIT modes: `skip`, `warn`, `check`
|
|
97
|
+
- [x] PyTree shape specs for nested params
|
|
98
|
+
- [ ] Performance benchmarks (deferred)
|
|
99
|
+
|
|
100
|
+
### API Additions
|
|
101
|
+
```python
|
|
102
|
+
from shapeguard import expects, config
|
|
103
|
+
|
|
104
|
+
config.jit_mode = "skip" # Global setting
|
|
105
|
+
|
|
106
|
+
@expects(x=(B, n, m), jit_mode="check")  # Per-function
|
|
107
|
+
@jax.jit
|
|
108
|
+
def forward(x): ...
|
|
109
|
+
|
|
110
|
+
@expects(
|
|
111
|
+
params={"weights": (n, m), "bias": (m,)},
|
|
112
|
+
x=(B, n)
|
|
113
|
+
)
|
|
114
|
+
def apply(params, x): ...
|
|
115
|
+
```
|
|
116
|
+
|
|
117
|
+
---
|
|
118
|
+
|
|
119
|
+
## Milestone 4: Broadcasting Support (v0.2)
|
|
120
|
+
|
|
121
|
+
### Goal
|
|
122
|
+
Explicit broadcasting inspection and validation.
|
|
123
|
+
|
|
124
|
+
### Deliverables
|
|
125
|
+
- [x] `broadcast_shape()` for concrete shapes
|
|
126
|
+
- [x] `explain_broadcast()` step-by-step explainer
|
|
127
|
+
- [ ] `_broadcast=True` option in `@expects` (deferred)
|
|
128
|
+
|
|
129
|
+
### API Additions
|
|
130
|
+
```python
|
|
131
|
+
from shapeguard import broadcast_shape, explain_broadcast
|
|
132
|
+
|
|
133
|
+
broadcast_shape((3, 1), (1, 4)) # → (3, 4)
|
|
134
|
+
broadcast_shape(a, b) # From arrays
|
|
135
|
+
|
|
136
|
+
explain_broadcast((3, 1, 4), (5, 4))
|
|
137
|
+
# Broadcasting (3, 1, 4) with (5, 4):
|
|
138
|
+
# Dim 0: 3 (from left only)
|
|
139
|
+
# Dim 1: 1 → 5 (broadcast)
|
|
140
|
+
# Dim 2: 4 = 4 (match)
|
|
141
|
+
# Result: (3, 5, 4)
|
|
142
|
+
```
|
|
143
|
+
|
|
144
|
+
---
|
|
145
|
+
|
|
146
|
+
## Milestone 5: Output Contracts (v0.3)
|
|
147
|
+
|
|
148
|
+
### Goal
|
|
149
|
+
Validate function outputs, not just inputs.
|
|
150
|
+
|
|
151
|
+
### Deliverables
|
|
152
|
+
- [x] `@ensures` decorator for output validation
|
|
153
|
+
- [x] `@contract` combined decorator
|
|
154
|
+
- [x] Tuple/dict output support
|
|
155
|
+
|
|
156
|
+
### API Additions
|
|
157
|
+
```python
|
|
158
|
+
from shapeguard import expects, ensures, contract
|
|
159
|
+
|
|
160
|
+
@expects(a=(n, m), b=(m, k))
|
|
161
|
+
@ensures(result=(n, k))
|
|
162
|
+
def matmul(a, b):
|
|
163
|
+
return a @ b
|
|
164
|
+
|
|
165
|
+
@contract(
|
|
166
|
+
inputs={"a": (n, m), "b": (m, k)},
|
|
167
|
+
output=(n, k)
|
|
168
|
+
)
|
|
169
|
+
def matmul(a, b):
|
|
170
|
+
return a @ b
|
|
171
|
+
```
|
|
172
|
+
|
|
173
|
+
---
|
|
174
|
+
|
|
175
|
+
## Milestone 6: ML Helpers (v0.3)
|
|
176
|
+
|
|
177
|
+
### Goal
|
|
178
|
+
Domain-specific helpers for common ML patterns.
|
|
179
|
+
|
|
180
|
+
### Deliverables
|
|
181
|
+
- [ ] Pre-defined dims: `B`, `T`, `C`, `H`, `W`, `D`
|
|
182
|
+
- [ ] `attention_shapes()` helper
|
|
183
|
+
- [ ] `conv_output_shape()` calculator
|
|
184
|
+
|
|
185
|
+
### API Additions
|
|
186
|
+
```python
|
|
187
|
+
from shapeguard.ml import B, T, C, H, W, D
|
|
188
|
+
from shapeguard.ml import attention_shapes, conv_output_shape
|
|
189
|
+
|
|
190
|
+
@expects(x=(B, T, D))
|
|
191
|
+
def transformer_layer(x): ...
|
|
192
|
+
|
|
193
|
+
@expects(**attention_shapes(B, heads, seq_q, seq_k, d_k))
|
|
194
|
+
def attention(q, k, v): ...
|
|
195
|
+
|
|
196
|
+
out_shape = conv_output_shape(
|
|
197
|
+
input=(B, C, 224, 224),
|
|
198
|
+
kernel=(3, 3),
|
|
199
|
+
stride=2,
|
|
200
|
+
padding=1
|
|
201
|
+
)
|
|
202
|
+
```
|
|
203
|
+
|
|
204
|
+
---
|
|
205
|
+
|
|
206
|
+
## Milestone 7: Testing Utilities (v0.4)
|
|
207
|
+
|
|
208
|
+
### Goal
|
|
209
|
+
Property-based testing support.
|
|
210
|
+
|
|
211
|
+
### Deliverables
|
|
212
|
+
- [ ] Hypothesis strategies for shaped arrays
|
|
213
|
+
- [ ] `verify_contract()` auto-test generator
|
|
214
|
+
- [ ] pytest plugin
|
|
215
|
+
|
|
216
|
+
### API Additions
|
|
217
|
+
```python
|
|
218
|
+
from shapeguard.testing import arrays, verify_contract
|
|
219
|
+
import hypothesis
|
|
220
|
+
|
|
221
|
+
@hypothesis.given(x=arrays(shape=(n, m), n=(1, 100), m=(1, 100)))
|
|
222
|
+
def test_normalize(x):
|
|
223
|
+
result = normalize(x)
|
|
224
|
+
assert result.shape == x.shape
|
|
225
|
+
|
|
226
|
+
verify_contract(matmul, samples=100)
|
|
227
|
+
```
|
|
228
|
+
|
|
229
|
+
---
|
|
230
|
+
|
|
231
|
+
## Summary Timeline
|
|
232
|
+
|
|
233
|
+
| Milestone | Version | Status |
|
|
234
|
+
|-----------|---------|--------|
|
|
235
|
+
| 1. Core Foundation | v0.1-alpha | ✅ Complete (91% coverage) |
|
|
236
|
+
| 2. ML Features | v0.1-beta | ✅ Complete (92% coverage) |
|
|
237
|
+
| 3. JAX Integration | v0.2 | ✅ Complete (92% coverage) |
|
|
238
|
+
| 4. Broadcasting | v0.2 | ✅ Complete |
|
|
239
|
+
| 5. Output Contracts | v0.3 | ✅ Complete (91% coverage) |
|
|
240
|
+
| 6. ML Helpers | v0.3 | 🔲 Not started |
|
|
241
|
+
| 7. Testing Utils | v0.4 | 🔲 Not started |
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: jax-shapeguard
|
|
3
|
+
Version: 0.3.0
|
|
4
|
+
Summary: Runtime shape contracts and diagnostics for NumPy and JAX
|
|
5
|
+
Project-URL: Homepage, https://github.com/jayendra13/jax-shape-guard
|
|
6
|
+
Project-URL: Repository, https://github.com/jayendra13/jax-shape-guard
|
|
7
|
+
Author-email: Jayendra Parmar <jayendra0parmar@gmail.com>
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Keywords: debugging,jax,ml,numpy,shapes,validation
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering
|
|
20
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
21
|
+
Classifier: Typing :: Typed
|
|
22
|
+
Requires-Python: >=3.10
|
|
23
|
+
Provides-Extra: jax
|
|
24
|
+
Requires-Dist: jax>=0.4; extra == 'jax'
|
|
25
|
+
Requires-Dist: jaxlib>=0.4; extra == 'jax'
|
|
26
|
+
Description-Content-Type: text/markdown
|
|
27
|
+
|
|
28
|
+
# ShapeGuard
|
|
29
|
+
|
|
30
|
+
[![Tests](https://github.com/jayendra13/jax-shape-guard/actions/workflows/test.yml/badge.svg)](https://github.com/jayendra13/jax-shape-guard/actions/workflows/test.yml)
|
|
31
|
+
[![Lint](https://github.com/jayendra13/jax-shape-guard/actions/workflows/lint.yml/badge.svg)](https://github.com/jayendra13/jax-shape-guard/actions/workflows/lint.yml)
|
|
32
|
+
[![PyPI version](https://img.shields.io/pypi/v/jax-shapeguard.svg)](https://pypi.org/project/jax-shapeguard/)
|
|
33
|
+
[![Python versions](https://img.shields.io/pypi/pyversions/jax-shapeguard.svg)](https://pypi.org/project/jax-shapeguard/)
|
|
34
|
+
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/jayendra13/jax-shape-guard/blob/main/LICENSE)
|
|
35
|
+
|
|
36
|
+
Runtime shape contracts and diagnostics for NumPy and JAX.
|
|
37
|
+
|
|
38
|
+
## Installation
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
pip install jax-shapeguard
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
## Quick Start
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
from shapeguard import Dim, expects
|
|
48
|
+
|
|
49
|
+
n, m, k = Dim("n"), Dim("m"), Dim("k")
|
|
50
|
+
|
|
51
|
+
@expects(a=(n, m), b=(m, k))
|
|
52
|
+
def matmul(a, b):
|
|
53
|
+
return a @ b
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
When shapes don't match, you get clear errors:
|
|
57
|
+
|
|
58
|
+
```
|
|
59
|
+
ShapeGuardError:
|
|
60
|
+
function: matmul
|
|
61
|
+
argument: b
|
|
62
|
+
expected: (m, k)
|
|
63
|
+
actual: (5, 7)
|
|
64
|
+
reason: dimension 'm' bound to 4 from a.shape[1], but got 5 from b.shape[0]
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
## License
|
|
68
|
+
|
|
69
|
+
MIT
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# ShapeGuard
|
|
2
|
+
|
|
3
|
+
[![Tests](https://github.com/jayendra13/jax-shape-guard/actions/workflows/test.yml/badge.svg)](https://github.com/jayendra13/jax-shape-guard/actions/workflows/test.yml)
|
|
4
|
+
[![Lint](https://github.com/jayendra13/jax-shape-guard/actions/workflows/lint.yml/badge.svg)](https://github.com/jayendra13/jax-shape-guard/actions/workflows/lint.yml)
|
|
5
|
+
[![PyPI version](https://img.shields.io/pypi/v/jax-shapeguard.svg)](https://pypi.org/project/jax-shapeguard/)
|
|
6
|
+
[![Python versions](https://img.shields.io/pypi/pyversions/jax-shapeguard.svg)](https://pypi.org/project/jax-shapeguard/)
|
|
7
|
+
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/jayendra13/jax-shape-guard/blob/main/LICENSE)
|
|
8
|
+
|
|
9
|
+
Runtime shape contracts and diagnostics for NumPy and JAX.
|
|
10
|
+
|
|
11
|
+
## Installation
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
pip install jax-shapeguard
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
## Quick Start
|
|
18
|
+
|
|
19
|
+
```python
|
|
20
|
+
from shapeguard import Dim, expects
|
|
21
|
+
|
|
22
|
+
n, m, k = Dim("n"), Dim("m"), Dim("k")
|
|
23
|
+
|
|
24
|
+
@expects(a=(n, m), b=(m, k))
|
|
25
|
+
def matmul(a, b):
|
|
26
|
+
return a @ b
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
When shapes don't match, you get clear errors:
|
|
30
|
+
|
|
31
|
+
```
|
|
32
|
+
ShapeGuardError:
|
|
33
|
+
function: matmul
|
|
34
|
+
argument: b
|
|
35
|
+
expected: (m, k)
|
|
36
|
+
actual: (5, 7)
|
|
37
|
+
reason: dimension 'm' bound to 4 from a.shape[1], but got 5 from b.shape[0]
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
## License
|
|
41
|
+
|
|
42
|
+
MIT
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[tool.hatch.build.targets.wheel]
|
|
6
|
+
packages = ["shapeguard"]
|
|
7
|
+
|
|
8
|
+
[project]
|
|
9
|
+
name = "jax-shapeguard"
|
|
10
|
+
version = "0.3.0"
|
|
11
|
+
description = "Runtime shape contracts and diagnostics for NumPy and JAX"
|
|
12
|
+
readme = "README.md"
|
|
13
|
+
license = "MIT"
|
|
14
|
+
requires-python = ">=3.10"
|
|
15
|
+
authors = [
|
|
16
|
+
{ name = "Jayendra Parmar", email = "jayendra0parmar@gmail.com" }
|
|
17
|
+
]
|
|
18
|
+
keywords = ["numpy", "jax", "shapes", "validation", "debugging", "ml"]
|
|
19
|
+
classifiers = [
|
|
20
|
+
"Development Status :: 3 - Alpha",
|
|
21
|
+
"Intended Audience :: Developers",
|
|
22
|
+
"Intended Audience :: Science/Research",
|
|
23
|
+
"License :: OSI Approved :: MIT License",
|
|
24
|
+
"Programming Language :: Python :: 3",
|
|
25
|
+
"Programming Language :: Python :: 3.10",
|
|
26
|
+
"Programming Language :: Python :: 3.11",
|
|
27
|
+
"Programming Language :: Python :: 3.12",
|
|
28
|
+
"Topic :: Scientific/Engineering",
|
|
29
|
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
30
|
+
"Typing :: Typed",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
[project.optional-dependencies]
|
|
34
|
+
jax = [
|
|
35
|
+
"jax>=0.4",
|
|
36
|
+
"jaxlib>=0.4",
|
|
37
|
+
]
|
|
38
|
+
|
|
39
|
+
[dependency-groups]
|
|
40
|
+
dev = [
|
|
41
|
+
"pytest>=7.0",
|
|
42
|
+
"pytest-cov>=4.0",
|
|
43
|
+
"numpy>=1.20",
|
|
44
|
+
"jax>=0.4",
|
|
45
|
+
"jaxlib>=0.4",
|
|
46
|
+
"ruff>=0.4",
|
|
47
|
+
"mypy>=1.0",
|
|
48
|
+
]
|
|
49
|
+
|
|
50
|
+
[project.urls]
|
|
51
|
+
Homepage = "https://github.com/jayendra13/jax-shape-guard"
|
|
52
|
+
Repository = "https://github.com/jayendra13/jax-shape-guard"
|
|
53
|
+
|
|
54
|
+
[tool.pytest.ini_options]
|
|
55
|
+
testpaths = ["tests"]
|
|
56
|
+
addopts = "-v --tb=short"
|
|
57
|
+
|
|
58
|
+
[tool.coverage.run]
|
|
59
|
+
source = ["shapeguard"]
|
|
60
|
+
branch = true
|
|
61
|
+
|
|
62
|
+
[tool.coverage.report]
|
|
63
|
+
exclude_lines = [
|
|
64
|
+
"pragma: no cover",
|
|
65
|
+
"if TYPE_CHECKING:",
|
|
66
|
+
"raise NotImplementedError",
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
[tool.ruff]
|
|
70
|
+
target-version = "py310"
|
|
71
|
+
line-length = 100
|
|
72
|
+
|
|
73
|
+
[tool.ruff.lint]
|
|
74
|
+
select = [
|
|
75
|
+
"E", # pycodestyle errors
|
|
76
|
+
"W", # pycodestyle warnings
|
|
77
|
+
"F", # pyflakes
|
|
78
|
+
"I", # isort
|
|
79
|
+
"B", # flake8-bugbear
|
|
80
|
+
"UP", # pyupgrade
|
|
81
|
+
]
|
|
82
|
+
ignore = [
|
|
83
|
+
"E501", # line too long (handled by formatter)
|
|
84
|
+
]
|
|
85
|
+
|
|
86
|
+
[tool.ruff.lint.isort]
|
|
87
|
+
known-first-party = ["shapeguard"]
|
|
88
|
+
|
|
89
|
+
[tool.mypy]
|
|
90
|
+
python_version = "3.10"
|
|
91
|
+
warn_return_any = true
|
|
92
|
+
warn_unused_ignores = true
|
|
93
|
+
disallow_untyped_defs = false
|
|
94
|
+
ignore_missing_imports = true
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ShapeGuard: Runtime shape contracts and diagnostics for NumPy and JAX.
|
|
3
|
+
|
|
4
|
+
Basic usage:
|
|
5
|
+
from shapeguard import Dim, expects, check_shape
|
|
6
|
+
|
|
7
|
+
n, m = Dim("n"), Dim("m")
|
|
8
|
+
|
|
9
|
+
@expects(x=(n, m), y=(m,))
|
|
10
|
+
def forward(x, y):
|
|
11
|
+
return x @ y
|
|
12
|
+
|
|
13
|
+
ML workflows:
|
|
14
|
+
from shapeguard import Batch, ShapeContext
|
|
15
|
+
|
|
16
|
+
B = Batch()
|
|
17
|
+
|
|
18
|
+
@expects(x=(B, n, m))
|
|
19
|
+
def layer(x): ...
|
|
20
|
+
|
|
21
|
+
# Ellipsis for variable leading dims
|
|
22
|
+
@expects(x=(..., n, m))
|
|
23
|
+
def normalize(x): ...
|
|
24
|
+
|
|
25
|
+
# Grouped checks
|
|
26
|
+
with ShapeContext() as ctx:
|
|
27
|
+
ctx.check(x, (n, m), "x")
|
|
28
|
+
ctx.check(y, (m, k), "y")
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from shapeguard.broadcast import broadcast_shape, explain_broadcast
|
|
32
|
+
from shapeguard.config import config
|
|
33
|
+
from shapeguard.context import ShapeContext
|
|
34
|
+
from shapeguard.core import Batch, Dim, UnificationContext
|
|
35
|
+
from shapeguard.decorator import contract, ensures, expects
|
|
36
|
+
from shapeguard.errors import BroadcastError, OutputShapeError, ShapeGuardError
|
|
37
|
+
from shapeguard.spec import check_shape
|
|
38
|
+
|
|
39
|
+
__version__ = "0.3.0"
|
|
40
|
+
|
|
41
|
+
__all__ = [
|
|
42
|
+
# Core
|
|
43
|
+
"Dim",
|
|
44
|
+
"Batch",
|
|
45
|
+
"UnificationContext",
|
|
46
|
+
# Validation
|
|
47
|
+
"expects",
|
|
48
|
+
"ensures",
|
|
49
|
+
"contract",
|
|
50
|
+
"check_shape",
|
|
51
|
+
"ShapeContext",
|
|
52
|
+
# Broadcasting
|
|
53
|
+
"broadcast_shape",
|
|
54
|
+
"explain_broadcast",
|
|
55
|
+
# Configuration
|
|
56
|
+
"config",
|
|
57
|
+
# Errors
|
|
58
|
+
"ShapeGuardError",
|
|
59
|
+
"OutputShapeError",
|
|
60
|
+
"BroadcastError",
|
|
61
|
+
]
|