multipers 2.3.1__cp312-cp312-win_amd64.whl → 2.3.2__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers has been flagged as potentially problematic; consult the package registry page for details.

Files changed (49):
  1. multipers/_signed_measure_meta.py +71 -65
  2. multipers/array_api/__init__.py +39 -0
  3. multipers/array_api/numpy.py +34 -0
  4. multipers/array_api/torch.py +35 -0
  5. multipers/distances.py +6 -2
  6. multipers/filtrations/density.py +23 -12
  7. multipers/filtrations/filtrations.py +74 -15
  8. multipers/function_rips.cp312-win_amd64.pyd +0 -0
  9. multipers/grids.cp312-win_amd64.pyd +0 -0
  10. multipers/grids.pyx +144 -61
  11. multipers/gudhi/Simplex_tree_multi_interface.h +35 -0
  12. multipers/gudhi/gudhi/Multi_persistence/Box.h +3 -0
  13. multipers/gudhi/gudhi/One_critical_filtration.h +17 -9
  14. multipers/gudhi/mma_interface_matrix.h +5 -3
  15. multipers/gudhi/truc.h +488 -42
  16. multipers/io.cp312-win_amd64.pyd +0 -0
  17. multipers/io.pyx +16 -86
  18. multipers/ml/mma.py +4 -4
  19. multipers/ml/signed_measures.py +60 -62
  20. multipers/mma_structures.cp312-win_amd64.pyd +0 -0
  21. multipers/mma_structures.pxd +2 -1
  22. multipers/mma_structures.pyx +56 -12
  23. multipers/mma_structures.pyx.tp +14 -3
  24. multipers/multiparameter_module_approximation/approximation.h +45 -13
  25. multipers/multiparameter_module_approximation.cp312-win_amd64.pyd +0 -0
  26. multipers/multiparameter_module_approximation.pyx +24 -7
  27. multipers/plots.py +1 -0
  28. multipers/point_measure.cp312-win_amd64.pyd +0 -0
  29. multipers/point_measure.pyx +6 -2
  30. multipers/simplex_tree_multi.cp312-win_amd64.pyd +0 -0
  31. multipers/simplex_tree_multi.pxd +1 -0
  32. multipers/simplex_tree_multi.pyx +535 -113
  33. multipers/simplex_tree_multi.pyx.tp +79 -19
  34. multipers/slicer.cp312-win_amd64.pyd +0 -0
  35. multipers/slicer.pxd +699 -217
  36. multipers/slicer.pxd.tp +22 -6
  37. multipers/slicer.pyx +5315 -1365
  38. multipers/slicer.pyx.tp +202 -46
  39. multipers/tbb12.dll +0 -0
  40. multipers/tbbbind_2_5.dll +0 -0
  41. multipers/tbbmalloc.dll +0 -0
  42. multipers/tbbmalloc_proxy.dll +0 -0
  43. multipers/tests/__init__.py +9 -4
  44. multipers/torch/diff_grids.py +30 -7
  45. {multipers-2.3.1.dist-info → multipers-2.3.2.dist-info}/METADATA +4 -25
  46. {multipers-2.3.1.dist-info → multipers-2.3.2.dist-info}/RECORD +49 -46
  47. {multipers-2.3.1.dist-info → multipers-2.3.2.dist-info}/WHEEL +1 -1
  48. {multipers-2.3.1.dist-info → multipers-2.3.2.dist-info/licenses}/LICENSE +0 -0
  49. {multipers-2.3.1.dist-info → multipers-2.3.2.dist-info}/top_level.txt +0 -0
multipers/gudhi/truc.h CHANGED
@@ -1,22 +1,27 @@
1
1
  #pragma once
2
+ #include "gudhi/Matrix.h"
3
+ #include "gudhi/mma_interface_matrix.h"
2
4
  #include "gudhi/Multi_persistence/Line.h"
3
5
  #include "multiparameter_module_approximation/format_python-cpp.h"
4
6
  #include <gudhi/One_critical_filtration.h>
5
7
  #include <gudhi/Multi_critical_filtration.h>
6
8
  #include <algorithm>
7
- #include <boost/mpl/aux_/na_fwd.hpp>
8
9
  #include <cassert>
9
10
  #include <csignal>
10
11
  #include <cstddef>
11
12
  #include <cstdint>
12
- // #include <gudhi/Simplex_tree/multi_filtrations/Finitely_critical_filtrations.h>
13
13
  #include <iostream>
14
14
  #include <limits>
15
15
  #include <numeric>
16
16
  #include <oneapi/tbb/enumerable_thread_specific.h>
17
17
  #include <oneapi/tbb/parallel_for.h>
18
+ #include <oneapi/tbb/parallel_sort.h>
18
19
  #include <oneapi/tbb/task_group.h>
20
+ #include <oneapi/tbb/mutex.h>
21
+ #include <ostream>
22
+ #include <ranges>
19
23
  #include <sstream>
24
+ #include <stdexcept>
20
25
  #include <string>
21
26
  #include <utility>
22
27
  #include <vector>
