torchshapeflow 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchshapeflow might be problematic. Click here for more details.

Files changed (56)
  1. torchshapeflow-0.2.0/.github/workflows/build-artifacts.yml +41 -0
  2. torchshapeflow-0.2.0/.github/workflows/ci.yml +43 -0
  3. torchshapeflow-0.2.0/.github/workflows/docs.yml +42 -0
  4. torchshapeflow-0.2.0/.github/workflows/release.yml +144 -0
  5. torchshapeflow-0.2.0/.gitignore +13 -0
  6. torchshapeflow-0.2.0/LICENSE +21 -0
  7. torchshapeflow-0.2.0/Makefile +52 -0
  8. torchshapeflow-0.2.0/PKG-INFO +113 -0
  9. torchshapeflow-0.2.0/README.md +71 -0
  10. torchshapeflow-0.2.0/docs/architecture.md +85 -0
  11. torchshapeflow-0.2.0/docs/development.md +120 -0
  12. torchshapeflow-0.2.0/docs/extension.md +58 -0
  13. torchshapeflow-0.2.0/docs/index.md +45 -0
  14. torchshapeflow-0.2.0/docs/limitations.md +44 -0
  15. torchshapeflow-0.2.0/docs/operators.md +632 -0
  16. torchshapeflow-0.2.0/docs/quickstart.md +81 -0
  17. torchshapeflow-0.2.0/docs/syntax.md +124 -0
  18. torchshapeflow-0.2.0/examples/attention_scores.py +13 -0
  19. torchshapeflow-0.2.0/examples/error_cases.py +17 -0
  20. torchshapeflow-0.2.0/examples/simple_cnn.py +17 -0
  21. torchshapeflow-0.2.0/examples/transformer_block.py +11 -0
  22. torchshapeflow-0.2.0/examples/vit_patch_embed.py +10 -0
  23. torchshapeflow-0.2.0/mkdocs.yml +19 -0
  24. torchshapeflow-0.2.0/pyproject.toml +77 -0
  25. torchshapeflow-0.2.0/scripts/bump_version.py +80 -0
  26. torchshapeflow-0.2.0/src/torchshapeflow/__init__.py +4 -0
  27. torchshapeflow-0.2.0/src/torchshapeflow/_version.py +1 -0
  28. torchshapeflow-0.2.0/src/torchshapeflow/analyzer.py +1489 -0
  29. torchshapeflow-0.2.0/src/torchshapeflow/annotations.py +34 -0
  30. torchshapeflow-0.2.0/src/torchshapeflow/cli.py +52 -0
  31. torchshapeflow-0.2.0/src/torchshapeflow/diagnostics.py +37 -0
  32. torchshapeflow-0.2.0/src/torchshapeflow/index.py +246 -0
  33. torchshapeflow-0.2.0/src/torchshapeflow/model.py +379 -0
  34. torchshapeflow-0.2.0/src/torchshapeflow/parser.py +112 -0
  35. torchshapeflow-0.2.0/src/torchshapeflow/py.typed +1 -0
  36. torchshapeflow-0.2.0/src/torchshapeflow/report.py +39 -0
  37. torchshapeflow-0.2.0/src/torchshapeflow/rules/__init__.py +59 -0
  38. torchshapeflow-0.2.0/src/torchshapeflow/rules/broadcasting.py +11 -0
  39. torchshapeflow-0.2.0/src/torchshapeflow/rules/common.py +98 -0
  40. torchshapeflow-0.2.0/src/torchshapeflow/rules/conv2d.py +41 -0
  41. torchshapeflow-0.2.0/src/torchshapeflow/rules/embedding.py +17 -0
  42. torchshapeflow-0.2.0/src/torchshapeflow/rules/indexing.py +94 -0
  43. torchshapeflow-0.2.0/src/torchshapeflow/rules/linear.py +21 -0
  44. torchshapeflow-0.2.0/src/torchshapeflow/rules/pool2d.py +40 -0
  45. torchshapeflow-0.2.0/src/torchshapeflow/rules/shape_ops.py +665 -0
  46. torchshapeflow-0.2.0/src/torchshapeflow/utils/__init__.py +1 -0
  47. torchshapeflow-0.2.0/src/torchshapeflow/utils/paths.py +9 -0
  48. torchshapeflow-0.2.0/tests/fixtures/attention_scores.py +13 -0
  49. torchshapeflow-0.2.0/tests/test_analyzer.py +873 -0
  50. torchshapeflow-0.2.0/tests/test_annotations.py +22 -0
  51. torchshapeflow-0.2.0/tests/test_cli.py +17 -0
  52. torchshapeflow-0.2.0/tests/test_index.py +427 -0
  53. torchshapeflow-0.2.0/tests/test_model.py +152 -0
  54. torchshapeflow-0.2.0/tests/test_parser.py +139 -0
  55. torchshapeflow-0.2.0/tests/test_rules.py +365 -0
  56. torchshapeflow-0.2.0/tests/test_shape_ops.py +631 -0
