multipers 2.0.0__cp310-cp310-macosx_13_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic; see the registry's advisory page for more details.

Files changed (78):
  1. multipers/.dylibs/libc++.1.0.dylib +0 -0
  2. multipers/.dylibs/libtbb.12.12.dylib +0 -0
  3. multipers/.dylibs/libtbbmalloc.2.12.dylib +0 -0
  4. multipers/__init__.py +11 -0
  5. multipers/_signed_measure_meta.py +268 -0
  6. multipers/_slicer_meta.py +171 -0
  7. multipers/data/MOL2.py +350 -0
  8. multipers/data/UCR.py +18 -0
  9. multipers/data/__init__.py +1 -0
  10. multipers/data/graphs.py +466 -0
  11. multipers/data/immuno_regions.py +27 -0
  12. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  13. multipers/data/pytorch2simplextree.py +91 -0
  14. multipers/data/shape3d.py +101 -0
  15. multipers/data/synthetic.py +68 -0
  16. multipers/distances.py +198 -0
  17. multipers/euler_characteristic.pyx +132 -0
  18. multipers/filtration_conversions.pxd +229 -0
  19. multipers/filtrations.pxd +225 -0
  20. multipers/function_rips.cpython-310-darwin.so +0 -0
  21. multipers/function_rips.pyx +105 -0
  22. multipers/grids.cpython-310-darwin.so +0 -0
  23. multipers/grids.pyx +281 -0
  24. multipers/hilbert_function.pyi +46 -0
  25. multipers/hilbert_function.pyx +153 -0
  26. multipers/io.cpython-310-darwin.so +0 -0
  27. multipers/io.pyx +571 -0
  28. multipers/ml/__init__.py +0 -0
  29. multipers/ml/accuracies.py +90 -0
  30. multipers/ml/convolutions.py +532 -0
  31. multipers/ml/invariants_with_persistable.py +79 -0
  32. multipers/ml/kernels.py +176 -0
  33. multipers/ml/mma.py +659 -0
  34. multipers/ml/one.py +472 -0
  35. multipers/ml/point_clouds.py +238 -0
  36. multipers/ml/signed_betti.py +50 -0
  37. multipers/ml/signed_measures.py +1542 -0
  38. multipers/ml/sliced_wasserstein.py +461 -0
  39. multipers/ml/tools.py +113 -0
  40. multipers/mma_structures.cpython-310-darwin.so +0 -0
  41. multipers/mma_structures.pxd +127 -0
  42. multipers/mma_structures.pyx +2433 -0
  43. multipers/multiparameter_edge_collapse.py +41 -0
  44. multipers/multiparameter_module_approximation.cpython-310-darwin.so +0 -0
  45. multipers/multiparameter_module_approximation.pyx +211 -0
  46. multipers/pickle.py +53 -0
  47. multipers/plots.py +326 -0
  48. multipers/point_measure_integration.cpython-310-darwin.so +0 -0
  49. multipers/point_measure_integration.pyx +139 -0
  50. multipers/rank_invariant.cpython-310-darwin.so +0 -0
  51. multipers/rank_invariant.pyx +229 -0
  52. multipers/simplex_tree_multi.cpython-310-darwin.so +0 -0
  53. multipers/simplex_tree_multi.pxd +129 -0
  54. multipers/simplex_tree_multi.pyi +715 -0
  55. multipers/simplex_tree_multi.pyx +4655 -0
  56. multipers/slicer.cpython-310-darwin.so +0 -0
  57. multipers/slicer.pxd +781 -0
  58. multipers/slicer.pyx +3393 -0
  59. multipers/tensor.pxd +13 -0
  60. multipers/test.pyx +44 -0
  61. multipers/tests/__init__.py +40 -0
  62. multipers/tests/old_test_rank_invariant.py +91 -0
  63. multipers/tests/test_diff_helper.py +74 -0
  64. multipers/tests/test_hilbert_function.py +82 -0
  65. multipers/tests/test_mma.py +51 -0
  66. multipers/tests/test_point_clouds.py +59 -0
  67. multipers/tests/test_python-cpp_conversion.py +82 -0
  68. multipers/tests/test_signed_betti.py +181 -0
  69. multipers/tests/test_simplextreemulti.py +98 -0
  70. multipers/tests/test_slicer.py +63 -0
  71. multipers/torch/__init__.py +1 -0
  72. multipers/torch/diff_grids.py +217 -0
  73. multipers/torch/rips_density.py +257 -0
  74. multipers-2.0.0.dist-info/LICENSE +21 -0
  75. multipers-2.0.0.dist-info/METADATA +29 -0
  76. multipers-2.0.0.dist-info/RECORD +78 -0
  77. multipers-2.0.0.dist-info/WHEEL +5 -0
  78. multipers-2.0.0.dist-info/top_level.txt +1 -0
