multipers 1.1.3__cp310-cp310-macosx_11_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic; see the registry page for more details.

Files changed (63)
  1. multipers/.dylibs/libtbb.12.12.dylib +0 -0
  2. multipers/.dylibs/libtbbmalloc.2.12.dylib +0 -0
  3. multipers/__init__.py +5 -0
  4. multipers/_old_rank_invariant.pyx +328 -0
  5. multipers/_signed_measure_meta.py +193 -0
  6. multipers/data/MOL2.py +350 -0
  7. multipers/data/UCR.py +18 -0
  8. multipers/data/__init__.py +1 -0
  9. multipers/data/graphs.py +466 -0
  10. multipers/data/immuno_regions.py +27 -0
  11. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  12. multipers/data/pytorch2simplextree.py +91 -0
  13. multipers/data/shape3d.py +101 -0
  14. multipers/data/synthetic.py +68 -0
  15. multipers/distances.py +172 -0
  16. multipers/euler_characteristic.cpython-310-darwin.so +0 -0
  17. multipers/euler_characteristic.pyx +137 -0
  18. multipers/function_rips.cpython-310-darwin.so +0 -0
  19. multipers/function_rips.pyx +102 -0
  20. multipers/hilbert_function.cpython-310-darwin.so +0 -0
  21. multipers/hilbert_function.pyi +46 -0
  22. multipers/hilbert_function.pyx +151 -0
  23. multipers/io.cpython-310-darwin.so +0 -0
  24. multipers/io.pyx +176 -0
  25. multipers/ml/__init__.py +0 -0
  26. multipers/ml/accuracies.py +61 -0
  27. multipers/ml/convolutions.py +510 -0
  28. multipers/ml/invariants_with_persistable.py +79 -0
  29. multipers/ml/kernels.py +128 -0
  30. multipers/ml/mma.py +657 -0
  31. multipers/ml/one.py +472 -0
  32. multipers/ml/point_clouds.py +191 -0
  33. multipers/ml/signed_betti.py +50 -0
  34. multipers/ml/signed_measures.py +1479 -0
  35. multipers/ml/sliced_wasserstein.py +313 -0
  36. multipers/ml/tools.py +116 -0
  37. multipers/mma_structures.cpython-310-darwin.so +0 -0
  38. multipers/mma_structures.pxd +155 -0
  39. multipers/mma_structures.pyx +651 -0
  40. multipers/multiparameter_edge_collapse.py +29 -0
  41. multipers/multiparameter_module_approximation.cpython-310-darwin.so +0 -0
  42. multipers/multiparameter_module_approximation.pyi +439 -0
  43. multipers/multiparameter_module_approximation.pyx +311 -0
  44. multipers/pickle.py +53 -0
  45. multipers/plots.py +292 -0
  46. multipers/point_measure_integration.cpython-310-darwin.so +0 -0
  47. multipers/point_measure_integration.pyx +59 -0
  48. multipers/rank_invariant.cpython-310-darwin.so +0 -0
  49. multipers/rank_invariant.pyx +154 -0
  50. multipers/simplex_tree_multi.cpython-310-darwin.so +0 -0
  51. multipers/simplex_tree_multi.pxd +121 -0
  52. multipers/simplex_tree_multi.pyi +715 -0
  53. multipers/simplex_tree_multi.pyx +1417 -0
  54. multipers/slicer.cpython-310-darwin.so +0 -0
  55. multipers/slicer.pxd +94 -0
  56. multipers/slicer.pyx +276 -0
  57. multipers/tensor.pxd +13 -0
  58. multipers/test.pyx +44 -0
  59. multipers-1.1.3.dist-info/LICENSE +21 -0
  60. multipers-1.1.3.dist-info/METADATA +22 -0
  61. multipers-1.1.3.dist-info/RECORD +63 -0
  62. multipers-1.1.3.dist-info/WHEEL +5 -0
  63. multipers-1.1.3.dist-info/top_level.txt +1 -0
