torchshapeflow 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchshapeflow might be problematic. Click here for more details.
- torchshapeflow-0.2.0/.github/workflows/build-artifacts.yml +41 -0
- torchshapeflow-0.2.0/.github/workflows/ci.yml +43 -0
- torchshapeflow-0.2.0/.github/workflows/docs.yml +42 -0
- torchshapeflow-0.2.0/.github/workflows/release.yml +144 -0
- torchshapeflow-0.2.0/.gitignore +13 -0
- torchshapeflow-0.2.0/LICENSE +21 -0
- torchshapeflow-0.2.0/Makefile +52 -0
- torchshapeflow-0.2.0/PKG-INFO +113 -0
- torchshapeflow-0.2.0/README.md +71 -0
- torchshapeflow-0.2.0/docs/architecture.md +85 -0
- torchshapeflow-0.2.0/docs/development.md +120 -0
- torchshapeflow-0.2.0/docs/extension.md +58 -0
- torchshapeflow-0.2.0/docs/index.md +45 -0
- torchshapeflow-0.2.0/docs/limitations.md +44 -0
- torchshapeflow-0.2.0/docs/operators.md +632 -0
- torchshapeflow-0.2.0/docs/quickstart.md +81 -0
- torchshapeflow-0.2.0/docs/syntax.md +124 -0
- torchshapeflow-0.2.0/examples/attention_scores.py +13 -0
- torchshapeflow-0.2.0/examples/error_cases.py +17 -0
- torchshapeflow-0.2.0/examples/simple_cnn.py +17 -0
- torchshapeflow-0.2.0/examples/transformer_block.py +11 -0
- torchshapeflow-0.2.0/examples/vit_patch_embed.py +10 -0
- torchshapeflow-0.2.0/mkdocs.yml +19 -0
- torchshapeflow-0.2.0/pyproject.toml +77 -0
- torchshapeflow-0.2.0/scripts/bump_version.py +80 -0
- torchshapeflow-0.2.0/src/torchshapeflow/__init__.py +4 -0
- torchshapeflow-0.2.0/src/torchshapeflow/_version.py +1 -0
- torchshapeflow-0.2.0/src/torchshapeflow/analyzer.py +1489 -0
- torchshapeflow-0.2.0/src/torchshapeflow/annotations.py +34 -0
- torchshapeflow-0.2.0/src/torchshapeflow/cli.py +52 -0
- torchshapeflow-0.2.0/src/torchshapeflow/diagnostics.py +37 -0
- torchshapeflow-0.2.0/src/torchshapeflow/index.py +246 -0
- torchshapeflow-0.2.0/src/torchshapeflow/model.py +379 -0
- torchshapeflow-0.2.0/src/torchshapeflow/parser.py +112 -0
- torchshapeflow-0.2.0/src/torchshapeflow/py.typed +1 -0
- torchshapeflow-0.2.0/src/torchshapeflow/report.py +39 -0
- torchshapeflow-0.2.0/src/torchshapeflow/rules/__init__.py +59 -0
- torchshapeflow-0.2.0/src/torchshapeflow/rules/broadcasting.py +11 -0
- torchshapeflow-0.2.0/src/torchshapeflow/rules/common.py +98 -0
- torchshapeflow-0.2.0/src/torchshapeflow/rules/conv2d.py +41 -0
- torchshapeflow-0.2.0/src/torchshapeflow/rules/embedding.py +17 -0
- torchshapeflow-0.2.0/src/torchshapeflow/rules/indexing.py +94 -0
- torchshapeflow-0.2.0/src/torchshapeflow/rules/linear.py +21 -0
- torchshapeflow-0.2.0/src/torchshapeflow/rules/pool2d.py +40 -0
- torchshapeflow-0.2.0/src/torchshapeflow/rules/shape_ops.py +665 -0
- torchshapeflow-0.2.0/src/torchshapeflow/utils/__init__.py +1 -0
- torchshapeflow-0.2.0/src/torchshapeflow/utils/paths.py +9 -0
- torchshapeflow-0.2.0/tests/fixtures/attention_scores.py +13 -0
- torchshapeflow-0.2.0/tests/test_analyzer.py +873 -0
- torchshapeflow-0.2.0/tests/test_annotations.py +22 -0
- torchshapeflow-0.2.0/tests/test_cli.py +17 -0
- torchshapeflow-0.2.0/tests/test_index.py +427 -0
- torchshapeflow-0.2.0/tests/test_model.py +152 -0
- torchshapeflow-0.2.0/tests/test_parser.py +139 -0
- torchshapeflow-0.2.0/tests/test_rules.py +365 -0
- torchshapeflow-0.2.0/tests/test_shape_ops.py +631 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
name: Build Artifacts
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches:
|
|
6
|
+
- main
|
|
7
|
+
pull_request:
|
|
8
|
+
workflow_dispatch:
|
|
9
|
+
|
|
10
|
+
jobs:
|
|
11
|
+
python-package:
|
|
12
|
+
runs-on: ubuntu-latest
|
|
13
|
+
steps:
|
|
14
|
+
- uses: actions/checkout@v6
|
|
15
|
+
- uses: astral-sh/setup-uv@v7
|
|
16
|
+
- uses: actions/setup-python@v6
|
|
17
|
+
with:
|
|
18
|
+
python-version: "3.12"
|
|
19
|
+
- run: uv build
|
|
20
|
+
- uses: actions/upload-artifact@v7
|
|
21
|
+
with:
|
|
22
|
+
name: python-dist
|
|
23
|
+
path: dist/*
|
|
24
|
+
|
|
25
|
+
vscode-extension:
|
|
26
|
+
runs-on: ubuntu-latest
|
|
27
|
+
steps:
|
|
28
|
+
- uses: actions/checkout@v6
|
|
29
|
+
- uses: actions/setup-node@v6
|
|
30
|
+
with:
|
|
31
|
+
node-version: "24"
|
|
32
|
+
cache: "npm"
|
|
33
|
+
cache-dependency-path: extensions/vscode/package-lock.json
|
|
34
|
+
- working-directory: extensions/vscode
|
|
35
|
+
run: npm ci
|
|
36
|
+
- working-directory: extensions/vscode
|
|
37
|
+
run: npm run package
|
|
38
|
+
- uses: actions/upload-artifact@v7
|
|
39
|
+
with:
|
|
40
|
+
name: vscode-extension
|
|
41
|
+
path: extensions/vscode/dist/*.vsix
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
name: CI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
pull_request:
|
|
6
|
+
|
|
7
|
+
jobs:
|
|
8
|
+
check:
|
|
9
|
+
runs-on: ubuntu-latest
|
|
10
|
+
steps:
|
|
11
|
+
- uses: actions/checkout@v6
|
|
12
|
+
- uses: astral-sh/setup-uv@v7
|
|
13
|
+
- uses: actions/setup-python@v6
|
|
14
|
+
with:
|
|
15
|
+
python-version: "3.12"
|
|
16
|
+
- uses: actions/setup-node@v6
|
|
17
|
+
with:
|
|
18
|
+
node-version: "24"
|
|
19
|
+
cache: "npm"
|
|
20
|
+
cache-dependency-path: extensions/vscode/package-lock.json
|
|
21
|
+
- run: uv sync --extra dev
|
|
22
|
+
- run: uv run ruff format . --check
|
|
23
|
+
- run: uv run ruff check .
|
|
24
|
+
- run: uv run mypy .
|
|
25
|
+
- run: uv run mkdocs build
|
|
26
|
+
- working-directory: extensions/vscode
|
|
27
|
+
run: npm ci
|
|
28
|
+
- working-directory: extensions/vscode
|
|
29
|
+
run: npm run build
|
|
30
|
+
|
|
31
|
+
test:
|
|
32
|
+
runs-on: ubuntu-latest
|
|
33
|
+
strategy:
|
|
34
|
+
matrix:
|
|
35
|
+
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
|
36
|
+
steps:
|
|
37
|
+
- uses: actions/checkout@v6
|
|
38
|
+
- uses: astral-sh/setup-uv@v7
|
|
39
|
+
- uses: actions/setup-python@v6
|
|
40
|
+
with:
|
|
41
|
+
python-version: ${{ matrix.python-version }}
|
|
42
|
+
- run: uv sync --extra dev
|
|
43
|
+
- run: uv run pytest -q
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
name: Docs
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
workflow_dispatch:
|
|
5
|
+
push:
|
|
6
|
+
branches:
|
|
7
|
+
- main
|
|
8
|
+
|
|
9
|
+
permissions:
|
|
10
|
+
contents: read
|
|
11
|
+
pages: write
|
|
12
|
+
id-token: write
|
|
13
|
+
|
|
14
|
+
concurrency:
|
|
15
|
+
group: pages
|
|
16
|
+
cancel-in-progress: true
|
|
17
|
+
|
|
18
|
+
jobs:
|
|
19
|
+
build:
|
|
20
|
+
runs-on: ubuntu-latest
|
|
21
|
+
steps:
|
|
22
|
+
- uses: actions/checkout@v6
|
|
23
|
+
- uses: astral-sh/setup-uv@v7
|
|
24
|
+
- uses: actions/setup-python@v6
|
|
25
|
+
with:
|
|
26
|
+
python-version: "3.12"
|
|
27
|
+
- run: uv sync --extra dev
|
|
28
|
+
- run: uv run mkdocs build
|
|
29
|
+
- uses: actions/upload-pages-artifact@v4
|
|
30
|
+
with:
|
|
31
|
+
path: site
|
|
32
|
+
|
|
33
|
+
deploy:
|
|
34
|
+
if: github.ref == 'refs/heads/main'
|
|
35
|
+
needs: build
|
|
36
|
+
runs-on: ubuntu-latest
|
|
37
|
+
environment:
|
|
38
|
+
name: github-pages
|
|
39
|
+
url: ${{ steps.deployment.outputs.page_url }}
|
|
40
|
+
steps:
|
|
41
|
+
- id: deployment
|
|
42
|
+
uses: actions/deploy-pages@v4
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
name: Release
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
tags:
|
|
6
|
+
- "v*"
|
|
7
|
+
workflow_dispatch:
|
|
8
|
+
|
|
9
|
+
jobs:
|
|
10
|
+
check-secrets:
|
|
11
|
+
runs-on: ubuntu-latest
|
|
12
|
+
outputs:
|
|
13
|
+
has-vsce-pat: ${{ steps.check.outputs.has-vsce-pat }}
|
|
14
|
+
has-ovsx-pat: ${{ steps.check.outputs.has-ovsx-pat }}
|
|
15
|
+
steps:
|
|
16
|
+
- id: check
|
|
17
|
+
env:
|
|
18
|
+
VSCE_PAT: ${{ secrets.VSCE_PAT }}
|
|
19
|
+
OVSX_PAT: ${{ secrets.OVSX_PAT }}
|
|
20
|
+
run: |
|
|
21
|
+
echo "has-vsce-pat=${{ secrets.VSCE_PAT != '' }}" >> "$GITHUB_OUTPUT"
|
|
22
|
+
echo "has-ovsx-pat=${{ secrets.OVSX_PAT != '' }}" >> "$GITHUB_OUTPUT"
|
|
23
|
+
|
|
24
|
+
build-python:
|
|
25
|
+
runs-on: ubuntu-latest
|
|
26
|
+
steps:
|
|
27
|
+
- uses: actions/checkout@v6
|
|
28
|
+
- uses: astral-sh/setup-uv@v7
|
|
29
|
+
- uses: actions/setup-python@v6
|
|
30
|
+
with:
|
|
31
|
+
python-version: "3.12"
|
|
32
|
+
- run: uv build
|
|
33
|
+
- uses: actions/upload-artifact@v7
|
|
34
|
+
with:
|
|
35
|
+
name: python-dist
|
|
36
|
+
path: dist/*
|
|
37
|
+
|
|
38
|
+
build-extension:
|
|
39
|
+
runs-on: ubuntu-latest
|
|
40
|
+
steps:
|
|
41
|
+
- uses: actions/checkout@v6
|
|
42
|
+
- uses: actions/setup-node@v6
|
|
43
|
+
with:
|
|
44
|
+
node-version: "24"
|
|
45
|
+
cache: "npm"
|
|
46
|
+
cache-dependency-path: extensions/vscode/package-lock.json
|
|
47
|
+
- working-directory: extensions/vscode
|
|
48
|
+
run: npm ci
|
|
49
|
+
- working-directory: extensions/vscode
|
|
50
|
+
run: npm run package
|
|
51
|
+
- uses: actions/upload-artifact@v7
|
|
52
|
+
with:
|
|
53
|
+
name: vscode-extension
|
|
54
|
+
path: extensions/vscode/dist/*.vsix
|
|
55
|
+
|
|
56
|
+
publish-vscode-marketplace:
|
|
57
|
+
if: needs.check-secrets.outputs.has-vsce-pat == 'true' && !contains(github.ref, '-rc') && !contains(github.ref, '-test')
|
|
58
|
+
needs: [check-secrets, build-extension]
|
|
59
|
+
runs-on: ubuntu-latest
|
|
60
|
+
steps:
|
|
61
|
+
- uses: actions/checkout@v6
|
|
62
|
+
- uses: actions/setup-node@v6
|
|
63
|
+
with:
|
|
64
|
+
node-version: "24"
|
|
65
|
+
cache: "npm"
|
|
66
|
+
cache-dependency-path: extensions/vscode/package-lock.json
|
|
67
|
+
- working-directory: extensions/vscode
|
|
68
|
+
run: npm ci
|
|
69
|
+
- working-directory: extensions/vscode
|
|
70
|
+
env:
|
|
71
|
+
VSCE_PAT: ${{ secrets.VSCE_PAT }}
|
|
72
|
+
run: npx @vscode/vsce publish -p "$VSCE_PAT"
|
|
73
|
+
|
|
74
|
+
publish-open-vsx:
|
|
75
|
+
if: needs.check-secrets.outputs.has-ovsx-pat == 'true' && !contains(github.ref, '-rc') && !contains(github.ref, '-test')
|
|
76
|
+
needs: [check-secrets, build-extension]
|
|
77
|
+
runs-on: ubuntu-latest
|
|
78
|
+
steps:
|
|
79
|
+
- uses: actions/checkout@v6
|
|
80
|
+
- uses: actions/setup-node@v6
|
|
81
|
+
with:
|
|
82
|
+
node-version: "24"
|
|
83
|
+
cache: "npm"
|
|
84
|
+
cache-dependency-path: extensions/vscode/package-lock.json
|
|
85
|
+
- working-directory: extensions/vscode
|
|
86
|
+
run: npm ci
|
|
87
|
+
- working-directory: extensions/vscode
|
|
88
|
+
run: npm run package
|
|
89
|
+
- working-directory: extensions/vscode
|
|
90
|
+
env:
|
|
91
|
+
OVSX_PAT: ${{ secrets.OVSX_PAT }}
|
|
92
|
+
run: npx ovsx publish dist/torchshapeflow.vsix -p "$OVSX_PAT"
|
|
93
|
+
|
|
94
|
+
publish-testpypi:
|
|
95
|
+
if: contains(github.ref, '-rc') || contains(github.ref, '-test')
|
|
96
|
+
needs: build-python
|
|
97
|
+
runs-on: ubuntu-latest
|
|
98
|
+
permissions:
|
|
99
|
+
id-token: write
|
|
100
|
+
environment:
|
|
101
|
+
name: testpypi
|
|
102
|
+
steps:
|
|
103
|
+
- uses: actions/download-artifact@v8
|
|
104
|
+
with:
|
|
105
|
+
name: python-dist
|
|
106
|
+
path: dist
|
|
107
|
+
- uses: pypa/gh-action-pypi-publish@release/v1
|
|
108
|
+
with:
|
|
109
|
+
repository-url: https://test.pypi.org/legacy/
|
|
110
|
+
|
|
111
|
+
publish-pypi:
|
|
112
|
+
if: "!contains(github.ref, '-rc') && !contains(github.ref, '-test')"
|
|
113
|
+
needs: build-python
|
|
114
|
+
runs-on: ubuntu-latest
|
|
115
|
+
permissions:
|
|
116
|
+
id-token: write
|
|
117
|
+
environment:
|
|
118
|
+
name: pypi
|
|
119
|
+
steps:
|
|
120
|
+
- uses: actions/download-artifact@v8
|
|
121
|
+
with:
|
|
122
|
+
name: python-dist
|
|
123
|
+
path: dist
|
|
124
|
+
- uses: pypa/gh-action-pypi-publish@release/v1
|
|
125
|
+
|
|
126
|
+
github-release:
|
|
127
|
+
needs: [build-python, build-extension]
|
|
128
|
+
runs-on: ubuntu-latest
|
|
129
|
+
permissions:
|
|
130
|
+
contents: write
|
|
131
|
+
steps:
|
|
132
|
+
- uses: actions/download-artifact@v8
|
|
133
|
+
with:
|
|
134
|
+
name: python-dist
|
|
135
|
+
path: dist
|
|
136
|
+
- uses: actions/download-artifact@v8
|
|
137
|
+
with:
|
|
138
|
+
name: vscode-extension
|
|
139
|
+
path: extensions/vscode/dist
|
|
140
|
+
- uses: softprops/action-gh-release@v2
|
|
141
|
+
with:
|
|
142
|
+
files: |
|
|
143
|
+
dist/*
|
|
144
|
+
extensions/vscode/dist/*.vsix
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Xuesong Wang
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
PYTHON ?= python
|
|
2
|
+
|
|
3
|
+
.PHONY: install format lint typecheck test check docs docs-serve build python-dist extension-build extension-package bump-patch bump-minor bump-major clean
|
|
4
|
+
|
|
5
|
+
install:
|
|
6
|
+
uv sync --extra dev
|
|
7
|
+
|
|
8
|
+
format:
|
|
9
|
+
uv run ruff format .
|
|
10
|
+
|
|
11
|
+
lint:
|
|
12
|
+
uv run ruff check . --fix
|
|
13
|
+
|
|
14
|
+
typecheck:
|
|
15
|
+
uv run mypy .
|
|
16
|
+
|
|
17
|
+
test:
|
|
18
|
+
uv run pytest -q
|
|
19
|
+
|
|
20
|
+
check: format lint typecheck test
|
|
21
|
+
|
|
22
|
+
docs:
|
|
23
|
+
uv run mkdocs build
|
|
24
|
+
|
|
25
|
+
docs-serve:
|
|
26
|
+
uv run mkdocs serve
|
|
27
|
+
|
|
28
|
+
build: python-dist extension-package
|
|
29
|
+
|
|
30
|
+
python-dist:
|
|
31
|
+
uv build
|
|
32
|
+
|
|
33
|
+
extension-build:
|
|
34
|
+
cd extensions/vscode && npm ci && npm run build
|
|
35
|
+
|
|
36
|
+
extension-package:
|
|
37
|
+
cd extensions/vscode && npm ci && npm run package
|
|
38
|
+
|
|
39
|
+
bump-patch:
|
|
40
|
+
uv run python scripts/bump_version.py patch
|
|
41
|
+
uv lock
|
|
42
|
+
|
|
43
|
+
bump-minor:
|
|
44
|
+
uv run python scripts/bump_version.py minor
|
|
45
|
+
uv lock
|
|
46
|
+
|
|
47
|
+
bump-major:
|
|
48
|
+
uv run python scripts/bump_version.py major
|
|
49
|
+
uv lock
|
|
50
|
+
|
|
51
|
+
clean:
|
|
52
|
+
rm -rf .pytest_cache .mypy_cache .ruff_cache build dist site extensions/vscode/dist extensions/vscode/out
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torchshapeflow
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Static AST-based PyTorch tensor shape analysis.
|
|
5
|
+
Project-URL: Homepage, https://github.com/Davidxswang/torchshapeflow
|
|
6
|
+
Project-URL: Repository, https://github.com/Davidxswang/torchshapeflow
|
|
7
|
+
Project-URL: Issues, https://github.com/Davidxswang/torchshapeflow/issues
|
|
8
|
+
Project-URL: Documentation, https://davidxswang.github.io/torchshapeflow/
|
|
9
|
+
Author: TorchShapeFlow Contributors
|
|
10
|
+
License: MIT License
|
|
11
|
+
|
|
12
|
+
Copyright (c) 2026 Xuesong Wang
|
|
13
|
+
|
|
14
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
15
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
16
|
+
in the Software without restriction, including without limitation the rights
|
|
17
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
18
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
19
|
+
furnished to do so, subject to the following conditions:
|
|
20
|
+
|
|
21
|
+
The above copyright notice and this permission notice shall be included in all
|
|
22
|
+
copies or substantial portions of the Software.
|
|
23
|
+
|
|
24
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
25
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
26
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
27
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
28
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
29
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
30
|
+
SOFTWARE.
|
|
31
|
+
License-File: LICENSE
|
|
32
|
+
Requires-Python: >=3.10
|
|
33
|
+
Requires-Dist: typer>=0.16.0
|
|
34
|
+
Provides-Extra: dev
|
|
35
|
+
Requires-Dist: mkdocs-material>=9.5.0; extra == 'dev'
|
|
36
|
+
Requires-Dist: mkdocs>=1.6.0; extra == 'dev'
|
|
37
|
+
Requires-Dist: mypy>=1.18.0; extra == 'dev'
|
|
38
|
+
Requires-Dist: pytest>=8.3.0; extra == 'dev'
|
|
39
|
+
Requires-Dist: ruff>=0.11.0; extra == 'dev'
|
|
40
|
+
Requires-Dist: torch>=2.10.0; extra == 'dev'
|
|
41
|
+
Description-Content-Type: text/markdown
|
|
42
|
+
|
|
43
|
+
# TorchShapeFlow
|
|
44
|
+
|
|
45
|
+
[![CI](https://github.com/Davidxswang/torchshapeflow/actions/workflows/ci.yml/badge.svg)](https://github.com/Davidxswang/torchshapeflow/actions/workflows/ci.yml)
|
|
46
|
+
[![PyPI](https://img.shields.io/pypi/v/torchshapeflow)](https://pypi.org/project/torchshapeflow/)
|
|
47
|
+
[![Python versions](https://img.shields.io/pypi/pyversions/torchshapeflow)](https://pypi.org/project/torchshapeflow/)
|
|
48
|
+
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
|
|
49
|
+
|
|
50
|
+
TorchShapeFlow is a static, AST-based shape analyzer for PyTorch. It reads your Python source — no execution required — infers tensor shapes through your code, and reports mismatches as structured diagnostics.
|
|
51
|
+
|
|
52
|
+
```python
|
|
53
|
+
from typing import Annotated
|
|
54
|
+
import torch
|
|
55
|
+
import torch.nn as nn
|
|
56
|
+
from torchshapeflow import Shape
|
|
57
|
+
|
|
58
|
+
class Net(nn.Module):
|
|
59
|
+
def __init__(self):
super().__init__()
|
|
60
|
+
self.conv = nn.Conv2d(3, 8, 3, padding=1)
|
|
61
|
+
self.linear = nn.Linear(8 * 32 * 32, 10)
|
|
62
|
+
|
|
63
|
+
def forward(self, x: Annotated[torch.Tensor, Shape("B", 3, 32, 32)]):
|
|
64
|
+
y = self.conv(x) # inferred: [B, 8, 32, 32]
|
|
65
|
+
z = y.flatten(1) # inferred: [B, 8192]
|
|
66
|
+
return self.linear(z) # inferred: [B, 10]
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
```bash
|
|
70
|
+
$ tsf check mymodel.py
|
|
71
|
+
mymodel.py: ok
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
## Install
|
|
75
|
+
|
|
76
|
+
```bash
|
|
77
|
+
pip install torchshapeflow
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
## Documentation
|
|
81
|
+
|
|
82
|
+
Full docs at **[davidxswang.github.io/torchshapeflow](https://davidxswang.github.io/torchshapeflow)**
|
|
83
|
+
|
|
84
|
+
- [Quickstart](https://davidxswang.github.io/torchshapeflow/quickstart/) — install and run your first check
|
|
85
|
+
- [Annotation syntax](https://davidxswang.github.io/torchshapeflow/syntax/) — how to annotate your tensors
|
|
86
|
+
- [Supported operators](https://davidxswang.github.io/torchshapeflow/operators/) — what is analyzed and what shapes are inferred
|
|
87
|
+
- [Limitations](https://davidxswang.github.io/torchshapeflow/limitations/) — what the analyzer does not handle
|
|
88
|
+
|
|
89
|
+
## Contributing
|
|
90
|
+
|
|
91
|
+
```bash
|
|
92
|
+
git clone https://github.com/Davidxswang/torchshapeflow
|
|
93
|
+
cd torchshapeflow
|
|
94
|
+
make install # uv sync --extra dev
|
|
95
|
+
make check # format + lint + typecheck + tests
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
See [docs/development.md](docs/development.md) for the full development guide: all make targets, CI workflow descriptions, and how to add new operators.
|
|
99
|
+
|
|
100
|
+
## Release
|
|
101
|
+
|
|
102
|
+
See [RELEASING.md](RELEASING.md) for the full release procedure.
|
|
103
|
+
|
|
104
|
+
Build commands:
|
|
105
|
+
|
|
106
|
+
- `make python-dist` — wheel and sdist into `dist/`
|
|
107
|
+
- `make extension-package` — VS Code extension `.vsix`
|
|
108
|
+
- `make build` — both
|
|
109
|
+
|
|
110
|
+
Marketplace publishing in the release workflow is gated on GitHub Actions secrets:
|
|
111
|
+
|
|
112
|
+
- `VSCE_PAT` for the VS Code Marketplace
|
|
113
|
+
- `OVSX_PAT` for Open VSX
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# TorchShapeFlow
|
|
2
|
+
|
|
3
|
+
[![CI](https://github.com/Davidxswang/torchshapeflow/actions/workflows/ci.yml/badge.svg)](https://github.com/Davidxswang/torchshapeflow/actions/workflows/ci.yml)
|
|
4
|
+
[![PyPI](https://img.shields.io/pypi/v/torchshapeflow)](https://pypi.org/project/torchshapeflow/)
|
|
5
|
+
[![Python versions](https://img.shields.io/pypi/pyversions/torchshapeflow)](https://pypi.org/project/torchshapeflow/)
|
|
6
|
+
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
|
|
7
|
+
|
|
8
|
+
TorchShapeFlow is a static, AST-based shape analyzer for PyTorch. It reads your Python source — no execution required — infers tensor shapes through your code, and reports mismatches as structured diagnostics.
|
|
9
|
+
|
|
10
|
+
```python
|
|
11
|
+
from typing import Annotated
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
from torchshapeflow import Shape
|
|
15
|
+
|
|
16
|
+
class Net(nn.Module):
|
|
17
|
+
def __init__(self):
super().__init__()
|
|
18
|
+
self.conv = nn.Conv2d(3, 8, 3, padding=1)
|
|
19
|
+
self.linear = nn.Linear(8 * 32 * 32, 10)
|
|
20
|
+
|
|
21
|
+
def forward(self, x: Annotated[torch.Tensor, Shape("B", 3, 32, 32)]):
|
|
22
|
+
y = self.conv(x) # inferred: [B, 8, 32, 32]
|
|
23
|
+
z = y.flatten(1) # inferred: [B, 8192]
|
|
24
|
+
return self.linear(z) # inferred: [B, 10]
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
$ tsf check mymodel.py
|
|
29
|
+
mymodel.py: ok
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
## Install
|
|
33
|
+
|
|
34
|
+
```bash
|
|
35
|
+
pip install torchshapeflow
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
## Documentation
|
|
39
|
+
|
|
40
|
+
Full docs at **[davidxswang.github.io/torchshapeflow](https://davidxswang.github.io/torchshapeflow)**
|
|
41
|
+
|
|
42
|
+
- [Quickstart](https://davidxswang.github.io/torchshapeflow/quickstart/) — install and run your first check
|
|
43
|
+
- [Annotation syntax](https://davidxswang.github.io/torchshapeflow/syntax/) — how to annotate your tensors
|
|
44
|
+
- [Supported operators](https://davidxswang.github.io/torchshapeflow/operators/) — what is analyzed and what shapes are inferred
|
|
45
|
+
- [Limitations](https://davidxswang.github.io/torchshapeflow/limitations/) — what the analyzer does not handle
|
|
46
|
+
|
|
47
|
+
## Contributing
|
|
48
|
+
|
|
49
|
+
```bash
|
|
50
|
+
git clone https://github.com/Davidxswang/torchshapeflow
|
|
51
|
+
cd torchshapeflow
|
|
52
|
+
make install # uv sync --extra dev
|
|
53
|
+
make check # format + lint + typecheck + tests
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
See [docs/development.md](docs/development.md) for the full development guide: all make targets, CI workflow descriptions, and how to add new operators.
|
|
57
|
+
|
|
58
|
+
## Release
|
|
59
|
+
|
|
60
|
+
See [RELEASING.md](RELEASING.md) for the full release procedure.
|
|
61
|
+
|
|
62
|
+
Build commands:
|
|
63
|
+
|
|
64
|
+
- `make python-dist` — wheel and sdist into `dist/`
|
|
65
|
+
- `make extension-package` — VS Code extension `.vsix`
|
|
66
|
+
- `make build` — both
|
|
67
|
+
|
|
68
|
+
Marketplace publishing in the release workflow is gated on GitHub Actions secrets:
|
|
69
|
+
|
|
70
|
+
- `VSCE_PAT` for the VS Code Marketplace
|
|
71
|
+
- `OVSX_PAT` for Open VSX
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
# Architecture
|
|
2
|
+
|
|
3
|
+
## Analysis pipeline
|
|
4
|
+
|
|
5
|
+
TorchShapeFlow analyzes Python source in a single pass per file:
|
|
6
|
+
|
|
7
|
+
1. **Parse** — `ast.parse` converts source text into an AST module (`parser.parse_source`).
|
|
8
|
+
2. **Collect module specs** — `_collect_class_specs` walks class `__init__` bodies to find `nn.Linear`, `nn.Conv2d`, `nn.Embedding`, `nn.MaxPool2d`, `nn.AvgPool2d`, `nn.Sequential`, `nn.MultiheadAttention`, and passthrough module assignments, recording their constructor arguments as spec values.
|
|
9
|
+
3. **Seed shape environment** — for each function (or `forward` method), annotated parameters are parsed via `parser.parse_tensor_annotation` and added to the environment `env: dict[str, Value]`.
|
|
10
|
+
4. **Propagate shapes** — `_analyze_statement` walks the function body statement by statement. For each assignment, `_eval_expr` evaluates the right-hand side, dispatching to the appropriate rule function. Results are stored back into `env`.
|
|
11
|
+
5. **Emit results** — diagnostics and hover facts accumulate in a `ModuleContext` and are returned as a `FileReport`.
|
|
12
|
+
|
|
13
|
+
## Module map
|
|
14
|
+
|
|
15
|
+
| Module | Responsibility |
|
|
16
|
+
|---|---|
|
|
17
|
+
| `model.py` | All core data types (`Dim` variants, `TensorShape`, `TensorValue`, `TensorTupleValue`, `LinearSpec`, `Conv2dSpec`, `PassthroughSpec`, `EmbeddingSpec`, `Pool2dSpec`, `SequentialSpec`, `MultiheadAttentionSpec`, `ModuleSpec`, `Value`). Shape arithmetic: `product_dim`, `quotient_dim`, `sum_dim`, `broadcast_shapes`, `batch_matmul_shape`, `normalize_index`. |
|
|
18
|
+
| `annotations.py` | Public `Shape` class used in `Annotated[Tensor, Shape(...)]`. |
|
|
19
|
+
| `parser.py` | Parses `Annotated[Tensor, Shape(...)]` annotation AST nodes into `TensorValue`. Raises `AnnotationParseError` on malformed annotations. |
|
|
20
|
+
| `analyzer.py` | Main AST walker. Manages the shape environment, dispatches to rule functions, emits diagnostics via `ModuleContext`. |
|
|
21
|
+
| `diagnostics.py` | `Diagnostic` dataclass and `Severity` type alias (`"error" \| "warning"`). |
|
|
22
|
+
| `report.py` | `FileReport` (list of diagnostics + hover facts per file) and `HoverFact` (inferred shape at a source location). |
|
|
23
|
+
| `cli.py` | Typer CLI. `tsf check` runs the analyzer and formats output. `tsf version` prints the package version. |
|
|
24
|
+
| `rules/__init__.py` | Re-exports all public inference functions. |
|
|
25
|
+
| `rules/shape_ops.py` | `infer_permute`, `infer_transpose`, `infer_reshape`, `infer_flatten`, `infer_squeeze`, `infer_unsqueeze`, `infer_size`, `infer_cat`, `infer_stack`, `infer_matmul`, `infer_mm`, `infer_movedim`, `infer_reduction`, `infer_chunk`, `infer_split`, `infer_einsum`. |
|
|
26
|
+
| `rules/broadcasting.py` | `infer_binary_broadcast` — wraps `broadcast_shapes` for element-wise ops. |
|
|
27
|
+
| `rules/linear.py` | `infer_linear` for `nn.Linear`. |
|
|
28
|
+
| `rules/conv2d.py` | `infer_conv2d` for `nn.Conv2d`. |
|
|
29
|
+
| `rules/embedding.py` | `infer_embedding` for `nn.Embedding`. |
|
|
30
|
+
| `rules/pool2d.py` | `infer_pool2d` for `nn.MaxPool2d` and `nn.AvgPool2d`. |
|
|
31
|
+
| `rules/indexing.py` | `infer_subscript` for tensor subscript and shape-tuple indexing. |
|
|
32
|
+
| `rules/common.py` | Shared AST helpers: `int_from_ast`, `qualified_name`, `dim_from_value`, `tuple_index`, `spatial_output_dim`. |
|
|
33
|
+
| `utils/paths.py` | `collect_python_files` — recursive `.py` file discovery. |
|
|
34
|
+
|
|
35
|
+
## Dim type hierarchy
|
|
36
|
+
|
|
37
|
+
```
|
|
38
|
+
Dim (TypeAlias)
|
|
39
|
+
├── ConstantDim(value: int) — a fixed integer size, e.g. 32
|
|
40
|
+
├── SymbolicDim(name: str) — a named unknown size, e.g. "B"
|
|
41
|
+
├── ExpressionDim(expr: str) — a derived expression, e.g. "4*B" or "(B*C)/4"
|
|
42
|
+
└── UnknownDim(token: str) — explicitly unresolvable
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
Shape arithmetic returns `ConstantDim` when all operands are constant and `ExpressionDim` otherwise. Expressions are stored as strings and compared structurally.
|
|
46
|
+
|
|
47
|
+
## Shape environment
|
|
48
|
+
|
|
49
|
+
The environment `env: dict[str, Value]` maps variable names to their inferred `Value`:
|
|
50
|
+
|
|
51
|
+
```
|
|
52
|
+
Value (TypeAlias)
|
|
53
|
+
├── TensorValue(shape: TensorShape, origin: str | None)
|
|
54
|
+
├── ShapeTupleValue(dims: tuple[Dim, ...]) — result of x.shape or x.size()
|
|
55
|
+
├── IntegerValue(value: int | None) — result of x.ndim or x.size(i)
|
|
56
|
+
├── TensorTupleValue(tensors: tuple[TensorValue, ...]) — result of chunk/split/MHA
|
|
57
|
+
│
|
|
58
|
+
│ ModuleSpec (TypeAlias — stored in module_specs and env)
|
|
59
|
+
├── LinearSpec(in_features, out_features) — nn.Linear
|
|
60
|
+
├── Conv2dSpec(in_channels, out_channels, kernel_size, stride, padding, dilation) — nn.Conv2d
|
|
61
|
+
├── PassthroughSpec() — shape-preserving modules (BatchNorm, ReLU, …)
|
|
62
|
+
├── EmbeddingSpec(embedding_dim) — nn.Embedding
|
|
63
|
+
├── Pool2dSpec(kernel_size, stride, padding, dilation) — nn.MaxPool2d / nn.AvgPool2d
|
|
64
|
+
├── SequentialSpec(specs: tuple[ModuleSpec, ...]) — nn.Sequential
|
|
65
|
+
└── MultiheadAttentionSpec(embed_dim, num_heads, batch_first) — nn.MultiheadAttention
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
Spec values are stored in `module_specs` (keyed by attribute name) when their constructor is parsed from `__init__`. When `self.linear(x)` is called, the analyzer looks up `"linear"` in `module_specs`, retrieves the spec, and calls the appropriate inference function. Module aliases (`m = self.linear; m(x)`) are also supported: spec values stored in `env` are looked up before falling through to `func_sigs`.
|
|
69
|
+
|
|
70
|
+
## Diagnostic codes
|
|
71
|
+
|
|
72
|
+
| Code | Trigger |
|
|
73
|
+
|---|---|
|
|
74
|
+
| `TSF1001` | Annotation parse error (malformed `Annotated` or `Shape`) |
|
|
75
|
+
| `TSF1003` | Incompatible `matmul` / `bmm` shapes |
|
|
76
|
+
| `TSF1004` | Invalid `reshape` or `flatten` dimensions |
|
|
77
|
+
| `TSF1005` | Invalid `cat` or `stack` dimensions or mismatched shapes |
|
|
78
|
+
| `TSF1006` | Broadcasting incompatibility |
|
|
79
|
+
| `TSF1007` | `nn.Linear`, `nn.Conv2d`, or `nn.MaxPool2d`/`nn.AvgPool2d` input shape mismatch |
|
|
80
|
+
| `TSF1008` | Invalid `permute`, `transpose`, `squeeze`, or `unsqueeze` dimensions |
|
|
81
|
+
| `TSF1009` | Return shape does not match the declared return type annotation |
|
|
82
|
+
|
|
83
|
+
## Adding a new operator
|
|
84
|
+
|
|
85
|
+
See [Development → Adding a new operator](development.md#adding-a-new-operator).
|