@@ -0,0 +1,41 @@
1
+ name: Build Artifacts
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ pull_request:
8
+ workflow_dispatch:
9
+
10
+ jobs:
11
+ python-package:
12
+ runs-on: ubuntu-latest
13
+ steps:
14
+ - uses: actions/checkout@v6
15
+ - uses: astral-sh/setup-uv@v7
16
+ - uses: actions/setup-python@v6
17
+ with:
18
+ python-version: "3.12"
19
+ - run: uv build
20
+ - uses: actions/upload-artifact@v7
21
+ with:
22
+ name: python-dist
23
+ path: dist/*
24
+
25
+ vscode-extension:
26
+ runs-on: ubuntu-latest
27
+ steps:
28
+ - uses: actions/checkout@v6
29
+ - uses: actions/setup-node@v6
30
+ with:
31
+ node-version: "24"
32
+ cache: "npm"
33
+ cache-dependency-path: extensions/vscode/package-lock.json
34
+ - working-directory: extensions/vscode
35
+ run: npm ci
36
+ - working-directory: extensions/vscode
37
+ run: npm run package
38
+ - uses: actions/upload-artifact@v7
39
+ with:
40
+ name: vscode-extension
41
+ path: extensions/vscode/dist/*.vsix
@@ -0,0 +1,43 @@
1
+ name: CI
2
+
3
+ on:
4
+ push:
5
+ pull_request:
6
+
7
+ jobs:
8
+ check:
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - uses: actions/checkout@v6
12
+ - uses: astral-sh/setup-uv@v7
13
+ - uses: actions/setup-python@v6
14
+ with:
15
+ python-version: "3.12"
16
+ - uses: actions/setup-node@v6
17
+ with:
18
+ node-version: "24"
19
+ cache: "npm"
20
+ cache-dependency-path: extensions/vscode/package-lock.json
21
+ - run: uv sync --extra dev
22
+ - run: uv run ruff format . --check
23
+ - run: uv run ruff check .
24
+ - run: uv run mypy .
25
+ - run: uv run mkdocs build
26
+ - working-directory: extensions/vscode
27
+ run: npm ci
28
+ - working-directory: extensions/vscode
29
+ run: npm run build
30
+
31
+ test:
32
+ runs-on: ubuntu-latest
33
+ strategy:
34
+ matrix:
35
+ python-version: ["3.10", "3.11", "3.12", "3.13"]
36
+ steps:
37
+ - uses: actions/checkout@v6
38
+ - uses: astral-sh/setup-uv@v7
39
+ - uses: actions/setup-python@v6
40
+ with:
41
+ python-version: ${{ matrix.python-version }}
42
+ - run: uv sync --extra dev
43
+ - run: uv run pytest -q
@@ -0,0 +1,42 @@
1
+ name: Docs
2
+
3
+ on:
4
+ workflow_dispatch:
5
+ push:
6
+ branches:
7
+ - main
8
+
9
+ permissions:
10
+ contents: read
11
+ pages: write
12
+ id-token: write
13
+
14
+ concurrency:
15
+ group: pages
16
+ cancel-in-progress: true
17
+
18
+ jobs:
19
+ build:
20
+ runs-on: ubuntu-latest
21
+ steps:
22
+ - uses: actions/checkout@v6
23
+ - uses: astral-sh/setup-uv@v7
24
+ - uses: actions/setup-python@v6
25
+ with:
26
+ python-version: "3.12"
27
+ - run: uv sync --extra dev
28
+ - run: uv run mkdocs build
29
+ - uses: actions/upload-pages-artifact@v4
30
+ with:
31
+ path: site
32
+
33
+ deploy:
34
+ if: github.ref == 'refs/heads/main'
35
+ needs: build
36
+ runs-on: ubuntu-latest
37
+ environment:
38
+ name: github-pages
39
+ url: ${{ steps.deployment.outputs.page_url }}
40
+ steps:
41
+ - id: deployment
42
+ uses: actions/deploy-pages@v4
@@ -0,0 +1,144 @@
1
+ name: Release
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - "v*"
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ check-secrets:
11
+ runs-on: ubuntu-latest
12
+ outputs:
13
+ has-vsce-pat: ${{ steps.check.outputs.has-vsce-pat }}
14
+ has-ovsx-pat: ${{ steps.check.outputs.has-ovsx-pat }}
15
+ steps:
16
+ - id: check
17
+ env:
18
+ VSCE_PAT: ${{ secrets.VSCE_PAT }}
19
+ OVSX_PAT: ${{ secrets.OVSX_PAT }}
20
+ run: |
21
+ echo "has-vsce-pat=${{ secrets.VSCE_PAT != '' }}" >> "$GITHUB_OUTPUT"
22
+ echo "has-ovsx-pat=${{ secrets.OVSX_PAT != '' }}" >> "$GITHUB_OUTPUT"
23
+
24
+ build-python:
25
+ runs-on: ubuntu-latest
26
+ steps:
27
+ - uses: actions/checkout@v6
28
+ - uses: astral-sh/setup-uv@v7
29
+ - uses: actions/setup-python@v6
30
+ with:
31
+ python-version: "3.12"
32
+ - run: uv build
33
+ - uses: actions/upload-artifact@v7
34
+ with:
35
+ name: python-dist
36
+ path: dist/*
37
+
38
+ build-extension:
39
+ runs-on: ubuntu-latest
40
+ steps:
41
+ - uses: actions/checkout@v6
42
+ - uses: actions/setup-node@v6
43
+ with:
44
+ node-version: "24"
45
+ cache: "npm"
46
+ cache-dependency-path: extensions/vscode/package-lock.json
47
+ - working-directory: extensions/vscode
48
+ run: npm ci
49
+ - working-directory: extensions/vscode
50
+ run: npm run package
51
+ - uses: actions/upload-artifact@v7
52
+ with:
53
+ name: vscode-extension
54
+ path: extensions/vscode/dist/*.vsix
55
+
56
+ publish-vscode-marketplace:
57
+ if: needs.check-secrets.outputs.has-vsce-pat == 'true' && !contains(github.ref, '-rc') && !contains(github.ref, '-test')
58
+ needs: [check-secrets, build-extension]
59
+ runs-on: ubuntu-latest
60
+ steps:
61
+ - uses: actions/checkout@v6
62
+ - uses: actions/setup-node@v6
63
+ with:
64
+ node-version: "24"
65
+ cache: "npm"
66
+ cache-dependency-path: extensions/vscode/package-lock.json
67
+ - working-directory: extensions/vscode
68
+ run: npm ci
69
+ - working-directory: extensions/vscode
70
+ env:
71
+ VSCE_PAT: ${{ secrets.VSCE_PAT }}
72
+ run: npx @vscode/vsce publish -p "$VSCE_PAT"
73
+
74
+ publish-open-vsx:
75
+ if: needs.check-secrets.outputs.has-ovsx-pat == 'true' && !contains(github.ref, '-rc') && !contains(github.ref, '-test')
76
+ needs: [check-secrets, build-extension]
77
+ runs-on: ubuntu-latest
78
+ steps:
79
+ - uses: actions/checkout@v6
80
+ - uses: actions/setup-node@v6
81
+ with:
82
+ node-version: "24"
83
+ cache: "npm"
84
+ cache-dependency-path: extensions/vscode/package-lock.json
85
+ - working-directory: extensions/vscode
86
+ run: npm ci
87
+ - working-directory: extensions/vscode
88
+ run: npm run package
89
+ - working-directory: extensions/vscode
90
+ env:
91
+ OVSX_PAT: ${{ secrets.OVSX_PAT }}
92
+ run: npx ovsx publish dist/torchshapeflow.vsix -p "$OVSX_PAT"
93
+
94
+ publish-testpypi:
95
+ if: contains(github.ref, '-rc') || contains(github.ref, '-test')
96
+ needs: build-python
97
+ runs-on: ubuntu-latest
98
+ permissions:
99
+ id-token: write
100
+ environment:
101
+ name: testpypi
102
+ steps:
103
+ - uses: actions/download-artifact@v8
104
+ with:
105
+ name: python-dist
106
+ path: dist
107
+ - uses: pypa/gh-action-pypi-publish@release/v1
108
+ with:
109
+ repository-url: https://test.pypi.org/legacy/
110
+
111
+ publish-pypi:
112
+ if: "!contains(github.ref, '-rc') && !contains(github.ref, '-test')"
113
+ needs: build-python
114
+ runs-on: ubuntu-latest
115
+ permissions:
116
+ id-token: write
117
+ environment:
118
+ name: pypi
119
+ steps:
120
+ - uses: actions/download-artifact@v8
121
+ with:
122
+ name: python-dist
123
+ path: dist
124
+ - uses: pypa/gh-action-pypi-publish@release/v1
125
+
126
+ github-release:
127
+ needs: [build-python, build-extension]
128
+ runs-on: ubuntu-latest
129
+ permissions:
130
+ contents: write
131
+ steps:
132
+ - uses: actions/download-artifact@v8
133
+ with:
134
+ name: python-dist
135
+ path: dist
136
+ - uses: actions/download-artifact@v8
137
+ with:
138
+ name: vscode-extension
139
+ path: extensions/vscode/dist
140
+ - uses: softprops/action-gh-release@v2
141
+ with:
142
+ files: |
143
+ dist/*
144
+ extensions/vscode/dist/*.vsix
@@ -0,0 +1,13 @@
1
+ .pytest_cache/
2
+ .mypy_cache/
3
+ .ruff_cache/
4
+ .venv/
5
+ build/
6
+ dist/
7
+ htmlcov/
8
+ site/
9
+ *.pyc
10
+ __pycache__/
11
+ node_modules/
12
+ extensions/vscode/out/
13
+ *.vsix
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Xuesong Wang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,52 @@
1
+ PYTHON ?= python
2
+
3
+ .PHONY: install format lint typecheck test check docs docs-serve build python-dist extension-build extension-package bump-patch bump-minor bump-major clean
4
+
5
+ install:
6
+ uv sync --extra dev
7
+
8
+ format:
9
+ uv run ruff format .
10
+
11
+ lint:
12
+ uv run ruff check . --fix
13
+
14
+ typecheck:
15
+ uv run mypy .
16
+
17
+ test:
18
+ uv run pytest -q
19
+
20
+ check: format lint typecheck test
21
+
22
+ docs:
23
+ uv run mkdocs build
24
+
25
+ docs-serve:
26
+ uv run mkdocs serve
27
+
28
+ build: python-dist extension-package
29
+
30
+ python-dist:
31
+ uv build
32
+
33
+ extension-build:
34
+ cd extensions/vscode && npm ci && npm run build
35
+
36
+ extension-package:
37
+ cd extensions/vscode && npm ci && npm run package
38
+
39
+ bump-patch:
40
+ uv run python scripts/bump_version.py patch
41
+ uv lock
42
+
43
+ bump-minor:
44
+ uv run python scripts/bump_version.py minor
45
+ uv lock
46
+
47
+ bump-major:
48
+ uv run python scripts/bump_version.py major
49
+ uv lock
50
+
51
+ clean:
52
+ rm -rf .pytest_cache .mypy_cache .ruff_cache build dist site extensions/vscode/dist extensions/vscode/out
@@ -0,0 +1,113 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchshapeflow
3
+ Version: 0.2.0
4
+ Summary: Static AST-based PyTorch tensor shape analysis.
5
+ Project-URL: Homepage, https://github.com/Davidxswang/torchshapeflow
6
+ Project-URL: Repository, https://github.com/Davidxswang/torchshapeflow
7
+ Project-URL: Issues, https://github.com/Davidxswang/torchshapeflow/issues
8
+ Project-URL: Documentation, https://davidxswang.github.io/torchshapeflow/
9
+ Author: TorchShapeFlow Contributors
10
+ License: MIT License
11
+
12
+ Copyright (c) 2026 Xuesong Wang
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ The above copyright notice and this permission notice shall be included in all
22
+ copies or substantial portions of the Software.
23
+
24
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ SOFTWARE.
31
+ License-File: LICENSE
32
+ Requires-Python: >=3.10
33
+ Requires-Dist: typer>=0.16.0
34
+ Provides-Extra: dev
35
+ Requires-Dist: mkdocs-material>=9.5.0; extra == 'dev'
36
+ Requires-Dist: mkdocs>=1.6.0; extra == 'dev'
37
+ Requires-Dist: mypy>=1.18.0; extra == 'dev'
38
+ Requires-Dist: pytest>=8.3.0; extra == 'dev'
39
+ Requires-Dist: ruff>=0.11.0; extra == 'dev'
40
+ Requires-Dist: torch>=2.10.0; extra == 'dev'
41
+ Description-Content-Type: text/markdown
42
+
43
+ # TorchShapeFlow
44
+
45
+ [![CI](https://github.com/Davidxswang/torchshapeflow/actions/workflows/ci.yml/badge.svg)](https://github.com/Davidxswang/torchshapeflow/actions/workflows/ci.yml)
46
+ [![PyPI](https://img.shields.io/pypi/v/torchshapeflow)](https://pypi.org/project/torchshapeflow/)
47
+ [![Python](https://img.shields.io/pypi/pyversions/torchshapeflow)](https://pypi.org/project/torchshapeflow/)
48
+ [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/Davidxswang/torchshapeflow/blob/main/LICENSE)
49
+
50
+ TorchShapeFlow is a static, AST-based shape analyzer for PyTorch. It reads your Python source — no execution required — infers tensor shapes through your code, and reports mismatches as structured diagnostics.
51
+
52
+ ```python
53
+ from typing import Annotated
54
+ import torch
55
+ import torch.nn as nn
56
+ from torchshapeflow import Shape
57
+
58
+ class Net(nn.Module):
59
+ def __init__(self):
+ super().__init__()
60
+ self.conv = nn.Conv2d(3, 8, 3, padding=1)
61
+ self.linear = nn.Linear(8 * 32 * 32, 10)
62
+
63
+ def forward(self, x: Annotated[torch.Tensor, Shape("B", 3, 32, 32)]):
64
+ y = self.conv(x) # inferred: [B, 8, 32, 32]
65
+ z = y.flatten(1) # inferred: [B, 8192]
66
+ return self.linear(z) # inferred: [B, 10]
67
+ ```
68
+
69
+ ```bash
70
+ $ tsf check mymodel.py
71
+ mymodel.py: ok
72
+ ```
73
+
74
+ ## Install
75
+
76
+ ```bash
77
+ pip install torchshapeflow
78
+ ```
79
+
80
+ ## Documentation
81
+
82
+ Full docs at **[davidxswang.github.io/torchshapeflow](https://davidxswang.github.io/torchshapeflow)**
83
+
84
+ - [Quickstart](https://davidxswang.github.io/torchshapeflow/quickstart/) — install and run your first check
85
+ - [Annotation syntax](https://davidxswang.github.io/torchshapeflow/syntax/) — how to annotate your tensors
86
+ - [Supported operators](https://davidxswang.github.io/torchshapeflow/operators/) — what is analyzed and what shapes are inferred
87
+ - [Limitations](https://davidxswang.github.io/torchshapeflow/limitations/) — what the analyzer does not handle
88
+
89
+ ## Contributing
90
+
91
+ ```bash
92
+ git clone https://github.com/Davidxswang/torchshapeflow
93
+ cd torchshapeflow
94
+ make install # uv sync --extra dev
95
+ make check # format + lint + typecheck + tests
96
+ ```
97
+
98
+ See [docs/development.md](https://github.com/Davidxswang/torchshapeflow/blob/main/docs/development.md) for the full development guide: all make targets, CI workflow descriptions, and how to add new operators.
99
+
100
+ ## Release
101
+
102
+ See [RELEASING.md](https://github.com/Davidxswang/torchshapeflow/blob/main/RELEASING.md) for the full release procedure.
103
+
104
+ Build commands:
105
+
106
+ - `make python-dist` — wheel and sdist into `dist/`
107
+ - `make extension-package` — VS Code extension `.vsix`
108
+ - `make build` — both
109
+
110
+ Marketplace publishing in the release workflow is gated on GitHub Actions secrets:
111
+
112
+ - `VSCE_PAT` for the VS Code Marketplace
113
+ - `OVSX_PAT` for Open VSX
@@ -0,0 +1,71 @@
1
+ # TorchShapeFlow
2
+
3
+ [![CI](https://github.com/Davidxswang/torchshapeflow/actions/workflows/ci.yml/badge.svg)](https://github.com/Davidxswang/torchshapeflow/actions/workflows/ci.yml)
4
+ [![PyPI](https://img.shields.io/pypi/v/torchshapeflow)](https://pypi.org/project/torchshapeflow/)
5
+ [![Python](https://img.shields.io/pypi/pyversions/torchshapeflow)](https://pypi.org/project/torchshapeflow/)
6
+ [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/Davidxswang/torchshapeflow/blob/main/LICENSE)
7
+
8
+ TorchShapeFlow is a static, AST-based shape analyzer for PyTorch. It reads your Python source — no execution required — infers tensor shapes through your code, and reports mismatches as structured diagnostics.
9
+
10
+ ```python
11
+ from typing import Annotated
12
+ import torch
13
+ import torch.nn as nn
14
+ from torchshapeflow import Shape
15
+
16
+ class Net(nn.Module):
17
+ def __init__(self):
+ super().__init__()
18
+ self.conv = nn.Conv2d(3, 8, 3, padding=1)
19
+ self.linear = nn.Linear(8 * 32 * 32, 10)
20
+
21
+ def forward(self, x: Annotated[torch.Tensor, Shape("B", 3, 32, 32)]):
22
+ y = self.conv(x) # inferred: [B, 8, 32, 32]
23
+ z = y.flatten(1) # inferred: [B, 8192]
24
+ return self.linear(z) # inferred: [B, 10]
25
+ ```
26
+
27
+ ```bash
28
+ $ tsf check mymodel.py
29
+ mymodel.py: ok
30
+ ```
31
+
32
+ ## Install
33
+
34
+ ```bash
35
+ pip install torchshapeflow
36
+ ```
37
+
38
+ ## Documentation
39
+
40
+ Full docs at **[davidxswang.github.io/torchshapeflow](https://davidxswang.github.io/torchshapeflow)**
41
+
42
+ - [Quickstart](https://davidxswang.github.io/torchshapeflow/quickstart/) — install and run your first check
43
+ - [Annotation syntax](https://davidxswang.github.io/torchshapeflow/syntax/) — how to annotate your tensors
44
+ - [Supported operators](https://davidxswang.github.io/torchshapeflow/operators/) — what is analyzed and what shapes are inferred
45
+ - [Limitations](https://davidxswang.github.io/torchshapeflow/limitations/) — what the analyzer does not handle
46
+
47
+ ## Contributing
48
+
49
+ ```bash
50
+ git clone https://github.com/Davidxswang/torchshapeflow
51
+ cd torchshapeflow
52
+ make install # uv sync --extra dev
53
+ make check # format + lint + typecheck + tests
54
+ ```
55
+
56
+ See [docs/development.md](https://github.com/Davidxswang/torchshapeflow/blob/main/docs/development.md) for the full development guide: all make targets, CI workflow descriptions, and how to add new operators.
57
+
58
+ ## Release
59
+
60
+ See [RELEASING.md](https://github.com/Davidxswang/torchshapeflow/blob/main/RELEASING.md) for the full release procedure.
61
+
62
+ Build commands:
63
+
64
+ - `make python-dist` — wheel and sdist into `dist/`
65
+ - `make extension-package` — VS Code extension `.vsix`
66
+ - `make build` — both
67
+
68
+ Marketplace publishing in the release workflow is gated on GitHub Actions secrets:
69
+
70
+ - `VSCE_PAT` for the VS Code Marketplace
71
+ - `OVSX_PAT` for Open VSX
@@ -0,0 +1,85 @@
1
+ # Architecture
2
+
3
+ ## Analysis pipeline
4
+
5
+ TorchShapeFlow analyzes Python source in a single pass per file:
6
+
7
+ 1. **Parse** — `ast.parse` converts source text into an AST module (`parser.parse_source`).
8
+ 2. **Collect module specs** — `_collect_class_specs` walks class `__init__` bodies to find `nn.Linear`, `nn.Conv2d`, `nn.Embedding`, `nn.MaxPool2d`, `nn.AvgPool2d`, `nn.Sequential`, `nn.MultiheadAttention`, and passthrough module assignments, recording their constructor arguments as spec values.
9
+ 3. **Seed shape environment** — for each function (or `forward` method), annotated parameters are parsed via `parser.parse_tensor_annotation` and added to the environment `env: dict[str, Value]`.
10
+ 4. **Propagate shapes** — `_analyze_statement` walks the function body statement by statement. For each assignment, `_eval_expr` evaluates the right-hand side, dispatching to the appropriate rule function. Results are stored back into `env`.
11
+ 5. **Emit results** — diagnostics and hover facts accumulate in a `ModuleContext` and are returned as a `FileReport`.
12
+
13
+ ## Module map
14
+
15
+ | Module | Responsibility |
16
+ |---|---|
17
+ | `model.py` | All core data types (`Dim` variants, `TensorShape`, `TensorValue`, `TensorTupleValue`, `LinearSpec`, `Conv2dSpec`, `PassthroughSpec`, `EmbeddingSpec`, `Pool2dSpec`, `SequentialSpec`, `MultiheadAttentionSpec`, `ModuleSpec`, `Value`). Shape arithmetic: `product_dim`, `quotient_dim`, `sum_dim`, `broadcast_shapes`, `batch_matmul_shape`, `normalize_index`. |
18
+ | `annotations.py` | Public `Shape` class used in `Annotated[Tensor, Shape(...)]`. |
19
+ | `parser.py` | Parses `Annotated[Tensor, Shape(...)]` annotation AST nodes into `TensorValue`. Raises `AnnotationParseError` on malformed annotations. |
20
+ | `analyzer.py` | Main AST walker. Manages the shape environment, dispatches to rule functions, emits diagnostics via `ModuleContext`. |
21
+ | `diagnostics.py` | `Diagnostic` dataclass and `Severity` type alias (`"error" \| "warning"`). |
22
+ | `report.py` | `FileReport` (list of diagnostics + hover facts per file) and `HoverFact` (inferred shape at a source location). |
23
+ | `cli.py` | Typer CLI. `tsf check` runs the analyzer and formats output. `tsf version` prints the package version. |
24
+ | `rules/__init__.py` | Re-exports all public inference functions. |
25
+ | `rules/shape_ops.py` | `infer_permute`, `infer_transpose`, `infer_reshape`, `infer_flatten`, `infer_squeeze`, `infer_unsqueeze`, `infer_size`, `infer_cat`, `infer_stack`, `infer_matmul`, `infer_mm`, `infer_movedim`, `infer_reduction`, `infer_chunk`, `infer_split`, `infer_einsum`. |
26
+ | `rules/broadcasting.py` | `infer_binary_broadcast` — wraps `broadcast_shapes` for element-wise ops. |
27
+ | `rules/linear.py` | `infer_linear` for `nn.Linear`. |
28
+ | `rules/conv2d.py` | `infer_conv2d` for `nn.Conv2d`. |
29
+ | `rules/embedding.py` | `infer_embedding` for `nn.Embedding`. |
30
+ | `rules/pool2d.py` | `infer_pool2d` for `nn.MaxPool2d` and `nn.AvgPool2d`. |
31
+ | `rules/indexing.py` | `infer_subscript` for tensor subscript and shape-tuple indexing. |
32
+ | `rules/common.py` | Shared AST helpers: `int_from_ast`, `qualified_name`, `dim_from_value`, `tuple_index`, `spatial_output_dim`. |
33
+ | `utils/paths.py` | `collect_python_files` — recursive `.py` file discovery. |
34
+
35
+ ## Dim type hierarchy
36
+
37
+ ```
38
+ Dim (TypeAlias)
39
+ ├── ConstantDim(value: int) — a fixed integer size, e.g. 32
40
+ ├── SymbolicDim(name: str) — a named unknown size, e.g. "B"
41
+ ├── ExpressionDim(expr: str) — a derived expression, e.g. "4*B" or "(B*C)/4"
42
+ └── UnknownDim(token: str) — explicitly unresolvable
43
+ ```
44
+
45
+ Shape arithmetic returns `ConstantDim` when all operands are constant and `ExpressionDim` otherwise. Expressions are stored as strings and compared structurally.
46
+
47
+ ## Shape environment
48
+
49
+ The environment `env: dict[str, Value]` maps variable names to their inferred `Value`:
50
+
51
+ ```
52
+ Value (TypeAlias)
53
+ ├── TensorValue(shape: TensorShape, origin: str | None)
54
+ ├── ShapeTupleValue(dims: tuple[Dim, ...]) — result of x.shape or x.size()
55
+ ├── IntegerValue(value: int | None) — result of x.ndim or x.size(i)
56
+ ├── TensorTupleValue(tensors: tuple[TensorValue, ...]) — result of chunk/split/MHA
57
+
58
+ │ ModuleSpec (TypeAlias — stored in module_specs and env)
59
+ ├── LinearSpec(in_features, out_features) — nn.Linear
60
+ ├── Conv2dSpec(in_channels, out_channels, kernel_size, stride, padding, dilation) — nn.Conv2d
61
+ ├── PassthroughSpec() — shape-preserving modules (BatchNorm, ReLU, …)
62
+ ├── EmbeddingSpec(embedding_dim) — nn.Embedding
63
+ ├── Pool2dSpec(kernel_size, stride, padding, dilation) — nn.MaxPool2d / nn.AvgPool2d
64
+ ├── SequentialSpec(specs: tuple[ModuleSpec, ...]) — nn.Sequential
65
+ └── MultiheadAttentionSpec(embed_dim, num_heads, batch_first) — nn.MultiheadAttention
66
+ ```
67
+
68
+ Spec values are stored in `module_specs` (keyed by attribute name) when their constructor is parsed from `__init__`. When `self.linear(x)` is called, the analyzer looks up `"linear"` in `module_specs`, retrieves the spec, and calls the appropriate inference function. Module aliases (`m = self.linear; m(x)`) are also supported: spec values stored in `env` are looked up before falling through to `func_sigs`.
69
+
70
+ ## Diagnostic codes
71
+
72
+ | Code | Trigger |
73
+ |---|---|
74
+ | `TSF1001` | Annotation parse error (malformed `Annotated` or `Shape`) |
75
+ | `TSF1003` | Incompatible `matmul` / `bmm` shapes |
76
+ | `TSF1004` | Invalid `reshape` or `flatten` dimensions |
77
+ | `TSF1005` | Invalid `cat` or `stack` dimensions or mismatched shapes |
78
+ | `TSF1006` | Broadcasting incompatibility |
79
+ | `TSF1007` | `nn.Linear`, `nn.Conv2d`, or `nn.MaxPool2d`/`AvgPool2d` input shape mismatch |
80
+ | `TSF1008` | Invalid `permute`, `transpose`, `squeeze`, or `unsqueeze` dimensions |
81
+ | `TSF1009` | Return shape does not match the declared return type annotation |
82
+
83
+ ## Adding a new operator
84
+
85
+ See [Development → Adding a new operator](development.md#adding-a-new-operator).