multipers/ml/mma.py ADDED
@@ -0,0 +1,659 @@
1
+ from typing import Callable, Iterable, List, Optional
2
+
3
+ import numpy as np
4
+ from joblib import Parallel, delayed
5
+ from sklearn.base import BaseEstimator, TransformerMixin
6
+ from tqdm import tqdm
7
+
8
+ import multipers as mp
9
+ import multipers.simplex_tree_multi
10
+ from multipers.grids import compute_grid as reduce_grid
11
+ from multipers.ml.tools import filtration_grid_to_coordinates
12
+ from multipers.mma_structures import PyBox_f64, PyModule_type
13
+
14
+
15
class SimplexTree2MMA(BaseEstimator, TransformerMixin):
    """
    Turns a list of simplextrees into MMA approximations (modules).

    Each input is either a single multiparameter simplextree, or a sequence of
    simplextrees (one per "axis"). In the latter case, every axis gets its own
    approximation box and the output for that input is a list of modules.

    Parameters
    ----------
    n_jobs : int
        Number of joblib workers (threading backend) used by ``transform``.
    expand_dim : int or None
        If not None, ``expansion(expand_dim)`` is called on each simplextree
        before its approximation is computed.
    prune_degrees_above : int or None
        If not None, ``prune_above_dimension(prune_degrees_above)`` is called
        on each simplextree at the start of ``transform``.
    progress : bool
        Display a tqdm progress bar during ``transform``.
    **persistence_kwargs
        Extra keyword arguments forwarded to ``persistence_approximation``.
    """

    def __init__(
        self,
        n_jobs: int = 1,
        expand_dim: Optional[int] = None,
        prune_degrees_above: Optional[int] = None,
        progress=False,
        **persistence_kwargs,
    ) -> None:
        super().__init__()
        self.persistence_args = persistence_kwargs
        self.n_jobs = n_jobs
        self._has_axis = None
        self._num_axis = None
        self.prune_degrees_above = prune_degrees_above
        self.progress = progress
        self.expand_dim = expand_dim
        self._boxes = None

    def fit(self, X, y=None):
        """
        Compute the approximation box(es) from the filtration bounds of ``X``.

        The box's lower corner is the coordinate-wise minimum of the lower
        filtration bounds, and its upper corner the coordinate-wise maximum of
        the upper filtration bounds, over all simplextrees (computed separately
        for each axis when inputs have axes).

        Raises
        ------
        IndexError
            If the first input is neither a simplextree nor an indexable,
            non-empty sequence of simplextrees.
        """
        if len(X) == 0:
            return self
        is_simplextree = mp.simplex_tree_multi.is_simplextree_multi
        self._has_axis = not is_simplextree(X[0])
        if self._has_axis:
            try:
                X[0][0]
            except IndexError as err:
                # Carry the diagnostic in the exception instead of printing it.
                message = f"IndexError, {X[0]=}"
                if len(X[0]) == 0:
                    message += (
                        "\nNo simplextree found, maybe you forgot to give a "
                        "filtration parameter to the previous pipeline"
                    )
                raise IndexError(message) from err
            assert is_simplextree(
                X[0][0]
            ), f"X[0] is not a simplextree, {X[0]=}, and X[0][0] neither."
            self._num_axis = len(X[0])
            # Shape: (num_axis, num_data, 2 (min/max), num_parameters)
            filtration_values = np.asarray(
                [
                    [x[axis].filtration_bounds() for x in X]
                    for axis in range(self._num_axis)
                ]
            )
            # Both of shape (num_axis, num_parameters)
            lower = filtration_values[:, :, 0, :].min(axis=1)
            upper = filtration_values[:, :, 1, :].max(axis=1)
            # One (2, num_parameters) box per axis
            self._boxes = [np.array([lo, hi]) for lo, hi in zip(lower, upper)]
        else:
            # Shape: (num_data, 2 (min/max), num_parameters)
            filtration_values = np.asarray([x.filtration_bounds() for x in X])
            self._boxes = [
                filtration_values[:, 0, :].min(axis=0),
                filtration_values[:, 1, :].max(axis=0),
            ]
        return self

    def transform(self, X):
        """
        Compute one MMA approximation per input, using the fitted box(es).

        Returns a list with one entry per input: a module, or a list of
        modules (one per axis) when inputs have axes.

        NOTE: ``prune_above_dimension`` and ``expansion`` are called on the
        input simplextrees themselves and their return values are discarded,
        so they presumably modify ``X`` in place.
        """
        if self.prune_degrees_above is not None:
            for x in X:
                trees = x if self._has_axis else (x,)
                for st in trees:
                    st.prune_above_dimension(self.prune_degrees_above)

        is_simplextree = mp.simplex_tree_multi.is_simplextree_multi

        def _approximate(st, box):
            # Approximation of a single simplextree within `box`.
            if self.expand_dim is not None:
                st.expansion(self.expand_dim)
            return st.persistence_approximation(
                box=box, verbose=False, **self.persistence_args
            )

        def _process(sts):
            # Handles one input: either one simplextree or one per axis.
            if self._has_axis:
                assert not is_simplextree(sts)
                return [_approximate(st, box) for st, box in zip(sts, self._boxes)]
            assert is_simplextree(sts)
            return _approximate(sts, self._boxes)

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(_process)(x)
            for x in tqdm(X, desc="Computing modules", disable=not self.progress)
        )
