cotengrust 0.1.1__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -37,7 +37,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
37
37
 
38
38
  [[package]]
39
39
  name = "cotengrust"
40
- version = "0.1.1"
40
+ version = "0.1.2"
41
41
  dependencies = [
42
42
  "bit-set",
43
43
  "ordered-float",
@@ -57,11 +57,17 @@ dependencies = [
57
57
  "wasi",
58
58
  ]
59
59
 
60
+ [[package]]
61
+ name = "heck"
62
+ version = "0.4.1"
63
+ source = "registry+https://github.com/rust-lang/crates.io-index"
64
+ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
65
+
60
66
  [[package]]
61
67
  name = "indoc"
62
- version = "1.0.9"
68
+ version = "2.0.4"
63
69
  source = "registry+https://github.com/rust-lang/crates.io-index"
64
- checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306"
70
+ checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8"
65
71
 
66
72
  [[package]]
67
73
  name = "libc"
@@ -105,9 +111,9 @@ checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
105
111
 
106
112
  [[package]]
107
113
  name = "ordered-float"
108
- version = "3.9.1"
114
+ version = "4.2.0"
109
115
  source = "registry+https://github.com/rust-lang/crates.io-index"
110
- checksum = "2a54938017eacd63036332b4ae5c8a49fc8c0c1d6d629893057e4f13609edd06"
116
+ checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e"
111
117
  dependencies = [
112
118
  "num-traits",
113
119
  ]
@@ -135,6 +141,12 @@ dependencies = [
135
141
  "windows-targets",
136
142
  ]
137
143
 
144
+ [[package]]
145
+ name = "portable-atomic"
146
+ version = "1.6.0"
147
+ source = "registry+https://github.com/rust-lang/crates.io-index"
148
+ checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
149
+
138
150
  [[package]]
139
151
  name = "ppv-lite86"
140
152
  version = "0.2.17"
@@ -152,15 +164,16 @@ dependencies = [
152
164
 
153
165
  [[package]]
154
166
  name = "pyo3"
155
- version = "0.19.2"
167
+ version = "0.21.1"
156
168
  source = "registry+https://github.com/rust-lang/crates.io-index"
157
- checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38"
169
+ checksum = "a7a8b1990bd018761768d5e608a13df8bd1ac5f678456e0f301bb93e5f3ea16b"
158
170
  dependencies = [
159
171
  "cfg-if",
160
172
  "indoc",
161
173
  "libc",
162
174
  "memoffset",
163
175
  "parking_lot",
176
+ "portable-atomic",
164
177
  "pyo3-build-config",
165
178
  "pyo3-ffi",
166
179
  "pyo3-macros",
@@ -169,9 +182,9 @@ dependencies = [
169
182
 
170
183
  [[package]]
171
184
  name = "pyo3-build-config"
172
- version = "0.19.2"
185
+ version = "0.21.1"
173
186
  source = "registry+https://github.com/rust-lang/crates.io-index"
174
- checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5"
187
+ checksum = "650dca34d463b6cdbdb02b1d71bfd6eb6b6816afc708faebb3bac1380ff4aef7"
175
188
  dependencies = [
176
189
  "once_cell",
177
190
  "target-lexicon",
@@ -179,9 +192,9 @@ dependencies = [
179
192
 
180
193
  [[package]]
181
194
  name = "pyo3-ffi"
182
- version = "0.19.2"
195
+ version = "0.21.1"
183
196
  source = "registry+https://github.com/rust-lang/crates.io-index"
184
- checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9"
197
+ checksum = "09a7da8fc04a8a2084909b59f29e1b8474decac98b951d77b80b26dc45f046ad"
185
198
  dependencies = [
186
199
  "libc",
187
200
  "pyo3-build-config",
@@ -189,9 +202,9 @@ dependencies = [
189
202
 
190
203
  [[package]]
191
204
  name = "pyo3-macros"
192
- version = "0.19.2"
205
+ version = "0.21.1"
193
206
  source = "registry+https://github.com/rust-lang/crates.io-index"
194
- checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1"
207
+ checksum = "4b8a199fce11ebb28e3569387228836ea98110e43a804a530a9fd83ade36d513"
195
208
  dependencies = [
196
209
  "proc-macro2",
197
210
  "pyo3-macros-backend",
@@ -201,11 +214,13 @@ dependencies = [
201
214
 
202
215
  [[package]]
203
216
  name = "pyo3-macros-backend"
204
- version = "0.19.2"
217
+ version = "0.21.1"
205
218
  source = "registry+https://github.com/rust-lang/crates.io-index"
206
- checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536"
219
+ checksum = "93fbbfd7eb553d10036513cb122b888dcd362a945a00b06c165f2ab480d4cc3b"
207
220
  dependencies = [
221
+ "heck",
208
222
  "proc-macro2",
223
+ "pyo3-build-config",
209
224
  "quote",
210
225
  "syn",
211
226
  ]
@@ -278,9 +293,9 @@ checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
278
293
 
279
294
  [[package]]
280
295
  name = "syn"
281
- version = "1.0.109"
296
+ version = "2.0.32"
282
297
  source = "registry+https://github.com/rust-lang/crates.io-index"
283
- checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
298
+ checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2"
284
299
  dependencies = [
285
300
  "proc-macro2",
286
301
  "quote",
@@ -301,9 +316,9 @@ checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c"
301
316
 
302
317
  [[package]]
303
318
  name = "unindent"
304
- version = "0.1.11"
319
+ version = "0.2.3"
305
320
  source = "registry+https://github.com/rust-lang/crates.io-index"
306
- checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c"
321
+ checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
307
322
 
308
323
  [[package]]
309
324
  name = "wasi"
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "cotengrust"
3
- version = "0.1.1"
3
+ version = "0.1.2"
4
4
  edition = "2021"
5
5
 
6
6
  # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -10,8 +10,8 @@ crate-type = ["cdylib"]
10
10
 
11
11
  [dependencies]
12
12
  bit-set = "0.5"
13
- pyo3 = "0.19"
14
- ordered-float = "3.9"
13
+ ordered-float = "4.2"
14
+ pyo3 = "0.21"
15
15
  rand = "0.8"
16
16
  rustc-hash = "1.1"
17
17
 
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.3
2
2
  Name: cotengrust
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Classifier: Programming Language :: Rust
5
5
  Classifier: Programming Language :: Python :: Implementation :: CPython
6
6
  Classifier: Programming Language :: Python :: Implementation :: PyPy
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "cotengrust"
3
- version = "0.1.1"
3
+ version = "0.1.2"
4
4
  description = "Fast contraction ordering primitives for tensor networks."
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.8"
@@ -15,7 +15,7 @@ authors = [
15
15
  ]
16
16
 
17
17
  [build-system]
18
- requires = ["maturin>=0.15,<0.16"]
18
+ requires = ["maturin>=1.0,<2.0"]
19
19
  build-backend = "maturin"
20
20
 
21
21
  [tool.maturin]
@@ -2,8 +2,9 @@ use bit_set::BitSet;
2
2
  use ordered_float::OrderedFloat;
3
3
  use pyo3::prelude::*;
4
4
  use rand::Rng;
5
+ use rand::SeedableRng;
5
6
  use rustc_hash::FxHashMap;
6
- use std::collections::{BTreeSet, BinaryHeap};
7
+ use std::collections::{BTreeSet, BinaryHeap, HashSet};
7
8
  use std::f32;
8
9
 
9
10
  use FxHashMap as Dict;
@@ -23,6 +24,7 @@ type BitPath = Vec<(Subgraph, Subgraph)>;
23
24
  type SubContraction = (Legs, Score, BitPath);
24
25
 
25
26
  /// helper struct to build contractions from bottom up
27
+ #[derive(Clone)]
26
28
  struct ContractionProcessor {
27
29
  nodes: Dict<Node, Legs>,
28
30
  edges: Dict<Ix, BTreeSet<Node>>,
@@ -30,6 +32,8 @@ struct ContractionProcessor {
30
32
  sizes: Vec<Score>,
31
33
  ssa: Node,
32
34
  ssa_path: SSAPath,
35
+ track_flops: bool,
36
+ flops: Score,
33
37
  }
34
38
 
35
39
  /// given log(x) and log(y) compute log(x + y), without exponentiating both
@@ -94,6 +98,21 @@ fn compute_size(legs: &Legs, sizes: &Vec<Score>) -> Score {
94
98
  legs.iter().map(|&(ix, _)| sizes[ix as usize]).sum()
95
99
  }
96
100
 
101
+ fn compute_flops(ilegs: &Legs, jlegs: &Legs, sizes: &Vec<Score>) -> Score {
102
+ let mut flops: Score = 0.0;
103
+ let mut seen: HashSet<Ix> = HashSet::with_capacity(ilegs.len());
104
+ for &(ix, _) in ilegs {
105
+ seen.insert(ix);
106
+ flops += sizes[ix as usize];
107
+ }
108
+ for (ix, _) in jlegs {
109
+ if !seen.contains(ix) {
110
+ flops += sizes[*ix as usize];
111
+ }
112
+ }
113
+ flops
114
+ }
115
+
97
116
  fn is_simplifiable(legs: &Legs, appearances: &Vec<Count>) -> bool {
98
117
  let mut prev_ix = Node::MAX;
99
118
  for &(ix, ix_count) in legs {
@@ -131,7 +150,12 @@ impl ContractionProcessor {
131
150
  inputs: Vec<Vec<char>>,
132
151
  output: Vec<char>,
133
152
  size_dict: Dict<char, f32>,
153
+ track_flops: bool,
134
154
  ) -> ContractionProcessor {
155
+ if size_dict.len() > Ix::MAX as usize {
156
+ panic!("cotengrust: too many indices, maximum is {}", Ix::MAX);
157
+ }
158
+
135
159
  let mut nodes: Dict<Node, Legs> = Dict::default();
136
160
  let mut edges: Dict<Ix, BTreeSet<Node>> = Dict::default();
137
161
  let mut indmap: Dict<char, Ix> = Dict::default();
@@ -149,7 +173,7 @@ impl ContractionProcessor {
149
173
  indmap.insert(ind, c);
150
174
  edges.insert(c, std::iter::once(i as Node).collect());
151
175
  appearances.push(1);
152
- sizes.push(f32::log(size_dict[&ind] as f32, 2.0));
176
+ sizes.push(f32::ln(size_dict[&ind] as f32));
153
177
  legs.push((c, 1));
154
178
  c += 1;
155
179
  }
@@ -170,6 +194,7 @@ impl ContractionProcessor {
170
194
 
171
195
  let ssa = nodes.len() as Node;
172
196
  let ssa_path: SSAPath = Vec::with_capacity(2 * ssa as usize - 1);
197
+ let flops: Score = 0.0;
173
198
 
174
199
  ContractionProcessor {
175
200
  nodes,
@@ -178,6 +203,8 @@ impl ContractionProcessor {
178
203
  sizes,
179
204
  ssa,
180
205
  ssa_path,
206
+ track_flops,
207
+ flops,
181
208
  }
182
209
  }
183
210
 
@@ -225,7 +252,9 @@ impl ContractionProcessor {
225
252
  for (ix, _) in &legs {
226
253
  self.edges
227
254
  .entry(*ix)
228
- .and_modify(|nodes| {nodes.insert(i);})
255
+ .and_modify(|nodes| {
256
+ nodes.insert(i);
257
+ })
229
258
  .or_insert(std::iter::once(i as Node).collect());
230
259
  }
231
260
  self.nodes.insert(i, legs);
@@ -236,12 +265,27 @@ impl ContractionProcessor {
236
265
  fn contract_nodes(&mut self, i: Node, j: Node) -> Node {
237
266
  let ilegs = self.pop_node(i);
238
267
  let jlegs = self.pop_node(j);
268
+ if self.track_flops {
269
+ self.flops = logadd(self.flops, compute_flops(&ilegs, &jlegs, &self.sizes));
270
+ }
239
271
  let new_legs = compute_legs(&ilegs, &jlegs, &self.appearances);
240
272
  let k = self.add_node(new_legs);
241
273
  self.ssa_path.push(vec![i, j]);
242
274
  k
243
275
  }
244
276
 
277
+ /// contract two nodes (which we already know the legs for), return the new node id
278
+ fn contract_nodes_given_legs(&mut self, i: Node, j: Node, new_legs: Legs) -> Node {
279
+ let ilegs = self.pop_node(i);
280
+ let jlegs = self.pop_node(j);
281
+ if self.track_flops {
282
+ self.flops = logadd(self.flops, compute_flops(&ilegs, &jlegs, &self.sizes));
283
+ }
284
+ let k = self.add_node(new_legs);
285
+ self.ssa_path.push(vec![i, j]);
286
+ k
287
+ }
288
+
245
289
  /// find any indices that appear in all terms and just remove/ignore them
246
290
  fn simplify_batch(&mut self) {
247
291
  let mut ix_to_remove = Vec::new();
@@ -366,13 +410,27 @@ impl ContractionProcessor {
366
410
  }
367
411
 
368
412
  /// greedily optimize the contraction order of all terms
369
- fn optimize_greedy(&mut self, costmod: Option<f32>, temperature: Option<f32>) {
370
- let mut rng = rand::thread_rng();
413
+ fn optimize_greedy(
414
+ &mut self,
415
+ costmod: Option<f32>,
416
+ temperature: Option<f32>,
417
+ seed: Option<u64>,
418
+ ) {
371
419
  let coeff_t = temperature.unwrap_or(0.0);
372
420
  let log_coeff_a = f32::ln(costmod.unwrap_or(1.0));
373
421
 
422
+ let mut rng = if coeff_t != 0.0 {
423
+ Some(match seed {
424
+ Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
425
+ None => rand::rngs::StdRng::from_entropy(),
426
+ })
427
+ } else {
428
+ // zero temp - no need for rng
429
+ None
430
+ };
431
+
374
432
  let mut local_score = |sa: Score, sb: Score, sab: Score| -> Score {
375
- let gumbel = if coeff_t != 0.0 {
433
+ let gumbel = if let Some(rng) = &mut rng {
376
434
  coeff_t * -f32::ln(-f32::ln(rng.gen()))
377
435
  } else {
378
436
  0.0 as f32
@@ -424,11 +482,7 @@ impl ContractionProcessor {
424
482
  }
425
483
 
426
484
  // perform contraction:
427
- // we already have the legs, so don't call contract_nodes
428
- self.pop_node(i);
429
- self.pop_node(j);
430
- let k = self.add_node(klegs.clone());
431
- self.ssa_path.push(vec![i, j]);
485
+ let k = self.contract_nodes_given_legs(i, j, klegs.clone());
432
486
  node_sizes.insert(k, ksize);
433
487
 
434
488
  for l in self.neighbors(k) {
@@ -800,7 +854,6 @@ impl ContractionProcessor {
800
854
  // --------------------------- PYTHON FUNCTIONS ---------------------------- //
801
855
 
802
856
  #[pyfunction]
803
- #[pyo3()]
804
857
  fn ssa_to_linear(ssa_path: SSAPath, n: Option<usize>) -> SSAPath {
805
858
  let n = match n {
806
859
  Some(n) => n,
@@ -828,18 +881,16 @@ fn ssa_to_linear(ssa_path: SSAPath, n: Option<usize>) -> SSAPath {
828
881
  }
829
882
 
830
883
  #[pyfunction]
831
- #[pyo3()]
832
884
  fn find_subgraphs(
833
885
  inputs: Vec<Vec<char>>,
834
886
  output: Vec<char>,
835
887
  size_dict: Dict<char, f32>,
836
888
  ) -> Vec<Vec<Node>> {
837
- let cp = ContractionProcessor::new(inputs, output, size_dict);
889
+ let cp = ContractionProcessor::new(inputs, output, size_dict, false);
838
890
  cp.subgraphs()
839
891
  }
840
892
 
841
893
  #[pyfunction]
842
- #[pyo3()]
843
894
  fn optimize_simplify(
844
895
  inputs: Vec<Vec<char>>,
845
896
  output: Vec<char>,
@@ -847,7 +898,7 @@ fn optimize_simplify(
847
898
  use_ssa: Option<bool>,
848
899
  ) -> SSAPath {
849
900
  let n = inputs.len();
850
- let mut cp = ContractionProcessor::new(inputs, output, size_dict);
901
+ let mut cp = ContractionProcessor::new(inputs, output, size_dict, false);
851
902
  cp.simplify();
852
903
  if use_ssa.unwrap_or(false) {
853
904
  cp.ssa_path
@@ -857,36 +908,94 @@ fn optimize_simplify(
857
908
  }
858
909
 
859
910
  #[pyfunction]
860
- #[pyo3()]
861
911
  fn optimize_greedy(
912
+ py: Python,
862
913
  inputs: Vec<Vec<char>>,
863
914
  output: Vec<char>,
864
915
  size_dict: Dict<char, f32>,
865
916
  costmod: Option<f32>,
866
917
  temperature: Option<f32>,
918
+ seed: Option<u64>,
867
919
  simplify: Option<bool>,
868
920
  use_ssa: Option<bool>,
869
921
  ) -> Vec<Vec<Node>> {
870
- let n = inputs.len();
871
- let mut cp = ContractionProcessor::new(inputs, output, size_dict);
872
- if simplify.unwrap_or(true) {
873
- // perform simplifications
874
- cp.simplify();
875
- }
876
- // greddily contract each connected subgraph
877
- cp.optimize_greedy(costmod, temperature);
878
- // optimize any remaining disconnected terms
879
- cp.optimize_remaining_by_size();
880
- if use_ssa.unwrap_or(false) {
881
- cp.ssa_path
882
- } else {
883
- ssa_to_linear(cp.ssa_path, Some(n))
884
- }
922
+ py.allow_threads(|| {
923
+ let n = inputs.len();
924
+ let mut cp = ContractionProcessor::new(inputs, output, size_dict, false);
925
+ if simplify.unwrap_or(true) {
926
+ // perform simplifications
927
+ cp.simplify();
928
+ }
929
+ // greedily contract each connected subgraph
930
+ cp.optimize_greedy(costmod, temperature, seed);
931
+ // optimize any remaining disconnected terms
932
+ cp.optimize_remaining_by_size();
933
+ if use_ssa.unwrap_or(false) {
934
+ cp.ssa_path
935
+ } else {
936
+ ssa_to_linear(cp.ssa_path, Some(n))
937
+ }
938
+ })
939
+ }
940
+
941
+ #[pyfunction]
942
+ fn optimize_random_greedy_track_flops(
943
+ py: Python,
944
+ inputs: Vec<Vec<char>>,
945
+ output: Vec<char>,
946
+ size_dict: Dict<char, f32>,
947
+ ntrials: usize,
948
+ costmod: Option<f32>,
949
+ temperature: Option<f32>,
950
+ seed: Option<u64>,
951
+ simplify: Option<bool>,
952
+ use_ssa: Option<bool>,
953
+ ) -> (Vec<Vec<Node>>, Score) {
954
+ py.allow_threads(|| {
955
+ let temperature = temperature.unwrap_or(0.01);
956
+ let mut rng = match seed {
957
+ Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
958
+ None => rand::rngs::StdRng::from_entropy(),
959
+ };
960
+ let seeds = (0..ntrials).map(|_| rng.gen()).collect::<Vec<u64>>();
961
+
962
+ let n: usize = inputs.len();
963
+ // construct processor and perform simplifications once
964
+ let mut cp0 = ContractionProcessor::new(inputs, output, size_dict, true);
965
+ if simplify.unwrap_or(true) {
966
+ cp0.simplify();
967
+ }
968
+
969
+ let mut best_path = None;
970
+ let mut best_flops = f32::INFINITY;
971
+
972
+ for seed in seeds {
973
+ let mut cp = cp0.clone();
974
+ // greedily contract each connected subgraph
975
+ cp.optimize_greedy(costmod, Some(temperature), Some(seed));
976
+ // optimize any remaining disconnected terms
977
+ cp.optimize_remaining_by_size();
978
+
979
+ if cp.flops < best_flops {
980
+ best_flops = cp.flops;
981
+ best_path = Some(cp.ssa_path);
982
+ }
983
+ }
984
+
985
+ // convert to base 10 for easier comparison
986
+ best_flops *= f32::consts::LOG10_E;
987
+
988
+ if use_ssa.unwrap_or(false) {
989
+ (best_path.unwrap(), best_flops)
990
+ } else {
991
+ (ssa_to_linear(best_path.unwrap(), Some(n)), best_flops)
992
+ }
993
+ })
885
994
  }
886
995
 
887
996
  #[pyfunction]
888
- #[pyo3()]
889
997
  fn optimize_optimal(
998
+ py: Python,
890
999
  inputs: Vec<Vec<char>>,
891
1000
  output: Vec<char>,
892
1001
  size_dict: Dict<char, f32>,
@@ -896,30 +1005,33 @@ fn optimize_optimal(
896
1005
  simplify: Option<bool>,
897
1006
  use_ssa: Option<bool>,
898
1007
  ) -> Vec<Vec<Node>> {
899
- let n = inputs.len();
900
- let mut cp = ContractionProcessor::new(inputs, output, size_dict);
901
- if simplify.unwrap_or(true) {
902
- // perform simplifications
903
- cp.simplify();
904
- }
905
- // optimally contract each connected subgraph
906
- cp.optimize_optimal(minimize, cost_cap, search_outer);
907
- // optimize any remaining disconnected terms
908
- cp.optimize_remaining_by_size();
909
- if use_ssa.unwrap_or(false) {
910
- cp.ssa_path
911
- } else {
912
- ssa_to_linear(cp.ssa_path, Some(n))
913
- }
1008
+ py.allow_threads(|| {
1009
+ let n = inputs.len();
1010
+ let mut cp = ContractionProcessor::new(inputs, output, size_dict, false);
1011
+ if simplify.unwrap_or(true) {
1012
+ // perform simplifications
1013
+ cp.simplify();
1014
+ }
1015
+ // optimally contract each connected subgraph
1016
+ cp.optimize_optimal(minimize, cost_cap, search_outer);
1017
+ // optimize any remaining disconnected terms
1018
+ cp.optimize_remaining_by_size();
1019
+ if use_ssa.unwrap_or(false) {
1020
+ cp.ssa_path
1021
+ } else {
1022
+ ssa_to_linear(cp.ssa_path, Some(n))
1023
+ }
1024
+ })
914
1025
  }
915
1026
 
916
1027
  /// A Python module implemented in Rust.
917
1028
  #[pymodule]
918
- fn cotengrust(_py: Python, m: &PyModule) -> PyResult<()> {
1029
+ fn cotengrust(m: &Bound<'_, PyModule>) -> PyResult<()> {
919
1030
  m.add_function(wrap_pyfunction!(ssa_to_linear, m)?)?;
920
1031
  m.add_function(wrap_pyfunction!(find_subgraphs, m)?)?;
921
1032
  m.add_function(wrap_pyfunction!(optimize_simplify, m)?)?;
922
1033
  m.add_function(wrap_pyfunction!(optimize_greedy, m)?)?;
1034
+ m.add_function(wrap_pyfunction!(optimize_random_greedy_track_flops, m)?)?;
923
1035
  m.add_function(wrap_pyfunction!(optimize_optimal, m)?)?;
924
1036
  Ok(())
925
1037
  }
@@ -56,6 +56,15 @@ def get_rand_size_dict(inputs, d_min=2, d_max=3):
56
56
 
57
57
  # these are taken from opt_einsum
58
58
  test_case_eqs = [
59
+ # Test single-term equations
60
+ "->",
61
+ "a->a",
62
+ "ab->ab",
63
+ "ab->ba",
64
+ "abc->bca",
65
+ "abc->b",
66
+ "baa->ba",
67
+ "aba->b",
59
68
  # Test scalar-like operations
60
69
  "a,->a",
61
70
  "ab,->ab",
@@ -188,18 +197,39 @@ def test_basic_rand(seed, which):
188
197
  @requires_cotengra
189
198
  def test_optimal_lattice_eq():
190
199
  inputs, output, _, size_dict = ctg.utils.lattice_equation(
191
- [4, 5], d_max=3, seed=42
200
+ [4, 5], d_max=2, seed=42
192
201
  )
193
202
 
194
203
  path = ctgr.optimize_optimal(inputs, output, size_dict, minimize='flops')
195
204
  tree = ctg.ContractionTree.from_path(
196
205
  inputs, output, size_dict, path=path
197
206
  )
198
- assert tree.contraction_cost() == 3628
207
+ assert tree.is_complete()
208
+ assert tree.contraction_cost() == 964
199
209
 
200
210
  path = ctgr.optimize_optimal(inputs, output, size_dict, minimize='size')
201
211
  assert all(len(con) <= 2 for con in path)
202
212
  tree = ctg.ContractionTree.from_path(
203
213
  inputs, output, size_dict, path=path
204
214
  )
205
- assert tree.contraction_width() == pytest.approx(6.754887502163468)
215
+ assert tree.contraction_width() == pytest.approx(5)
216
+
217
+
218
+ @requires_cotengra
219
+ def test_optimize_random_greedy_log_flops():
220
+ inputs, output, _, size_dict = ctg.utils.lattice_equation(
221
+ [10, 10], d_max=3, seed=42
222
+ )
223
+
224
+ path, cost1 = ctgr.optimize_random_greedy_track_flops(
225
+ inputs, output, size_dict, ntrials=4, seed=42
226
+ )
227
+ _, cost2 = ctgr.optimize_random_greedy_track_flops(
228
+ inputs, output, size_dict, ntrials=4, seed=42
229
+ )
230
+ assert cost1 == cost2
231
+ tree = ctg.ContractionTree.from_path(
232
+ inputs, output, size_dict, path=path
233
+ )
234
+ assert tree.is_complete()
235
+ assert tree.contraction_cost(log=10) == pytest.approx(cost1)
File without changes
File without changes
File without changes