cotengrust 0.1.5__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -37,7 +37,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
37
37
 
38
38
  [[package]]
39
39
  name = "cotengrust"
40
- version = "0.1.5"
40
+ version = "0.2.0"
41
41
  dependencies = [
42
42
  "bit-set",
43
43
  "num-traits",
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "cotengrust"
3
- version = "0.1.5"
3
+ version = "0.2.0"
4
4
  edition = "2021"
5
5
  readme = "README.md"
6
6
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cotengrust
3
- Version: 0.1.5
3
+ Version: 0.2.0
4
4
  Classifier: Programming Language :: Rust
5
5
  Classifier: Programming Language :: Python :: Implementation :: CPython
6
6
  Classifier: Programming Language :: Python :: Implementation :: PyPy
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "cotengrust"
3
- version = "0.1.5"
3
+ version = "0.2.0"
4
4
  description = "Fast contraction ordering primitives for tensor networks."
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.8"
@@ -12,7 +12,7 @@ use std::hash::Hash;
12
12
  use FxHashMap as Dict;
13
13
 
14
14
  // n.b. this constrains the maximum number of index appearances to < 65536 (u16)
15
- type Count = u8;
15
+ type Count = u16;
16
16
  type Score = f32;
17
17
  type GreedyScore = OrderedFloat<Score>;
18
18
  type SSAPath = Vec<Vec<u32>>;
@@ -245,6 +245,24 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
245
245
  js
246
246
  }
247
247
 
248
+ /// Like `neighbors`, but skip any index shared by more than `max_neighbors` nodes (0 = no limit); used by greedy.
249
+ fn neighbors_limit(&self, i: Node, max_neighbors: usize) -> BTreeSet<Node> {
250
+ let mut js = BTreeSet::default();
251
+ for (ix, _) in self.nodes[&i].iter() {
252
+ if max_neighbors != 0 && self.edges[&ix].len() > max_neighbors {
253
+ // basically a batch index with too many combinations -> skip
254
+ continue;
255
+ }
256
+
257
+ self.edges[&ix].iter().for_each(|&j| {
258
+ if j != i {
259
+ js.insert(j);
260
+ };
261
+ });
262
+ }
263
+ js
264
+ }
265
+
248
266
  /// remove an index from the graph, updating all legs