multipers/ml/mma.py ADDED
@@ -0,0 +1,657 @@
1
+ from typing import Callable, Iterable, List, Optional
2
+ import multipers as mp
3
+ from multipers.ml.tools import filtration_grid_to_coordinates
4
+ import numpy as np
5
+ from joblib import Parallel, delayed
6
+ from sklearn.base import BaseEstimator, TransformerMixin
7
+ from multipers.multiparameter_module_approximation import PyModule
8
+ from tqdm import tqdm
9
+
10
+ from multipers.simplex_tree_multi import SimplexTreeMulti
11
+
12
# Module-level shorthand for the grid-coarsening routine of SimplexTreeMulti;
# MMAFormatter._infer_grid uses it to reduce per-parameter filtration values.
reduce_grid = SimplexTreeMulti._reduce_grid
13
+
14
+
15
class SimplexTree2MMA(BaseEstimator, TransformerMixin):
    """
    Turns a list of simplextrees into MMA approximations.

    ``X`` is either a list of ``SimplexTreeMulti``, or a list of lists of
    ``SimplexTreeMulti`` (one simplextree per "axis").

    Parameters
    ----------
    n_jobs : int
        Number of joblib threads used in ``transform``.
    expand_dim : int, optional
        If given, ``expansion(expand_dim)`` is called on each simplextree
        before the approximation is computed.
    prune_degrees_above : int, optional
        If given, ``prune_above_dimension(prune_degrees_above)`` is called on
        each simplextree at the start of ``transform``.
    progress : bool
        Whether to display a tqdm progress bar.
    **persistence_kwargs
        Forwarded to ``SimplexTreeMulti.persistence_approximation``.

    Notes
    -----
    ``transform`` calls ``prune_above_dimension`` / ``expansion`` on the input
    simplextrees themselves and discards their return values, so the inputs
    are presumably modified in place.
    """

    def __init__(
        self,
        n_jobs: int = 1,
        expand_dim: Optional[int] = None,
        prune_degrees_above: Optional[int] = None,
        progress=False,
        **persistence_kwargs,
    ) -> None:
        super().__init__()
        self.persistence_args = persistence_kwargs
        self.n_jobs = n_jobs
        self._has_axis = None
        self._num_axis = None
        self.prune_degrees_above = prune_degrees_above
        self.progress = progress
        self.expand_dim = expand_dim
        self._boxes = None

    def fit(self, X, y=None):
        """
        Compute the bounding box(es) of the filtration values of ``X``.

        Without axis, ``self._boxes`` is ``[m, M]`` with ``m`` / ``M`` the
        per-parameter minimum / maximum over all simplextrees. With axis, it
        is a list (one entry per axis) of arrays of shape
        ``(2, num_parameters)``.

        Raises
        ------
        IndexError
            If ``X[0]`` is an empty container (no simplextree found).
        """
        if len(X) == 0:
            return self
        self._has_axis = not isinstance(X[0], mp.SimplexTreeMulti)
        if self._has_axis:
            if len(X[0]) == 0:
                raise IndexError(
                    "No simplextree found, maybe you forgot to give a "
                    "filtration parameter to the previous pipeline"
                )
            assert isinstance(
                X[0][0], mp.SimplexTreeMulti
            ), f"X[0] is not a simplextre, {X[0]=}, and X[0][0] neither."
            self._num_axis = len(X[0])
            # Layout: (axis, data, min/max, num_parameters)
            filtration_values = np.asarray(
                [
                    [x[axis].filtration_bounds() for x in X]
                    for axis in range(self._num_axis)
                ]
            )
            # Reduce over the data dimension -> (axis, num_parameters)
            m = filtration_values[:, :, 0, :].min(axis=1)
            M = filtration_values[:, :, 1, :].max(axis=1)
            self._boxes = [
                np.array([m_of_axis, M_of_axis]) for m_of_axis, M_of_axis in zip(m, M)
            ]
        else:
            # Layout: (data, min/max, num_parameters)
            filtration_values = np.asarray([x.filtration_bounds() for x in X])
            m = filtration_values[:, 0, :].min(axis=0)
            M = filtration_values[:, 1, :].max(axis=0)
            self._boxes = [m, M]
        return self

    def transform(self, X):
        """
        Compute one module approximation per simplextree.

        Returns a list with one entry per element of ``X``: a module, or (with
        axis) a list of modules, one per axis, each computed on the box of
        its axis.
        """
        if self.prune_degrees_above is not None:
            for x in X:
                trees = x if self._has_axis else (x,)
                for st in trees:
                    st.prune_above_dimension(self.prune_degrees_above)

        def todo1(x: mp.SimplexTreeMulti, box):
            if self.expand_dim is not None:
                x.expansion(self.expand_dim)
            return x.persistence_approximation(
                box=box, verbose=False, **self.persistence_args
            )

        def todo(sts: List[SimplexTreeMulti] | SimplexTreeMulti):
            if self._has_axis:
                assert not isinstance(sts, SimplexTreeMulti)
                return [todo1(st, box) for st, box in zip(sts, self._boxes)]
            assert isinstance(sts, SimplexTreeMulti)
            return todo1(sts, self._boxes)

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x)
            for x in tqdm(X, desc="Computing modules", disable=not self.progress)
        )
