torch-mpo 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_mpo-0.1.0/.github/workflows/python-publish.yml +97 -0
- torch_mpo-0.1.0/.github/workflows/tests.yml +57 -0
- torch_mpo-0.1.0/.gitignore +5 -0
- torch_mpo-0.1.0/.isort.cfg +2 -0
- torch_mpo-0.1.0/LICENSE +21 -0
- torch_mpo-0.1.0/PKG-INFO +307 -0
- torch_mpo-0.1.0/README.md +264 -0
- torch_mpo-0.1.0/benchmarks/benchmark_layers.py +336 -0
- torch_mpo-0.1.0/docs/tutorial.md +425 -0
- torch_mpo-0.1.0/examples/cifar10_vgg16_mpo.py +236 -0
- torch_mpo-0.1.0/examples/compress_vgg.py +116 -0
- torch_mpo-0.1.0/examples/imagenet_resnet50_mpo.py +443 -0
- torch_mpo-0.1.0/examples/mnist_lenet5_mpo.py +219 -0
- torch_mpo-0.1.0/pyproject.toml +77 -0
- torch_mpo-0.1.0/src/torch_mpo/__init__.py +7 -0
- torch_mpo-0.1.0/src/torch_mpo/decomposition/__init__.py +5 -0
- torch_mpo-0.1.0/src/torch_mpo/decomposition/tt_svd.py +219 -0
- torch_mpo-0.1.0/src/torch_mpo/layers/__init__.py +6 -0
- torch_mpo-0.1.0/src/torch_mpo/layers/tt_conv.py +409 -0
- torch_mpo-0.1.0/src/torch_mpo/layers/tt_linear.py +395 -0
- torch_mpo-0.1.0/src/torch_mpo/models/__init__.py +24 -0
- torch_mpo-0.1.0/src/torch_mpo/models/resnet_mpo.py +472 -0
- torch_mpo-0.1.0/src/torch_mpo/models/vgg_mpo.py +334 -0
- torch_mpo-0.1.0/src/torch_mpo/utils/__init__.py +5 -0
- torch_mpo-0.1.0/src/torch_mpo/utils/compression.py +311 -0
- torch_mpo-0.1.0/tests/__init__.py +1 -0
- torch_mpo-0.1.0/tests/test_gradient_flow.py +157 -0
- torch_mpo-0.1.0/tests/test_integration.py +166 -0
- torch_mpo-0.1.0/tests/test_tt_conv.py +296 -0
- torch_mpo-0.1.0/tests/test_tt_linear.py +253 -0
- torch_mpo-0.1.0/uv.lock +2029 -0
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
name: Build and Publish
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
tags: ["v*"]
|
|
6
|
+
workflow_dispatch:
|
|
7
|
+
|
|
8
|
+
permissions:
|
|
9
|
+
contents: read
|
|
10
|
+
|
|
11
|
+
jobs:
|
|
12
|
+
build:
|
|
13
|
+
name: Build distribution
|
|
14
|
+
runs-on: ubuntu-latest
|
|
15
|
+
|
|
16
|
+
steps:
|
|
17
|
+
- uses: actions/checkout@v4
|
|
18
|
+
|
|
19
|
+
- name: Set up Python
|
|
20
|
+
uses: actions/setup-python@v5
|
|
21
|
+
with:
|
|
22
|
+
python-version: "3.12"
|
|
23
|
+
|
|
24
|
+
- name: Install dependencies
|
|
25
|
+
run: |
|
|
26
|
+
python -m pip install --upgrade pip
|
|
27
|
+
python -m pip install build
|
|
28
|
+
|
|
29
|
+
- name: Build package
|
|
30
|
+
run: python -m build
|
|
31
|
+
|
|
32
|
+
- name: Store distribution packages
|
|
33
|
+
uses: actions/upload-artifact@v4
|
|
34
|
+
with:
|
|
35
|
+
name: python-package-distributions
|
|
36
|
+
path: dist/
|
|
37
|
+
|
|
38
|
+
publish-to-pypi:
|
|
39
|
+
name: Publish to PyPI
|
|
40
|
+
needs:
|
|
41
|
+
- build
|
|
42
|
+
runs-on: ubuntu-latest
|
|
43
|
+
environment:
|
|
44
|
+
name: pypi
|
|
45
|
+
url: https://pypi.org/p/torch-mpo
|
|
46
|
+
permissions:
|
|
47
|
+
id-token: write # IMPORTANT: mandatory for trusted publishing
|
|
48
|
+
|
|
49
|
+
steps:
|
|
50
|
+
- name: Download distribution packages
|
|
51
|
+
uses: actions/download-artifact@v4
|
|
52
|
+
with:
|
|
53
|
+
name: python-package-distributions
|
|
54
|
+
path: dist/
|
|
55
|
+
|
|
56
|
+
- name: Publish to PyPI
|
|
57
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
58
|
+
|
|
59
|
+
github-release:
|
|
60
|
+
name: Create GitHub Release
|
|
61
|
+
needs:
|
|
62
|
+
- publish-to-pypi
|
|
63
|
+
runs-on: ubuntu-latest
|
|
64
|
+
permissions:
|
|
65
|
+
contents: write # IMPORTANT: mandatory for creating releases
|
|
66
|
+
id-token: write
|
|
67
|
+
|
|
68
|
+
steps:
|
|
69
|
+
- name: Download distribution packages
|
|
70
|
+
uses: actions/download-artifact@v4
|
|
71
|
+
with:
|
|
72
|
+
name: python-package-distributions
|
|
73
|
+
path: dist/
|
|
74
|
+
|
|
75
|
+
- name: Create GitHub Release
|
|
76
|
+
env:
|
|
77
|
+
GITHUB_TOKEN: ${{ github.token }}
|
|
78
|
+
run: >-
|
|
79
|
+
gh release create
|
|
80
|
+
'${{ github.ref_name }}'
|
|
81
|
+
--repo '${{ github.repository }}'
|
|
82
|
+
--notes "Release ${{ github.ref_name }}
|
|
83
|
+
|
|
84
|
+
Install from PyPI:
|
|
85
|
+
```bash
|
|
86
|
+
pip install torch-mpo
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
See [CHANGELOG.md](https://github.com/${{ github.repository }}/blob/main/CHANGELOG.md) for details."
|
|
90
|
+
|
|
91
|
+
- name: Upload artifacts to GitHub Release
|
|
92
|
+
env:
|
|
93
|
+
GITHUB_TOKEN: ${{ github.token }}
|
|
94
|
+
run: >-
|
|
95
|
+
gh release upload
|
|
96
|
+
'${{ github.ref_name }}' dist/**
|
|
97
|
+
--repo '${{ github.repository }}'
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
name: Tests
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [main]
|
|
6
|
+
pull_request:
|
|
7
|
+
branches: [main]
|
|
8
|
+
|
|
9
|
+
jobs:
|
|
10
|
+
test:
|
|
11
|
+
runs-on: ${{ matrix.os }}
|
|
12
|
+
strategy:
|
|
13
|
+
fail-fast: false
|
|
14
|
+
matrix:
|
|
15
|
+
os: [ubuntu-latest, macos-latest, windows-latest]
|
|
16
|
+
python-version: ["3.12"]
|
|
17
|
+
|
|
18
|
+
steps:
|
|
19
|
+
- uses: actions/checkout@v4
|
|
20
|
+
|
|
21
|
+
- name: Set up Python ${{ matrix.python-version }}
|
|
22
|
+
uses: actions/setup-python@v5
|
|
23
|
+
with:
|
|
24
|
+
python-version: ${{ matrix.python-version }}
|
|
25
|
+
|
|
26
|
+
- name: Install dependencies
|
|
27
|
+
run: |
|
|
28
|
+
python -m pip install --upgrade pip
|
|
29
|
+
python -m pip install -e ".[dev]"
|
|
30
|
+
|
|
31
|
+
- name: Lint with flake8
|
|
32
|
+
run: |
|
|
33
|
+
# Stop the build if there are Python syntax errors or undefined names
|
|
34
|
+
python -m flake8 src tests --count --select=E9,F63,F7,F82 --show-source --statistics
|
|
35
|
+
# Exit-zero treats all errors as warnings
|
|
36
|
+
python -m flake8 src tests --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
|
37
|
+
|
|
38
|
+
- name: Format check with black
|
|
39
|
+
run: |
|
|
40
|
+
python -m black --check src tests
|
|
41
|
+
|
|
42
|
+
- name: Import sort check with isort
|
|
43
|
+
run: |
|
|
44
|
+
python -m isort --check-only src tests
|
|
45
|
+
|
|
46
|
+
- name: Type check with mypy
|
|
47
|
+
run: |
|
|
48
|
+
python -m mypy src/torch_mpo --ignore-missing-imports
|
|
49
|
+
|
|
50
|
+
- name: Test with pytest
|
|
51
|
+
run: |
|
|
52
|
+
python -m pytest tests --cov=torch_mpo --cov-report=term-missing
|
|
53
|
+
|
|
54
|
+
- name: Test package installation
|
|
55
|
+
run: |
|
|
56
|
+
python -m pip install .
|
|
57
|
+
python -c "from torch_mpo import TTLinear, TTConv2d, compress_model; print('Import successful')"
|
torch_mpo-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 torch-mpo contributors
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
torch_mpo-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torch-mpo
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: PyTorch implementation of Matrix Product Operators for neural network compression
|
|
5
|
+
Project-URL: Homepage, https://github.com/krzysztofwos/torch-mpo
|
|
6
|
+
Project-URL: Repository, https://github.com/krzysztofwos/torch-mpo
|
|
7
|
+
Project-URL: Issues, https://github.com/krzysztofwos/torch-mpo/issues
|
|
8
|
+
Project-URL: Documentation, https://github.com/krzysztofwos/torch-mpo/blob/main/docs/tutorial.md
|
|
9
|
+
Author: Krzysztof Woś
|
|
10
|
+
Maintainer: Krzysztof Woś
|
|
11
|
+
License: MIT
|
|
12
|
+
License-File: LICENSE
|
|
13
|
+
Keywords: deep-learning,model-compression,mpo,neural-networks,pytorch,tensor-decomposition,tensor-train
|
|
14
|
+
Classifier: Development Status :: 4 - Beta
|
|
15
|
+
Classifier: Intended Audience :: Science/Research
|
|
16
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
17
|
+
Classifier: Programming Language :: Python :: 3
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
|
+
Requires-Python: >=3.12
|
|
21
|
+
Requires-Dist: matplotlib>=3.7.0
|
|
22
|
+
Requires-Dist: numpy>=1.24.0
|
|
23
|
+
Requires-Dist: pytorch-lightning>=2.0.0
|
|
24
|
+
Requires-Dist: tensorboard>=2.13.0
|
|
25
|
+
Requires-Dist: torch>=2.0.0
|
|
26
|
+
Requires-Dist: torchvision>=0.15.0
|
|
27
|
+
Requires-Dist: tqdm>=4.65.0
|
|
28
|
+
Requires-Dist: wandb>=0.15.0
|
|
29
|
+
Provides-Extra: dev
|
|
30
|
+
Requires-Dist: black>=23.3.0; extra == 'dev'
|
|
31
|
+
Requires-Dist: flake8>=6.0.0; extra == 'dev'
|
|
32
|
+
Requires-Dist: isort>=5.12.0; extra == 'dev'
|
|
33
|
+
Requires-Dist: mypy>=1.15.0; extra == 'dev'
|
|
34
|
+
Requires-Dist: pytest-asyncio>=0.26.0; extra == 'dev'
|
|
35
|
+
Requires-Dist: pytest-cov>=4.1.0; extra == 'dev'
|
|
36
|
+
Requires-Dist: pytest-mock>=3.14.0; extra == 'dev'
|
|
37
|
+
Requires-Dist: pytest>=8.3.5; extra == 'dev'
|
|
38
|
+
Provides-Extra: docs
|
|
39
|
+
Requires-Dist: sphinx-autodoc-typehints>=1.23.0; extra == 'docs'
|
|
40
|
+
Requires-Dist: sphinx-rtd-theme>=1.2.0; extra == 'docs'
|
|
41
|
+
Requires-Dist: sphinx>=6.2.0; extra == 'docs'
|
|
42
|
+
Description-Content-Type: text/markdown
|
|
43
|
+
|
|
44
|
+
# PyTorch Matrix Product Operators
|
|
45
|
+
|
|
46
|
+
A modern PyTorch implementation of Matrix Product Operators (MPO) for neural network compression, based on the paper "Compressing deep neural networks by matrix product operators" by Ze-Feng Gao et al.
|
|
47
|
+
|
|
48
|
+
## Overview
|
|
49
|
+
|
|
50
|
+
This library provides PyTorch implementations of tensor-train decomposed neural network layers that can significantly reduce the number of parameters in deep neural networks while maintaining accuracy.
|
|
51
|
+
|
|
52
|
+
## Features
|
|
53
|
+
|
|
54
|
+
- **TT-decomposed layers**: `TTLinear` and `TTConv2d` for compressed fully-connected and convolutional layers
|
|
55
|
+
- **Modern PyTorch**: Full compatibility with PyTorch 2.0+, type hints, device-agnostic
|
|
56
|
+
- **Pretrained model compression**: Convert existing PyTorch models to MPO format
|
|
57
|
+
- **Multiple architectures**: VGG-16/19, ResNet-18/34/50/101/152, and custom models
|
|
58
|
+
- **Automatic factorization**: Smart dimension factorization for optimal compression
|
|
59
|
+
- **Comprehensive examples**: MNIST, CIFAR-10, ImageNet training scripts
|
|
60
|
+
- **Analysis tools**: Compression ratio calculation, performance benchmarks
|
|
61
|
+
|
|
62
|
+
## Installation
|
|
63
|
+
|
|
64
|
+
```bash
|
|
65
|
+
# Clone the repository
|
|
66
|
+
git clone https://github.com/krzysztofwos/torch-mpo
|
|
67
|
+
cd torch-mpo
|
|
68
|
+
|
|
69
|
+
# Install with uv (recommended)
|
|
70
|
+
uv sync # Install base dependencies
|
|
71
|
+
uv sync --all-extras # Install with all extras (dev, docs)
|
|
72
|
+
|
|
73
|
+
# Or install with pip (development mode)
|
|
74
|
+
pip install -e .
|
|
75
|
+
pip install -e ".[dev]" # With development dependencies
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
## Quick Start
|
|
79
|
+
|
|
80
|
+
### Basic Usage
|
|
81
|
+
|
|
82
|
+
```python
|
|
83
|
+
import torch
|
|
84
|
+
from torch_mpo import TTLinear, TTConv2d
|
|
85
|
+
|
|
86
|
+
# Create a TT-decomposed linear layer
|
|
87
|
+
linear = TTLinear(
|
|
88
|
+
in_features=1024,
|
|
89
|
+
out_features=512,
|
|
90
|
+
tt_ranks=8, # Higher rank = better accuracy, more parameters
|
|
91
|
+
bias=True
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
# Create a TT-decomposed convolutional layer
|
|
95
|
+
conv = TTConv2d(
|
|
96
|
+
in_channels=128,
|
|
97
|
+
out_channels=256,
|
|
98
|
+
kernel_size=3,
|
|
99
|
+
padding=1,
|
|
100
|
+
tt_ranks=8
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
# Use them like standard PyTorch layers
|
|
104
|
+
x = torch.randn(32, 1024)
|
|
105
|
+
y = linear(x) # Shape: [32, 512]
|
|
106
|
+
|
|
107
|
+
x = torch.randn(32, 128, 32, 32)
|
|
108
|
+
y = conv(x) # Shape: [32, 256, 32, 32]
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
### Compress Existing Models
|
|
112
|
+
|
|
113
|
+
```python
|
|
114
|
+
from torch_mpo import compress_model
|
|
115
|
+
import torchvision.models as models
|
|
116
|
+
|
|
117
|
+
# Load a pretrained model
|
|
118
|
+
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
|
|
119
|
+
|
|
120
|
+
# Compress it with MPO
|
|
121
|
+
compressed_model = compress_model(
|
|
122
|
+
model,
|
|
123
|
+
compression_ratio=0.1, # Target 10x compression
|
|
124
|
+
compress_linear=True, # Compress Linear layers
|
|
125
|
+
compress_conv=True, # Compress Conv2d layers
|
|
126
|
+
verbose=True
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
# Fine-tune the compressed model
|
|
130
|
+
optimizer = torch.optim.Adam(compressed_model.parameters(), lr=1e-4)
|
|
131
|
+
# ... continue with training
|
|
132
|
+
```
|
|
133
|
+
|
|
134
|
+
### Use Pre-built Architectures
|
|
135
|
+
|
|
136
|
+
```python
|
|
137
|
+
from torch_mpo.models import vgg16_mpo, resnet50_mpo
|
|
138
|
+
|
|
139
|
+
# VGG-16 with MPO compression
|
|
140
|
+
model = vgg16_mpo(
|
|
141
|
+
num_classes=10,
|
|
142
|
+
tt_ranks_conv=8, # TT-rank for conv layers
|
|
143
|
+
tt_ranks_fc=16, # TT-rank for FC layers
|
|
144
|
+
compress_conv=True,
|
|
145
|
+
compress_fc=True
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
# ResNet-50 with MPO compression
|
|
149
|
+
model = resnet50_mpo(
|
|
150
|
+
num_classes=1000,
|
|
151
|
+
tt_ranks_conv=16,
|
|
152
|
+
tt_ranks_fc=32,
|
|
153
|
+
use_mpo_conv=True,
|
|
154
|
+
use_mpo_fc=True
|
|
155
|
+
)
|
|
156
|
+
```
|
|
157
|
+
|
|
158
|
+
## Examples
|
|
159
|
+
|
|
160
|
+
The `examples/` directory contains complete training scripts:
|
|
161
|
+
|
|
162
|
+
### MNIST with LeNet-5 MPO
|
|
163
|
+
|
|
164
|
+
```bash
|
|
165
|
+
python examples/mnist_lenet5_mpo.py --tt-rank 8 --epochs 10
|
|
166
|
+
```
|
|
167
|
+
|
|
168
|
+
### CIFAR-10 with VGG-16 MPO
|
|
169
|
+
|
|
170
|
+
```bash
|
|
171
|
+
python examples/cifar10_vgg16_mpo.py --tt-rank-conv 8 --tt-rank-fc 16 --epochs 20
|
|
172
|
+
```
|
|
173
|
+
|
|
174
|
+
### ImageNet with ResNet-50 MPO
|
|
175
|
+
|
|
176
|
+
```bash
|
|
177
|
+
python examples/imagenet_resnet50_mpo.py /path/to/imagenet \
|
|
178
|
+
--tt-rank-conv 16 --tt-rank-fc 32 --epochs 90
|
|
179
|
+
```
|
|
180
|
+
|
|
181
|
+
### Compress Pretrained VGG
|
|
182
|
+
|
|
183
|
+
```bash
|
|
184
|
+
python examples/compress_vgg.py --model vgg16 --compression-ratio 0.1
|
|
185
|
+
```
|
|
186
|
+
|
|
187
|
+
## Performance Benchmarks
|
|
188
|
+
|
|
189
|
+
Run benchmarks to compare MPO layers with standard layers:
|
|
190
|
+
|
|
191
|
+
```bash
|
|
192
|
+
python benchmarks/benchmark_layers.py
|
|
193
|
+
```
|
|
194
|
+
|
|
195
|
+
### Typical Results
|
|
196
|
+
|
|
197
|
+
| Layer | Original Params | MPO Params (rank=8) | Compression | Speedup |
|
|
198
|
+
| ---------------------- | --------------- | ------------------- | ----------- | ------- |
|
|
199
|
+
| Linear(4096, 4096) | 16.8M | 655K | 25.6x | 0.8x |
|
|
200
|
+
| Conv2d(256, 512, 3) | 1.2M | 123K | 9.7x | 0.9x |
|
|
201
|
+
| VGG-16 (full model) | 138M | 15M | 9.2x | 0.85x |
|
|
202
|
+
| ResNet-50 (full model) | 25.6M | 8.2M | 3.1x | 0.95x |
|
|
203
|
+
|
|
204
|
+
## Documentation
|
|
205
|
+
|
|
206
|
+
See the comprehensive tutorial in `docs/tutorial.md` covering:
|
|
207
|
+
|
|
208
|
+
- Mathematical foundations of TT decomposition
|
|
209
|
+
- How MPO compression works
|
|
210
|
+
- Implementation details
|
|
211
|
+
- Best practices and tips
|
|
212
|
+
- Advanced topics
|
|
213
|
+
|
|
214
|
+
## Key Concepts
|
|
215
|
+
|
|
216
|
+
### TT-Ranks
|
|
217
|
+
|
|
218
|
+
The `tt_ranks` parameter controls the trade-off between compression and accuracy:
|
|
219
|
+
|
|
220
|
+
- **Lower ranks** (4-8): High compression, some accuracy loss
|
|
221
|
+
- **Medium ranks** (8-16): Good balance
|
|
222
|
+
- **Higher ranks** (16-32): Less compression, minimal accuracy loss
|
|
223
|
+
|
|
224
|
+
### Automatic Factorization
|
|
225
|
+
|
|
226
|
+
The library automatically factorizes dimensions for optimal compression:
|
|
227
|
+
|
|
228
|
+
```python
|
|
229
|
+
# 1024 = 4 × 16 × 16 (automatic factorization)
|
|
230
|
+
layer = TTLinear(1024, 512, tt_ranks=8)
|
|
231
|
+
```
|
|
232
|
+
|
|
233
|
+
### Custom Factorization
|
|
234
|
+
|
|
235
|
+
You can also specify custom factorizations:
|
|
236
|
+
|
|
237
|
+
```python
|
|
238
|
+
layer = TTLinear(
|
|
239
|
+
in_features=784, # 28×28 MNIST
|
|
240
|
+
out_features=256,
|
|
241
|
+
inp_modes=[7, 4, 7, 4], # 7×4×7×4 = 784
|
|
242
|
+
out_modes=[4, 4, 4, 4], # 4×4×4×4 = 256
|
|
243
|
+
tt_ranks=[1, 8, 8, 8, 1]
|
|
244
|
+
)
|
|
245
|
+
```
|
|
246
|
+
|
|
247
|
+
### Initialization and Numerical Stability
|
|
248
|
+
|
|
249
|
+
Proper initialization is crucial for TT-decomposed layers to maintain stable gradients during training:
|
|
250
|
+
|
|
251
|
+
#### TTLinear Initialization
|
|
252
|
+
|
|
253
|
+
- Uses standard Xavier/Kaiming initialization for each core
|
|
254
|
+
- No additional scaling needed as the decomposition naturally regularizes
|
|
255
|
+
|
|
256
|
+
#### TTConv2d Initialization
|
|
257
|
+
|
|
258
|
+
- More complex due to spatial convolution followed by TT cores
|
|
259
|
+
- **Key insight**: Variance accumulates through both spatial conv and TT cores
|
|
260
|
+
- **Solution**: TT cores are scaled by `1/d^0.25` where `d` is the number of cores
|
|
261
|
+
- This empirically maintains output variance similar to standard Conv2d layers
|
|
262
|
+
|
|
263
|
+
Without proper initialization scaling, deep networks can experience:
|
|
264
|
+
|
|
265
|
+
- **Exploding activations**: Outputs growing exponentially through layers
|
|
266
|
+
- **Vanishing gradients**: Making training impossible
|
|
267
|
+
- **Poor convergence**: Model stuck at random performance
|
|
268
|
+
|
|
269
|
+
The library handles this automatically, but when implementing custom layers, careful attention to initialization is essential.
|
|
270
|
+
|
|
271
|
+
## Contributing
|
|
272
|
+
|
|
273
|
+
Contributions are welcome. Please feel free to submit a Pull Request.
|
|
274
|
+
|
|
275
|
+
## Citation
|
|
276
|
+
|
|
277
|
+
If you use this code in your research, please cite both the original paper and this implementation:
|
|
278
|
+
|
|
279
|
+
### Original Paper
|
|
280
|
+
|
|
281
|
+
```bibtex
|
|
282
|
+
@article{gao2020compressing,
|
|
283
|
+
title={Compressing deep neural networks by matrix product operators},
|
|
284
|
+
author={Gao, Ze-Feng and Song, Chao and Wang, Lei and others},
|
|
285
|
+
journal={Physical Review Research},
|
|
286
|
+
volume={2},
|
|
287
|
+
number={2},
|
|
288
|
+
pages={023300},
|
|
289
|
+
year={2020}
|
|
290
|
+
}
|
|
291
|
+
```
|
|
292
|
+
|
|
293
|
+
### This Implementation
|
|
294
|
+
|
|
295
|
+
```bibtex
|
|
296
|
+
@software{torch-mpo2024,
|
|
297
|
+
title={torch-mpo: PyTorch Matrix Product Operators},
|
|
298
|
+
author={Woś, Krzysztof},
|
|
299
|
+
year={2024},
|
|
300
|
+
url={https://github.com/krzysztofwos/torch-mpo},
|
|
301
|
+
version={0.1.0}
|
|
302
|
+
}
|
|
303
|
+
```
|
|
304
|
+
|
|
305
|
+
## License
|
|
306
|
+
|
|
307
|
+
MIT License
|