138
+
139
+
140
class MMAFormatter(BaseEstimator, TransformerMixin):
    """
    Selects axes and degrees of MMA modules, and optionally normalizes or
    serializes them.

    Inputs are modules (``PyModule_type``), sequences of modules (one per
    "axis"), or a list of pickled ``bytes`` of either (unpickled first).

    Parameters
    ----------
    degrees : list
        Homological degrees used to compute bounds and kept when normalizing.
    axis : int or None
        Axis to select when inputs have axes. ``None`` on axis-valued input is
        replaced by ``-1`` (all axes) during ``fit``.
    verbose : bool
        Print fitted statistics.
    normalize : bool
        If True, the bounds are inferred from the data and modules are
        translated by the lower bound and rescaled by ``weights / (M - m)``.
        If False, the bounds are ``[0, 1]``, so only ``weights`` rescale.
    weights : array-like or None
        Per-parameter scaling factors. Also used as the upper corner of the
        new box of normalized modules.
    quantiles : (float, float) or None
        If given as ``(qm, qM)``, the bounds are the ``qm`` and ``1 - qM``
        nan-quantiles of the per-module bounds, instead of nan-min/nan-max.
    dump : bool
        If True, the non-normalizing path of ``transform`` returns pickled
        ``bytes``.
    from_dump : bool
        Stored but not read by this class: pickled input is detected
        automatically.
    """

    def __init__(
        self,
        degrees: list = [0, 1],  # never mutated by this class
        axis=None,
        verbose: bool = False,
        normalize: bool = False,
        weights=None,
        quantiles=None,
        dump=False,
        from_dump=False,
    ):
        self._module_bounds = None
        self.verbose = verbose
        self.axis = axis
        self._axis = []
        self._has_axis = None
        self._num_axis = 0
        self.degrees = degrees
        self.normalize = normalize
        self._num_parameters = None
        self.weights = weights
        self.quantiles = quantiles
        self.dump = dump
        self.from_dump = from_dump

    @staticmethod
    def _maybe_from_dump(X_in):
        """
        Unpickle ``X_in`` if it is a list of ``bytes``; otherwise return it as is.

        SECURITY: ``pickle.loads`` can execute arbitrary code. Only feed data
        produced by a trusted ``MMAFormatter(dump=True)``, never untrusted input.
        """
        if len(X_in) == 0:
            return X_in
        import pickle

        if isinstance(X_in[0], bytes):
            return [pickle.loads(mods) for mods in X_in]
        return X_in

    @staticmethod
    def _get_module_bound(x, degree):
        """
        Return the (lower, upper) filtration bounds of module ``x`` in ``degree``.

        Output format : (2, num_parameters). A parameter without any
        filtration value in this degree gets NaN bounds, so that it is ignored
        by the nan-aware reductions of ``_infer_bounds`` instead of shifting
        the other parameters.
        """
        # unique=True: each per-parameter array is assumed sorted, so its
        # first/last entries are its min/max.
        filtration_values = x.get_module_of_degree(degree).get_filtration_values(
            unique=True
        )
        if all(len(f) == 0 for f in filtration_values):
            print(f"Missing degree {degree} here !")
            m = M = [np.nan for _ in range(x.num_parameters)]
            return m, M
        m = np.array([f[0] if len(f) > 0 else np.nan for f in filtration_values])
        M = np.array([f[-1] if len(f) > 0 else np.nan for f in filtration_values])
        return m, M

    @staticmethod
    def _infer_axis(X):
        """Return True if the elements of ``X`` are sequences of modules (axes)."""
        has_axis = not isinstance(X[0], PyModule_type)
        assert not has_axis or isinstance(X[0][0], PyModule_type)
        return has_axis

    @staticmethod
    def _infer_num_parameters(X, ax=slice(None)):
        """Number of filtration parameters of the first module, on axis ``ax``."""
        return X[0][ax].num_parameters

    @staticmethod
    def _infer_bounds(X, degrees=None, axis=[slice(None)], quantiles=None):
        """
        Compute bounds of filtration values of a list of modules.

        ``degrees=None`` means every degree up to the first module's
        ``max_degree``. ``axis`` (never mutated) lists the indices applied to
        each element of ``X``.

        Output Format
        -------------
        m,M each of shape : (num_axis,num_degrees,num_parameters)
        """
        if degrees is None:
            degrees = np.arange(X[0][axis[0]].max_degree + 1)
        # Shape: (num_modules, num_axis, num_degrees, 2 (min/max), num_parameters)
        bounds = np.array(
            [
                [
                    [MMAFormatter._get_module_bound(x[ax], degree) for degree in degrees]
                    for ax in axis
                ]
                for x in X
            ]
        )
        lower = bounds[:, :, :, 0, :]
        upper = bounds[:, :, :, 1, :]
        # Reduce over the modules; NaNs (missing degrees/parameters) are ignored.
        if quantiles is not None:
            qm, qM = quantiles
            m = np.nanquantile(lower, q=qm, axis=0)
            M = np.nanquantile(upper, q=1 - qM, axis=0)
        else:
            m = np.nanmin(lower, axis=0)
            M = np.nanmax(upper, axis=0)
        return (m, M)

    @staticmethod
    def _infer_grid(X: List[PyModule_type], strategy: str, resolution: int, degrees=None):
        """
        Given a list of PyModules, compute a multiparameter discrete grid,
        with a given strategy,
        from the filtration values of the summands of the modules.

        A strategy containing ``"_mean"`` reduces each module's values with
        the strategy prefix (``strategy.split("_")[0]``) and averages the
        resulting grids. Any other strategy reduces the union of all modules'
        values once. Returns ``(coordinates, new_resolution)`` as produced by
        ``filtration_grid_to_coordinates``.
        """
        num_parameters = X[0].num_parameters
        if degrees is None:
            # Format here : ((filtration values of parameter) for parameter)
            filtration_values = tuple(
                mod.get_filtration_values(unique=True) for mod in X
            )
        else:
            filtration_values = tuple(
                mod.get_module_of_degrees(degrees).get_filtration_values(unique=True)
                for mod in X
            )

        if "_mean" in strategy:
            substrategy = strategy.split("_")[0]
            processed_filtration_values = [
                reduce_grid(f, resolution, substrategy, unique=False)
                for f in filtration_values
            ]
            reduced_grid = np.mean(processed_filtration_values, axis=0)
        else:
            merged_values = [
                np.unique(
                    np.concatenate([f[parameter] for f in filtration_values], axis=0)
                )
                for parameter in range(num_parameters)
            ]
            reduced_grid = reduce_grid(merged_values, resolution, strategy, unique=True)

        coordinates, new_resolution = filtration_grid_to_coordinates(
            reduced_grid, return_resolution=True
        )
        return coordinates, new_resolution

    def _print_module_sizes(self, X):
        """Verbose helper: print mean ± std of summand counts per axis and degree."""
        print("---- Module size :")
        for ax in self._axis:
            print(f"- Axis {ax}")
            for degree in self.degrees:
                sizes = [len(x[ax].get_module_of_degree(degree)) for x in X]
                print(
                    f" - Degree {degree} size "
                    f"{np.mean(sizes).round(decimals=2)}"
                    f"±{np.std(sizes).round(decimals=2)}"
                )
        print("----------------------------------")

    def fit(self, X_in, y=None):
        """
        Infer the axis layout, the number of parameters, the bounds and the
        normalization factors from ``X_in``.

        Raises
        ------
        ValueError
            If an axis is requested but the inputs have none.
        """
        X = self._maybe_from_dump(X_in)
        if len(X) == 0:
            return self
        self._has_axis = self._infer_axis(X)
        if self.axis is None and self._has_axis:
            self.axis = -1
        if self.axis is not None and not self._has_axis:
            raise ValueError(
                f"MMAFormatter didn't find an axis, but requested axis {self.axis}"
            )
        if self._has_axis:
            self._num_axis = len(X[0])
        # Indices applied to each input: the whole module when there is no
        # axis, every axis for axis == -1, or only the requested one.
        if self.axis is None:
            self._axis = [slice(None)]
        elif self.axis == -1:
            self._axis = range(self._num_axis)
        else:
            self._axis = [self.axis]
        self._num_parameters = self._infer_num_parameters(X, ax=self._axis[0])

        if self.verbose:
            print("-----------MMAFormatter-----------")
            print("---- Infered stats")
            print(f"Found axis : {self._has_axis}, num : {self._num_axis}")
            print(f"Number of parameters : {self._num_parameters}")

        if self.normalize:
            self._module_bounds = self._infer_bounds(
                X, self.degrees, self._axis, self.quantiles
            )
        else:
            # One row per processed axis, so that `transform` can index the
            # bounds with enumerate(self._axis), also when there is no axis.
            m = np.zeros((len(self._axis), len(self.degrees), self._num_parameters))
            M = m + 1
            self._module_bounds = (m, M)
        assert self._num_parameters == self._module_bounds[0].shape[-1]

        if self.verbose:
            print("---- Bounds (only computed if normalize):")
            if self._has_axis and self._num_axis > 1:
                print("(axis) x (degree) x (parameter)")
            else:
                print("(degree) x (parameter)")
            m, M = self._module_bounds
            print("-- Lower bound : ", m.shape)
            print(m)
            print("-- Upper bound :", M.shape)
            print(M)

        w = 1 if self.weights is None else np.asarray(self.weights)
        m, M = self._module_bounds
        normalizer = M - m
        zero_normalizer = normalizer == 0
        if np.any(zero_normalizer):
            from warnings import warn

            warn(f"Encountered empty bounds. Please fix me. \n M-m = {normalizer}")
        # Avoid division by zero on degenerate bounds.
        normalizer[zero_normalizer] = 1
        self._normalization_factors = w / normalizer
        if self.verbose:
            print("-- Normalization factors:", self._normalization_factors.shape)
            print(self._normalization_factors)
            self._print_module_sizes(X)
        return self

    @staticmethod
    def copy_transform(mod, degrees, translation, rescale_factors, new_box):
        """
        Return ``mod`` restricted to ``degrees``, translated by
        ``translation[j]`` and rescaled by ``rescale_factors[j]`` in degree
        ``degrees[j]``, with its box set to ``new_box``.

        The result is whatever ``mod.get_module_of_degrees(degrees)`` returns
        (presumably a new module; TODO confirm that ``mod`` itself is untouched).
        """
        copy = mod.get_module_of_degrees(degrees)
        for j, degree in enumerate(degrees):
            copy.translate(translation[j], degree=degree)
            copy.rescale(rescale_factors[j], degree=degree)
        copy.set_box(new_box)
        return copy

    def transform(self, X_in):
        """
        Normalize and/or select axes of the modules in ``X_in``.

        If any normalization factor differs from 1, it returns, for each input,
        a list with one transformed module per processed axis. Otherwise it
        selects ``self.axis`` when a single axis was requested, then pickles
        each entry if ``dump`` is set.
        """
        X = self._maybe_from_dump(X_in)
        if np.any(self._normalization_factors != 1):
            if self.verbose:
                print("Normalizing...", end="")
            w = (
                [1] * self._num_parameters
                if self.weights is None
                else np.asarray(self.weights)
            )
            standard_box = PyBox_f64([0] * self._num_parameters, w)
            lower_bounds = self._module_bounds[0]

            X_copy = [
                [
                    self.copy_transform(
                        mod=x[ax],
                        degrees=self.degrees,
                        translation=-lower_bounds[i],
                        rescale_factors=self._normalization_factors[i],
                        new_box=standard_box,
                    )
                    for i, ax in enumerate(self._axis)
                ]
                for x in X
            ]
            if self.verbose:
                print("Done.")
            # NOTE(review): `dump` is not applied on this path (historical
            # behavior; downstream MMA2IMG cannot read bytes). Confirm intent.
            return X_copy
        # axis None means the inputs have no axis: there is nothing to select.
        if self.axis is not None and self.axis != -1:
            X = [x[self.axis] for x in X]
        if self.dump:
            import pickle

            X = [pickle.dumps(mods) for mods in X]
        return X