140
+
141
+
142
class MMAFormatter(BaseEstimator, TransformerMixin):
    """
    Formats (and optionally normalizes) a list of MMA modules (``PyModule``).

    ``X`` is either a list of modules, or a list of lists of modules (one per
    "axis"). It may also be a list of pickled modules (``bytes``), which are
    unpickled first.

    Parameters
    ----------
    degrees : list
        Homological degrees to keep / normalize.
    axis : int, optional
        Axis to select when the input has an axis. ``-1`` keeps every axis;
        ``fit`` sets it to ``-1`` when the data has an axis and none was given.
    verbose : bool
        Print statistics during ``fit`` / ``transform``.
    normalize : bool
        If True, ``fit`` infers filtration bounds from the data; otherwise
        bounds default to ``[0, 1]`` per parameter.
    weights : array-like, optional
        Per-parameter weights. The normalized modules are put in the box
        ``[0, weights]`` and rescaled by ``weights / (M - m)``.
    quantiles : tuple, optional
        ``(qm, qM)``, assumed scalars: the lower (resp. upper) bound is the
        ``qm`` (resp. ``1 - qM``) quantile over the data instead of the
        min (resp. max).
    dump : bool
        If True, the non-normalized output of ``transform`` is pickled.
    from_dump : bool
        Stored only; this class does not read it (pickled inputs are detected
        automatically).
    """

    def __init__(
        self,
        degrees: list = [0, 1],
        axis=None,
        verbose: bool = False,
        normalize: bool = False,
        weights=None,
        quantiles=None,
        dump=False,
        from_dump=False,
    ):
        self._module_bounds = None
        self.verbose = verbose
        self.axis = axis
        self._axis = []
        self._has_axis = None
        self._num_axis = 0
        self.degrees = degrees
        self.normalize = normalize
        self._num_parameters = None
        self.weights = weights
        self.quantiles = quantiles
        self.dump = dump
        self.from_dump = from_dump
        # Set by ``fit``; None means "not fitted on data" (no normalization).
        self._normalization_factors = None

    @staticmethod
    def _maybe_from_dump(X_in):
        """
        Unpickle ``X_in`` when its elements are ``bytes``; otherwise return it
        unchanged.
        """
        if len(X_in) == 0:
            return X_in
        import pickle

        # SECURITY: pickle.loads can execute arbitrary code. Only feed this
        # transformer pickles produced by a trusted ``dump=True`` step.
        if isinstance(X_in[0], bytes):
            return [pickle.loads(mods) for mods in X_in]
        return X_in

    @staticmethod
    def _get_module_bound(x, degree):
        """
        Return the first and last filtration values, per parameter, of the
        degree-``degree`` part of module ``x``.

        Output format : (2, num_parameters). When no filtration value is found,
        a message is printed and both bounds are lists of NaN.
        """
        filtration_values = x.get_module_of_degree(degree).get_filtration_values(
            unique=True
        )
        # NOTE(review): assumes unique=True returns sorted values, so that the
        # first/last entries are the min/max — confirm.
        out = np.array([[f[0], f[-1]] for f in filtration_values if len(f) > 0]).T
        if len(out) != 2:
            print(f"Missing degree {degree} here !")
            m = M = [np.nan for _ in range(x.num_parameters)]
        else:
            m, M = out
        return m, M

    @staticmethod
    def _infer_axis(X):
        """Return True if ``X`` is a list of lists of modules (i.e. has an axis)."""
        has_axis = not isinstance(X[0], PyModule)
        assert not has_axis or isinstance(X[0][0], PyModule)
        return has_axis

    @staticmethod
    def _infer_num_parameters(X, ax=slice(None)):
        """Return ``num_parameters`` of the module ``X[0][ax]``."""
        return X[0][ax].num_parameters

    @staticmethod
    def _infer_bounds(X, degrees=None, axis=(slice(None),), quantiles=None):
        """
        Compute bounds of filtration values of a list of modules.

        Bounds are NaN-aware (modules missing a degree contribute NaN).

        Output Format
        -------------
        m, M, each of shape (num_axis, num_degrees, num_parameters)
        """
        if degrees is None:
            degrees = np.arange(X[0][axis[0]].max_degree + 1)
        # Layout: (data, axis, degree, min/max, num_parameters)
        bounds = np.array(
            [
                [
                    [MMAFormatter._get_module_bound(x[ax], degree) for degree in degrees]
                    for ax in axis
                ]
                for x in X
            ]
        )
        lower = bounds[:, :, :, 0, :]
        upper = bounds[:, :, :, 1, :]
        if quantiles is not None:
            qm, qM = quantiles
            m = np.nanquantile(lower, q=qm, axis=0)
            M = np.nanquantile(upper, q=1 - qM, axis=0)
        else:
            m = np.nanmin(lower, axis=0)
            M = np.nanmax(upper, axis=0)
        return (m, M)

    @staticmethod
    def _infer_grid(X: List[PyModule], strategy: str, resolution: int, degrees=None):
        """
        Given a list of PyModules, computes a multiparameter discrete grid,
        with a given strategy,
        from the filtration values of the summands of the modules.

        Strategies containing ``"_mean"`` (e.g. ``"regular_mean"``) reduce each
        module's grid with the prefix strategy and average the results; other
        strategies reduce the union of all filtration values.

        Returns ``(coordinates, new_resolution)`` as given by
        ``filtration_grid_to_coordinates``.
        """
        num_parameters = X[0].num_parameters
        if degrees is None:
            # Format here : ((filtration values of parameter) for parameter)
            filtration_values = tuple(
                mod.get_filtration_values(unique=True) for mod in X
            )
        else:
            filtration_values = tuple(
                mod.get_module_of_degrees(degrees).get_filtration_values(unique=True)
                for mod in X
            )

        if "_mean" in strategy:
            substrategy = strategy.split("_")[0]
            processed_filtration_values = [
                reduce_grid(f, resolution, substrategy, unique=False)
                for f in filtration_values
            ]
            reduced_grid = np.mean(processed_filtration_values, axis=0)
        else:
            # Union of every module's values, per parameter.
            filtration_values = [
                np.unique(
                    np.concatenate([f[parameter] for f in filtration_values], axis=0)
                )
                for parameter in range(num_parameters)
            ]
            reduced_grid = reduce_grid(
                filtration_values, resolution, strategy, unique=True
            )

        coordinates, new_resolution = filtration_grid_to_coordinates(
            reduced_grid, return_resolution=True
        )
        return coordinates, new_resolution

    def fit(self, X_in, y=None):
        """
        Infer the axis layout, the number of parameters, the filtration bounds
        and the normalization factors from ``X_in``.
        """
        X = self._maybe_from_dump(X_in)
        if len(X) == 0:
            return self
        self._has_axis = self._infer_axis(X)
        if self.axis is None and self._has_axis:
            self.axis = -1
        if self.axis is not None and not self._has_axis:
            raise Exception(
                f"MMAFormatter didn't find an axis, but requested axis {self.axis}"
            )
        if self._has_axis:
            self._num_axis = len(X[0])
        if self.axis is None:
            self._axis = [slice(None)]
        elif self.axis == -1:
            self._axis = range(self._num_axis)
        else:
            self._axis = [self.axis]
        # Computed before the verbose report so that it prints the real value.
        self._num_parameters = self._infer_num_parameters(X, ax=self._axis[0])
        if self.verbose:
            print("-----------MMAFormatter-----------")
            print("---- Infered stats")
            print(f"Found axis : {self._has_axis}, num : {self._num_axis}")
            print(f"Number of parameters : {self._num_parameters}")

        if self.normalize:
            self._module_bounds = self._infer_bounds(
                X, self.degrees, self._axis, self.quantiles
            )
        else:
            m = np.zeros((self._num_axis, len(self.degrees), self._num_parameters))
            M = m + 1
            self._module_bounds = (m, M)
        assert self._num_parameters == self._module_bounds[0].shape[-1]
        if self.verbose:
            print("---- Bounds (only computed if normalize):")
            if self._has_axis and self._num_axis > 1:
                print("(axis) x (degree) x (parameter)")
            else:
                print("(degree) x (parameter)")
            m, M = self._module_bounds
            print("-- Lower bound : ", m.shape)
            print(m)
            print("-- Upper bound :", M.shape)
            print(M)

        w = 1 if self.weights is None else np.asarray(self.weights)
        m, M = self._module_bounds
        normalizer = M - m
        zero_normalizer = normalizer == 0
        if np.any(zero_normalizer):
            from warnings import warn

            warn(f"Encountered empty bounds. Please fix me. \n M-m = {normalizer}")
            # Avoid division by zero; those entries are left unscaled.
            normalizer[zero_normalizer] = 1
        self._normalization_factors = w / normalizer
        if self.verbose:
            print("-- Normalization factors:", self._normalization_factors.shape)
            print(self._normalization_factors)

        if self.verbose:
            print("---- Module size :")
            for ax in self._axis:
                print(f"- Axis {ax}")
                for degree in self.degrees:
                    sizes = [len(x[ax].get_module_of_degree(degree)) for x in X]
                    print(
                        f" - Degree {degree} size "
                        f"{np.mean(sizes).round(decimals=2)}"
                        f"±{np.std(sizes).round(decimals=2)}"
                    )
            print("----------------------------------")
        return self

    @staticmethod
    def copy_transform(mod, degrees, translation, rescale_factors, new_box):
        """
        Return a copy of ``mod`` restricted to ``degrees``, translated by
        ``translation[j]`` and rescaled by ``rescale_factors[j]`` for the j-th
        degree, with its box set to ``new_box``.
        """
        copy = mod.get_module_of_degrees(degrees)  # only keeps the given degrees
        for j, degree in enumerate(degrees):
            copy.translate(translation[j], degree=degree)
            copy.rescale(rescale_factors[j], degree=degree)
        copy.set_box(new_box)
        return copy

    def transform(self, X_in):
        """
        Normalize (if some normalization factor differs from 1) or select the
        requested axis of the modules.

        The normalized output is a list (per data) of lists (per selected axis)
        of modules; it is never pickled. Otherwise the (axis-selected) input is
        returned, pickled if ``dump`` is True.
        """
        X = self._maybe_from_dump(X_in)
        # getattr: tolerate instances that never went through ``fit`` on data.
        factors = getattr(self, "_normalization_factors", None)
        if factors is not None and np.any(factors != 1):
            if self.verbose:
                print("Normalizing...", end="")
            w = (
                [1] * self._num_parameters
                if self.weights is None
                else np.asarray(self.weights)
            )
            standard_box = mp.multiparameter_module_approximation.PyBox(
                [0] * self._num_parameters, w
            )

            X_copy = [
                [
                    self.copy_transform(
                        mod=x[ax],
                        degrees=self.degrees,
                        translation=-self._module_bounds[0][i],
                        rescale_factors=factors[i],
                        new_box=standard_box,
                    )
                    for i, ax in enumerate(self._axis)
                ]
                for x in X
            ]
            if self.verbose:
                print("Done.")
            return X_copy
        # ``axis is None`` means the data has no axis: nothing to select.
        if self.axis is not None and self.axis != -1:
            X = [x[self.axis] for x in X]
        if self.dump:
            import pickle

            X = [pickle.dumps(mods) for mods in X]
        return X
