cotengrust 0.1.1__tar.gz → 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cotengrust-0.1.1 → cotengrust-0.1.2}/Cargo.lock +34 -19
- {cotengrust-0.1.1 → cotengrust-0.1.2}/Cargo.toml +3 -3
- {cotengrust-0.1.1 → cotengrust-0.1.2}/PKG-INFO +2 -2
- {cotengrust-0.1.1 → cotengrust-0.1.2}/pyproject.toml +2 -2
- {cotengrust-0.1.1 → cotengrust-0.1.2}/src/lib.rs +161 -49
- {cotengrust-0.1.1 → cotengrust-0.1.2}/tests/test_cotengrust.py +33 -3
- {cotengrust-0.1.1 → cotengrust-0.1.2}/.github/workflows/CI.yml +0 -0
- {cotengrust-0.1.1 → cotengrust-0.1.2}/.gitignore +0 -0
- {cotengrust-0.1.1 → cotengrust-0.1.2}/LICENSE +0 -0
- {cotengrust-0.1.1 → cotengrust-0.1.2}/README.md +0 -0
|
@@ -37,7 +37,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
|
|
37
37
|
|
|
38
38
|
[[package]]
|
|
39
39
|
name = "cotengrust"
|
|
40
|
-
version = "0.1.1"
|
|
40
|
+
version = "0.1.2"
|
|
41
41
|
dependencies = [
|
|
42
42
|
"bit-set",
|
|
43
43
|
"ordered-float",
|
|
@@ -57,11 +57,17 @@ dependencies = [
|
|
|
57
57
|
"wasi",
|
|
58
58
|
]
|
|
59
59
|
|
|
60
|
+
[[package]]
|
|
61
|
+
name = "heck"
|
|
62
|
+
version = "0.4.1"
|
|
63
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
64
|
+
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
|
|
65
|
+
|
|
60
66
|
[[package]]
|
|
61
67
|
name = "indoc"
|
|
62
|
-
version = "
|
|
68
|
+
version = "2.0.4"
|
|
63
69
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
64
|
-
checksum = "
|
|
70
|
+
checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8"
|
|
65
71
|
|
|
66
72
|
[[package]]
|
|
67
73
|
name = "libc"
|
|
@@ -105,9 +111,9 @@ checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
|
|
|
105
111
|
|
|
106
112
|
[[package]]
|
|
107
113
|
name = "ordered-float"
|
|
108
|
-
version = "
|
|
114
|
+
version = "4.2.0"
|
|
109
115
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
110
|
-
checksum = "
|
|
116
|
+
checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e"
|
|
111
117
|
dependencies = [
|
|
112
118
|
"num-traits",
|
|
113
119
|
]
|
|
@@ -135,6 +141,12 @@ dependencies = [
|
|
|
135
141
|
"windows-targets",
|
|
136
142
|
]
|
|
137
143
|
|
|
144
|
+
[[package]]
|
|
145
|
+
name = "portable-atomic"
|
|
146
|
+
version = "1.6.0"
|
|
147
|
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
148
|
+
checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
|
|
149
|
+
|
|
138
150
|
[[package]]
|
|
139
151
|
name = "ppv-lite86"
|
|
140
152
|
version = "0.2.17"
|
|
@@ -152,15 +164,16 @@ dependencies = [
|
|
|
152
164
|
|
|
153
165
|
[[package]]
|
|
154
166
|
name = "pyo3"
|
|
155
|
-
version = "0.
|
|
167
|
+
version = "0.21.1"
|
|
156
168
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
157
|
-
checksum = "
|
|
169
|
+
checksum = "a7a8b1990bd018761768d5e608a13df8bd1ac5f678456e0f301bb93e5f3ea16b"
|
|
158
170
|
dependencies = [
|
|
159
171
|
"cfg-if",
|
|
160
172
|
"indoc",
|
|
161
173
|
"libc",
|
|
162
174
|
"memoffset",
|
|
163
175
|
"parking_lot",
|
|
176
|
+
"portable-atomic",
|
|
164
177
|
"pyo3-build-config",
|
|
165
178
|
"pyo3-ffi",
|
|
166
179
|
"pyo3-macros",
|
|
@@ -169,9 +182,9 @@ dependencies = [
|
|
|
169
182
|
|
|
170
183
|
[[package]]
|
|
171
184
|
name = "pyo3-build-config"
|
|
172
|
-
version = "0.
|
|
185
|
+
version = "0.21.1"
|
|
173
186
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
174
|
-
checksum = "
|
|
187
|
+
checksum = "650dca34d463b6cdbdb02b1d71bfd6eb6b6816afc708faebb3bac1380ff4aef7"
|
|
175
188
|
dependencies = [
|
|
176
189
|
"once_cell",
|
|
177
190
|
"target-lexicon",
|
|
@@ -179,9 +192,9 @@ dependencies = [
|
|
|
179
192
|
|
|
180
193
|
[[package]]
|
|
181
194
|
name = "pyo3-ffi"
|
|
182
|
-
version = "0.
|
|
195
|
+
version = "0.21.1"
|
|
183
196
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
184
|
-
checksum = "
|
|
197
|
+
checksum = "09a7da8fc04a8a2084909b59f29e1b8474decac98b951d77b80b26dc45f046ad"
|
|
185
198
|
dependencies = [
|
|
186
199
|
"libc",
|
|
187
200
|
"pyo3-build-config",
|
|
@@ -189,9 +202,9 @@ dependencies = [
|
|
|
189
202
|
|
|
190
203
|
[[package]]
|
|
191
204
|
name = "pyo3-macros"
|
|
192
|
-
version = "0.
|
|
205
|
+
version = "0.21.1"
|
|
193
206
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
194
|
-
checksum = "
|
|
207
|
+
checksum = "4b8a199fce11ebb28e3569387228836ea98110e43a804a530a9fd83ade36d513"
|
|
195
208
|
dependencies = [
|
|
196
209
|
"proc-macro2",
|
|
197
210
|
"pyo3-macros-backend",
|
|
@@ -201,11 +214,13 @@ dependencies = [
|
|
|
201
214
|
|
|
202
215
|
[[package]]
|
|
203
216
|
name = "pyo3-macros-backend"
|
|
204
|
-
version = "0.
|
|
217
|
+
version = "0.21.1"
|
|
205
218
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
206
|
-
checksum = "
|
|
219
|
+
checksum = "93fbbfd7eb553d10036513cb122b888dcd362a945a00b06c165f2ab480d4cc3b"
|
|
207
220
|
dependencies = [
|
|
221
|
+
"heck",
|
|
208
222
|
"proc-macro2",
|
|
223
|
+
"pyo3-build-config",
|
|
209
224
|
"quote",
|
|
210
225
|
"syn",
|
|
211
226
|
]
|
|
@@ -278,9 +293,9 @@ checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
|
|
|
278
293
|
|
|
279
294
|
[[package]]
|
|
280
295
|
name = "syn"
|
|
281
|
-
version = "
|
|
296
|
+
version = "2.0.32"
|
|
282
297
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
283
|
-
checksum = "
|
|
298
|
+
checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2"
|
|
284
299
|
dependencies = [
|
|
285
300
|
"proc-macro2",
|
|
286
301
|
"quote",
|
|
@@ -301,9 +316,9 @@ checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c"
|
|
|
301
316
|
|
|
302
317
|
[[package]]
|
|
303
318
|
name = "unindent"
|
|
304
|
-
version = "0.
|
|
319
|
+
version = "0.2.3"
|
|
305
320
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
306
|
-
checksum = "
|
|
321
|
+
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
|
|
307
322
|
|
|
308
323
|
[[package]]
|
|
309
324
|
name = "wasi"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[package]
|
|
2
2
|
name = "cotengrust"
|
|
3
|
-
version = "0.1.1"
|
|
3
|
+
version = "0.1.2"
|
|
4
4
|
edition = "2021"
|
|
5
5
|
|
|
6
6
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
|
@@ -10,8 +10,8 @@ crate-type = ["cdylib"]
|
|
|
10
10
|
|
|
11
11
|
[dependencies]
|
|
12
12
|
bit-set = "0.5"
|
|
13
|
-
|
|
14
|
-
|
|
13
|
+
ordered-float = "4.2"
|
|
14
|
+
pyo3 = "0.21"
|
|
15
15
|
rand = "0.8"
|
|
16
16
|
rustc-hash = "1.1"
|
|
17
17
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
2
|
Name: cotengrust
|
|
3
|
-
Version: 0.1.1
|
|
3
|
+
Version: 0.1.2
|
|
4
4
|
Classifier: Programming Language :: Rust
|
|
5
5
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
6
6
|
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "cotengrust"
|
|
3
|
-
version = "0.1.1"
|
|
3
|
+
version = "0.1.2"
|
|
4
4
|
description = "Fast contraction ordering primitives for tensor networks."
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.8"
|
|
@@ -15,7 +15,7 @@ authors = [
|
|
|
15
15
|
]
|
|
16
16
|
|
|
17
17
|
[build-system]
|
|
18
|
-
requires = ["maturin>=0.
|
|
18
|
+
requires = ["maturin>=1.0,<2.0"]
|
|
19
19
|
build-backend = "maturin"
|
|
20
20
|
|
|
21
21
|
[tool.maturin]
|
|
@@ -2,8 +2,9 @@ use bit_set::BitSet;
|
|
|
2
2
|
use ordered_float::OrderedFloat;
|
|
3
3
|
use pyo3::prelude::*;
|
|
4
4
|
use rand::Rng;
|
|
5
|
+
use rand::SeedableRng;
|
|
5
6
|
use rustc_hash::FxHashMap;
|
|
6
|
-
use std::collections::{BTreeSet, BinaryHeap};
|
|
7
|
+
use std::collections::{BTreeSet, BinaryHeap, HashSet};
|
|
7
8
|
use std::f32;
|
|
8
9
|
|
|
9
10
|
use FxHashMap as Dict;
|
|
@@ -23,6 +24,7 @@ type BitPath = Vec<(Subgraph, Subgraph)>;
|
|
|
23
24
|
type SubContraction = (Legs, Score, BitPath);
|
|
24
25
|
|
|
25
26
|
/// helper struct to build contractions from bottom up
|
|
27
|
+
#[derive(Clone)]
|
|
26
28
|
struct ContractionProcessor {
|
|
27
29
|
nodes: Dict<Node, Legs>,
|
|
28
30
|
edges: Dict<Ix, BTreeSet<Node>>,
|
|
@@ -30,6 +32,8 @@ struct ContractionProcessor {
|
|
|
30
32
|
sizes: Vec<Score>,
|
|
31
33
|
ssa: Node,
|
|
32
34
|
ssa_path: SSAPath,
|
|
35
|
+
track_flops: bool,
|
|
36
|
+
flops: Score,
|
|
33
37
|
}
|
|
34
38
|
|
|
35
39
|
/// given log(x) and log(y) compute log(x + y), without exponentiating both
|
|
@@ -94,6 +98,21 @@ fn compute_size(legs: &Legs, sizes: &Vec<Score>) -> Score {
|
|
|
94
98
|
legs.iter().map(|&(ix, _)| sizes[ix as usize]).sum()
|
|
95
99
|
}
|
|
96
100
|
|
|
101
|
+
fn compute_flops(ilegs: &Legs, jlegs: &Legs, sizes: &Vec<Score>) -> Score {
|
|
102
|
+
let mut flops: Score = 0.0;
|
|
103
|
+
let mut seen: HashSet<Ix> = HashSet::with_capacity(ilegs.len());
|
|
104
|
+
for &(ix, _) in ilegs {
|
|
105
|
+
seen.insert(ix);
|
|
106
|
+
flops += sizes[ix as usize];
|
|
107
|
+
}
|
|
108
|
+
for (ix, _) in jlegs {
|
|
109
|
+
if !seen.contains(ix) {
|
|
110
|
+
flops += sizes[*ix as usize];
|
|
111
|
+
}
|
|
112
|
+
}
|
|
113
|
+
flops
|
|
114
|
+
}
|
|
115
|
+
|
|
97
116
|
fn is_simplifiable(legs: &Legs, appearances: &Vec<Count>) -> bool {
|
|
98
117
|
let mut prev_ix = Node::MAX;
|
|
99
118
|
for &(ix, ix_count) in legs {
|
|
@@ -131,7 +150,12 @@ impl ContractionProcessor {
|
|
|
131
150
|
inputs: Vec<Vec<char>>,
|
|
132
151
|
output: Vec<char>,
|
|
133
152
|
size_dict: Dict<char, f32>,
|
|
153
|
+
track_flops: bool,
|
|
134
154
|
) -> ContractionProcessor {
|
|
155
|
+
if size_dict.len() > Ix::MAX as usize {
|
|
156
|
+
panic!("cotengrust: too many indices, maximum is {}", Ix::MAX);
|
|
157
|
+
}
|
|
158
|
+
|
|
135
159
|
let mut nodes: Dict<Node, Legs> = Dict::default();
|
|
136
160
|
let mut edges: Dict<Ix, BTreeSet<Node>> = Dict::default();
|
|
137
161
|
let mut indmap: Dict<char, Ix> = Dict::default();
|
|
@@ -149,7 +173,7 @@ impl ContractionProcessor {
|
|
|
149
173
|
indmap.insert(ind, c);
|
|
150
174
|
edges.insert(c, std::iter::once(i as Node).collect());
|
|
151
175
|
appearances.push(1);
|
|
152
|
-
sizes.push(f32::
|
|
176
|
+
sizes.push(f32::ln(size_dict[&ind] as f32));
|
|
153
177
|
legs.push((c, 1));
|
|
154
178
|
c += 1;
|
|
155
179
|
}
|
|
@@ -170,6 +194,7 @@ impl ContractionProcessor {
|
|
|
170
194
|
|
|
171
195
|
let ssa = nodes.len() as Node;
|
|
172
196
|
let ssa_path: SSAPath = Vec::with_capacity(2 * ssa as usize - 1);
|
|
197
|
+
let flops: Score = 0.0;
|
|
173
198
|
|
|
174
199
|
ContractionProcessor {
|
|
175
200
|
nodes,
|
|
@@ -178,6 +203,8 @@ impl ContractionProcessor {
|
|
|
178
203
|
sizes,
|
|
179
204
|
ssa,
|
|
180
205
|
ssa_path,
|
|
206
|
+
track_flops,
|
|
207
|
+
flops,
|
|
181
208
|
}
|
|
182
209
|
}
|
|
183
210
|
|
|
@@ -225,7 +252,9 @@ impl ContractionProcessor {
|
|
|
225
252
|
for (ix, _) in &legs {
|
|
226
253
|
self.edges
|
|
227
254
|
.entry(*ix)
|
|
228
|
-
.and_modify(|nodes| {
|
|
255
|
+
.and_modify(|nodes| {
|
|
256
|
+
nodes.insert(i);
|
|
257
|
+
})
|
|
229
258
|
.or_insert(std::iter::once(i as Node).collect());
|
|
230
259
|
}
|
|
231
260
|
self.nodes.insert(i, legs);
|
|
@@ -236,12 +265,27 @@ impl ContractionProcessor {
|
|
|
236
265
|
fn contract_nodes(&mut self, i: Node, j: Node) -> Node {
|
|
237
266
|
let ilegs = self.pop_node(i);
|
|
238
267
|
let jlegs = self.pop_node(j);
|
|
268
|
+
if self.track_flops {
|
|
269
|
+
self.flops = logadd(self.flops, compute_flops(&ilegs, &jlegs, &self.sizes));
|
|
270
|
+
}
|
|
239
271
|
let new_legs = compute_legs(&ilegs, &jlegs, &self.appearances);
|
|
240
272
|
let k = self.add_node(new_legs);
|
|
241
273
|
self.ssa_path.push(vec![i, j]);
|
|
242
274
|
k
|
|
243
275
|
}
|
|
244
276
|
|
|
277
|
+
/// contract two nodes (which we already know the legs for), return the new node id
|
|
278
|
+
fn contract_nodes_given_legs(&mut self, i: Node, j: Node, new_legs: Legs) -> Node {
|
|
279
|
+
let ilegs = self.pop_node(i);
|
|
280
|
+
let jlegs = self.pop_node(j);
|
|
281
|
+
if self.track_flops {
|
|
282
|
+
self.flops = logadd(self.flops, compute_flops(&ilegs, &jlegs, &self.sizes));
|
|
283
|
+
}
|
|
284
|
+
let k = self.add_node(new_legs);
|
|
285
|
+
self.ssa_path.push(vec![i, j]);
|
|
286
|
+
k
|
|
287
|
+
}
|
|
288
|
+
|
|
245
289
|
/// find any indices that appear in all terms and just remove/ignore them
|
|
246
290
|
fn simplify_batch(&mut self) {
|
|
247
291
|
let mut ix_to_remove = Vec::new();
|
|
@@ -366,13 +410,27 @@ impl ContractionProcessor {
|
|
|
366
410
|
}
|
|
367
411
|
|
|
368
412
|
/// greedily optimize the contraction order of all terms
|
|
369
|
-
fn optimize_greedy(
|
|
370
|
-
|
|
413
|
+
fn optimize_greedy(
|
|
414
|
+
&mut self,
|
|
415
|
+
costmod: Option<f32>,
|
|
416
|
+
temperature: Option<f32>,
|
|
417
|
+
seed: Option<u64>,
|
|
418
|
+
) {
|
|
371
419
|
let coeff_t = temperature.unwrap_or(0.0);
|
|
372
420
|
let log_coeff_a = f32::ln(costmod.unwrap_or(1.0));
|
|
373
421
|
|
|
422
|
+
let mut rng = if coeff_t != 0.0 {
|
|
423
|
+
Some(match seed {
|
|
424
|
+
Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
|
|
425
|
+
None => rand::rngs::StdRng::from_entropy(),
|
|
426
|
+
})
|
|
427
|
+
} else {
|
|
428
|
+
// zero temp - no need for rng
|
|
429
|
+
None
|
|
430
|
+
};
|
|
431
|
+
|
|
374
432
|
let mut local_score = |sa: Score, sb: Score, sab: Score| -> Score {
|
|
375
|
-
let gumbel = if
|
|
433
|
+
let gumbel = if let Some(rng) = &mut rng {
|
|
376
434
|
coeff_t * -f32::ln(-f32::ln(rng.gen()))
|
|
377
435
|
} else {
|
|
378
436
|
0.0 as f32
|
|
@@ -424,11 +482,7 @@ impl ContractionProcessor {
|
|
|
424
482
|
}
|
|
425
483
|
|
|
426
484
|
// perform contraction:
|
|
427
|
-
|
|
428
|
-
self.pop_node(i);
|
|
429
|
-
self.pop_node(j);
|
|
430
|
-
let k = self.add_node(klegs.clone());
|
|
431
|
-
self.ssa_path.push(vec![i, j]);
|
|
485
|
+
let k = self.contract_nodes_given_legs(i, j, klegs.clone());
|
|
432
486
|
node_sizes.insert(k, ksize);
|
|
433
487
|
|
|
434
488
|
for l in self.neighbors(k) {
|
|
@@ -800,7 +854,6 @@ impl ContractionProcessor {
|
|
|
800
854
|
// --------------------------- PYTHON FUNCTIONS ---------------------------- //
|
|
801
855
|
|
|
802
856
|
#[pyfunction]
|
|
803
|
-
#[pyo3()]
|
|
804
857
|
fn ssa_to_linear(ssa_path: SSAPath, n: Option<usize>) -> SSAPath {
|
|
805
858
|
let n = match n {
|
|
806
859
|
Some(n) => n,
|
|
@@ -828,18 +881,16 @@ fn ssa_to_linear(ssa_path: SSAPath, n: Option<usize>) -> SSAPath {
|
|
|
828
881
|
}
|
|
829
882
|
|
|
830
883
|
#[pyfunction]
|
|
831
|
-
#[pyo3()]
|
|
832
884
|
fn find_subgraphs(
|
|
833
885
|
inputs: Vec<Vec<char>>,
|
|
834
886
|
output: Vec<char>,
|
|
835
887
|
size_dict: Dict<char, f32>,
|
|
836
888
|
) -> Vec<Vec<Node>> {
|
|
837
|
-
let cp = ContractionProcessor::new(inputs, output, size_dict);
|
|
889
|
+
let cp = ContractionProcessor::new(inputs, output, size_dict, false);
|
|
838
890
|
cp.subgraphs()
|
|
839
891
|
}
|
|
840
892
|
|
|
841
893
|
#[pyfunction]
|
|
842
|
-
#[pyo3()]
|
|
843
894
|
fn optimize_simplify(
|
|
844
895
|
inputs: Vec<Vec<char>>,
|
|
845
896
|
output: Vec<char>,
|
|
@@ -847,7 +898,7 @@ fn optimize_simplify(
|
|
|
847
898
|
use_ssa: Option<bool>,
|
|
848
899
|
) -> SSAPath {
|
|
849
900
|
let n = inputs.len();
|
|
850
|
-
let mut cp = ContractionProcessor::new(inputs, output, size_dict);
|
|
901
|
+
let mut cp = ContractionProcessor::new(inputs, output, size_dict, false);
|
|
851
902
|
cp.simplify();
|
|
852
903
|
if use_ssa.unwrap_or(false) {
|
|
853
904
|
cp.ssa_path
|
|
@@ -857,36 +908,94 @@ fn optimize_simplify(
|
|
|
857
908
|
}
|
|
858
909
|
|
|
859
910
|
#[pyfunction]
|
|
860
|
-
#[pyo3()]
|
|
861
911
|
fn optimize_greedy(
|
|
912
|
+
py: Python,
|
|
862
913
|
inputs: Vec<Vec<char>>,
|
|
863
914
|
output: Vec<char>,
|
|
864
915
|
size_dict: Dict<char, f32>,
|
|
865
916
|
costmod: Option<f32>,
|
|
866
917
|
temperature: Option<f32>,
|
|
918
|
+
seed: Option<u64>,
|
|
867
919
|
simplify: Option<bool>,
|
|
868
920
|
use_ssa: Option<bool>,
|
|
869
921
|
) -> Vec<Vec<Node>> {
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
922
|
+
py.allow_threads(|| {
|
|
923
|
+
let n = inputs.len();
|
|
924
|
+
let mut cp = ContractionProcessor::new(inputs, output, size_dict, false);
|
|
925
|
+
if simplify.unwrap_or(true) {
|
|
926
|
+
// perform simplifications
|
|
927
|
+
cp.simplify();
|
|
928
|
+
}
|
|
929
|
+
// greedily contract each connected subgraph
|
|
930
|
+
cp.optimize_greedy(costmod, temperature, seed);
|
|
931
|
+
// optimize any remaining disconnected terms
|
|
932
|
+
cp.optimize_remaining_by_size();
|
|
933
|
+
if use_ssa.unwrap_or(false) {
|
|
934
|
+
cp.ssa_path
|
|
935
|
+
} else {
|
|
936
|
+
ssa_to_linear(cp.ssa_path, Some(n))
|
|
937
|
+
}
|
|
938
|
+
})
|
|
939
|
+
}
|
|
940
|
+
|
|
941
|
+
#[pyfunction]
|
|
942
|
+
fn optimize_random_greedy_track_flops(
|
|
943
|
+
py: Python,
|
|
944
|
+
inputs: Vec<Vec<char>>,
|
|
945
|
+
output: Vec<char>,
|
|
946
|
+
size_dict: Dict<char, f32>,
|
|
947
|
+
ntrials: usize,
|
|
948
|
+
costmod: Option<f32>,
|
|
949
|
+
temperature: Option<f32>,
|
|
950
|
+
seed: Option<u64>,
|
|
951
|
+
simplify: Option<bool>,
|
|
952
|
+
use_ssa: Option<bool>,
|
|
953
|
+
) -> (Vec<Vec<Node>>, Score) {
|
|
954
|
+
py.allow_threads(|| {
|
|
955
|
+
let temperature = temperature.unwrap_or(0.01);
|
|
956
|
+
let mut rng = match seed {
|
|
957
|
+
Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
|
|
958
|
+
None => rand::rngs::StdRng::from_entropy(),
|
|
959
|
+
};
|
|
960
|
+
let seeds = (0..ntrials).map(|_| rng.gen()).collect::<Vec<u64>>();
|
|
961
|
+
|
|
962
|
+
let n: usize = inputs.len();
|
|
963
|
+
// construct processor and perform simplifications once
|
|
964
|
+
let mut cp0 = ContractionProcessor::new(inputs, output, size_dict, true);
|
|
965
|
+
if simplify.unwrap_or(true) {
|
|
966
|
+
cp0.simplify();
|
|
967
|
+
}
|
|
968
|
+
|
|
969
|
+
let mut best_path = None;
|
|
970
|
+
let mut best_flops = f32::INFINITY;
|
|
971
|
+
|
|
972
|
+
for seed in seeds {
|
|
973
|
+
let mut cp = cp0.clone();
|
|
974
|
+
// greedily contract each connected subgraph
|
|
975
|
+
cp.optimize_greedy(costmod, Some(temperature), Some(seed));
|
|
976
|
+
// optimize any remaining disconnected terms
|
|
977
|
+
cp.optimize_remaining_by_size();
|
|
978
|
+
|
|
979
|
+
if cp.flops < best_flops {
|
|
980
|
+
best_flops = cp.flops;
|
|
981
|
+
best_path = Some(cp.ssa_path);
|
|
982
|
+
}
|
|
983
|
+
}
|
|
984
|
+
|
|
985
|
+
// convert to base 10 for easier comparison
|
|
986
|
+
best_flops *= f32::consts::LOG10_E;
|
|
987
|
+
|
|
988
|
+
if use_ssa.unwrap_or(false) {
|
|
989
|
+
(best_path.unwrap(), best_flops)
|
|
990
|
+
} else {
|
|
991
|
+
(ssa_to_linear(best_path.unwrap(), Some(n)), best_flops)
|
|
992
|
+
}
|
|
993
|
+
})
|
|
885
994
|
}
|
|
886
995
|
|
|
887
996
|
#[pyfunction]
|
|
888
|
-
#[pyo3()]
|
|
889
997
|
fn optimize_optimal(
|
|
998
|
+
py: Python,
|
|
890
999
|
inputs: Vec<Vec<char>>,
|
|
891
1000
|
output: Vec<char>,
|
|
892
1001
|
size_dict: Dict<char, f32>,
|
|
@@ -896,30 +1005,33 @@ fn optimize_optimal(
|
|
|
896
1005
|
simplify: Option<bool>,
|
|
897
1006
|
use_ssa: Option<bool>,
|
|
898
1007
|
) -> Vec<Vec<Node>> {
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
1008
|
+
py.allow_threads(|| {
|
|
1009
|
+
let n = inputs.len();
|
|
1010
|
+
let mut cp = ContractionProcessor::new(inputs, output, size_dict, false);
|
|
1011
|
+
if simplify.unwrap_or(true) {
|
|
1012
|
+
// perform simplifications
|
|
1013
|
+
cp.simplify();
|
|
1014
|
+
}
|
|
1015
|
+
// optimally contract each connected subgraph
|
|
1016
|
+
cp.optimize_optimal(minimize, cost_cap, search_outer);
|
|
1017
|
+
// optimize any remaining disconnected terms
|
|
1018
|
+
cp.optimize_remaining_by_size();
|
|
1019
|
+
if use_ssa.unwrap_or(false) {
|
|
1020
|
+
cp.ssa_path
|
|
1021
|
+
} else {
|
|
1022
|
+
ssa_to_linear(cp.ssa_path, Some(n))
|
|
1023
|
+
}
|
|
1024
|
+
})
|
|
914
1025
|
}
|
|
915
1026
|
|
|
916
1027
|
/// A Python module implemented in Rust.
|
|
917
1028
|
#[pymodule]
|
|
918
|
-
fn cotengrust(
|
|
1029
|
+
fn cotengrust(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
|
919
1030
|
m.add_function(wrap_pyfunction!(ssa_to_linear, m)?)?;
|
|
920
1031
|
m.add_function(wrap_pyfunction!(find_subgraphs, m)?)?;
|
|
921
1032
|
m.add_function(wrap_pyfunction!(optimize_simplify, m)?)?;
|
|
922
1033
|
m.add_function(wrap_pyfunction!(optimize_greedy, m)?)?;
|
|
1034
|
+
m.add_function(wrap_pyfunction!(optimize_random_greedy_track_flops, m)?)?;
|
|
923
1035
|
m.add_function(wrap_pyfunction!(optimize_optimal, m)?)?;
|
|
924
1036
|
Ok(())
|
|
925
1037
|
}
|
|
@@ -56,6 +56,15 @@ def get_rand_size_dict(inputs, d_min=2, d_max=3):
|
|
|
56
56
|
|
|
57
57
|
# these are taken from opt_einsum
|
|
58
58
|
test_case_eqs = [
|
|
59
|
+
# Test single-term equations
|
|
60
|
+
"->",
|
|
61
|
+
"a->a",
|
|
62
|
+
"ab->ab",
|
|
63
|
+
"ab->ba",
|
|
64
|
+
"abc->bca",
|
|
65
|
+
"abc->b",
|
|
66
|
+
"baa->ba",
|
|
67
|
+
"aba->b",
|
|
59
68
|
# Test scalar-like operations
|
|
60
69
|
"a,->a",
|
|
61
70
|
"ab,->ab",
|
|
@@ -188,18 +197,39 @@ def test_basic_rand(seed, which):
|
|
|
188
197
|
@requires_cotengra
def test_optimal_lattice_eq():
    """Check optimal paths on a small lattice for both 'flops' and 'size'."""
    inputs, output, _, size_dict = ctg.utils.lattice_equation(
        [4, 5], d_max=2, seed=42
    )

    def tree_for(minimize):
        # optimize with the given objective and build the matching tree
        found = ctgr.optimize_optimal(
            inputs, output, size_dict, minimize=minimize
        )
        tree = ctg.ContractionTree.from_path(
            inputs, output, size_dict, path=found
        )
        return found, tree

    # minimizing flops should reach the known optimal cost
    _, flops_tree = tree_for('flops')
    assert flops_tree.is_complete()
    assert flops_tree.contraction_cost() == 964

    # minimizing size should give a pairwise path of known width
    size_path, size_tree = tree_for('size')
    assert all(len(con) <= 2 for con in size_path)
    assert size_tree.contraction_width() == pytest.approx(5)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@requires_cotengra
def test_optimize_random_greedy_log_flops():
    """Seeded random-greedy runs are reproducible and report log10 flops."""
    inputs, output, _, size_dict = ctg.utils.lattice_equation(
        [10, 10], d_max=3, seed=42
    )

    # two identical seeded runs must report identical costs
    (path, cost1), (_, cost2) = [
        ctgr.optimize_random_greedy_track_flops(
            inputs, output, size_dict, ntrials=4, seed=42
        )
        for _ in range(2)
    ]
    assert cost1 == cost2

    # the reported cost should match cotengra's own log10 flop count
    tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path)
    assert tree.is_complete()
    assert tree.contraction_cost(log=10) == pytest.approx(cost1)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|