459
+
460
+
461
class MMA2IMG(BaseEstimator, TransformerMixin):
    """
    Turns MMA modules into images using ``representation``.

    ``fit`` infers the evaluation coordinates from the filtration values of
    the training modules in ``degrees`` (separately per axis when inputs have
    axes), using ``grid_strategy`` and ``resolution``.

    Parameters
    ----------
    degrees : list
        Degrees to compute images for (one image per degree).
    bandwidth, power, normalize, kernel, signed :
        Forwarded to ``representation`` (``power`` as ``p``).
    resolution : list or int
        Resolution passed to the grid inference.
    plot : bool
        Stored only; not forwarded to ``representation``.
    box :
        Forwarded as is to ``representation``.
    n_jobs : int
        Used both inside ``representation`` and for the outer joblib loop.
    flatten : bool
        If True, each output is one flat vector. Otherwise, each output is a
        tuple of (num_degrees, *resolution) arrays, one per axis.
    progress : bool
        Display a tqdm progress bar during ``transform``.
    grid_strategy : str
        Strategy passed to the grid inference.
    """

    def __init__(
        self,
        degrees: list,
        bandwidth: float = 0.1,
        power: float = 1,
        normalize: bool = False,
        resolution: list | int = 50,
        plot: bool = False,
        box=None,
        n_jobs=-1,
        flatten=False,
        progress=False,
        grid_strategy="regular",
        kernel="linear",
        signed: bool = False,
    ):
        self.bandwidth = bandwidth
        self.degrees = degrees
        self.resolution = resolution
        self.box = box
        self.plot = plot
        self._box = None
        self.normalize = normalize
        self.power = power
        self._has_axis = None
        self._num_parameters = None
        self.n_jobs = n_jobs
        self.flatten = flatten
        self.progress = progress
        self.grid_strategy = grid_strategy
        self._num_axis = None
        self._coords_to_compute = None
        self._new_resolutions = None
        self.kernel = kernel
        self.signed = signed

    def fit(self, X, y=None):
        """Infer the axis layout and the evaluation coordinates from ``X``."""
        # TODO infer box
        # TODO rescale module
        if len(X) == 0:
            return self
        self._has_axis = MMAFormatter._infer_axis(X)
        if self._has_axis:
            self._num_axis = len(X[0])
        # NOTE(review): `_box` is not read anywhere in this class (transform
        # forwards `self.box`); the default also assumes 2 parameters.
        if self.box is None:
            self._box = [[0, 0], [1, 1]]
        else:
            self._box = self.box
        if self._has_axis:
            per_axis = (tuple(x[axis] for x in X) for axis in range(self._num_axis))
            coords_and_resolutions = tuple(
                MMAFormatter._infer_grid(
                    X_axis, self.grid_strategy, self.resolution, degrees=self.degrees
                )
                for X_axis in per_axis
            )
            # Axes may have different resolutions, so the coordinates cannot
            # be stacked into one array.
            self._coords_to_compute = [c for c, _ in coords_and_resolutions]
            self._new_resolutions = np.asarray([r for _, r in coords_and_resolutions])
        else:
            coords, new_resolution = MMAFormatter._infer_grid(
                X, self.grid_strategy, self.resolution, degrees=self.degrees
            )
            self._coords_to_compute = coords
            self._new_resolutions = new_resolution
        return self

    def transform(self, X):
        """Compute the image(s) of each module (or tuple of modules) in ``X``."""
        img_args = {
            "bandwidth": self.bandwidth,
            "p": self.power,
            "normalize": self.normalize,
            "box": self.box,
            "degrees": self.degrees,
            # num_jobs is better for parallel over modules.
            "n_jobs": self.n_jobs,
            "kernel": self.kernel,
            "signed": self.signed,
            "flatten": True,  # custom coordinates
        }
        coordinates = self._coords_to_compute
        if self._has_axis:

            def _images(mods):
                # One image per axis, each evaluated on that axis' coordinates.
                return tuple(
                    mod.representation(coordinates=c, **img_args)
                    for mod, c in zip(mods, coordinates)
                )

            resolutions = self._new_resolutions
        else:

            def _images(mod):
                # Leading axis added so that the layout matches the has_axis case.
                return mod.representation(coordinates=coordinates, **img_args)[None, :]

            # A single resolution: wrap it so that zip pairs it with the
            # single image instead of iterating over its entries.
            resolutions = [self._new_resolutions]

        num_degrees = len(img_args["degrees"])
        if self.flatten:

            def todo(mods):
                return np.concatenate(_images(mods), axis=1).flatten()

        else:

            def todo(mods):
                return tuple(
                    img.reshape(num_degrees, *r)
                    for img, r in zip(_images(mods), resolutions)
                )

        # The resolution depends on the axis (see _infer_grid).
        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x)
            for x in tqdm(X, desc="Computing images", disable=not self.progress)
        )