249
267
  fn remove_ix(&mut self, ix: Ix) {
250
268
  for j in self.edges.remove(&ix).unwrap() {
@@ -439,10 +457,12 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
439
457
  &mut self,
440
458
  costmod: Option<f32>,
441
459
  temperature: Option<f32>,
460
+ max_neighbors: Option<usize>,
442
461
  seed: Option<u64>,
443
462
  ) -> bool {
444
463
  let coeff_t = temperature.unwrap_or(0.0);
445
464
  let log_coeff_a = f32::ln(costmod.unwrap_or(1.0));
465
+ let max_neighbors = max_neighbors.unwrap_or(16);
446
466
 
447
467
  let mut rng = if coeff_t != 0.0 {
448
468
  Some(match seed {
@@ -479,6 +499,11 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
479
499
 
480
500
  // get the initial candidate contractions
481
501
  for ix_nodes in self.edges.values() {
502
+ if max_neighbors != 0 && ix_nodes.len() > max_neighbors {
503
+ // basically a batch index with too many combinations -> skip
504
+ continue;
505
+ }
506
+
482
507
  // convert to vector for combinational indexing
483
508
  let ix_nodes: Vec<Node> = ix_nodes.iter().cloned().collect();
484
509
  // for all combinations of nodes with a connected edge
@@ -516,7 +541,7 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
516
541
 
517
542
  node_sizes.insert(k, ksize);
518
543
 
519
- for l in self.neighbors(k) {
544
+ for l in self.neighbors_limit(k, max_neighbors) {
520
545
  // assess all neighboring contractions of new node
521
546
  let llegs = &self.nodes[&l];
522
547
  let lsize = node_sizes[&l];
@@ -528,6 +553,22 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
528
553
  contractions.insert(c, (k, l, msize, mlegs));
529
554
  c -= 1;
530
555
  }
556
+
557
+ // // potential queue pruning?
558
+ // if queue.len() > 100_000 {
559
+ // let mut valid_contractions = Vec::new();
560
+ // for (score, cid) in queue.drain() {
561
+ // if let Some((i, j, _, _)) = contractions.get(&cid) {
562
+ // if self.nodes.contains_key(&i) && self.nodes.contains_key(&j) {
563
+ // valid_contractions.push((score, cid));
564
+ // } else {
565
+ // // Remove stale contraction from map
566
+ // contractions.remove(&cid);
567
+ // }
568
+ // }
569
+ // }
570
+ // queue = BinaryHeap::from(valid_contractions);
571
+ // }
531
572
  }
532
573
  // success
533
574
  return true;
@@ -941,6 +982,7 @@ fn run_greedy<Ix: IndexType, Node: NodeType>(
941
982
  size_dict: Dict<char, f32>,
942
983
  costmod: Option<f32>,
943
984
  temperature: Option<f32>,
985
+ max_neighbors: Option<usize>,
944
986
  seed: Option<u64>,
945
987
  simplify: bool,
946
988
  ) -> SSAPath {
@@ -949,7 +991,7 @@ fn run_greedy<Ix: IndexType, Node: NodeType>(
949
991
  if simplify {
950
992
  cp.simplify();
951
993
  }
952
- cp.optimize_greedy(costmod, temperature, seed);
994
+ cp.optimize_greedy(costmod, temperature, max_neighbors, seed);
953
995
  cp.optimize_remaining_by_size();
954
996
  cp.ssa_path
955
997
  }
@@ -987,6 +1029,7 @@ fn run_random_greedy_optimization<Ix: IndexType, Node: NodeType>(
987
1029
  log_temp_min: f32,
988
1030
  log_temp_diff: f32,
989
1031
  is_const_temp: bool,
1032
+ max_neighbors: Option<usize>,
990
1033
  rng: &mut rand::rngs::StdRng,
991
1034
  ) -> (SSAPath, Score) {
992
1035
  let mut cp0: ContractionProcessor<Ix, Node> =
@@ -1013,7 +1056,8 @@ fn run_random_greedy_optimization<Ix: IndexType, Node: NodeType>(
1013
1056
  f32::exp(log_temp_min + rng.random::<f32>() * log_temp_diff)
1014
1057
  };
1015
1058
 
1016
- let success = cp.optimize_greedy(Some(costmod), Some(temperature), Some(*seed));
1059
+ let success =
1060
+ cp.optimize_greedy(Some(costmod), Some(temperature), max_neighbors, Some(*seed));
1017
1061
 
1018
1062
  if !success {
1019
1063
  continue;
@@ -1135,7 +1179,7 @@ fn optimize_simplify(
1135
1179
  }
1136
1180
 
1137
1181
  #[pyfunction]
1138
- #[pyo3(signature = (inputs, output, size_dict, costmod=None, temperature=None, seed=None, simplify=None, use_ssa=None))]
1182
+ #[pyo3(signature = (inputs, output, size_dict, costmod=None, temperature=None, max_neighbors=None, seed=None, simplify=None, use_ssa=None))]
1139
1183
  /// Find a contraction path using a (randomizable) greedy algorithm.
1140
1184
  ///
1141
1185
  /// Parameters
@@ -1160,6 +1204,12 @@ fn optimize_simplify(
1160
1204
  /// score -> sign(score) * log(|score|) - temperature * gumbel()
1161
1205
  ///
1162
1206
  /// which implements boltzmann sampling.
1207
+ /// max_neighbors : int, optional
1208
+ /// If non-zero, skip any index that connects to more than this many
1209
+ /// nodes. This is useful to avoid combinatorial explosions when
1210
+ /// dealing with essentially batch indices. Default: 16.
1211
+ /// seed : int, optional
1212
+ /// The seed for the random number generator.
1163
1213
  /// simplify : bool, optional
1164
1214
  /// Whether to perform simplifications before optimizing. These are:
1165
1215
  ///
@@ -1190,6 +1240,7 @@ fn optimize_greedy(
1190
1240
  size_dict: Dict<char, f32>,
1191
1241
  costmod: Option<f32>,
1192
1242
  temperature: Option<f32>,
1243
+ max_neighbors: Option<usize>,
1193
1244
  seed: Option<u64>,
1194
1245
  simplify: Option<bool>,
1195
1246
  use_ssa: Option<bool>,
@@ -1208,6 +1259,7 @@ fn optimize_greedy(
1208
1259
  size_dict,
1209
1260
  costmod,
1210
1261
  temperature,
1262
+ max_neighbors,
1211
1263
  seed,
1212
1264
  simplify,
1213
1265
  )
@@ -1219,6 +1271,7 @@ fn optimize_greedy(
1219
1271
  size_dict,
1220
1272
  costmod,
1221
1273
  temperature,
1274
+ max_neighbors,
1222
1275
  seed,
1223
1276
  simplify,
1224
1277
  )
@@ -1229,6 +1282,7 @@ fn optimize_greedy(
1229
1282
  size_dict,
1230
1283
  costmod,
1231
1284
  temperature,
1285
+ max_neighbors,
1232
1286
  seed,
1233
1287
  simplify,
1234
1288
  ),
@@ -1243,7 +1297,7 @@ fn optimize_greedy(
1243
1297
  }
1244
1298
 
1245
1299
  #[pyfunction]
1246
- #[pyo3(signature = (inputs, output, size_dict, ntrials, costmod=None, temperature=None, seed=None, simplify=None, use_ssa=None))]
1300
+ #[pyo3(signature = (inputs, output, size_dict, ntrials, costmod=None, temperature=None, max_neighbors=None, seed=None, simplify=None, use_ssa=None))]
1247
1301
  /// Perform a batch of random greedy optimizations, simultaneously tracking
1248
1302
  /// the best contraction path in terms of flops, so as to avoid constructing a
1249
1303
  /// separate contraction tree.
@@ -1273,6 +1327,10 @@ fn optimize_greedy(
1273
1327
  ///
1274
1328
  /// which implements boltzmann sampling. It is sampled log-uniformly from
1275
1329
  /// the given range.
1330
+ /// max_neighbors : int, optional
1331
+ /// If non-zero, skip any index that connects to more than this many
1332
+ /// nodes. This is useful to avoid combinatorial explosions when
1333
+ /// dealing with essentially batch indices. Default: 16.
1276
1334
  /// seed : int, optional
1277
1335
  /// The seed for the random number generator.
1278
1336
  /// simplify : bool, optional
@@ -1309,6 +1367,7 @@ fn optimize_random_greedy_track_flops(
1309
1367
  ntrials: usize,
1310
1368
  costmod: Option<(f32, f32)>,
1311
1369
  temperature: Option<(f32, f32)>,
1370
+ max_neighbors: Option<usize>,
1312
1371
  seed: Option<u64>,
1313
1372
  simplify: Option<bool>,
1314
1373
  use_ssa: Option<bool>,
@@ -1350,6 +1409,7 @@ fn optimize_random_greedy_track_flops(
1350
1409
  log_temp_min,
1351
1410
  log_temp_diff,
1352
1411
  is_const_temp,
1412
+ max_neighbors,
1353
1413
  &mut rng,
1354
1414
  )
1355
1415
  }
@@ -1367,6 +1427,7 @@ fn optimize_random_greedy_track_flops(
1367
1427
  log_temp_min,
1368
1428
  log_temp_diff,
1369
1429
  is_const_temp,
1430
+ max_neighbors,
1370
1431
  &mut rng,
1371
1432
  )
1372
1433
  }
@@ -1383,6 +1444,7 @@ fn optimize_random_greedy_track_flops(
1383
1444
  log_temp_min,
1384
1445
  log_temp_diff,
1385
1446
  is_const_temp,
1447
+ max_neighbors,
1386
1448
  &mut rng,
1387
1449
  ),
1388
1450
  };
File without changes
File without changes
File without changes
File without changes