jax-shapeguard 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. jax_shapeguard-0.3.0/.github/workflows/lint.yml +37 -0
  2. jax_shapeguard-0.3.0/.github/workflows/publish.yml +29 -0
  3. jax_shapeguard-0.3.0/.github/workflows/test.yml +39 -0
  4. jax_shapeguard-0.3.0/.gitignore +44 -0
  5. jax_shapeguard-0.3.0/LICENSE +21 -0
  6. jax_shapeguard-0.3.0/MILESTONES.md +241 -0
  7. jax_shapeguard-0.3.0/PKG-INFO +69 -0
  8. jax_shapeguard-0.3.0/README.md +42 -0
  9. jax_shapeguard-0.3.0/pyproject.toml +94 -0
  10. jax_shapeguard-0.3.0/shapeguard/__init__.py +61 -0
  11. jax_shapeguard-0.3.0/shapeguard/_compat.py +96 -0
  12. jax_shapeguard-0.3.0/shapeguard/broadcast.py +211 -0
  13. jax_shapeguard-0.3.0/shapeguard/config.py +51 -0
  14. jax_shapeguard-0.3.0/shapeguard/context.py +104 -0
  15. jax_shapeguard-0.3.0/shapeguard/core.py +159 -0
  16. jax_shapeguard-0.3.0/shapeguard/decorator.py +472 -0
  17. jax_shapeguard-0.3.0/shapeguard/errors.py +208 -0
  18. jax_shapeguard-0.3.0/shapeguard/spec.py +223 -0
  19. jax_shapeguard-0.3.0/tests/__init__.py +1 -0
  20. jax_shapeguard-0.3.0/tests/conftest.py +49 -0
  21. jax_shapeguard-0.3.0/tests/test_batch.py +88 -0
  22. jax_shapeguard-0.3.0/tests/test_broadcast.py +221 -0
  23. jax_shapeguard-0.3.0/tests/test_config.py +61 -0
  24. jax_shapeguard-0.3.0/tests/test_context.py +140 -0
  25. jax_shapeguard-0.3.0/tests/test_contract.py +163 -0
  26. jax_shapeguard-0.3.0/tests/test_core.py +113 -0
  27. jax_shapeguard-0.3.0/tests/test_decorator.py +211 -0
  28. jax_shapeguard-0.3.0/tests/test_ellipsis.py +165 -0
  29. jax_shapeguard-0.3.0/tests/test_ensures.py +272 -0
  30. jax_shapeguard-0.3.0/tests/test_errors.py +171 -0
  31. jax_shapeguard-0.3.0/tests/test_jit_modes.py +236 -0
  32. jax_shapeguard-0.3.0/tests/test_pytree.py +161 -0
  33. jax_shapeguard-0.3.0/tests/test_spec.py +154 -0