463
+
464
+
465
class MMA2IMG(BaseEstimator, TransformerMixin):
    """
    Turns MMA modules into images by evaluating them on a grid
    (``PyModule._compute_pixels``).

    ``X`` is either a list of modules or a list of lists of modules (one per
    "axis"); each axis gets its own grid, inferred in ``fit``.

    Parameters
    ----------
    degrees : list
        Homological degrees to compute images for.
    bandwidth : float
        Forwarded as ``delta`` to ``_compute_pixels``.
    power : float
        Forwarded as ``p``.
    normalize : bool
        Forwarded as ``normalize``.
    resolution : list or int
        Target resolution of the grid(s) inferred by
        ``MMAFormatter._infer_grid``.
    plot : bool
        Stored only; not read by this class.
    box : optional
        Forwarded as ``box`` to ``_compute_pixels``.
    n_jobs : int
        Threads for the loop over data; also forwarded to ``_compute_pixels``.
    flatten : bool
        If True, each output is one 1D array (per-axis images concatenated
        along axis 1, then flattened). Otherwise each output is a list (one per
        axis) of arrays of shape ``(len(degrees), *resolution_of_axis)``.
    progress : bool
        Whether to display a tqdm progress bar.
    grid_strategy : str
        Strategy passed to ``MMAFormatter._infer_grid``.
    """

    def __init__(
        self,
        degrees: list,
        bandwidth: float = 0.1,
        power: float = 1,
        normalize: bool = False,
        resolution: list | int = 50,
        plot: bool = False,
        box=None,
        n_jobs=1,
        flatten=False,
        progress=False,
        grid_strategy="regular",
    ):
        self.bandwidth = bandwidth
        self.degrees = degrees
        self.resolution = resolution
        self.box = box
        self.plot = plot
        self._box = None
        self.normalize = normalize
        self.power = power
        self._has_axis = None
        self._num_parameters = None
        self.n_jobs = n_jobs
        self.flatten = flatten
        self.progress = progress
        self.grid_strategy = grid_strategy
        self._num_axis = None
        self._coords_to_compute = None
        self._new_resolutions = None

    def fit(self, X, y=None):
        """Infer the evaluation grid (one per axis) from the modules in ``X``."""
        # TODO infer box
        # TODO rescale module
        if len(X) == 0:
            return self
        self._has_axis = MMAFormatter._infer_axis(X)
        if self._has_axis:
            self._num_axis = len(X[0])
        # NOTE(review): ``_box`` is not read by ``transform``, which forwards
        # ``self.box``; kept for backward compatibility.
        self._box = [[0], [1, 1]] if self.box is None else self.box
        if self._has_axis:
            per_axis = (tuple(x[axis] for x in X) for axis in range(self._num_axis))
            grids = tuple(
                MMAFormatter._infer_grid(
                    X_axis, self.grid_strategy, self.resolution, degrees=self.degrees
                )
                for X_axis in per_axis
            )
            # Resolutions may differ between axes, so coordinates cannot be
            # stacked into a single array.
            self._coords_to_compute = [coords for coords, _ in grids]
            self._new_resolutions = np.asarray([res for _, res in grids])
        else:
            coords, new_resolution = MMAFormatter._infer_grid(
                X, self.grid_strategy, self.resolution, degrees=self.degrees
            )
            self._coords_to_compute = coords
            self._new_resolutions = new_resolution
        return self

    def transform(self, X):
        """Compute the images of every element of ``X`` (see class docstring)."""
        img_args = {
            "delta": self.bandwidth,
            "p": self.power,
            "normalize": self.normalize,
            "box": self.box,
            "degrees": self.degrees,
            # num_jobs is better for parallel over modules.
            "n_jobs": self.n_jobs,
        }
        if self._has_axis:
            coords_per_axis = self._coords_to_compute
            resolutions = self._new_resolutions
        else:
            # Single grid: wrap the coordinates AND the resolution as a
            # one-axis layout, so both cases share the per-axis code below.
            coords_per_axis = [self._coords_to_compute]
            resolutions = [self._new_resolutions]

        def images_of(mods):
            # One image per axis, each evaluated on the grid of its axis.
            if not self._has_axis:
                mods = [mods]
            return [
                mod._compute_pixels(coords, **img_args)
                for mod, coords in zip(mods, coords_per_axis)
            ]

        if self.flatten:

            def todo(mods):
                return np.concatenate(images_of(mods), axis=1).flatten()

        else:

            def todo(mods):
                return [
                    img.reshape(len(img_args["degrees"]), *res)
                    for img, res in zip(images_of(mods), resolutions)
                ]

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x)
            for x in tqdm(X, desc="Computing images", disable=not self.progress)
        )  # res depends on ax (infer_grid)