@@ -28,6 +33,12 @@ namespace multiparameter {
28
33
  namespace truc_interface {
29
34
  using index_type = std::uint32_t;
30
35
 
36
// Detection trait: has_columns<T>::value is true exactly when T declares a nested type named
// `options` (checked below on PersBackend before PersBackend::options is used).
template <typename T, typename Enable = void>
struct has_columns : std::integral_constant<bool, false> {};

// Selected only when `typename T::options` is well-formed.
template <typename T>
struct has_columns<T, std::void_t<typename T::options>> : std::integral_constant<bool, true> {};
41
+
31
42
  class PresentationStructure {
32
43
  public:
33
44
  PresentationStructure() {}
@@ -72,9 +83,10 @@ class PresentationStructure {
72
83
 
73
84
  inline friend std::ostream &operator<<(std::ostream &stream, const PresentationStructure &structure) {
74
85
  stream << "Boundary:\n";
75
- stream << "{";
76
- for (const auto &stuff : structure.generators) {
77
- stream << "{";
86
+ stream << "{\n";
87
+ for (auto i : std::views::iota(0u, structure.size())) {
88
+ const auto &stuff = structure.generators[i];
89
+ stream << i << ": {";
78
90
  for (auto truc : stuff) stream << truc << ", ";
79
91
 
80
92
  if (!stuff.empty()) stream << "\b" << "\b ";
@@ -118,18 +130,31 @@ class PresentationStructure {
118
130
  }
119
131
 
120
132
  PresentationStructure permute(const std::vector<index_type> &order) const {
121
- std::vector<std::vector<index_type>> new_generators(generators.size());
122
- std::vector<int> new_generator_dimensions(generator_dimensions.size());
133
+ if (order.size() > generators.size()) {
134
+ throw std::invalid_argument("Permutation order must have the same size as the number of generators.");
135
+ }
136
+ index_type flag = -1;
137
+ std::vector<index_type> inverse_order(generators.size(), flag);
123
138
  for (std::size_t i = 0; i < order.size(); i++) {
124
- new_generators[i] = std::vector<index_type>(generators[order[i]].size());
139
+ inverse_order[order[i]] = i;
140
+ }
141
+ std::vector<std::vector<index_type>> new_generators(order.size());
142
+ std::vector<int> new_generator_dimensions(order.size());
143
+
144
+ for (auto i : std::views::iota(0u, order.size())) {
145
+ new_generators[i].reserve(generators[order[i]].size());
125
146
  for (std::size_t j = 0; j < generators[order[i]].size(); j++) {
126
- new_generators[i][j] = order[generators[order[i]][j]];
147
+ index_type stuff = inverse_order[generators[order[i]][j]];
148
+ if (stuff != flag) new_generators[i].push_back(stuff);
127
149
  }
150
+ std::sort(new_generators[i].begin(), new_generators[i].end());
128
151
  new_generator_dimensions[i] = generator_dimensions[order[i]];
129
152
  }
130
153
  return PresentationStructure(new_generators, new_generator_dimensions);
131
154
  }
132
155
 
156
+ void update_matrix(std::vector<std::vector<index_type>> &new_gens) { std::swap(generators, new_gens); }
157
+
133
158
  private:
134
159
  std::vector<std::vector<index_type>> generators;
135
160
  std::vector<int> generator_dimensions;
@@ -210,9 +235,11 @@ template <class PersBackend, class Structure, class MultiFiltration>
210
235
  class Truc {
211
236
  public:
212
237
  using Filtration_value = MultiFiltration;
238
+ using MultiFiltrations = std::vector<MultiFiltration>;
213
239
  using value_type = typename MultiFiltration::value_type;
214
240
  using split_barcode =
215
241
  std::vector<std::vector<std::pair<typename MultiFiltration::value_type, typename MultiFiltration::value_type>>>;
242
+ using split_barcode_idx = std::vector<std::vector<std::pair<int, int>>>;
216
243
  template <typename value_type = value_type>
217
244
  using flat_barcode = std::vector<std::pair<int, std::pair<value_type, value_type>>>;
218
245
 
@@ -279,7 +306,7 @@ class Truc {
279
306
  return structure.dimension(i) < structure.dimension(j);
280
307
  };
281
308
 
282
- inline bool colexical_order(const index_type &i, const index_type &j) const {
309
+ inline bool colexical_order(const index_type &i, const index_type &j) const {
283
310
  if (structure.dimension(i) > structure.dimension(j)) return false;
284
311
  if (structure.dimension(i) < structure.dimension(j)) return true;
285
312
  if constexpr (MultiFiltration::is_multicritical()) // TODO : this may not be the best
@@ -294,25 +321,347 @@ class Truc {
294
321
  return false;
295
322
  };
296
323
 
324
+ // TODO : inside of MultiFiltration
325
+ inline static bool lexical_order(const MultiFiltration &a, const MultiFiltration &b) {
326
+ if constexpr (MultiFiltration::is_multicritical()) // TODO : this may not be the best
327
+ throw "Not implemented in the multicritical case";
328
+ if (a.is_plus_inf() || a.is_nan() || b.is_minus_inf()) return false;
329
+ if (b.is_plus_inf() || b.is_nan() || a.is_minus_inf()) return true;
330
+ for (auto idx = 0u; idx < a.num_parameters(); ++idx) {
331
+ if (a[idx] < b[idx])
332
+ return true;
333
+ else if (a[idx] > b[idx])
334
+ return false;
335
+ }
336
+ return false;
337
+ };
297
338
 
298
- template <class Fun>
299
- inline Truc rearange_sort(const Fun&& fun) const {
300
- std::vector<index_type> permutation(generator_order.size());
301
- std::iota(permutation.begin(), permutation.end(), 0);
302
- std::sort(permutation.begin(), permutation.end(), [&](std::size_t i, std::size_t j) {
303
- return fun(i, j);
304
- });
305
- std::vector<MultiFiltration> new_filtration(generator_filtration_values.size());
306
- for (std::size_t i = 0; i < generator_filtration_values.size(); i++) {
339
+ inline bool lexical_order(const index_type &i, const index_type &j) const {
340
+ if (structure.dimension(i) > structure.dimension(j)) return false;
341
+ if (structure.dimension(i) < structure.dimension(j)) return true;
342
+ if constexpr (MultiFiltration::is_multicritical()) // TODO : this may not be the best
343
+ throw "Not implemented in the multicritical case";
344
+
345
+ for (int idx = 0; idx < generator_filtration_values[i].num_parameters(); ++idx) {
346
+ if (generator_filtration_values[i][idx] < generator_filtration_values[j][idx])
347
+ return true;
348
+ else if (generator_filtration_values[i][idx] > generator_filtration_values[j][idx])
349
+ return false;
350
+ }
351
+ return false;
352
+ };
353
+
354
+ inline Truc permute(const std::vector<index_type> &permutation) const {
355
+ auto num_new_gen = permutation.size();
356
+ if (permutation.size() > this->num_generators()) {
357
+ throw std::invalid_argument("Invalid permutation size. Got " + std::to_string(num_new_gen) + " expected " +
358
+ std::to_string(this->num_generators()) + ".");
359
+ }
360
+ std::vector<MultiFiltration> new_filtration(num_new_gen);
361
+ for (auto i : std::views::iota(0u, num_new_gen)) { // assumes permutation has correct indices.
307
362
  new_filtration[i] = generator_filtration_values[permutation[i]];
308
363
  }
309
364
  return Truc(structure.permute(permutation), new_filtration);
310
365
  }
311
366
 
312
- Truc colexical_rearange() const {
367
+ template <class Fun>
368
+ inline std::pair<Truc, std::vector<index_type>> rearange_sort(const Fun &&fun) const {
369
+ std::vector<index_type> permutation(generator_order.size());
370
+ std::iota(permutation.begin(), permutation.end(), 0);
371
+ tbb::parallel_sort(permutation.begin(), permutation.end(), [&](std::size_t i, std::size_t j) { return fun(i, j); });
372
+ return {permute(permutation), permutation};
373
+ }
374
+
375
+ std::pair<Truc, std::vector<index_type>> colexical_rearange() const {
313
376
  return rearange_sort([this](std::size_t i, std::size_t j) { return this->colexical_order(i, j); });
314
377
  }
315
378
 
379
+ template <bool generator_only = false>
380
+ std::conditional_t<generator_only, std::pair<std::vector<std::vector<index_type>>, MultiFiltrations>, Truc>
381
+ projective_cover_kernel(int dim) {
382
+ if constexpr (MultiFiltration::is_multicritical() || !std::is_same_v<Structure, PresentationStructure> ||
383
+ !has_columns<PersBackend>::value) // TODO : this may not be the best
384
+ {
385
+ throw std::invalid_argument("Not implemented for this Truc");
386
+ } else {
387
+ // TODO : this only works for 2 parameter modules. Optimize w.r.t. this.
388
+ const bool verbose = false;
389
+ // filtration values are assumed to be dim + colexicographically sorted
390
+ // vector seem to be good here
391
+ using SmallMatrix = Gudhi::persistence_matrix::Matrix<
392
+ Gudhi::multiparameter::truc_interface::fix_presentation_options<PersBackend::options::column_type, false>>;
393
+
394
+ int nd = 0;
395
+ int ndpp = 0;
396
+ for (auto i : std::views::iota(0u, structure.size())) {
397
+ if (structure.dimension(i) == dim) {
398
+ nd++;
399
+ } else if (structure.dimension(i) == dim + 1) {
400
+ ndpp++;
401
+ } else {
402
+ throw std::invalid_argument("This truc contains bad dimensions. Got " +
403
+ std::to_string(structure.dimension(i)) + " expected " + std::to_string(dim) +
404
+ " or " + std::to_string(dim + 1) + " in position " + std::to_string(i) + " .");
405
+ }
406
+ }
407
+ if (ndpp == 0)
408
+ throw std::invalid_argument("Given dimension+1 has no simplices. Got " + std::to_string(nd) + " " +
409
+ std::to_string(ndpp) + ".");
410
+ // lexico iterator
411
+ auto lex_cmp = [&](const MultiFiltration &a, const MultiFiltration &b) { return lexical_order(a, b); };
412
+
413
+ struct SmallQueue {
414
+ SmallQueue() {};
415
+
416
+ struct MFWrapper {
417
+ MFWrapper(const MultiFiltration &g) : g(g) {};
418
+
419
+ MFWrapper(const MultiFiltration &g, int col) : g(g) { some_cols.insert(col); }
420
+
421
+ MFWrapper(const MultiFiltration &g, std::initializer_list<int> cols)
422
+ : g(g), some_cols(cols.begin(), cols.end()) {}
423
+
424
+ inline void insert(int col) const { some_cols.insert(col); }
425
+
426
+ inline bool operator<(const MFWrapper &other) const { return lexical_order(g, other.g); }
427
+
428
+ public:
429
+ MultiFiltration g;
430
+ mutable std::set<int> some_cols;
431
+ };
432
+
433
+ inline void insert(const MultiFiltration &g, int col) {
434
+ auto it = queue.find(g);
435
+ if (it != queue.end()) {
436
+ it->insert(col);
437
+ } else {
438
+ queue.emplace(g, col);
439
+ }
440
+ };
441
+
442
+ inline void insert(const MultiFiltration &g, const std::initializer_list<int> &cols) {
443
+ auto it = queue.find(g);
444
+ if (it != queue.end()) {
445
+ for (int c : cols) it->insert(c);
446
+ } else {
447
+ queue.emplace(g, cols);
448
+ }
449
+ };
450
+
451
+ inline bool empty() const { return queue.empty(); }
452
+
453
+ inline MultiFiltration pop() {
454
+ if (queue.empty()) [[unlikely]]
455
+ throw std::runtime_error("Queue is empty");
456
+
457
+ auto out = std::move(*queue.begin());
458
+ queue.erase(queue.begin());
459
+ std::swap(last_cols, out.some_cols);
460
+ return out.g;
461
+ }
462
+
463
+ const auto &get_current_cols() const { return last_cols; }
464
+
465
+ private:
466
+ std::set<MFWrapper> queue;
467
+ std::set<int> last_cols;
468
+ };
469
+
470
+ SmallQueue lexico_it;
471
+ SmallMatrix M(nd + ndpp);
472
+ for (int i = 0; i < nd + ndpp; i++) {
473
+ const auto &b = structure[i];
474
+ M.insert_boundary(b);
475
+ }
476
+ SmallMatrix N(nd + ndpp); // slave
477
+ for (auto i : std::views::iota(0u, static_cast<unsigned int>(nd + ndpp))) {
478
+ N.insert_boundary({i});
479
+ };
480
+
481
+ auto get_fil = [&](int i) -> MultiFiltration & { return generator_filtration_values[i]; };
482
+ auto get_pivot = [&](int j) -> int {
483
+ const auto &col = M.get_column(j);
484
+ return col.size() > 0 ? (*col.rbegin()).get_row_index() : -1;
485
+ };
486
+
487
+ if constexpr (verbose) {
488
+ std::cout << "Initial matrix (" << nd << " + " << ndpp << "):" << std::endl;
489
+ for (int i = 0; i < nd + ndpp; i++) {
490
+ std::cout << "Column " << i << " : {";
491
+ for (const auto &j : M.get_column(i)) std::cout << j << " ";
492
+ std::cout << "} | " << get_fil(i) << std::endl;
493
+ }
494
+ }
495
+
496
+ // TODO : pivot caches are small : maybe use a flat container instead ?
497
+ std::vector<std::set<int>> pivot_cache(nd + ndpp); // this[pivot] = cols of given pivot (<=nd)
498
+ std::vector<bool> reduced_columns(nd + ndpp); // small cache
499
+ MultiFiltration grid_value;
500
+
501
+ std::vector<std::vector<index_type>> out_structure;
502
+ out_structure.reserve(2 * ndpp);
503
+ std::vector<MultiFiltration> out_filtration;
504
+ out_filtration.reserve(2 * ndpp);
505
+ std::vector<int> out_dimension;
506
+ out_dimension.reserve(2 * ndpp);
507
+ if constexpr (!generator_only) {
508
+ for (auto i : std::views::iota(nd, nd + ndpp)) {
509
+ out_structure.push_back({});
510
+ out_filtration.push_back(this->get_filtration_values()[i]);
511
+ out_dimension.push_back(this->structure.dimension(i));
512
+ }
513
+ }
514
+ // pivot cache
515
+ if constexpr (verbose) {
516
+ std::cout << "Initial pivot cache:\n";
517
+ }
518
+ for (int j : std::views::iota(nd, nd + ndpp)) {
519
+ int col_pivot = get_pivot(j);
520
+ if (col_pivot < 0) {
521
+ reduced_columns[j] = true;
522
+ continue;
523
+ };
524
+ auto &current_pivot_cache = pivot_cache[col_pivot];
525
+ current_pivot_cache.emplace_hint(current_pivot_cache.cend(), j); // j is increasing
526
+ }
527
+ if constexpr (verbose) {
528
+ int i = 0;
529
+ for (const auto &cols : pivot_cache) {
530
+ std::cout << " - " << i++ << " : ";
531
+ for (const auto &col : cols) {
532
+ std::cout << col << " ";
533
+ }
534
+ std::cout << std::endl;
535
+ }
536
+ }
537
+
538
+ // if constexpr (!use_grid) {
539
+ if constexpr (verbose) std::cout << "Initial grid queue:\n";
540
+ for (int j : std::views::iota(nd, nd + ndpp)) {
541
+ int col_pivot = get_pivot(j);
542
+ if (col_pivot < 0) continue;
543
+ lexico_it.insert(get_fil(j), j);
544
+ auto it = pivot_cache[col_pivot].find(j);
545
+ if (it == pivot_cache[col_pivot].end()) [[unlikely]]
546
+ throw std::runtime_error("Column " + std::to_string(j) + " not in pivot cache");
547
+ it++;
548
+ // for (int k : pivot_cache[col_pivot]) {
549
+ for (auto _k = it; _k != pivot_cache[col_pivot].end(); ++_k) {
550
+ int k = *_k;
551
+ if (k <= j) [[unlikely]]
552
+ throw std::runtime_error("Column " + std::to_string(k) + " is not a future column");
553
+ auto prev = get_fil(k);
554
+ prev.push_to_least_common_upper_bound(get_fil(j));
555
+ if constexpr (verbose) std::cout << " - (" << j << ", " << k << ") are interacting at " << prev << "\n";
556
+ lexico_it.insert(std::move(prev), k);
557
+ }
558
+ }
559
+ // TODO : check poset cache ?
560
+ if constexpr (verbose) std::cout << std::flush;
561
+ // }
562
+ auto reduce_column = [&](int j) -> bool {
563
+ int pivot = get_pivot(j);
564
+ if constexpr (verbose) std::cout << "Reducing column " << j << " with pivot " << pivot << "\n";
565
+ if (pivot < 0) {
566
+ if (!reduced_columns[j]) {
567
+ std::vector<index_type> _b(N.get_column(j).begin(), N.get_column(j).end());
568
+ for (auto &stuff : _b) stuff -= nd;
569
+ out_structure.push_back(std::move(_b));
570
+ out_filtration.emplace_back(grid_value.begin(), grid_value.end());
571
+ if constexpr (!generator_only) out_dimension.emplace_back(this->structure.dimension(j) + 1);
572
+ reduced_columns[j] = true;
573
+ }
574
+ return false;
575
+ }
576
+ if constexpr (verbose) std::cout << "Previous registered pivot : " << *pivot_cache[pivot].begin() << std::endl;
577
+ // WARN : we lazy update variables linked with col j...
578
+ if (pivot_cache[pivot].size() == 0) {
579
+ return false;
580
+ }
581
+ for (int k : pivot_cache[pivot]) {
582
+ if (k >= j) { // cannot reduce more here. this is a (local) pivot.
583
+ return false;
584
+ }
585
+ if (get_fil(k) <= grid_value) {
586
+ M.add_to(k, j);
587
+ N.add_to(k, j);
588
+ // std::cout << "Adding " << k << " to " << j << " at grid time " << grid_value << std::endl;
589
+ pivot_cache[pivot].erase(j);
590
+ // WARN : we update the pivot cache after the update loop
591
+ if (get_pivot(j) >= pivot) {
592
+ throw std::runtime_error("Addition failed ? current " + std::to_string(get_pivot(j)) + " previous " +
593
+ std::to_string(pivot));
594
+ }
595
+ return true; // pivot has changed
596
+ }
597
+ }
598
+ return false; // for loop exhausted (j may not be there because of lazy)
599
+ };
600
+ auto chores_after_new_pivot = [&](int j) {
601
+ int col_pivot = get_pivot(j);
602
+ if (col_pivot < 0) {
603
+ if (!reduced_columns[j]) throw std::runtime_error("Empty column should have been detected before");
604
+ return;
605
+ };
606
+ auto [it, was_there] = pivot_cache[col_pivot].insert(j);
607
+ it++;
608
+ // if constexpr (!use_grid) {
609
+ for (auto _k = it; _k != pivot_cache[col_pivot].end(); ++_k) {
610
+ int k = *_k;
611
+ if (k <= j) [[unlikely]]
612
+ throw std::runtime_error("(chores) col " + std::to_string(k) + " is not a future column");
613
+ if (get_fil(k) >= get_fil(j)) continue;
614
+ auto prev = get_fil(k);
615
+ prev.push_to_least_common_upper_bound(get_fil(j));
616
+ if (lex_cmp(grid_value, prev)) {
617
+ if constexpr (verbose)
618
+ std::cout << "(chores) Updating grid queue, (" << j << ", " << k << ") are interacting at " << prev
619
+ << std::endl;
620
+ lexico_it.insert(prev, k);
621
+ }
622
+ }
623
+ // }
624
+ };
625
+ if constexpr (verbose) {
626
+ std::cout << "Initially reduced columns: [";
627
+ for (int i = 0; i < nd + ndpp; i++) {
628
+ std::cout << reduced_columns[i] << ", ";
629
+ }
630
+ std::cout << "]" << std::endl;
631
+ }
632
+ while (!lexico_it.empty()) {
633
+ // if constexpr (use_grid) {
634
+ // grid_value = lexico_it.next();
635
+ // } else {
636
+ grid_value = std::move(lexico_it.pop());
637
+ // }
638
+ if constexpr (verbose) {
639
+ std::cout << "Grid value: " << grid_value << std::endl;
640
+ std::cout << " Reduced cols: ";
641
+ for (int i = 0; i < nd + ndpp; i++) {
642
+ std::cout << reduced_columns[i] << ", ";
643
+ }
644
+ std::cout << "]" << std::endl;
645
+ }
646
+
647
+ for (int i : lexico_it.get_current_cols()) {
648
+ if constexpr (false) {
649
+ if ((reduced_columns[i] || !(get_fil(i) <= grid_value))) continue;
650
+ if ((get_fil(i) > grid_value)) break;
651
+ }
652
+ while (reduce_column(i));
653
+ chores_after_new_pivot(i);
654
+ }
655
+ }
656
+ // std::cout<< grid_.str() << std::endl;
657
+ if constexpr (generator_only)
658
+ return {out_structure, out_dimension};
659
+ else {
660
+ return Truc(out_structure, out_dimension, out_filtration);
661
+ }
662
+ }
663
+ }
664
+
316
665
  template <bool ignore_inf>
317
666
  std::vector<std::pair<int, std::vector<index_type>>> get_current_boundary_matrix() {
318
667
  std::vector<index_type> permutation(generator_order.size());
@@ -324,9 +673,9 @@ class Truc {
324
673
  return filtration_container[val] == MultiFiltration::Generator::T_inf;
325
674
  }),
326
675
  permutation.end());
327
- std::sort(permutation.begin(), permutation.end());
676
+ tbb::parallel_sort(permutation.begin(), permutation.end());
328
677
  }
329
- std::sort(permutation.begin(), permutation.end(), [&](std::size_t i, std::size_t j) {
678
+ tbb::parallel_sort(permutation.begin(), permutation.end(), [&](std::size_t i, std::size_t j) {
330
679
  if (structure.dimension(i) > structure.dimension(j)) return false;
331
680
  if (structure.dimension(i) < structure.dimension(j)) return true;
332
681
  return filtration_container[i] < filtration_container[j];
@@ -383,7 +732,9 @@ class Truc {
383
732
 
384
733
  template <class array1d>
385
734
  inline void set_one_filtration(const array1d &truc) {
386
- assert(truc.size() == this->num_generators());
735
+ if (truc.size() != this->num_generators())
736
+ throw std::invalid_argument("(setting one filtration) Bad size. Got " + std::to_string(truc.size()) +
737
+ " expected " + std::to_string(this->num_generators()));
387
738
  this->filtration_container = truc;
388
739
  }
389
740
 
@@ -397,7 +748,7 @@ class Truc {
397
748
  const bool ignore_inf) const { // needed ftm as PersBackend only points there
398
749
  constexpr const bool verbose = false;
399
750
  if (one_filtration.size() != this->num_generators()) {
400
- throw;
751
+ throw std::runtime_error("The one parameter filtration doesn't have a proper size.");
401
752
  }
402
753
  out_gen_order.resize(this->num_generators());
403
754
  std::iota(out_gen_order.begin(),
@@ -482,6 +833,49 @@ class Truc {
482
833
  vineyard_update(this->persistence, this->filtration_container, this->generator_order);
483
834
  }
484
835
 
836
+ inline split_barcode_idx get_barcode_idx(
837
+ PersBackend &persistence) const {
838
+ auto barcode_indices = persistence.get_barcode();
839
+ split_barcode_idx out(this->structure.max_dimension() + 1); // TODO : This doesn't allow for negative dimensions
840
+ for (const auto &bar : barcode_indices) {
841
+ int death = bar.death == static_cast<typename PersBackend::pos_index>(-1) ? -1 : bar.death;
842
+ out[bar.dim].push_back({bar.birth, death});
843
+ }
844
+ return out;
845
+ }
846
+
847
+ // puts the degree-ordered bc starting out_ptr, and returns the "next" pointer.
848
+ // corresond to an array of shape (num_bar, 2);
849
+ template <bool return_shape = false>
850
+ inline std::conditional_t<return_shape, std::pair<std::vector<int>, int*>, int*> get_barcode_idx(
851
+ PersBackend &persistence,
852
+ int *start_ptr) const {
853
+ const auto &bc = persistence.barcode();
854
+ if (bc.size() == 0) return start_ptr;
855
+ std::vector<int> shape(this->structure.max_dimension());
856
+ for (const auto &b : bc) shape[b.dim]++;
857
+ // dim in barcode may be unsorted...
858
+ std::vector<int *> ptr_shifts(shape.size());
859
+ int shape_cumsum = 0;
860
+ for (auto i : std::views::iota(0u, bc.size())) {
861
+ if (i != 0u) shape_cumsum += shape[i - 1];
862
+ // 2 for (birth, death)
863
+ ptr_shifts[i] = 2 * shape_cumsum + start_ptr;
864
+ }
865
+ for (const auto &b : bc) {
866
+ int *current_loc = ptr_shifts[b.dim];
867
+ *(current_loc++) = b.birth;
868
+ *(current_loc++) = b.death == static_cast<typename PersBackend::pos_index>(-1) ? -1 : b.death;
869
+ }
870
+
871
+ if constexpr (return_shape)
872
+ return {shape, ptr_shifts.back()};
873
+ else
874
+ return ptr_shifts.back();
875
+ }
876
+
877
+
878
+
485
879
  inline split_barcode get_barcode(
486
880
  PersBackend &persistence,
487
881
  const std::vector<typename MultiFiltration::value_type> &filtration_container) const {
@@ -517,7 +911,7 @@ class Truc {
517
911
  << " --" << bar.death << "(" << death_filtration << ")"
518
912
  << " dim " << bar.dim << std::endl;
519
913
  }
520
- if (birth_filtration < death_filtration)
914
+ if (birth_filtration <= death_filtration)
521
915
  out[bar.dim].push_back({birth_filtration, death_filtration});
522
916
  else {
523
917
  out[bar.dim].push_back({inf, inf});
@@ -528,6 +922,8 @@ class Truc {
528
922
 
529
923
  inline split_barcode get_barcode() { return get_barcode(this->persistence, this->filtration_container); }
530
924
 
925
+ inline split_barcode_idx get_barcode_idx() { return get_barcode_idx(this->persistence); }
926
+
531
927
  template <typename value_type = value_type>
532
928
  static inline flat_nodim_barcode<value_type> get_flat_nodim_barcode(
533
929
  PersBackend &persistence,
@@ -555,7 +951,7 @@ class Truc {
555
951
  << std::endl;
556
952
  }
557
953
 
558
- if (birth_filtration < death_filtration)
954
+ if (birth_filtration <= death_filtration)
559
955
  out[idx] = {birth_filtration, death_filtration};
560
956
  else {
561
957
  out[idx] = {inf, inf};
@@ -592,7 +988,7 @@ class Truc {
592
988
  << std::endl;
593
989
  }
594
990
 
595
- if (birth_filtration < death_filtration)
991
+ if (birth_filtration <= death_filtration)
596
992
  out[idx] = {bar.dim, {birth_filtration, death_filtration}};
597
993
  else {
598
994
  out[idx] = {bar.dim, {inf, inf}};
@@ -698,6 +1094,8 @@ class Truc {
698
1094
  return out;
699
1095
  }
700
1096
 
1097
+ inline int get_dimension(int i) const { return structure.dimension(i); }
1098
+
701
1099
  inline void prune_above_dimension(int max_dim) {
702
1100
  int idx = structure.prune_above_dimension(max_dim);
703
1101
  generator_filtration_values.resize(idx);
@@ -714,12 +1112,15 @@ class Truc {
714
1112
  return out;
715
1113
  }
716
1114
 
717
- auto coarsen_on_grid(const std::vector<std::vector<typename MultiFiltration::value_type>> grid) {
1115
+ auto coarsen_on_grid(const std::vector<std::vector<typename MultiFiltration::value_type>>& grid) {
718
1116
  using return_type = decltype(std::declval<MultiFiltration>().template as_type<std::int32_t>());
719
1117
  std::vector<return_type> coords(this->num_generators());
720
- for (std::size_t gen = 0u; gen < coords.size(); ++gen) {
1118
+ // for (std::size_t gen = 0u; gen < coords.size(); ++gen) { // TODO : parallelize
1119
+ // coords[gen] = compute_coordinates_in_grid<int32_t>(generator_filtration_values[gen], grid);
1120
+ // }
1121
+ tbb::parallel_for(static_cast<std::size_t>(0u), coords.size(), [&](std::size_t gen){
721
1122
  coords[gen] = compute_coordinates_in_grid<int32_t>(generator_filtration_values[gen], grid);
722
- }
1123
+ });
723
1124
  return Truc<PersBackend, Structure, return_type>(structure, coords);
724
1125
  }
725
1126
 
@@ -832,6 +1233,10 @@ class Truc {
832
1233
 
833
1234
  inline split_barcode get_barcode() { return truc_ptr->get_barcode(this->persistence, this->filtration_container); }
834
1235
 
1236
+ inline split_barcode_idx get_barcode_idx() {
1237
+ return truc_ptr->get_barcode_idx(this->persistence);
1238
+ }
1239
+
835
1240
  inline std::size_t num_generators() const { return this->truc_ptr->structure.size(); }
836
1241
 
837
1242
  inline std::size_t num_parameters() const {
@@ -846,6 +1251,14 @@ class Truc {
846
1251
  return this->filtration_container;
847
1252
  }
848
1253
 
1254
+ template <class array1d>
1255
+ inline void set_one_filtration(const array1d &truc) {
1256
+ if (truc.size() != this->num_generators())
1257
+ throw std::invalid_argument("(setting one filtration) Bad size. Got " + std::to_string(truc.size()) +
1258
+ " expected " + std::to_string(this->num_generators()));
1259
+ this->filtration_container = truc;
1260
+ }
1261
+
849
1262
  private:
850
1263
  const Truc *truc_ptr;
851
1264
  std::vector<index_type> generator_order; // size fixed at construction time,
@@ -858,21 +1271,34 @@ class Truc {
858
1271
  * returns barcodes of the f(multipers)
859
1272
  *
860
1273
  */
861
- template <typename Fun, typename Fun_arg>
862
- inline std::vector<split_barcode> barcodes(Fun &&f, const std::vector<Fun_arg> &args, const bool ignore_inf = true) {
1274
+ template <typename Fun, typename Fun_arg, bool idx = false, bool custom = false>
1275
+ inline std::conditional_t<idx, std::vector<split_barcode_idx>, std::vector<split_barcode>>
1276
+ barcodes(Fun &&f, const std::vector<Fun_arg> &args, const bool ignore_inf = true) {
863
1277
  if (args.size() == 0) {
864
1278
  return {};
865
1279
  }
866
- std::vector<split_barcode> out(args.size());
1280
+ std::conditional_t<idx, std::vector<split_barcode_idx>, std::vector<split_barcode>> out(args.size());
867
1281
 
868
1282
  if constexpr (PersBackend::is_vine) {
869
- this->push_to(f(args[0]));
1283
+ if constexpr (custom)
1284
+ this->set_one_filtration(f(args[0]));
1285
+ else
1286
+ this->push_to(f(args[0]));
870
1287
  this->compute_persistence();
871
- out[0] = this->get_barcode();
1288
+ if constexpr (idx)
1289
+ out[0] = this->get_barcode_idx();
1290
+ else
1291
+ out[0] = this->get_barcode();
872
1292
  for (auto i = 1u; i < args.size(); ++i) {
873
- this->push_to(f(args[i]));
1293
+ if constexpr (custom)
1294
+ this->set_one_filtration(f(args[i]));
1295
+ else
1296
+ this->push_to(f(args[i]));
874
1297
  this->vineyard_update();
875
- out[i] = this->get_barcode();
1298
+ if constexpr (idx)
1299
+ out[i] = this->get_barcode_idx();
1300
+ else
1301
+ out[i] = this->get_barcode();
876
1302
  }
877
1303
 
878
1304
  } else {
@@ -880,9 +1306,15 @@ class Truc {
880
1306
  tbb::enumerable_thread_specific<ThreadSafe> thread_locals(local_template);
881
1307
  tbb::parallel_for(static_cast<std::size_t>(0), args.size(), [&](const std::size_t &i) {
882
1308
  ThreadSafe &s = thread_locals.local();
883
- s.push_to(f(args[i]));
1309
+ if constexpr (custom)
1310
+ s.set_one_filtration(f(args[i]));
1311
+ else
1312
+ s.push_to(f(args[i]));
884
1313
  s.compute_persistence(ignore_inf);
885
- out[i] = s.get_barcode();
1314
+ if constexpr (idx) {
1315
+ out[i] = s.get_barcode_idx();
1316
+ } else
1317
+ out[i] = s.get_barcode();
886
1318
  });
887
1319
  }
888
1320
  return out;
@@ -898,6 +1330,20 @@ class Truc {
898
1330
  ignore_inf);
899
1331
  }
900
1332
 
1333
+ inline std::vector<split_barcode_idx> custom_persistences(const value_type *filtrations, int size, bool ignore_inf) {
1334
+ std::vector<const value_type *> args(size);
1335
+ for (auto i = 0; i < size; ++i) args[i] = filtrations + this->num_generators() * i;
1336
+
1337
+ auto fun = [&](const value_type *one_filtration_ptr) {
1338
+ std::vector<value_type> fil(this->num_generators());
1339
+ for (auto i : std::views::iota(0u, this->num_generators())) {
1340
+ fil[i] = *(one_filtration_ptr + i);
1341
+ }
1342
+ return std::move(fil);
1343
+ };
1344
+ return barcodes<decltype(fun), const value_type *, true, true>(std::move(fun), args, ignore_inf);
1345
+ }
1346
+
901
1347
  inline std::vector<split_barcode> persistence_on_lines(
902
1348
  const std::vector<std::pair<std::vector<value_type>, std::vector<value_type>>> &bp_dirs,
903
1349
  bool ignore_inf) {
@@ -944,9 +1390,9 @@ class Truc {
944
1390
  // bool reverse);
945
1391
 
946
1392
  private:
947
- std::vector<MultiFiltration> generator_filtration_values; // defined at construction time. Const
948
- std::vector<index_type> generator_order; // size fixed at construction time, // TODO : CHANGE THAT TO UINT32
949
- Structure structure; // defined at construction time. Const
1393
+ MultiFiltrations generator_filtration_values; // defined at construction time. Const
1394
+ std::vector<index_type> generator_order; // size fixed at construction time
1395
+ Structure structure; // defined at construction time. Const
950
1396
  std::vector<typename MultiFiltration::value_type> filtration_container; // filtration of the current slice
951
1397
  PersBackend persistence; // generated by the structure, and generator_order.
952
1398