577
+
578
+
579
class MMA2Landscape(BaseEstimator, TransformerMixin):
    """
    Turns a list of MMA approximations into Landscapes vectorisations.

    For each module and degree, ``landscapes`` is computed, reduced with
    ``phi(..., axis=0)`` and flattened. The results for all degrees are then
    concatenated into one vector per module.

    Parameters
    ----------
    resolution : list
        Resolution forwarded to ``landscapes`` (never mutated).
    degrees : list of int or None
        Degrees to vectorize. ``None`` means every degree from 0 to the
        module's ``max_degree``.
    ks : iterable of int
        Landscape indices forwarded to ``landscapes``.
    phi : callable
        Reduction applied to each landscape stack; must accept ``axis=0``.
    box :
        If None, ``fit`` sets it from quantiles of the modules' bottoms/tops.
    plot : bool
        Forwarded to ``landscapes``.
    n_jobs : int
        Number of joblib workers (threading backend).
    filtration_quantile : float
        Quantile used by ``fit`` to compute ``box``.
    """

    def __init__(
        self,
        resolution=[100, 100],
        degrees: list[int] | None = [0, 1],
        ks: Iterable[int] = range(5),
        phi: Callable = np.sum,
        box=None,
        plot: bool = False,
        n_jobs=-1,
        filtration_quantile: float = 0.01,
    ) -> None:
        super().__init__()
        self.resolution: list[int] = resolution
        self.degrees = degrees
        self.ks = ks
        self.phi = phi  # Has to have a axis=0 !
        self.box = box
        self.plot = plot
        self.n_jobs = n_jobs
        self.filtration_quantile = filtration_quantile

    def fit(self, X, y=None):
        """
        If ``box`` is None, set it to the ``filtration_quantile`` and
        ``1 - filtration_quantile`` quantiles of the modules' bottoms and tops.

        Always returns ``self`` (also on empty input), so that
        ``fit(X).transform(X)`` and ``fit_transform`` work.
        """
        if len(X) == 0:
            return self
        assert (
            X[0].num_parameters == 2
        ), f"Number of parameters {X[0].num_parameters} has to be 2."
        if self.box is None:

            def _bottom(mod):
                return mod.get_bottom()

            def _top(mod):
                return mod.get_top()

            m = np.quantile(
                Parallel(n_jobs=self.n_jobs, backend="threading")(
                    delayed(_bottom)(mod) for mod in X
                ),
                q=self.filtration_quantile,
                axis=0,
            )
            M = np.quantile(
                Parallel(n_jobs=self.n_jobs, backend="threading")(
                    delayed(_top)(mod) for mod in X
                ),
                q=1 - self.filtration_quantile,
                axis=0,
            )
            # NOTE(review): this overwrites the constructor parameter and is
            # not forwarded to `landscapes` in `transform`. Confirm intent.
            self.box = [m, M]
        return self

    def transform(self, X) -> list[np.ndarray]:
        """Return one flat landscape vector per module (an empty list for empty ``X``)."""
        if len(X) == 0:
            return []

        def todo(mod):
            degrees = (
                range(mod.max_degree + 1) if self.degrees is None else self.degrees
            )
            return np.concatenate(
                [
                    self.phi(
                        mod.landscapes(
                            ks=self.ks,
                            resolution=self.resolution,
                            degree=degree,
                            plot=self.plot,
                        ),
                        axis=0,
                    ).flatten()
                    for degree in degrees
                ]
            ).flatten()

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x) for x in X
        )