575
+
576
+
577
class MMA2Landscape(BaseEstimator, TransformerMixin):
    """
    Turns a list of MMA approximations into Landscapes vectorisations.

    Parameters
    ----------
    resolution : list[int]
        Forwarded to ``PyModule.landscapes``.
    degrees : list[int] or None
        Degrees to vectorize; ``None`` means every degree from 0 to
        ``mod.max_degree`` (per module).
    ks : Iterable[int]
        Forwarded as ``ks`` to ``landscapes``.
    phi : Callable
        Reduction applied with ``axis=0`` to each landscape output.
    box : optional
        If None, ``fit`` sets it to quantiles of the modules' bottom / top.
    plot : bool
        Forwarded to ``landscapes``.
    n_jobs : int
        Number of joblib threads.
    filtration_quantile : float
        ``fit`` uses the ``filtration_quantile`` (resp.
        ``1 - filtration_quantile``) quantile of the bottoms (resp. tops).
    """

    def __init__(
        self,
        resolution=[100, 100],
        degrees: list[int] | None = [0, 1],
        ks: Iterable[int] = range(5),
        phi: Callable = np.sum,
        box=None,
        plot: bool = False,
        n_jobs=-1,
        filtration_quantile: float = 0.01,
    ) -> None:
        super().__init__()
        self.resolution: list[int] = resolution
        self.degrees = degrees
        self.ks = ks
        self.phi = phi  # Has to have a axis=0 !
        self.box = box
        self.plot = plot
        self.n_jobs = n_jobs
        self.filtration_quantile = filtration_quantile

    def fit(self, X, y=None):
        """
        Check that modules are 2-parameter and, if ``box`` is None, infer it.

        Returns ``self`` (also for empty ``X``, so ``fit_transform`` works).
        """
        if len(X) <= 0:
            return self
        assert (
            X[0].num_parameters == 2
        ), f"Number of parameters {X[0].num_parameters} has to be 2."
        if self.box is None:

            def _bottom(mod):
                return mod.get_bottom()

            def _top(mod):
                return mod.get_top()

            m = np.quantile(
                Parallel(n_jobs=self.n_jobs, backend="threading")(
                    delayed(_bottom)(mod) for mod in X
                ),
                q=self.filtration_quantile,
                axis=0,
            )
            M = np.quantile(
                Parallel(n_jobs=self.n_jobs, backend="threading")(
                    delayed(_top)(mod) for mod in X
                ),
                q=1 - self.filtration_quantile,
                axis=0,
            )
            self.box = [m, M]
        return self

    def transform(self, X) -> list[np.ndarray]:
        """
        Return one flattened vector per module: the concatenation, over
        degrees, of ``phi(landscapes, axis=0).flatten()``.
        """
        if len(X) <= 0:
            return []

        def todo(mod):
            degrees = (
                range(mod.max_degree + 1) if self.degrees is None else self.degrees
            )
            # NOTE(review): ``self.box`` (inferred in ``fit``) is not forwarded
            # to ``landscapes`` — confirm whether it should be.
            return np.concatenate(
                [
                    self.phi(
                        mod.landscapes(
                            ks=self.ks,
                            resolution=self.resolution,
                            degree=degree,
                            plot=self.plot,
                        ),
                        axis=0,
                    ).flatten()
                    for degree in degrees
                ]
            ).flatten()

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x) for x in X
        )