@@ -0,0 +1,37 @@
1
+ name: Lint
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ jobs:
10
+ lint:
11
+ runs-on: ubuntu-latest
12
+
13
+ steps:
14
+ - uses: actions/checkout@v4
15
+
16
+ - name: Install uv
17
+ uses: astral-sh/setup-uv@v4
18
+ with:
19
+ version: "latest"
20
+
21
+ - name: Set up Python
22
+ run: uv python install 3.11
23
+
24
+ - name: Install dependencies
25
+ run: uv sync --dev
26
+
27
+ - name: Install linting tools
28
+ run: uv pip install ruff mypy
29
+
30
+ - name: Run ruff check
31
+ run: uv run ruff check shapeguard tests
32
+
33
+ - name: Run ruff format check
34
+ run: uv run ruff format --check shapeguard tests
35
+
36
+ - name: Run mypy
37
+ run: uv run mypy shapeguard --ignore-missing-imports
@@ -0,0 +1,29 @@
1
+ name: Publish to PyPI
2
+
3
+ on:
4
+ release:
5
+ types: [published]
6
+
7
+ jobs:
8
+ publish:
9
+ runs-on: ubuntu-latest
10
+ environment: pypi
11
+ permissions:
12
+ id-token: write # Required for trusted publishing
13
+
14
+ steps:
15
+ - uses: actions/checkout@v4
16
+
17
+ - name: Install uv
18
+ uses: astral-sh/setup-uv@v4
19
+ with:
20
+ version: "latest"
21
+
22
+ - name: Set up Python
23
+ run: uv python install 3.11
24
+
25
+ - name: Build package
26
+ run: uv build
27
+
28
+ - name: Publish to PyPI
29
+ uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,39 @@
1
+ name: Tests
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ jobs:
10
+ test:
11
+ runs-on: ubuntu-latest
12
+ strategy:
13
+ fail-fast: false
14
+ matrix:
15
+ python-version: ["3.10", "3.11", "3.12"]
16
+
17
+ steps:
18
+ - uses: actions/checkout@v4
19
+
20
+ - name: Install uv
21
+ uses: astral-sh/setup-uv@v4
22
+ with:
23
+ version: "latest"
24
+
25
+ - name: Set up Python ${{ matrix.python-version }}
26
+ run: uv python install ${{ matrix.python-version }}
27
+
28
+ - name: Install dependencies
29
+ run: uv sync --dev
30
+
31
+ - name: Run tests
32
+ run: uv run pytest -v --cov=shapeguard --cov-report=xml
33
+
34
+ - name: Upload coverage
35
+ uses: codecov/codecov-action@v4
36
+ if: matrix.python-version == '3.11'
37
+ with:
38
+ files: coverage.xml
39
+ fail_ci_if_error: false
@@ -0,0 +1,44 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ .venv/
25
+ venv/
26
+ ENV/
27
+
28
+ # Testing
29
+ .pytest_cache/
30
+ .coverage
31
+ htmlcov/
32
+ .tox/
33
+ .nox/
34
+
35
+ # IDE
36
+ .idea/
37
+ .vscode/
38
+ *.swp
39
+ *.swo
40
+ *~
41
+
42
+ # OS
43
+ .DS_Store
44
+ Thumbs.db
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Jayendra Parmar
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,241 @@
1
+ # ShapeGuard Milestones
2
+
3
+ ## Design Decisions
4
+
5
+ - **Dim identity**: Same object required (`n = Dim("n")` must be reused)
6
+ - **Performance**: Dev-only tool, prioritize error quality over speed
7
+ - **Backend priority**: NumPy-first, JAX support in v0.2
8
+
9
+ ---
10
+
11
+ ## Milestone 1: Core Foundation (v0.1-alpha)
12
+
13
+ ### Goal
14
+ Minimal working library with symbolic dimensions, shape checking, and decorator.
15
+
16
+ ### Files
17
+ ```
18
+ shapeguard/
19
+ __init__.py # Public API exports
20
+ core.py # Dim class, UnificationContext
21
+ spec.py # Shape specification matching
22
+ decorator.py # @expects decorator
23
+ errors.py # ShapeGuardError
24
+ _compat.py # Backend detection
25
+ tests/
26
+ test_core.py
27
+ test_spec.py
28
+ test_decorator.py
29
+ test_errors.py
30
+ conftest.py
31
+ pyproject.toml
32
+ ```
33
+
34
+ ### Deliverables
35
+ - [x] `Dim` class with identity-based unification
36
+ - [x] `UnificationContext` for tracking bindings across arguments
37
+ - [x] `ShapeSpec` matching: concrete `(3, 4)`, symbolic `(n, m)`, wildcard `(None, 4)`
38
+ - [x] `@expects` decorator for input validation
39
+ - [x] `ShapeGuardError` with function, argument, expected, actual, reason
40
+ - [x] `check_shape(x, spec, name)` standalone function
41
+ - [x] Backend-agnostic shape extraction (works with any `.shape` attribute)
42
+ - [x] Unit tests with 90%+ coverage (91% achieved)
43
+
44
+ ### API Surface
45
+ ```python
46
+ from shapeguard import Dim, expects, check_shape, ShapeGuardError
47
+
48
+ n, m = Dim("n"), Dim("m")
49
+
50
+ @expects(x=(n, m), y=(m,))
51
+ def forward(x, y):
52
+ return x @ y
53
+
54
+ check_shape(arr, (n, 128), name="input")
55
+ ```
56
+
57
+ ---
58
+
59
+ ## Milestone 2: ML-Practical Features (v0.1-beta)
60
+
61
+ ### Goal
62
+ Ergonomic features for real ML workflows.
63
+
64
+ ### Deliverables
65
+ - [x] `Batch` dimension (always first, flexible size per call)
66
+ - [x] Ellipsis support `(..., n, m)` for variable leading dims
67
+ - [x] `ShapeContext` manager for grouped checks with shared bindings
68
+ - [x] Improved error messages with binding trace (92% coverage)
69
+
70
+ ### API Additions
71
+ ```python
72
+ from shapeguard import Batch, ShapeContext
73
+
74
+ B = Batch()
75
+
76
+ @expects(x=(B, n, m))
77
+ def layer(x): ...
78
+
79
+ @expects(x=(..., n, m))
80
+ def normalize(x): ...
81
+
82
+ with ShapeContext() as ctx:
83
+ ctx.check(x, (n, m), "x")
84
+ ctx.check(y, (m, k), "y")
85
+ ```
86
+
87
+ ---
88
+
89
+ ## Milestone 3: JAX Integration (v0.2)
90
+
91
+ ### Goal
92
+ Seamless JAX compatibility including JIT behavior.
93
+
94
+ ### Deliverables
95
+ - [x] JIT/tracing detection
96
+ - [x] Configurable JIT modes: `skip`, `warn`, `check`
97
+ - [x] PyTree shape specs for nested params
98
+ - [ ] Performance benchmarks (deferred)
99
+
100
+ ### API Additions
101
+ ```python
102
+ from shapeguard import expects, config
103
+
104
+ config.jit_mode = "skip" # Global setting
105
+
106
+ @expects(x=(B, n, m), jit_mode="check") # Per-function
107
+ @jax.jit
108
+ def forward(x): ...
109
+
110
+ @expects(
111
+ params={"weights": (n, m), "bias": (m,)},
112
+ x=(B, n)
113
+ )
114
+ def apply(params, x): ...
115
+ ```
116
+
117
+ ---
118
+
119
+ ## Milestone 4: Broadcasting Support (v0.2)
120
+
121
+ ### Goal
122
+ Explicit broadcasting inspection and validation.
123
+
124
+ ### Deliverables
125
+ - [x] `broadcast_shape()` for concrete shapes
126
+ - [x] `explain_broadcast()` step-by-step explainer
127
+ - [ ] `_broadcast=True` option in `@expects` (deferred)
128
+
129
+ ### API Additions
130
+ ```python
131
+ from shapeguard import broadcast_shape, explain_broadcast
132
+
133
+ broadcast_shape((3, 1), (1, 4)) # → (3, 4)
134
+ broadcast_shape(a, b) # From arrays
135
+
136
+ explain_broadcast((3, 1, 4), (5, 4))
137
+ # Broadcasting (3, 1, 4) with (5, 4):
138
+ # Dim 0: 3 (from left only)
139
+ # Dim 1: 1 → 5 (broadcast)
140
+ # Dim 2: 4 = 4 (match)
141
+ # Result: (3, 5, 4)
142
+ ```
143
+
144
+ ---
145
+
146
+ ## Milestone 5: Output Contracts (v0.3)
147
+
148
+ ### Goal
149
+ Validate function outputs, not just inputs.
150
+
151
+ ### Deliverables
152
+ - [x] `@ensures` decorator for output validation
153
+ - [x] `@contract` combined decorator
154
+ - [x] Tuple/dict output support
155
+
156
+ ### API Additions
157
+ ```python
158
+ from shapeguard import expects, ensures, contract
159
+
160
+ @expects(a=(n, m), b=(m, k))
161
+ @ensures(result=(n, k))
162
+ def matmul(a, b):
163
+ return a @ b
164
+
165
+ @contract(
166
+ inputs={"a": (n, m), "b": (m, k)},
167
+ output=(n, k)
168
+ )
169
+ def matmul(a, b):
170
+ return a @ b
171
+ ```
172
+
173
+ ---
174
+
175
+ ## Milestone 6: ML Helpers (v0.3)
176
+
177
+ ### Goal
178
+ Domain-specific helpers for common ML patterns.
179
+
180
+ ### Deliverables
181
+ - [ ] Pre-defined dims: `B`, `T`, `C`, `H`, `W`, `D`
182
+ - [ ] `attention_shapes()` helper
183
+ - [ ] `conv_output_shape()` calculator
184
+
185
+ ### API Additions
186
+ ```python
187
+ from shapeguard.ml import B, T, C, H, W, D
188
+ from shapeguard.ml import attention_shapes, conv_output_shape
189
+
190
+ @expects(x=(B, T, D))
191
+ def transformer_layer(x): ...
192
+
193
+ @expects(**attention_shapes(B, heads, seq_q, seq_k, d_k))
194
+ def attention(q, k, v): ...
195
+
196
+ out_shape = conv_output_shape(
197
+ input=(B, C, 224, 224),
198
+ kernel=(3, 3),
199
+ stride=2,
200
+ padding=1
201
+ )
202
+ ```
203
+
204
+ ---
205
+
206
+ ## Milestone 7: Testing Utilities (v0.4)
207
+
208
+ ### Goal
209
+ Property-based testing support.
210
+
211
+ ### Deliverables
212
+ - [ ] Hypothesis strategies for shaped arrays
213
+ - [ ] `verify_contract()` auto-test generator
214
+ - [ ] pytest plugin
215
+
216
+ ### API Additions
217
+ ```python
218
+ from shapeguard.testing import arrays, verify_contract
219
+ import hypothesis
220
+
221
+ @hypothesis.given(x=arrays(shape=(n, m), n=(1, 100), m=(1, 100)))
222
+ def test_normalize(x):
223
+ result = normalize(x)
224
+ assert result.shape == x.shape
225
+
226
+ verify_contract(matmul, samples=100)
227
+ ```
228
+
229
+ ---
230
+
231
+ ## Summary Timeline
232
+
233
+ | Milestone | Version | Status |
234
+ |-----------|---------|--------|
235
+ | 1. Core Foundation | v0.1-alpha | ✅ Complete (91% coverage) |
236
+ | 2. ML Features | v0.1-beta | ✅ Complete (92% coverage) |
237
+ | 3. JAX Integration | v0.2 | ✅ Complete (92% coverage) |
238
+ | 4. Broadcasting | v0.2 | ✅ Complete |
239
+ | 5. Output Contracts | v0.3 | ✅ Complete (91% coverage) |
240
+ | 6. ML Helpers | v0.3 | 🔲 Not started |
241
+ | 7. Testing Utils | v0.4 | 🔲 Not started |
@@ -0,0 +1,69 @@
1
+ Metadata-Version: 2.4
2
+ Name: jax-shapeguard
3
+ Version: 0.3.0
4
+ Summary: Runtime shape contracts and diagnostics for NumPy and JAX
5
+ Project-URL: Homepage, https://github.com/jayendra13/jax-shape-guard
6
+ Project-URL: Repository, https://github.com/jayendra13/jax-shape-guard
7
+ Author-email: Jayendra Parmar <jayendra0parmar@gmail.com>
8
+ License-Expression: MIT
9
+ License-File: LICENSE
10
+ Keywords: debugging,jax,ml,numpy,shapes,validation
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: OSI Approved :: MIT License
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Topic :: Scientific/Engineering
20
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
21
+ Classifier: Typing :: Typed
22
+ Requires-Python: >=3.10
23
+ Provides-Extra: jax
24
+ Requires-Dist: jax>=0.4; extra == 'jax'
25
+ Requires-Dist: jaxlib>=0.4; extra == 'jax'
26
+ Description-Content-Type: text/markdown
27
+
28
+ # ShapeGuard
29
+
30
+ [![Tests](https://github.com/jayendra13/jax-shape-guard/actions/workflows/test.yml/badge.svg)](https://github.com/jayendra13/jax-shape-guard/actions/workflows/test.yml)
31
+ [![Lint](https://github.com/jayendra13/jax-shape-guard/actions/workflows/lint.yml/badge.svg)](https://github.com/jayendra13/jax-shape-guard/actions/workflows/lint.yml)
32
+ [![PyPI version](https://img.shields.io/pypi/v/jax-shapeguard.svg)](https://pypi.org/project/jax-shapeguard/)
33
+ [![Python versions](https://img.shields.io/pypi/pyversions/jax-shapeguard.svg)](https://pypi.org/project/jax-shapeguard/)
34
+ [![License](https://img.shields.io/github/license/jayendra13/jax-shape-guard.svg)](https://github.com/jayendra13/jax-shape-guard/blob/main/LICENSE)
35
+
36
+ Runtime shape contracts and diagnostics for NumPy and JAX.
37
+
38
+ ## Installation
39
+
40
+ ```bash
41
+ pip install jax-shapeguard
42
+ ```
43
+
44
+ ## Quick Start
45
+
46
+ ```python
47
+ from shapeguard import Dim, expects
48
+
49
+ n, m, k = Dim("n"), Dim("m"), Dim("k")
50
+
51
+ @expects(a=(n, m), b=(m, k))
52
+ def matmul(a, b):
53
+ return a @ b
54
+ ```
55
+
56
+ When shapes don't match, you get clear errors:
57
+
58
+ ```
59
+ ShapeGuardError:
60
+ function: matmul
61
+ argument: b
62
+ expected: (m, k)
63
+ actual: (5, 7)
64
+ reason: dimension 'm' bound to 4 from a.shape[1], but got 5 from b.shape[0]
65
+ ```
66
+
67
+ ## License
68
+
69
+ MIT
@@ -0,0 +1,42 @@
1
+ # ShapeGuard
2
+
3
+ [![Tests](https://github.com/jayendra13/jax-shape-guard/actions/workflows/test.yml/badge.svg)](https://github.com/jayendra13/jax-shape-guard/actions/workflows/test.yml)
4
+ [![Lint](https://github.com/jayendra13/jax-shape-guard/actions/workflows/lint.yml/badge.svg)](https://github.com/jayendra13/jax-shape-guard/actions/workflows/lint.yml)
5
+ [![PyPI version](https://img.shields.io/pypi/v/jax-shapeguard.svg)](https://pypi.org/project/jax-shapeguard/)
6
+ [![Python versions](https://img.shields.io/pypi/pyversions/jax-shapeguard.svg)](https://pypi.org/project/jax-shapeguard/)
7
+ [![License](https://img.shields.io/github/license/jayendra13/jax-shape-guard.svg)](https://github.com/jayendra13/jax-shape-guard/blob/main/LICENSE)
8
+
9
+ Runtime shape contracts and diagnostics for NumPy and JAX.
10
+
11
+ ## Installation
12
+
13
+ ```bash
14
+ pip install jax-shapeguard
15
+ ```
16
+
17
+ ## Quick Start
18
+
19
+ ```python
20
+ from shapeguard import Dim, expects
21
+
22
+ n, m, k = Dim("n"), Dim("m"), Dim("k")
23
+
24
+ @expects(a=(n, m), b=(m, k))
25
+ def matmul(a, b):
26
+ return a @ b
27
+ ```
28
+
29
+ When shapes don't match, you get clear errors:
30
+
31
+ ```
32
+ ShapeGuardError:
33
+ function: matmul
34
+ argument: b
35
+ expected: (m, k)
36
+ actual: (5, 7)
37
+ reason: dimension 'm' bound to 4 from a.shape[1], but got 5 from b.shape[0]
38
+ ```
39
+
40
+ ## License
41
+
42
+ MIT
@@ -0,0 +1,94 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [tool.hatch.build.targets.wheel]
6
+ packages = ["shapeguard"]
7
+
8
+ [project]
9
+ name = "jax-shapeguard"
10
+ version = "0.3.0"
11
+ description = "Runtime shape contracts and diagnostics for NumPy and JAX"
12
+ readme = "README.md"
13
+ license = "MIT"
14
+ requires-python = ">=3.10"
15
+ authors = [
16
+ { name = "Jayendra Parmar", email = "jayendra0parmar@gmail.com" }
17
+ ]
18
+ keywords = ["numpy", "jax", "shapes", "validation", "debugging", "ml"]
19
+ classifiers = [
20
+ "Development Status :: 3 - Alpha",
21
+ "Intended Audience :: Developers",
22
+ "Intended Audience :: Science/Research",
23
+ "License :: OSI Approved :: MIT License",
24
+ "Programming Language :: Python :: 3",
25
+ "Programming Language :: Python :: 3.10",
26
+ "Programming Language :: Python :: 3.11",
27
+ "Programming Language :: Python :: 3.12",
28
+ "Topic :: Scientific/Engineering",
29
+ "Topic :: Software Development :: Libraries :: Python Modules",
30
+ "Typing :: Typed",
31
+ ]
32
+
33
+ [project.optional-dependencies]
34
+ jax = [
35
+ "jax>=0.4",
36
+ "jaxlib>=0.4",
37
+ ]
38
+
39
+ [dependency-groups]
40
+ dev = [
41
+ "pytest>=7.0",
42
+ "pytest-cov>=4.0",
43
+ "numpy>=1.20",
44
+ "jax>=0.4",
45
+ "jaxlib>=0.4",
46
+ "ruff>=0.4",
47
+ "mypy>=1.0",
48
+ ]
49
+
50
+ [project.urls]
51
+ Homepage = "https://github.com/jayendra13/jax-shape-guard"
52
+ Repository = "https://github.com/jayendra13/jax-shape-guard"
53
+
54
+ [tool.pytest.ini_options]
55
+ testpaths = ["tests"]
56
+ addopts = "-v --tb=short"
57
+
58
+ [tool.coverage.run]
59
+ source = ["shapeguard"]
60
+ branch = true
61
+
62
+ [tool.coverage.report]
63
+ exclude_lines = [
64
+ "pragma: no cover",
65
+ "if TYPE_CHECKING:",
66
+ "raise NotImplementedError",
67
+ ]
68
+
69
+ [tool.ruff]
70
+ target-version = "py310"
71
+ line-length = 100
72
+
73
+ [tool.ruff.lint]
74
+ select = [
75
+ "E", # pycodestyle errors
76
+ "W", # pycodestyle warnings
77
+ "F", # pyflakes
78
+ "I", # isort
79
+ "B", # flake8-bugbear
80
+ "UP", # pyupgrade
81
+ ]
82
+ ignore = [
83
+ "E501", # line too long (handled by formatter)
84
+ ]
85
+
86
+ [tool.ruff.lint.isort]
87
+ known-first-party = ["shapeguard"]
88
+
89
+ [tool.mypy]
90
+ python_version = "3.10"
91
+ warn_return_any = true
92
+ warn_unused_ignores = true
93
+ disallow_untyped_defs = false
94
+ ignore_missing_imports = true
@@ -0,0 +1,61 @@
1
+ """
2
+ ShapeGuard: Runtime shape contracts and diagnostics for NumPy and JAX.
3
+
4
+ Basic usage:
5
+ from shapeguard import Dim, expects, check_shape
6
+
7
+ n, m = Dim("n"), Dim("m")
8
+
9
+ @expects(x=(n, m), y=(m,))
10
+ def forward(x, y):
11
+ return x @ y
12
+
13
+ ML workflows:
14
+ from shapeguard import Batch, ShapeContext
15
+
16
+ B = Batch()
17
+
18
+ @expects(x=(B, n, m))
19
+ def layer(x): ...
20
+
21
+ # Ellipsis for variable leading dims
22
+ @expects(x=(..., n, m))
23
+ def normalize(x): ...
24
+
25
+ # Grouped checks
26
+ with ShapeContext() as ctx:
27
+ ctx.check(x, (n, m), "x")
28
+ ctx.check(y, (m, k), "y")
29
+ """
30
+
31
+ from shapeguard.broadcast import broadcast_shape, explain_broadcast
32
+ from shapeguard.config import config
33
+ from shapeguard.context import ShapeContext
34
+ from shapeguard.core import Batch, Dim, UnificationContext
35
+ from shapeguard.decorator import contract, ensures, expects
36
+ from shapeguard.errors import BroadcastError, OutputShapeError, ShapeGuardError
37
+ from shapeguard.spec import check_shape
38
+
39
+ __version__ = "0.3.0"
40
+
41
+ __all__ = [
42
+ # Core
43
+ "Dim",
44
+ "Batch",
45
+ "UnificationContext",
46
+ # Validation
47
+ "expects",
48
+ "ensures",
49
+ "contract",
50
+ "check_shape",
51
+ "ShapeContext",
52
+ # Broadcasting
53
+ "broadcast_shape",
54
+ "explain_broadcast",
55
+ # Configuration
56
+ "config",
57
+ # Errors
58
+ "ShapeGuardError",
59
+ "OutputShapeError",
60
+ "BroadcastError",
61
+ ]