cotengrust 0.1.1__tar.gz → 0.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotengrust-0.1.3/.github/workflows/CI.yml +138 -0
- {cotengrust-0.1.1 → cotengrust-0.1.3}/Cargo.lock +34 -19
- {cotengrust-0.1.1 → cotengrust-0.1.3}/Cargo.toml +3 -3
- {cotengrust-0.1.1 → cotengrust-0.1.3}/PKG-INFO +83 -7
- {cotengrust-0.1.1 → cotengrust-0.1.3}/README.md +81 -5
- {cotengrust-0.1.1 → cotengrust-0.1.3}/pyproject.toml +2 -2
- {cotengrust-0.1.1 → cotengrust-0.1.3}/src/lib.rs +203 -50
- {cotengrust-0.1.1 → cotengrust-0.1.3}/tests/test_cotengrust.py +33 -3
- cotengrust-0.1.1/.github/workflows/CI.yml +0 -162
- {cotengrust-0.1.1 → cotengrust-0.1.3}/.gitignore +0 -0
- {cotengrust-0.1.1 → cotengrust-0.1.3}/LICENSE +0 -0
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
# This file is autogenerated by maturin v1.5.1
|
|
2
|
+
# To update, run
|
|
3
|
+
#
|
|
4
|
+
# maturin generate-ci github
|
|
5
|
+
#
|
|
6
|
+
name: CI
|
|
7
|
+
|
|
8
|
+
on:
|
|
9
|
+
push:
|
|
10
|
+
branches:
|
|
11
|
+
- main
|
|
12
|
+
- master
|
|
13
|
+
tags:
|
|
14
|
+
- '*'
|
|
15
|
+
pull_request:
|
|
16
|
+
workflow_dispatch:
|
|
17
|
+
|
|
18
|
+
permissions:
|
|
19
|
+
contents: read
|
|
20
|
+
|
|
21
|
+
jobs:
|
|
22
|
+
linux:
|
|
23
|
+
runs-on: ${{ matrix.platform.runner }}
|
|
24
|
+
strategy:
|
|
25
|
+
matrix:
|
|
26
|
+
platform:
|
|
27
|
+
- runner: ubuntu-latest
|
|
28
|
+
target: x86_64
|
|
29
|
+
- runner: ubuntu-latest
|
|
30
|
+
target: x86
|
|
31
|
+
- runner: ubuntu-latest
|
|
32
|
+
target: aarch64
|
|
33
|
+
- runner: ubuntu-latest
|
|
34
|
+
target: armv7
|
|
35
|
+
- runner: ubuntu-latest
|
|
36
|
+
target: s390x
|
|
37
|
+
- runner: ubuntu-latest
|
|
38
|
+
target: ppc64le
|
|
39
|
+
steps:
|
|
40
|
+
- uses: actions/checkout@v4
|
|
41
|
+
- uses: actions/setup-python@v5
|
|
42
|
+
with:
|
|
43
|
+
python-version: '3.10'
|
|
44
|
+
- name: Build wheels
|
|
45
|
+
uses: PyO3/maturin-action@v1
|
|
46
|
+
with:
|
|
47
|
+
target: ${{ matrix.platform.target }}
|
|
48
|
+
args: --release --out dist --find-interpreter
|
|
49
|
+
sccache: 'true'
|
|
50
|
+
manylinux: auto
|
|
51
|
+
- name: Upload wheels
|
|
52
|
+
uses: actions/upload-artifact@v4
|
|
53
|
+
with:
|
|
54
|
+
name: wheels-linux-${{ matrix.platform.target }}
|
|
55
|
+
path: dist
|
|
56
|
+
|
|
57
|
+
windows:
|
|
58
|
+
runs-on: ${{ matrix.platform.runner }}
|
|
59
|
+
strategy:
|
|
60
|
+
matrix:
|
|
61
|
+
platform:
|
|
62
|
+
- runner: windows-latest
|
|
63
|
+
target: x64
|
|
64
|
+
- runner: windows-latest
|
|
65
|
+
target: x86
|
|
66
|
+
steps:
|
|
67
|
+
- uses: actions/checkout@v4
|
|
68
|
+
- uses: actions/setup-python@v5
|
|
69
|
+
with:
|
|
70
|
+
python-version: '3.10'
|
|
71
|
+
architecture: ${{ matrix.platform.target }}
|
|
72
|
+
- name: Build wheels
|
|
73
|
+
uses: PyO3/maturin-action@v1
|
|
74
|
+
with:
|
|
75
|
+
target: ${{ matrix.platform.target }}
|
|
76
|
+
args: --release --out dist --find-interpreter
|
|
77
|
+
sccache: 'true'
|
|
78
|
+
- name: Upload wheels
|
|
79
|
+
uses: actions/upload-artifact@v4
|
|
80
|
+
with:
|
|
81
|
+
name: wheels-windows-${{ matrix.platform.target }}
|
|
82
|
+
path: dist
|
|
83
|
+
|
|
84
|
+
macos:
|
|
85
|
+
runs-on: ${{ matrix.platform.runner }}
|
|
86
|
+
strategy:
|
|
87
|
+
matrix:
|
|
88
|
+
platform:
|
|
89
|
+
- runner: macos-latest
|
|
90
|
+
target: x86_64
|
|
91
|
+
- runner: macos-14
|
|
92
|
+
target: aarch64
|
|
93
|
+
steps:
|
|
94
|
+
- uses: actions/checkout@v4
|
|
95
|
+
- uses: actions/setup-python@v5
|
|
96
|
+
with:
|
|
97
|
+
python-version: '3.10'
|
|
98
|
+
- name: Build wheels
|
|
99
|
+
uses: PyO3/maturin-action@v1
|
|
100
|
+
with:
|
|
101
|
+
target: ${{ matrix.platform.target }}
|
|
102
|
+
args: --release --out dist --find-interpreter
|
|
103
|
+
sccache: 'true'
|
|
104
|
+
- name: Upload wheels
|
|
105
|
+
uses: actions/upload-artifact@v4
|
|
106
|
+
with:
|
|
107
|
+
name: wheels-macos-${{ matrix.platform.target }}
|
|
108
|
+
path: dist
|
|
109
|
+
|
|
110
|
+
sdist:
|
|
111
|
+
runs-on: ubuntu-latest
|
|
112
|
+
steps:
|
|
113
|
+
- uses: actions/checkout@v4
|
|
114
|
+
- name: Build sdist
|
|
115
|
+
uses: PyO3/maturin-action@v1
|
|
116
|
+
with:
|
|
117
|
+
command: sdist
|
|
118
|
+
args: --out dist
|
|
119
|
+
- name: Upload sdist
|
|
120
|
+
uses: actions/upload-artifact@v4
|
|
121
|
+
with:
|
|
122
|
+
name: wheels-sdist
|
|
123
|
+
path: dist
|
|
124
|
+
|
|
125
|
+
release:
|
|
126
|
+
name: Release
|
|
127
|
+
runs-on: ubuntu-latest
|
|
128
|
+
if: startsWith(github.ref, 'refs/tags/')
|
|
129
|
+
needs: [linux, windows, macos, sdist]
|
|
130
|
+
steps:
|
|
131
|
+
- uses: actions/download-artifact@v4
|
|
132
|
+
- name: Publish to PyPI
|
|
133
|
+
uses: PyO3/maturin-action@v1
|
|
134
|
+
env:
|
|
135
|
+
MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
|
|
136
|
+
with:
|
|
137
|
+
command: upload
|
|
138
|
+
args: --non-interactive --skip-existing wheels-*/*
|
|
@@ -37,7 +37,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
|
|
37
37
|
|
|
38
38
|
[[package]]
|
|
39
39
|
name = "cotengrust"
|
|
40
|
-
version = "0.1.
|
|
40
|
+
version = "0.1.3"
|
|
41
41
|
dependencies = [
|
|
42
42
|
"bit-set",
|
|
43
43
|
"ordered-float",
|
|
@@ -57,11 +57,17 @@ dependencies = [
|
|
|
57
57
|
"wasi",
|
|
58
58
|
]
|
|
59
59
|
|
|
60
|
+
[[package]]
|
|
61
|
+
name = "heck"
|
|
62
|
+
version = "0.4.1"
|
|
63
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
64
|
+
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
|
|
65
|
+
|
|
60
66
|
[[package]]
|
|
61
67
|
name = "indoc"
|
|
62
|
-
version = "
|
|
68
|
+
version = "2.0.4"
|
|
63
69
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
64
|
-
checksum = "
|
|
70
|
+
checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8"
|
|
65
71
|
|
|
66
72
|
[[package]]
|
|
67
73
|
name = "libc"
|
|
@@ -105,9 +111,9 @@ checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
|
|
|
105
111
|
|
|
106
112
|
[[package]]
|
|
107
113
|
name = "ordered-float"
|
|
108
|
-
version = "
|
|
114
|
+
version = "4.2.0"
|
|
109
115
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
110
|
-
checksum = "
|
|
116
|
+
checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e"
|
|
111
117
|
dependencies = [
|
|
112
118
|
"num-traits",
|
|
113
119
|
]
|
|
@@ -135,6 +141,12 @@ dependencies = [
|
|
|
135
141
|
"windows-targets",
|
|
136
142
|
]
|
|
137
143
|
|
|
144
|
+
[[package]]
|
|
145
|
+
name = "portable-atomic"
|
|
146
|
+
version = "1.6.0"
|
|
147
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
148
|
+
checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
|
|
149
|
+
|
|
138
150
|
[[package]]
|
|
139
151
|
name = "ppv-lite86"
|
|
140
152
|
version = "0.2.17"
|
|
@@ -152,15 +164,16 @@ dependencies = [
|
|
|
152
164
|
|
|
153
165
|
[[package]]
|
|
154
166
|
name = "pyo3"
|
|
155
|
-
version = "0.
|
|
167
|
+
version = "0.21.1"
|
|
156
168
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
157
|
-
checksum = "
|
|
169
|
+
checksum = "a7a8b1990bd018761768d5e608a13df8bd1ac5f678456e0f301bb93e5f3ea16b"
|
|
158
170
|
dependencies = [
|
|
159
171
|
"cfg-if",
|
|
160
172
|
"indoc",
|
|
161
173
|
"libc",
|
|
162
174
|
"memoffset",
|
|
163
175
|
"parking_lot",
|
|
176
|
+
"portable-atomic",
|
|
164
177
|
"pyo3-build-config",
|
|
165
178
|
"pyo3-ffi",
|
|
166
179
|
"pyo3-macros",
|
|
@@ -169,9 +182,9 @@ dependencies = [
|
|
|
169
182
|
|
|
170
183
|
[[package]]
|
|
171
184
|
name = "pyo3-build-config"
|
|
172
|
-
version = "0.
|
|
185
|
+
version = "0.21.1"
|
|
173
186
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
174
|
-
checksum = "
|
|
187
|
+
checksum = "650dca34d463b6cdbdb02b1d71bfd6eb6b6816afc708faebb3bac1380ff4aef7"
|
|
175
188
|
dependencies = [
|
|
176
189
|
"once_cell",
|
|
177
190
|
"target-lexicon",
|
|
@@ -179,9 +192,9 @@ dependencies = [
|
|
|
179
192
|
|
|
180
193
|
[[package]]
|
|
181
194
|
name = "pyo3-ffi"
|
|
182
|
-
version = "0.
|
|
195
|
+
version = "0.21.1"
|
|
183
196
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
184
|
-
checksum = "
|
|
197
|
+
checksum = "09a7da8fc04a8a2084909b59f29e1b8474decac98b951d77b80b26dc45f046ad"
|
|
185
198
|
dependencies = [
|
|
186
199
|
"libc",
|
|
187
200
|
"pyo3-build-config",
|
|
@@ -189,9 +202,9 @@ dependencies = [
|
|
|
189
202
|
|
|
190
203
|
[[package]]
|
|
191
204
|
name = "pyo3-macros"
|
|
192
|
-
version = "0.
|
|
205
|
+
version = "0.21.1"
|
|
193
206
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
194
|
-
checksum = "
|
|
207
|
+
checksum = "4b8a199fce11ebb28e3569387228836ea98110e43a804a530a9fd83ade36d513"
|
|
195
208
|
dependencies = [
|
|
196
209
|
"proc-macro2",
|
|
197
210
|
"pyo3-macros-backend",
|
|
@@ -201,11 +214,13 @@ dependencies = [
|
|
|
201
214
|
|
|
202
215
|
[[package]]
|
|
203
216
|
name = "pyo3-macros-backend"
|
|
204
|
-
version = "0.
|
|
217
|
+
version = "0.21.1"
|
|
205
218
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
206
|
-
checksum = "
|
|
219
|
+
checksum = "93fbbfd7eb553d10036513cb122b888dcd362a945a00b06c165f2ab480d4cc3b"
|
|
207
220
|
dependencies = [
|
|
221
|
+
"heck",
|
|
208
222
|
"proc-macro2",
|
|
223
|
+
"pyo3-build-config",
|
|
209
224
|
"quote",
|
|
210
225
|
"syn",
|
|
211
226
|
]
|
|
@@ -278,9 +293,9 @@ checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
|
|
|
278
293
|
|
|
279
294
|
[[package]]
|
|
280
295
|
name = "syn"
|
|
281
|
-
version = "
|
|
296
|
+
version = "2.0.32"
|
|
282
297
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
283
|
-
checksum = "
|
|
298
|
+
checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2"
|
|
284
299
|
dependencies = [
|
|
285
300
|
"proc-macro2",
|
|
286
301
|
"quote",
|
|
@@ -301,9 +316,9 @@ checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c"
|
|
|
301
316
|
|
|
302
317
|
[[package]]
|
|
303
318
|
name = "unindent"
|
|
304
|
-
version = "0.
|
|
319
|
+
version = "0.2.3"
|
|
305
320
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
306
|
-
checksum = "
|
|
321
|
+
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
|
|
307
322
|
|
|
308
323
|
[[package]]
|
|
309
324
|
name = "wasi"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[package]
|
|
2
2
|
name = "cotengrust"
|
|
3
|
-
version = "0.1.
|
|
3
|
+
version = "0.1.3"
|
|
4
4
|
edition = "2021"
|
|
5
5
|
|
|
6
6
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
|
@@ -10,8 +10,8 @@ crate-type = ["cdylib"]
|
|
|
10
10
|
|
|
11
11
|
[dependencies]
|
|
12
12
|
bit-set = "0.5"
|
|
13
|
-
|
|
14
|
-
|
|
13
|
+
ordered-float = "4.2"
|
|
14
|
+
pyo3 = "0.21"
|
|
15
15
|
rand = "0.8"
|
|
16
16
|
rustc-hash = "1.1"
|
|
17
17
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
2
|
Name: cotengrust
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.3
|
|
4
4
|
Classifier: Programming Language :: Rust
|
|
5
5
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
6
6
|
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
@@ -19,9 +19,14 @@ are:
|
|
|
19
19
|
- `optimize_optimal(inputs, output, size_dict, **kwargs)`
|
|
20
20
|
- `optimize_greedy(inputs, output, size_dict, **kwargs)`
|
|
21
21
|
|
|
22
|
-
The optimal algorithm is an optimized version of the `opt_einsum` 'dp'
|
|
22
|
+
The optimal algorithm is an optimized version of the `opt_einsum` 'dp'
|
|
23
23
|
path - itself an implementation of https://arxiv.org/abs/1304.6112.
|
|
24
24
|
|
|
25
|
+
There is also a variant of the greedy algorithm, which runs `ntrials` of greedy,
|
|
26
|
+
randomized paths and computes and reports the flops cost (log10) simultaneously:
|
|
27
|
+
|
|
28
|
+
- `optimize_random_greedy_track_flops(inputs, output, size_dict, **kwargs)`
|
|
29
|
+
|
|
25
30
|
|
|
26
31
|
## Installation
|
|
27
32
|
|
|
@@ -32,7 +37,7 @@ path - itself an implementation of https://arxiv.org/abs/1304.6112.
|
|
|
32
37
|
pip install cotengrust
|
|
33
38
|
```
|
|
34
39
|
|
|
35
|
-
or if you want to develop locally (which requires [pyo3](https://github.com/PyO3/pyo3)
|
|
40
|
+
or if you want to develop locally (which requires [pyo3](https://github.com/PyO3/pyo3)
|
|
36
41
|
and [maturin](https://github.com/PyO3/maturin)):
|
|
37
42
|
|
|
38
43
|
```bash
|
|
@@ -46,8 +51,8 @@ maturin develop --release
|
|
|
46
51
|
## Usage
|
|
47
52
|
|
|
48
53
|
If `cotengrust` is installed, then by default `cotengra` will use it for its
|
|
49
|
-
greedy and optimal subroutines, notably subtree
|
|
50
|
-
call the routines directly:
|
|
54
|
+
greedy, random-greedy, and optimal subroutines, notably subtree
|
|
55
|
+
reconfiguration. You can also call the routines directly:
|
|
51
56
|
|
|
52
57
|
```python
|
|
53
58
|
import cotengra as ctg
|
|
@@ -171,7 +176,7 @@ def optimize_greedy(
|
|
|
171
176
|
When assessing local greedy scores how much to weight the size of the
|
|
172
177
|
tensors removed compared to the size of the tensor added::
|
|
173
178
|
|
|
174
|
-
score = size_ab
|
|
179
|
+
score = size_ab / costmod - (size_a + size_b) * costmod
|
|
175
180
|
|
|
176
181
|
This can be a useful hyper-parameter to tune.
|
|
177
182
|
temperature : float, optional
|
|
@@ -237,6 +242,77 @@ def optimize_simplify(
|
|
|
237
242
|
"""
|
|
238
243
|
...
|
|
239
244
|
|
|
245
|
+
def optimize_random_greedy_track_flops(
|
|
246
|
+
inputs,
|
|
247
|
+
output,
|
|
248
|
+
size_dict,
|
|
249
|
+
ntrials=1,
|
|
250
|
+
costmod=(0.1, 4.0),
|
|
251
|
+
temperature=(0.001, 1.0),
|
|
252
|
+
seed=None,
|
|
253
|
+
simplify=True,
|
|
254
|
+
use_ssa=False,
|
|
255
|
+
):
|
|
256
|
+
"""Perform a batch of random greedy optimizations, simultaneously tracking
|
|
257
|
+
the best contraction path in terms of flops, so as to avoid constructing a
|
|
258
|
+
separate contraction tree.
|
|
259
|
+
|
|
260
|
+
Parameters
|
|
261
|
+
----------
|
|
262
|
+
inputs : tuple[tuple[str]]
|
|
263
|
+
The indices of each input tensor.
|
|
264
|
+
output : tuple[str]
|
|
265
|
+
The indices of the output tensor.
|
|
266
|
+
size_dict : dict[str, int]
|
|
267
|
+
A dictionary mapping indices to their dimension.
|
|
268
|
+
ntrials : int, optional
|
|
269
|
+
The number of random greedy trials to perform. The default is 1.
|
|
270
|
+
costmod : (float, float), optional
|
|
271
|
+
When assessing local greedy scores how much to weight the size of the
|
|
272
|
+
tensors removed compared to the size of the tensor added::
|
|
273
|
+
|
|
274
|
+
score = size_ab / costmod - (size_a + size_b) * costmod
|
|
275
|
+
|
|
276
|
+
It is sampled uniformly from the given range.
|
|
277
|
+
temperature : (float, float), optional
|
|
278
|
+
When assessing local greedy scores, how much to randomly perturb the
|
|
279
|
+
score. This is implemented as::
|
|
280
|
+
|
|
281
|
+
score -> sign(score) * log(|score|) - temperature * gumbel()
|
|
282
|
+
|
|
283
|
+
which implements boltzmann sampling. It is sampled log-uniformly from
|
|
284
|
+
the given range.
|
|
285
|
+
seed : int, optional
|
|
286
|
+
The seed for the random number generator.
|
|
287
|
+
simplify : bool, optional
|
|
288
|
+
Whether to perform simplifications before optimizing. These are:
|
|
289
|
+
|
|
290
|
+
- ignore any indices that appear in all terms
|
|
291
|
+
- combine any repeated indices within a single term
|
|
292
|
+
- reduce any non-output indices that only appear on a single term
|
|
293
|
+
- combine any scalar terms
|
|
294
|
+
- combine any tensors with matching indices (hadamard products)
|
|
295
|
+
|
|
296
|
+
Such simplifications may be required in the general case for the proper
|
|
297
|
+
functioning of the core optimization, but may be skipped if the input
|
|
298
|
+
indices are already in a simplified form.
|
|
299
|
+
use_ssa : bool, optional
|
|
300
|
+
Whether to return the contraction path in 'single static assignment'
|
|
301
|
+
(SSA) format (i.e. as if each intermediate is appended to the list of
|
|
302
|
+
inputs, without removals). This can be quicker and easier to work with
|
|
303
|
+
than the 'linear recycled' format that `numpy` and `opt_einsum` use.
|
|
304
|
+
|
|
305
|
+
Returns
|
|
306
|
+
-------
|
|
307
|
+
path : list[list[int]]
|
|
308
|
+
The best contraction path, given as a sequence of pairs of node
|
|
309
|
+
indices.
|
|
310
|
+
flops : float
|
|
311
|
+
The flops (/ contraction cost / number of multiplications), of the best
|
|
312
|
+
contraction path, given log10.
|
|
313
|
+
"""
|
|
314
|
+
...
|
|
315
|
+
|
|
240
316
|
def ssa_to_linear(ssa_path, n=None):
|
|
241
317
|
"""Convert a SSA path to linear format."""
|
|
242
318
|
...
|
|
@@ -7,9 +7,14 @@ are:
|
|
|
7
7
|
- `optimize_optimal(inputs, output, size_dict, **kwargs)`
|
|
8
8
|
- `optimize_greedy(inputs, output, size_dict, **kwargs)`
|
|
9
9
|
|
|
10
|
-
The optimal algorithm is an optimized version of the `opt_einsum` 'dp'
|
|
10
|
+
The optimal algorithm is an optimized version of the `opt_einsum` 'dp'
|
|
11
11
|
path - itself an implementation of https://arxiv.org/abs/1304.6112.
|
|
12
12
|
|
|
13
|
+
There is also a variant of the greedy algorithm, which runs `ntrials` of greedy,
|
|
14
|
+
randomized paths and computes and reports the flops cost (log10) simultaneously:
|
|
15
|
+
|
|
16
|
+
- `optimize_random_greedy_track_flops(inputs, output, size_dict, **kwargs)`
|
|
17
|
+
|
|
13
18
|
|
|
14
19
|
## Installation
|
|
15
20
|
|
|
@@ -20,7 +25,7 @@ path - itself an implementation of https://arxiv.org/abs/1304.6112.
|
|
|
20
25
|
pip install cotengrust
|
|
21
26
|
```
|
|
22
27
|
|
|
23
|
-
or if you want to develop locally (which requires [pyo3](https://github.com/PyO3/pyo3)
|
|
28
|
+
or if you want to develop locally (which requires [pyo3](https://github.com/PyO3/pyo3)
|
|
24
29
|
and [maturin](https://github.com/PyO3/maturin)):
|
|
25
30
|
|
|
26
31
|
```bash
|
|
@@ -34,8 +39,8 @@ maturin develop --release
|
|
|
34
39
|
## Usage
|
|
35
40
|
|
|
36
41
|
If `cotengrust` is installed, then by default `cotengra` will use it for its
|
|
37
|
-
greedy and optimal subroutines, notably subtree
|
|
38
|
-
call the routines directly:
|
|
42
|
+
greedy, random-greedy, and optimal subroutines, notably subtree
|
|
43
|
+
reconfiguration. You can also call the routines directly:
|
|
39
44
|
|
|
40
45
|
```python
|
|
41
46
|
import cotengra as ctg
|
|
@@ -159,7 +164,7 @@ def optimize_greedy(
|
|
|
159
164
|
When assessing local greedy scores how much to weight the size of the
|
|
160
165
|
tensors removed compared to the size of the tensor added::
|
|
161
166
|
|
|
162
|
-
score = size_ab
|
|
167
|
+
score = size_ab / costmod - (size_a + size_b) * costmod
|
|
163
168
|
|
|
164
169
|
This can be a useful hyper-parameter to tune.
|
|
165
170
|
temperature : float, optional
|
|
@@ -225,6 +230,77 @@ def optimize_simplify(
|
|
|
225
230
|
"""
|
|
226
231
|
...
|
|
227
232
|
|
|
233
|
+
def optimize_random_greedy_track_flops(
|
|
234
|
+
inputs,
|
|
235
|
+
output,
|
|
236
|
+
size_dict,
|
|
237
|
+
ntrials=1,
|
|
238
|
+
costmod=(0.1, 4.0),
|
|
239
|
+
temperature=(0.001, 1.0),
|
|
240
|
+
seed=None,
|
|
241
|
+
simplify=True,
|
|
242
|
+
use_ssa=False,
|
|
243
|
+
):
|
|
244
|
+
"""Perform a batch of random greedy optimizations, simultaneously tracking
|
|
245
|
+
the best contraction path in terms of flops, so as to avoid constructing a
|
|
246
|
+
separate contraction tree.
|
|
247
|
+
|
|
248
|
+
Parameters
|
|
249
|
+
----------
|
|
250
|
+
inputs : tuple[tuple[str]]
|
|
251
|
+
The indices of each input tensor.
|
|
252
|
+
output : tuple[str]
|
|
253
|
+
The indices of the output tensor.
|
|
254
|
+
size_dict : dict[str, int]
|
|
255
|
+
A dictionary mapping indices to their dimension.
|
|
256
|
+
ntrials : int, optional
|
|
257
|
+
The number of random greedy trials to perform. The default is 1.
|
|
258
|
+
costmod : (float, float), optional
|
|
259
|
+
When assessing local greedy scores how much to weight the size of the
|
|
260
|
+
tensors removed compared to the size of the tensor added::
|
|
261
|
+
|
|
262
|
+
score = size_ab / costmod - (size_a + size_b) * costmod
|
|
263
|
+
|
|
264
|
+
It is sampled uniformly from the given range.
|
|
265
|
+
temperature : (float, float), optional
|
|
266
|
+
When assessing local greedy scores, how much to randomly perturb the
|
|
267
|
+
score. This is implemented as::
|
|
268
|
+
|
|
269
|
+
score -> sign(score) * log(|score|) - temperature * gumbel()
|
|
270
|
+
|
|
271
|
+
which implements boltzmann sampling. It is sampled log-uniformly from
|
|
272
|
+
the given range.
|
|
273
|
+
seed : int, optional
|
|
274
|
+
The seed for the random number generator.
|
|
275
|
+
simplify : bool, optional
|
|
276
|
+
Whether to perform simplifications before optimizing. These are:
|
|
277
|
+
|
|
278
|
+
- ignore any indices that appear in all terms
|
|
279
|
+
- combine any repeated indices within a single term
|
|
280
|
+
- reduce any non-output indices that only appear on a single term
|
|
281
|
+
- combine any scalar terms
|
|
282
|
+
- combine any tensors with matching indices (hadamard products)
|
|
283
|
+
|
|
284
|
+
Such simplifications may be required in the general case for the proper
|
|
285
|
+
functioning of the core optimization, but may be skipped if the input
|
|
286
|
+
indices are already in a simplified form.
|
|
287
|
+
use_ssa : bool, optional
|
|
288
|
+
Whether to return the contraction path in 'single static assignment'
|
|
289
|
+
(SSA) format (i.e. as if each intermediate is appended to the list of
|
|
290
|
+
inputs, without removals). This can be quicker and easier to work with
|
|
291
|
+
than the 'linear recycled' format that `numpy` and `opt_einsum` use.
|
|
292
|
+
|
|
293
|
+
Returns
|
|
294
|
+
-------
|
|
295
|
+
path : list[list[int]]
|
|
296
|
+
The best contraction path, given as a sequence of pairs of node
|
|
297
|
+
indices.
|
|
298
|
+
flops : float
|
|
299
|
+
The flops (/ contraction cost / number of multiplications), of the best
|
|
300
|
+
contraction path, given log10.
|
|
301
|
+
"""
|
|
302
|
+
...
|
|
303
|
+
|
|
228
304
|
def ssa_to_linear(ssa_path, n=None):
|
|
229
305
|
"""Convert a SSA path to linear format."""
|
|
230
306
|
...
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "cotengrust"
|
|
3
|
-
version = "0.1.
|
|
3
|
+
version = "0.1.3"
|
|
4
4
|
description = "Fast contraction ordering primitives for tensor networks."
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.8"
|
|
@@ -15,7 +15,7 @@ authors = [
|
|
|
15
15
|
]
|
|
16
16
|
|
|
17
17
|
[build-system]
|
|
18
|
-
requires = ["maturin>=0.
|
|
18
|
+
requires = ["maturin>=1.0,<2.0"]
|
|
19
19
|
build-backend = "maturin"
|
|
20
20
|
|
|
21
21
|
[tool.maturin]
|
|
@@ -2,8 +2,9 @@ use bit_set::BitSet;
|
|
|
2
2
|
use ordered_float::OrderedFloat;
|
|
3
3
|
use pyo3::prelude::*;
|
|
4
4
|
use rand::Rng;
|
|
5
|
+
use rand::SeedableRng;
|
|
5
6
|
use rustc_hash::FxHashMap;
|
|
6
|
-
use std::collections::{BTreeSet, BinaryHeap};
|
|
7
|
+
use std::collections::{BTreeSet, BinaryHeap, HashSet};
|
|
7
8
|
use std::f32;
|
|
8
9
|
|
|
9
10
|
use FxHashMap as Dict;
|
|
@@ -23,6 +24,7 @@ type BitPath = Vec<(Subgraph, Subgraph)>;
|
|
|
23
24
|
type SubContraction = (Legs, Score, BitPath);
|
|
24
25
|
|
|
25
26
|
/// helper struct to build contractions from bottom up
|
|
27
|
+
#[derive(Clone)]
|
|
26
28
|
struct ContractionProcessor {
|
|
27
29
|
nodes: Dict<Node, Legs>,
|
|
28
30
|
edges: Dict<Ix, BTreeSet<Node>>,
|
|
@@ -30,6 +32,9 @@ struct ContractionProcessor {
|
|
|
30
32
|
sizes: Vec<Score>,
|
|
31
33
|
ssa: Node,
|
|
32
34
|
ssa_path: SSAPath,
|
|
35
|
+
track_flops: bool,
|
|
36
|
+
flops: Score,
|
|
37
|
+
flops_limit: Score,
|
|
33
38
|
}
|
|
34
39
|
|
|
35
40
|
/// given log(x) and log(y) compute log(x + y), without exponentiating both
|
|
@@ -94,6 +99,21 @@ fn compute_size(legs: &Legs, sizes: &Vec<Score>) -> Score {
|
|
|
94
99
|
legs.iter().map(|&(ix, _)| sizes[ix as usize]).sum()
|
|
95
100
|
}
|
|
96
101
|
|
|
102
|
+
fn compute_flops(ilegs: &Legs, jlegs: &Legs, sizes: &Vec<Score>) -> Score {
|
|
103
|
+
let mut flops: Score = 0.0;
|
|
104
|
+
let mut seen: HashSet<Ix> = HashSet::with_capacity(ilegs.len());
|
|
105
|
+
for &(ix, _) in ilegs {
|
|
106
|
+
seen.insert(ix);
|
|
107
|
+
flops += sizes[ix as usize];
|
|
108
|
+
}
|
|
109
|
+
for (ix, _) in jlegs {
|
|
110
|
+
if !seen.contains(ix) {
|
|
111
|
+
flops += sizes[*ix as usize];
|
|
112
|
+
}
|
|
113
|
+
}
|
|
114
|
+
flops
|
|
115
|
+
}
|
|
116
|
+
|
|
97
117
|
fn is_simplifiable(legs: &Legs, appearances: &Vec<Count>) -> bool {
|
|
98
118
|
let mut prev_ix = Node::MAX;
|
|
99
119
|
for &(ix, ix_count) in legs {
|
|
@@ -131,7 +151,12 @@ impl ContractionProcessor {
|
|
|
131
151
|
inputs: Vec<Vec<char>>,
|
|
132
152
|
output: Vec<char>,
|
|
133
153
|
size_dict: Dict<char, f32>,
|
|
154
|
+
track_flops: bool,
|
|
134
155
|
) -> ContractionProcessor {
|
|
156
|
+
if size_dict.len() > Ix::MAX as usize {
|
|
157
|
+
panic!("cotengrust: too many indices, maximum is {}", Ix::MAX);
|
|
158
|
+
}
|
|
159
|
+
|
|
135
160
|
let mut nodes: Dict<Node, Legs> = Dict::default();
|
|
136
161
|
let mut edges: Dict<Ix, BTreeSet<Node>> = Dict::default();
|
|
137
162
|
let mut indmap: Dict<char, Ix> = Dict::default();
|
|
@@ -149,7 +174,7 @@ impl ContractionProcessor {
|
|
|
149
174
|
indmap.insert(ind, c);
|
|
150
175
|
edges.insert(c, std::iter::once(i as Node).collect());
|
|
151
176
|
appearances.push(1);
|
|
152
|
-
sizes.push(f32::
|
|
177
|
+
sizes.push(f32::ln(size_dict[&ind] as f32));
|
|
153
178
|
legs.push((c, 1));
|
|
154
179
|
c += 1;
|
|
155
180
|
}
|
|
@@ -170,6 +195,8 @@ impl ContractionProcessor {
|
|
|
170
195
|
|
|
171
196
|
let ssa = nodes.len() as Node;
|
|
172
197
|
let ssa_path: SSAPath = Vec::with_capacity(2 * ssa as usize - 1);
|
|
198
|
+
let flops: Score = 0.0;
|
|
199
|
+
let flops_limit: Score = Score::INFINITY;
|
|
173
200
|
|
|
174
201
|
ContractionProcessor {
|
|
175
202
|
nodes,
|
|
@@ -178,6 +205,9 @@ impl ContractionProcessor {
|
|
|
178
205
|
sizes,
|
|
179
206
|
ssa,
|
|
180
207
|
ssa_path,
|
|
208
|
+
track_flops,
|
|
209
|
+
flops,
|
|
210
|
+
flops_limit,
|
|
181
211
|
}
|
|
182
212
|
}
|
|
183
213
|
|
|
@@ -225,7 +255,9 @@ impl ContractionProcessor {
|
|
|
225
255
|
for (ix, _) in &legs {
|
|
226
256
|
self.edges
|
|
227
257
|
.entry(*ix)
|
|
228
|
-
.and_modify(|nodes| {
|
|
258
|
+
.and_modify(|nodes| {
|
|
259
|
+
nodes.insert(i);
|
|
260
|
+
})
|
|
229
261
|
.or_insert(std::iter::once(i as Node).collect());
|
|
230
262
|
}
|
|
231
263
|
self.nodes.insert(i, legs);
|
|
@@ -236,12 +268,27 @@ impl ContractionProcessor {
|
|
|
236
268
|
fn contract_nodes(&mut self, i: Node, j: Node) -> Node {
|
|
237
269
|
let ilegs = self.pop_node(i);
|
|
238
270
|
let jlegs = self.pop_node(j);
|
|
271
|
+
if self.track_flops {
|
|
272
|
+
self.flops = logadd(self.flops, compute_flops(&ilegs, &jlegs, &self.sizes));
|
|
273
|
+
}
|
|
239
274
|
let new_legs = compute_legs(&ilegs, &jlegs, &self.appearances);
|
|
240
275
|
let k = self.add_node(new_legs);
|
|
241
276
|
self.ssa_path.push(vec![i, j]);
|
|
242
277
|
k
|
|
243
278
|
}
|
|
244
279
|
|
|
280
|
+
/// contract two nodes (which we already know the legs for), return the new node id
|
|
281
|
+
fn contract_nodes_given_legs(&mut self, i: Node, j: Node, new_legs: Legs) -> Node {
|
|
282
|
+
let ilegs = self.pop_node(i);
|
|
283
|
+
let jlegs = self.pop_node(j);
|
|
284
|
+
if self.track_flops {
|
|
285
|
+
self.flops = logadd(self.flops, compute_flops(&ilegs, &jlegs, &self.sizes));
|
|
286
|
+
}
|
|
287
|
+
let k = self.add_node(new_legs);
|
|
288
|
+
self.ssa_path.push(vec![i, j]);
|
|
289
|
+
k
|
|
290
|
+
}
|
|
291
|
+
|
|
245
292
|
/// find any indices that appear in all terms and just remove/ignore them
|
|
246
293
|
fn simplify_batch(&mut self) {
|
|
247
294
|
let mut ix_to_remove = Vec::new();
|
|
@@ -366,18 +413,32 @@ impl ContractionProcessor {
|
|
|
366
413
|
}
|
|
367
414
|
|
|
368
415
|
/// greedily optimize the contraction order of all terms
|
|
369
|
-
fn optimize_greedy(
|
|
370
|
-
|
|
416
|
+
fn optimize_greedy(
|
|
417
|
+
&mut self,
|
|
418
|
+
costmod: Option<f32>,
|
|
419
|
+
temperature: Option<f32>,
|
|
420
|
+
seed: Option<u64>,
|
|
421
|
+
) -> bool {
|
|
371
422
|
let coeff_t = temperature.unwrap_or(0.0);
|
|
372
423
|
let log_coeff_a = f32::ln(costmod.unwrap_or(1.0));
|
|
373
424
|
|
|
425
|
+
let mut rng = if coeff_t != 0.0 {
|
|
426
|
+
Some(match seed {
|
|
427
|
+
Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
|
|
428
|
+
None => rand::rngs::StdRng::from_entropy(),
|
|
429
|
+
})
|
|
430
|
+
} else {
|
|
431
|
+
// zero temp - no need for rng
|
|
432
|
+
None
|
|
433
|
+
};
|
|
434
|
+
|
|
374
435
|
let mut local_score = |sa: Score, sb: Score, sab: Score| -> Score {
|
|
375
|
-
let gumbel = if
|
|
436
|
+
let gumbel = if let Some(rng) = &mut rng {
|
|
376
437
|
coeff_t * -f32::ln(-f32::ln(rng.gen()))
|
|
377
438
|
} else {
|
|
378
439
|
0.0 as f32
|
|
379
440
|
};
|
|
380
|
-
logsub(sab
|
|
441
|
+
logsub(sab - log_coeff_a, logadd(sa, sb) + log_coeff_a) - gumbel
|
|
381
442
|
};
|
|
382
443
|
|
|
383
444
|
// cache all current nodes sizes as we go
|
|
@@ -424,11 +485,13 @@ impl ContractionProcessor {
|
|
|
424
485
|
}
|
|
425
486
|
|
|
426
487
|
// perform contraction:
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
self.
|
|
430
|
-
|
|
431
|
-
|
|
488
|
+
let k = self.contract_nodes_given_legs(i, j, klegs.clone());
|
|
489
|
+
|
|
490
|
+
if self.track_flops && self.flops >= self.flops_limit {
|
|
491
|
+
// stop if we have reached the flops limit
|
|
492
|
+
return false;
|
|
493
|
+
}
|
|
494
|
+
|
|
432
495
|
node_sizes.insert(k, ksize);
|
|
433
496
|
|
|
434
497
|
for l in self.neighbors(k) {
|
|
@@ -444,6 +507,8 @@ impl ContractionProcessor {
|
|
|
444
507
|
c -= 1;
|
|
445
508
|
}
|
|
446
509
|
}
|
|
510
|
+
// success
|
|
511
|
+
return true;
|
|
447
512
|
}
|
|
448
513
|
|
|
449
514
|
/// Optimize the contraction order of all terms using a greedy algorithm
|
|
@@ -800,7 +865,6 @@ impl ContractionProcessor {
|
|
|
800
865
|
// --------------------------- PYTHON FUNCTIONS ---------------------------- //
|
|
801
866
|
|
|
802
867
|
#[pyfunction]
|
|
803
|
-
#[pyo3()]
|
|
804
868
|
fn ssa_to_linear(ssa_path: SSAPath, n: Option<usize>) -> SSAPath {
|
|
805
869
|
let n = match n {
|
|
806
870
|
Some(n) => n,
|
|
@@ -828,18 +892,16 @@ fn ssa_to_linear(ssa_path: SSAPath, n: Option<usize>) -> SSAPath {
|
|
|
828
892
|
}
|
|
829
893
|
|
|
830
894
|
#[pyfunction]
|
|
831
|
-
#[pyo3()]
|
|
832
895
|
fn find_subgraphs(
|
|
833
896
|
inputs: Vec<Vec<char>>,
|
|
834
897
|
output: Vec<char>,
|
|
835
898
|
size_dict: Dict<char, f32>,
|
|
836
899
|
) -> Vec<Vec<Node>> {
|
|
837
|
-
let cp = ContractionProcessor::new(inputs, output, size_dict);
|
|
900
|
+
let cp = ContractionProcessor::new(inputs, output, size_dict, false);
|
|
838
901
|
cp.subgraphs()
|
|
839
902
|
}
|
|
840
903
|
|
|
841
904
|
#[pyfunction]
|
|
842
|
-
#[pyo3()]
|
|
843
905
|
fn optimize_simplify(
|
|
844
906
|
inputs: Vec<Vec<char>>,
|
|
845
907
|
output: Vec<char>,
|
|
@@ -847,7 +909,7 @@ fn optimize_simplify(
|
|
|
847
909
|
use_ssa: Option<bool>,
|
|
848
910
|
) -> SSAPath {
|
|
849
911
|
let n = inputs.len();
|
|
850
|
-
let mut cp = ContractionProcessor::new(inputs, output, size_dict);
|
|
912
|
+
let mut cp = ContractionProcessor::new(inputs, output, size_dict, false);
|
|
851
913
|
cp.simplify();
|
|
852
914
|
if use_ssa.unwrap_or(false) {
|
|
853
915
|
cp.ssa_path
|
|
@@ -857,36 +919,124 @@ fn optimize_simplify(
|
|
|
857
919
|
}
|
|
858
920
|
|
|
859
921
|
#[pyfunction]
|
|
860
|
-
#[pyo3()]
|
|
861
922
|
fn optimize_greedy(
|
|
923
|
+
py: Python,
|
|
862
924
|
inputs: Vec<Vec<char>>,
|
|
863
925
|
output: Vec<char>,
|
|
864
926
|
size_dict: Dict<char, f32>,
|
|
865
927
|
costmod: Option<f32>,
|
|
866
928
|
temperature: Option<f32>,
|
|
929
|
+
seed: Option<u64>,
|
|
867
930
|
simplify: Option<bool>,
|
|
868
931
|
use_ssa: Option<bool>,
|
|
869
932
|
) -> Vec<Vec<Node>> {
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
933
|
+
py.allow_threads(|| {
|
|
934
|
+
let n = inputs.len();
|
|
935
|
+
let mut cp = ContractionProcessor::new(inputs, output, size_dict, false);
|
|
936
|
+
if simplify.unwrap_or(true) {
|
|
937
|
+
// perform simplifications
|
|
938
|
+
cp.simplify();
|
|
939
|
+
}
|
|
940
|
+
// greedily contract each connected subgraph
|
|
941
|
+
cp.optimize_greedy(costmod, temperature, seed);
|
|
942
|
+
// optimize any remaining disconnected terms
|
|
943
|
+
cp.optimize_remaining_by_size();
|
|
944
|
+
if use_ssa.unwrap_or(false) {
|
|
945
|
+
cp.ssa_path
|
|
946
|
+
} else {
|
|
947
|
+
ssa_to_linear(cp.ssa_path, Some(n))
|
|
948
|
+
}
|
|
949
|
+
})
|
|
950
|
+
}
|
|
951
|
+
|
|
952
|
+
#[pyfunction]
|
|
953
|
+
fn optimize_random_greedy_track_flops(
|
|
954
|
+
py: Python,
|
|
955
|
+
inputs: Vec<Vec<char>>,
|
|
956
|
+
output: Vec<char>,
|
|
957
|
+
size_dict: Dict<char, f32>,
|
|
958
|
+
ntrials: usize,
|
|
959
|
+
costmod: Option<(f32, f32)>,
|
|
960
|
+
temperature: Option<(f32, f32)>,
|
|
961
|
+
seed: Option<u64>,
|
|
962
|
+
simplify: Option<bool>,
|
|
963
|
+
use_ssa: Option<bool>,
|
|
964
|
+
) -> (Vec<Vec<Node>>, Score) {
|
|
965
|
+
py.allow_threads(|| {
|
|
966
|
+
let (costmod_min, costmod_max) = costmod.unwrap_or((0.1, 4.0));
|
|
967
|
+
let costmod_diff = (costmod_max - costmod_min).abs();
|
|
968
|
+
let is_const_costmod = costmod_diff < Score::EPSILON;
|
|
969
|
+
|
|
970
|
+
let (temp_min, temp_max) = temperature.unwrap_or((0.001, 1.0));
|
|
971
|
+
let log_temp_min = Score::ln(temp_min);
|
|
972
|
+
let log_temp_max = Score::ln(temp_max);
|
|
973
|
+
let log_temp_diff = (log_temp_max - log_temp_min).abs();
|
|
974
|
+
let is_const_temp = log_temp_diff < Score::EPSILON;
|
|
975
|
+
|
|
976
|
+
let mut rng = match seed {
|
|
977
|
+
Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
|
|
978
|
+
None => rand::rngs::StdRng::from_entropy(),
|
|
979
|
+
};
|
|
980
|
+
let seeds = (0..ntrials).map(|_| rng.gen()).collect::<Vec<u64>>();
|
|
981
|
+
|
|
982
|
+
let n: usize = inputs.len();
|
|
983
|
+
// construct processor and perform simplifications once
|
|
984
|
+
let mut cp0 = ContractionProcessor::new(inputs, output, size_dict, true);
|
|
985
|
+
if simplify.unwrap_or(true) {
|
|
986
|
+
cp0.simplify();
|
|
987
|
+
}
|
|
988
|
+
|
|
989
|
+
let mut best_path = None;
|
|
990
|
+
let mut best_flops = f32::INFINITY;
|
|
991
|
+
|
|
992
|
+
for seed in seeds {
|
|
993
|
+
let mut cp = cp0.clone();
|
|
994
|
+
|
|
995
|
+
// uniform sample for costmod
|
|
996
|
+
let costmod = if is_const_costmod {
|
|
997
|
+
costmod_min
|
|
998
|
+
} else {
|
|
999
|
+
costmod_min + rng.gen::<f32>() * costmod_diff
|
|
1000
|
+
};
|
|
1001
|
+
|
|
1002
|
+
// log-uniform sample for temperature
|
|
1003
|
+
let temperature = if is_const_temp {
|
|
1004
|
+
temp_min
|
|
1005
|
+
} else {
|
|
1006
|
+
f32::exp(log_temp_min + rng.gen::<f32>() * log_temp_diff)
|
|
1007
|
+
};
|
|
1008
|
+
|
|
1009
|
+
// greedily contract each connected subgraph
|
|
1010
|
+
let success = cp.optimize_greedy(Some(costmod), Some(temperature), Some(seed));
|
|
1011
|
+
|
|
1012
|
+
if !success {
|
|
1013
|
+
continue;
|
|
1014
|
+
}
|
|
1015
|
+
|
|
1016
|
+
// optimize any remaining disconnected terms
|
|
1017
|
+
cp.optimize_remaining_by_size();
|
|
1018
|
+
|
|
1019
|
+
if cp.flops < best_flops {
|
|
1020
|
+
best_path = Some(cp.ssa_path);
|
|
1021
|
+
best_flops = cp.flops;
|
|
1022
|
+
cp0.flops_limit = cp.flops;
|
|
1023
|
+
}
|
|
1024
|
+
}
|
|
1025
|
+
|
|
1026
|
+
// convert to base 10 for easier comparison
|
|
1027
|
+
best_flops *= f32::consts::LOG10_E;
|
|
1028
|
+
|
|
1029
|
+
if use_ssa.unwrap_or(false) {
|
|
1030
|
+
(best_path.unwrap(), best_flops)
|
|
1031
|
+
} else {
|
|
1032
|
+
(ssa_to_linear(best_path.unwrap(), Some(n)), best_flops)
|
|
1033
|
+
}
|
|
1034
|
+
})
|
|
885
1035
|
}
|
|
886
1036
|
|
|
887
1037
|
#[pyfunction]
|
|
888
|
-
#[pyo3()]
|
|
889
1038
|
fn optimize_optimal(
|
|
1039
|
+
py: Python,
|
|
890
1040
|
inputs: Vec<Vec<char>>,
|
|
891
1041
|
output: Vec<char>,
|
|
892
1042
|
size_dict: Dict<char, f32>,
|
|
@@ -896,30 +1046,33 @@ fn optimize_optimal(
|
|
|
896
1046
|
simplify: Option<bool>,
|
|
897
1047
|
use_ssa: Option<bool>,
|
|
898
1048
|
) -> Vec<Vec<Node>> {
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
1049
|
+
py.allow_threads(|| {
|
|
1050
|
+
let n = inputs.len();
|
|
1051
|
+
let mut cp = ContractionProcessor::new(inputs, output, size_dict, false);
|
|
1052
|
+
if simplify.unwrap_or(true) {
|
|
1053
|
+
// perform simplifications
|
|
1054
|
+
cp.simplify();
|
|
1055
|
+
}
|
|
1056
|
+
// optimally contract each connected subgraph
|
|
1057
|
+
cp.optimize_optimal(minimize, cost_cap, search_outer);
|
|
1058
|
+
// optimize any remaining disconnected terms
|
|
1059
|
+
cp.optimize_remaining_by_size();
|
|
1060
|
+
if use_ssa.unwrap_or(false) {
|
|
1061
|
+
cp.ssa_path
|
|
1062
|
+
} else {
|
|
1063
|
+
ssa_to_linear(cp.ssa_path, Some(n))
|
|
1064
|
+
}
|
|
1065
|
+
})
|
|
914
1066
|
}
|
|
915
1067
|
|
|
916
1068
|
/// A Python module implemented in Rust.
|
|
917
1069
|
#[pymodule]
|
|
918
|
-
fn cotengrust(
|
|
1070
|
+
fn cotengrust(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
|
919
1071
|
m.add_function(wrap_pyfunction!(ssa_to_linear, m)?)?;
|
|
920
1072
|
m.add_function(wrap_pyfunction!(find_subgraphs, m)?)?;
|
|
921
1073
|
m.add_function(wrap_pyfunction!(optimize_simplify, m)?)?;
|
|
922
1074
|
m.add_function(wrap_pyfunction!(optimize_greedy, m)?)?;
|
|
1075
|
+
m.add_function(wrap_pyfunction!(optimize_random_greedy_track_flops, m)?)?;
|
|
923
1076
|
m.add_function(wrap_pyfunction!(optimize_optimal, m)?)?;
|
|
924
1077
|
Ok(())
|
|
925
1078
|
}
|
|
@@ -56,6 +56,15 @@ def get_rand_size_dict(inputs, d_min=2, d_max=3):
|
|
|
56
56
|
|
|
57
57
|
# these are taken from opt_einsum
|
|
58
58
|
test_case_eqs = [
|
|
59
|
+
# Test single-term equations
|
|
60
|
+
"->",
|
|
61
|
+
"a->a",
|
|
62
|
+
"ab->ab",
|
|
63
|
+
"ab->ba",
|
|
64
|
+
"abc->bca",
|
|
65
|
+
"abc->b",
|
|
66
|
+
"baa->ba",
|
|
67
|
+
"aba->b",
|
|
59
68
|
# Test scalar-like operations
|
|
60
69
|
"a,->a",
|
|
61
70
|
"ab,->ab",
|
|
@@ -188,18 +197,39 @@ def test_basic_rand(seed, which):
|
|
|
188
197
|
@requires_cotengra
|
|
189
198
|
def test_optimal_lattice_eq():
|
|
190
199
|
inputs, output, _, size_dict = ctg.utils.lattice_equation(
|
|
191
|
-
[4, 5], d_max=
|
|
200
|
+
[4, 5], d_max=2, seed=42
|
|
192
201
|
)
|
|
193
202
|
|
|
194
203
|
path = ctgr.optimize_optimal(inputs, output, size_dict, minimize='flops')
|
|
195
204
|
tree = ctg.ContractionTree.from_path(
|
|
196
205
|
inputs, output, size_dict, path=path
|
|
197
206
|
)
|
|
198
|
-
assert tree.
|
|
207
|
+
assert tree.is_complete()
|
|
208
|
+
assert tree.contraction_cost() == 964
|
|
199
209
|
|
|
200
210
|
path = ctgr.optimize_optimal(inputs, output, size_dict, minimize='size')
|
|
201
211
|
assert all(len(con) <= 2 for con in path)
|
|
202
212
|
tree = ctg.ContractionTree.from_path(
|
|
203
213
|
inputs, output, size_dict, path=path
|
|
204
214
|
)
|
|
205
|
-
assert tree.contraction_width() == pytest.approx(
|
|
215
|
+
assert tree.contraction_width() == pytest.approx(5)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@requires_cotengra
|
|
219
|
+
def test_optimize_random_greedy_log_flops():
|
|
220
|
+
inputs, output, _, size_dict = ctg.utils.lattice_equation(
|
|
221
|
+
[10, 10], d_max=3, seed=42
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
path, cost1 = ctgr.optimize_random_greedy_track_flops(
|
|
225
|
+
inputs, output, size_dict, ntrials=4, seed=42
|
|
226
|
+
)
|
|
227
|
+
_, cost2 = ctgr.optimize_random_greedy_track_flops(
|
|
228
|
+
inputs, output, size_dict, ntrials=4, seed=42
|
|
229
|
+
)
|
|
230
|
+
assert cost1 == cost2
|
|
231
|
+
tree = ctg.ContractionTree.from_path(
|
|
232
|
+
inputs, output, size_dict, path=path
|
|
233
|
+
)
|
|
234
|
+
assert tree.is_complete()
|
|
235
|
+
assert tree.contraction_cost(log=10) == pytest.approx(cost1)
|
|
@@ -1,162 +0,0 @@
|
|
|
1
|
-
# This file is autogenerated by maturin v1.2.3
|
|
2
|
-
# To update, run
|
|
3
|
-
#
|
|
4
|
-
# maturin generate-ci github --pytest
|
|
5
|
-
#
|
|
6
|
-
name: CI
|
|
7
|
-
|
|
8
|
-
on:
|
|
9
|
-
push:
|
|
10
|
-
branches:
|
|
11
|
-
- main
|
|
12
|
-
- master
|
|
13
|
-
tags:
|
|
14
|
-
- '*'
|
|
15
|
-
pull_request:
|
|
16
|
-
workflow_dispatch:
|
|
17
|
-
|
|
18
|
-
permissions:
|
|
19
|
-
contents: read
|
|
20
|
-
|
|
21
|
-
jobs:
|
|
22
|
-
linux:
|
|
23
|
-
runs-on: ubuntu-latest
|
|
24
|
-
strategy:
|
|
25
|
-
matrix:
|
|
26
|
-
target: [x86_64, x86, aarch64, armv7, s390x, ppc64le]
|
|
27
|
-
steps:
|
|
28
|
-
- uses: actions/checkout@v3
|
|
29
|
-
- uses: actions/setup-python@v4
|
|
30
|
-
with:
|
|
31
|
-
python-version: '3.10'
|
|
32
|
-
- name: Build wheels
|
|
33
|
-
uses: PyO3/maturin-action@v1
|
|
34
|
-
with:
|
|
35
|
-
target: ${{ matrix.target }}
|
|
36
|
-
args: --release --out dist --find-interpreter
|
|
37
|
-
sccache: 'true'
|
|
38
|
-
manylinux: auto
|
|
39
|
-
- name: Upload wheels
|
|
40
|
-
uses: actions/upload-artifact@v3
|
|
41
|
-
with:
|
|
42
|
-
name: wheels
|
|
43
|
-
path: dist
|
|
44
|
-
- name: pytest
|
|
45
|
-
if: ${{ startsWith(matrix.target, 'x86_64') }}
|
|
46
|
-
shell: bash
|
|
47
|
-
run: |
|
|
48
|
-
set -e
|
|
49
|
-
ls dist/*
|
|
50
|
-
pip install cotengrust --find-links dist --force-reinstall
|
|
51
|
-
pip install pytest numpy cotengra
|
|
52
|
-
pytest --verbose
|
|
53
|
-
- name: pytest
|
|
54
|
-
if: ${{ !startsWith(matrix.target, 'x86') && matrix.target != 'ppc64' }}
|
|
55
|
-
uses: uraimo/run-on-arch-action@v2.5.0
|
|
56
|
-
with:
|
|
57
|
-
arch: ${{ matrix.target }}
|
|
58
|
-
distro: ubuntu22.04
|
|
59
|
-
githubToken: ${{ github.token }}
|
|
60
|
-
install: |
|
|
61
|
-
apt-get update
|
|
62
|
-
apt-get install -y --no-install-recommends python3 python3-pip
|
|
63
|
-
pip3 install -U pip pytest # numpy cotengra
|
|
64
|
-
run: |
|
|
65
|
-
set -e
|
|
66
|
-
pip3 install cotengrust --find-links dist --force-reinstall
|
|
67
|
-
pytest --verbose
|
|
68
|
-
|
|
69
|
-
windows:
|
|
70
|
-
runs-on: windows-latest
|
|
71
|
-
strategy:
|
|
72
|
-
matrix:
|
|
73
|
-
target: [x64, x86]
|
|
74
|
-
steps:
|
|
75
|
-
- uses: actions/checkout@v3
|
|
76
|
-
- uses: actions/setup-python@v4
|
|
77
|
-
with:
|
|
78
|
-
python-version: '3.10'
|
|
79
|
-
architecture: ${{ matrix.target }}
|
|
80
|
-
- name: Build wheels
|
|
81
|
-
uses: PyO3/maturin-action@v1
|
|
82
|
-
with:
|
|
83
|
-
target: ${{ matrix.target }}
|
|
84
|
-
args: --release --out dist --find-interpreter
|
|
85
|
-
sccache: 'true'
|
|
86
|
-
- name: Upload wheels
|
|
87
|
-
uses: actions/upload-artifact@v3
|
|
88
|
-
with:
|
|
89
|
-
name: wheels
|
|
90
|
-
path: dist
|
|
91
|
-
- name: pytest
|
|
92
|
-
if: ${{ !startsWith(matrix.target, 'aarch64') }}
|
|
93
|
-
shell: bash
|
|
94
|
-
run: |
|
|
95
|
-
set -e
|
|
96
|
-
ls dist/*
|
|
97
|
-
pip install cotengrust --find-links dist --force-reinstall
|
|
98
|
-
pip install pytest numpy cotengra
|
|
99
|
-
pytest --verbose
|
|
100
|
-
|
|
101
|
-
macos:
|
|
102
|
-
runs-on: macos-latest
|
|
103
|
-
strategy:
|
|
104
|
-
matrix:
|
|
105
|
-
target: [x86_64, aarch64]
|
|
106
|
-
steps:
|
|
107
|
-
- uses: actions/checkout@v3
|
|
108
|
-
- uses: actions/setup-python@v4
|
|
109
|
-
with:
|
|
110
|
-
python-version: '3.10'
|
|
111
|
-
- name: Build wheels
|
|
112
|
-
uses: PyO3/maturin-action@v1
|
|
113
|
-
with:
|
|
114
|
-
target: ${{ matrix.target }}
|
|
115
|
-
args: --release --out dist --find-interpreter
|
|
116
|
-
sccache: 'true'
|
|
117
|
-
- name: Upload wheels
|
|
118
|
-
uses: actions/upload-artifact@v3
|
|
119
|
-
with:
|
|
120
|
-
name: wheels
|
|
121
|
-
path: dist
|
|
122
|
-
- name: pytest
|
|
123
|
-
if: ${{ !startsWith(matrix.target, 'aarch64') }}
|
|
124
|
-
shell: bash
|
|
125
|
-
run: |
|
|
126
|
-
set -e
|
|
127
|
-
ls dist/*
|
|
128
|
-
pip install cotengrust --find-links dist --force-reinstall
|
|
129
|
-
pip install pytest numpy cotengra
|
|
130
|
-
pytest --verbose
|
|
131
|
-
|
|
132
|
-
sdist:
|
|
133
|
-
runs-on: ubuntu-latest
|
|
134
|
-
steps:
|
|
135
|
-
- uses: actions/checkout@v3
|
|
136
|
-
- name: Build sdist
|
|
137
|
-
uses: PyO3/maturin-action@v1
|
|
138
|
-
with:
|
|
139
|
-
command: sdist
|
|
140
|
-
args: --out dist
|
|
141
|
-
- name: Upload sdist
|
|
142
|
-
uses: actions/upload-artifact@v3
|
|
143
|
-
with:
|
|
144
|
-
name: wheels
|
|
145
|
-
path: dist
|
|
146
|
-
|
|
147
|
-
release:
|
|
148
|
-
name: Release
|
|
149
|
-
runs-on: ubuntu-latest
|
|
150
|
-
if: startsWith(github.ref, 'refs/tags/')
|
|
151
|
-
needs: [linux, windows, macos, sdist]
|
|
152
|
-
steps:
|
|
153
|
-
- uses: actions/download-artifact@v3
|
|
154
|
-
with:
|
|
155
|
-
name: wheels
|
|
156
|
-
- name: Publish to PyPI
|
|
157
|
-
uses: PyO3/maturin-action@v1
|
|
158
|
-
env:
|
|
159
|
-
MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
|
|
160
|
-
with:
|
|
161
|
-
command: upload
|
|
162
|
-
args: --non-interactive --skip-existing *
|
|
File without changes
|
|
File without changes
|