cotengrust 0.1.5__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cotengrust-0.1.5 → cotengrust-0.2.0}/Cargo.lock +1 -1
- {cotengrust-0.1.5 → cotengrust-0.2.0}/Cargo.toml +1 -1
- {cotengrust-0.1.5 → cotengrust-0.2.0}/PKG-INFO +1 -1
- {cotengrust-0.1.5 → cotengrust-0.2.0}/pyproject.toml +1 -1
- {cotengrust-0.1.5 → cotengrust-0.2.0}/src/lib.rs +68 -6
- {cotengrust-0.1.5 → cotengrust-0.2.0}/.github/dependabot.yml +0 -0
- {cotengrust-0.1.5 → cotengrust-0.2.0}/.github/workflows/CI.yml +0 -0
- {cotengrust-0.1.5 → cotengrust-0.2.0}/.gitignore +0 -0
- {cotengrust-0.1.5 → cotengrust-0.2.0}/LICENSE +0 -0
- {cotengrust-0.1.5 → cotengrust-0.2.0}/README.md +0 -0
- {cotengrust-0.1.5 → cotengrust-0.2.0}/cotengrust.pyi +0 -0
- {cotengrust-0.1.5 → cotengrust-0.2.0}/tests/test_cotengrust.py +0 -0
|
@@ -12,7 +12,7 @@ use std::hash::Hash;
|
|
|
12
12
|
use FxHashMap as Dict;
|
|
13
13
|
|
|
14
14
|
// n.b. this constrains the maximum number of index appearances to the range of `Count` (< 256 for u8, < 65536 for u16)
|
|
15
|
-
type Count = u8;
|
|
15
|
+
type Count = u16;
|
|
16
16
|
type Score = f32;
|
|
17
17
|
type GreedyScore = OrderedFloat<Score>;
|
|
18
18
|
type SSAPath = Vec<Vec<u32>>;
|
|
@@ -245,6 +245,24 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
|
|
|
245
245
|
js
|
|
246
246
|
}
|
|
247
247
|
|
|
248
|
+
/// like neighbors but skip edges with too many neighbors, for greedy
|
|
249
|
+
fn neighbors_limit(&self, i: Node, max_neighbors: usize) -> BTreeSet<Node> {
|
|
250
|
+
let mut js = BTreeSet::default();
|
|
251
|
+
for (ix, _) in self.nodes[&i].iter() {
|
|
252
|
+
if max_neighbors != 0 && self.edges[&ix].len() > max_neighbors {
|
|
253
|
+
// basically a batch index with too many combinations -> skip
|
|
254
|
+
continue;
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
self.edges[&ix].iter().for_each(|&j| {
|
|
258
|
+
if j != i {
|
|
259
|
+
js.insert(j);
|
|
260
|
+
};
|
|
261
|
+
});
|
|
262
|
+
}
|
|
263
|
+
js
|
|
264
|
+
}
|
|
265
|
+
|
|
248
266
|
/// remove an index from the graph, updating all legs
|
|
249
267
|
fn remove_ix(&mut self, ix: Ix) {
|
|
250
268
|
for j in self.edges.remove(&ix).unwrap() {
|
|
@@ -439,10 +457,12 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
|
|
|
439
457
|
&mut self,
|
|
440
458
|
costmod: Option<f32>,
|
|
441
459
|
temperature: Option<f32>,
|
|
460
|
+
max_neighbors: Option<usize>,
|
|
442
461
|
seed: Option<u64>,
|
|
443
462
|
) -> bool {
|
|
444
463
|
let coeff_t = temperature.unwrap_or(0.0);
|
|
445
464
|
let log_coeff_a = f32::ln(costmod.unwrap_or(1.0));
|
|
465
|
+
let max_neighbors = max_neighbors.unwrap_or(16);
|
|
446
466
|
|
|
447
467
|
let mut rng = if coeff_t != 0.0 {
|
|
448
468
|
Some(match seed {
|
|
@@ -479,6 +499,11 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
|
|
|
479
499
|
|
|
480
500
|
// get the initial candidate contractions
|
|
481
501
|
for ix_nodes in self.edges.values() {
|
|
502
|
+
if max_neighbors != 0 && ix_nodes.len() > max_neighbors {
|
|
503
|
+
// basically a batch index with too many combinations -> skip
|
|
504
|
+
continue;
|
|
505
|
+
}
|
|
506
|
+
|
|
482
507
|
// convert to vector for combinational indexing
|
|
483
508
|
let ix_nodes: Vec<Node> = ix_nodes.iter().cloned().collect();
|
|
484
509
|
// for all combinations of nodes with a connected edge
|
|
@@ -516,7 +541,7 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
|
|
|
516
541
|
|
|
517
542
|
node_sizes.insert(k, ksize);
|
|
518
543
|
|
|
519
|
-
for l in self.neighbors(k) {
|
|
544
|
+
for l in self.neighbors_limit(k, max_neighbors) {
|
|
520
545
|
// assess all neighboring contractions of new node
|
|
521
546
|
let llegs = &self.nodes[&l];
|
|
522
547
|
let lsize = node_sizes[&l];
|
|
@@ -528,6 +553,22 @@ impl<Ix: IndexType, Node: NodeType> ContractionProcessor<Ix, Node> {
|
|
|
528
553
|
contractions.insert(c, (k, l, msize, mlegs));
|
|
529
554
|
c -= 1;
|
|
530
555
|
}
|
|
556
|
+
|
|
557
|
+
// // potential queue pruning?
|
|
558
|
+
// if queue.len() > 100_000 {
|
|
559
|
+
// let mut valid_contractions = Vec::new();
|
|
560
|
+
// for (score, cid) in queue.drain() {
|
|
561
|
+
// if let Some((i, j, _, _)) = contractions.get(&cid) {
|
|
562
|
+
// if self.nodes.contains_key(&i) && self.nodes.contains_key(&j) {
|
|
563
|
+
// valid_contractions.push((score, cid));
|
|
564
|
+
// } else {
|
|
565
|
+
// // Remove stale contraction from map
|
|
566
|
+
// contractions.remove(&cid);
|
|
567
|
+
// }
|
|
568
|
+
// }
|
|
569
|
+
// }
|
|
570
|
+
// queue = BinaryHeap::from(valid_contractions);
|
|
571
|
+
// }
|
|
531
572
|
}
|
|
532
573
|
// success
|
|
533
574
|
return true;
|
|
@@ -941,6 +982,7 @@ fn run_greedy<Ix: IndexType, Node: NodeType>(
|
|
|
941
982
|
size_dict: Dict<char, f32>,
|
|
942
983
|
costmod: Option<f32>,
|
|
943
984
|
temperature: Option<f32>,
|
|
985
|
+
max_neighbors: Option<usize>,
|
|
944
986
|
seed: Option<u64>,
|
|
945
987
|
simplify: bool,
|
|
946
988
|
) -> SSAPath {
|
|
@@ -949,7 +991,7 @@ fn run_greedy<Ix: IndexType, Node: NodeType>(
|
|
|
949
991
|
if simplify {
|
|
950
992
|
cp.simplify();
|
|
951
993
|
}
|
|
952
|
-
cp.optimize_greedy(costmod, temperature, seed);
|
|
994
|
+
cp.optimize_greedy(costmod, temperature, max_neighbors, seed);
|
|
953
995
|
cp.optimize_remaining_by_size();
|
|
954
996
|
cp.ssa_path
|
|
955
997
|
}
|
|
@@ -987,6 +1029,7 @@ fn run_random_greedy_optimization<Ix: IndexType, Node: NodeType>(
|
|
|
987
1029
|
log_temp_min: f32,
|
|
988
1030
|
log_temp_diff: f32,
|
|
989
1031
|
is_const_temp: bool,
|
|
1032
|
+
max_neighbors: Option<usize>,
|
|
990
1033
|
rng: &mut rand::rngs::StdRng,
|
|
991
1034
|
) -> (SSAPath, Score) {
|
|
992
1035
|
let mut cp0: ContractionProcessor<Ix, Node> =
|
|
@@ -1013,7 +1056,8 @@ fn run_random_greedy_optimization<Ix: IndexType, Node: NodeType>(
|
|
|
1013
1056
|
f32::exp(log_temp_min + rng.random::<f32>() * log_temp_diff)
|
|
1014
1057
|
};
|
|
1015
1058
|
|
|
1016
|
-
let success = cp.optimize_greedy(Some(costmod), Some(temperature), Some(*seed));
|
|
1059
|
+
let success =
|
|
1060
|
+
cp.optimize_greedy(Some(costmod), Some(temperature), max_neighbors, Some(*seed));
|
|
1017
1061
|
|
|
1018
1062
|
if !success {
|
|
1019
1063
|
continue;
|
|
@@ -1135,7 +1179,7 @@ fn optimize_simplify(
|
|
|
1135
1179
|
}
|
|
1136
1180
|
|
|
1137
1181
|
#[pyfunction]
|
|
1138
|
-
#[pyo3(signature = (inputs, output, size_dict, costmod=None, temperature=None, seed=None, simplify=None, use_ssa=None))]
|
|
1182
|
+
#[pyo3(signature = (inputs, output, size_dict, costmod=None, temperature=None, max_neighbors=None, seed=None, simplify=None, use_ssa=None))]
|
|
1139
1183
|
/// Find a contraction path using a (randomizable) greedy algorithm.
|
|
1140
1184
|
///
|
|
1141
1185
|
/// Parameters
|
|
@@ -1160,6 +1204,12 @@ fn optimize_simplify(
|
|
|
1160
1204
|
/// score -> sign(score) * log(|score|) - temperature * gumbel()
|
|
1161
1205
|
///
|
|
1162
1206
|
/// which implements boltzmann sampling.
|
|
1207
|
+
/// max_neighbors : int, optional
|
|
1208
|
+
/// If non-zero, skip any index that connects to more than this many
|
|
1209
|
+
/// nodes. This is useful to avoid combinatorial explosions when
|
|
1210
|
+
/// dealing with essentially batch indices. Default: 16.
|
|
1211
|
+
/// seed : int, optional
|
|
1212
|
+
/// The seed for the random number generator.
|
|
1163
1213
|
/// simplify : bool, optional
|
|
1164
1214
|
/// Whether to perform simplifications before optimizing. These are:
|
|
1165
1215
|
///
|
|
@@ -1190,6 +1240,7 @@ fn optimize_greedy(
|
|
|
1190
1240
|
size_dict: Dict<char, f32>,
|
|
1191
1241
|
costmod: Option<f32>,
|
|
1192
1242
|
temperature: Option<f32>,
|
|
1243
|
+
max_neighbors: Option<usize>,
|
|
1193
1244
|
seed: Option<u64>,
|
|
1194
1245
|
simplify: Option<bool>,
|
|
1195
1246
|
use_ssa: Option<bool>,
|
|
@@ -1208,6 +1259,7 @@ fn optimize_greedy(
|
|
|
1208
1259
|
size_dict,
|
|
1209
1260
|
costmod,
|
|
1210
1261
|
temperature,
|
|
1262
|
+
max_neighbors,
|
|
1211
1263
|
seed,
|
|
1212
1264
|
simplify,
|
|
1213
1265
|
)
|
|
@@ -1219,6 +1271,7 @@ fn optimize_greedy(
|
|
|
1219
1271
|
size_dict,
|
|
1220
1272
|
costmod,
|
|
1221
1273
|
temperature,
|
|
1274
|
+
max_neighbors,
|
|
1222
1275
|
seed,
|
|
1223
1276
|
simplify,
|
|
1224
1277
|
)
|
|
@@ -1229,6 +1282,7 @@ fn optimize_greedy(
|
|
|
1229
1282
|
size_dict,
|
|
1230
1283
|
costmod,
|
|
1231
1284
|
temperature,
|
|
1285
|
+
max_neighbors,
|
|
1232
1286
|
seed,
|
|
1233
1287
|
simplify,
|
|
1234
1288
|
),
|
|
@@ -1243,7 +1297,7 @@ fn optimize_greedy(
|
|
|
1243
1297
|
}
|
|
1244
1298
|
|
|
1245
1299
|
#[pyfunction]
|
|
1246
|
-
#[pyo3(signature = (inputs, output, size_dict, ntrials, costmod=None, temperature=None, seed=None, simplify=None, use_ssa=None))]
|
|
1300
|
+
#[pyo3(signature = (inputs, output, size_dict, ntrials, costmod=None, temperature=None, max_neighbors=None, seed=None, simplify=None, use_ssa=None))]
|
|
1247
1301
|
/// Perform a batch of random greedy optimizations, simultaneously tracking
|
|
1248
1302
|
/// the best contraction path in terms of flops, so as to avoid constructing a
|
|
1249
1303
|
/// separate contraction tree.
|
|
@@ -1273,6 +1327,10 @@ fn optimize_greedy(
|
|
|
1273
1327
|
///
|
|
1274
1328
|
/// which implements boltzmann sampling. It is sampled log-uniformly from
|
|
1275
1329
|
/// the given range.
|
|
1330
|
+
/// max_neighbors : int, optional
|
|
1331
|
+
/// If non-zero, skip any index that connects to more than this many
|
|
1332
|
+
/// nodes. This is useful to avoid combinatorial explosions when
|
|
1333
|
+
/// dealing with essentially batch indices. Default: 16.
|
|
1276
1334
|
/// seed : int, optional
|
|
1277
1335
|
/// The seed for the random number generator.
|
|
1278
1336
|
/// simplify : bool, optional
|
|
@@ -1309,6 +1367,7 @@ fn optimize_random_greedy_track_flops(
|
|
|
1309
1367
|
ntrials: usize,
|
|
1310
1368
|
costmod: Option<(f32, f32)>,
|
|
1311
1369
|
temperature: Option<(f32, f32)>,
|
|
1370
|
+
max_neighbors: Option<usize>,
|
|
1312
1371
|
seed: Option<u64>,
|
|
1313
1372
|
simplify: Option<bool>,
|
|
1314
1373
|
use_ssa: Option<bool>,
|
|
@@ -1350,6 +1409,7 @@ fn optimize_random_greedy_track_flops(
|
|
|
1350
1409
|
log_temp_min,
|
|
1351
1410
|
log_temp_diff,
|
|
1352
1411
|
is_const_temp,
|
|
1412
|
+
max_neighbors,
|
|
1353
1413
|
&mut rng,
|
|
1354
1414
|
)
|
|
1355
1415
|
}
|
|
@@ -1367,6 +1427,7 @@ fn optimize_random_greedy_track_flops(
|
|
|
1367
1427
|
log_temp_min,
|
|
1368
1428
|
log_temp_diff,
|
|
1369
1429
|
is_const_temp,
|
|
1430
|
+
max_neighbors,
|
|
1370
1431
|
&mut rng,
|
|
1371
1432
|
)
|
|
1372
1433
|
}
|
|
@@ -1383,6 +1444,7 @@ fn optimize_random_greedy_track_flops(
|
|
|
1383
1444
|
log_temp_min,
|
|
1384
1445
|
log_temp_diff,
|
|
1385
1446
|
is_const_temp,
|
|
1447
|
+
max_neighbors,
|
|
1386
1448
|
&mut rng,
|
|
1387
1449
|
),
|
|
1388
1450
|
};
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|