multipers 2.4.0b1__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- multipers/.dylibs/libboost_timer.dylib +0 -0
- multipers/.dylibs/libc++.1.0.dylib +0 -0
- multipers/.dylibs/libtbb.12.17.dylib +0 -0
- multipers/__init__.py +33 -0
- multipers/_signed_measure_meta.py +426 -0
- multipers/_slicer_meta.py +231 -0
- multipers/array_api/__init__.py +62 -0
- multipers/array_api/numpy.py +124 -0
- multipers/array_api/torch.py +133 -0
- multipers/data/MOL2.py +458 -0
- multipers/data/UCR.py +18 -0
- multipers/data/__init__.py +1 -0
- multipers/data/graphs.py +466 -0
- multipers/data/immuno_regions.py +27 -0
- multipers/data/minimal_presentation_to_st_bf.py +0 -0
- multipers/data/pytorch2simplextree.py +91 -0
- multipers/data/shape3d.py +101 -0
- multipers/data/synthetic.py +113 -0
- multipers/distances.py +202 -0
- multipers/filtration_conversions.pxd +736 -0
- multipers/filtration_conversions.pxd.tp +226 -0
- multipers/filtrations/__init__.py +21 -0
- multipers/filtrations/density.py +529 -0
- multipers/filtrations/filtrations.py +480 -0
- multipers/filtrations.pxd +534 -0
- multipers/filtrations.pxd.tp +332 -0
- multipers/function_rips.cpython-312-darwin.so +0 -0
- multipers/function_rips.pyx +104 -0
- multipers/grids.cpython-312-darwin.so +0 -0
- multipers/grids.pyx +538 -0
- multipers/gudhi/Persistence_slices_interface.h +213 -0
- multipers/gudhi/Simplex_tree_interface.h +274 -0
- multipers/gudhi/Simplex_tree_multi_interface.h +648 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
- multipers/gudhi/gudhi/Debug_utils.h +52 -0
- multipers/gudhi/gudhi/Degree_rips_bifiltration.h +2307 -0
- multipers/gudhi/gudhi/Dynamic_multi_parameter_filtration.h +2524 -0
- multipers/gudhi/gudhi/Fields/Multi_field.h +453 -0
- multipers/gudhi/gudhi/Fields/Multi_field_operators.h +460 -0
- multipers/gudhi/gudhi/Fields/Multi_field_shared.h +444 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small.h +584 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +490 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +580 -0
- multipers/gudhi/gudhi/Fields/Z2_field.h +391 -0
- multipers/gudhi/gudhi/Fields/Z2_field_operators.h +389 -0
- multipers/gudhi/gudhi/Fields/Zp_field.h +493 -0
- multipers/gudhi/gudhi/Fields/Zp_field_operators.h +384 -0
- multipers/gudhi/gudhi/Fields/Zp_field_shared.h +492 -0
- multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
- multipers/gudhi/gudhi/Matrix.h +2200 -0
- multipers/gudhi/gudhi/Multi_filtration/Multi_parameter_generator.h +1712 -0
- multipers/gudhi/gudhi/Multi_filtration/multi_filtration_conversions.h +237 -0
- multipers/gudhi/gudhi/Multi_filtration/multi_filtration_utils.h +225 -0
- multipers/gudhi/gudhi/Multi_parameter_filtered_complex.h +485 -0
- multipers/gudhi/gudhi/Multi_parameter_filtration.h +2643 -0
- multipers/gudhi/gudhi/Multi_persistence/Box.h +233 -0
- multipers/gudhi/gudhi/Multi_persistence/Line.h +309 -0
- multipers/gudhi/gudhi/Multi_persistence/Multi_parameter_filtered_complex_pcoh_interface.h +268 -0
- multipers/gudhi/gudhi/Multi_persistence/Persistence_interface_cohomology.h +159 -0
- multipers/gudhi/gudhi/Multi_persistence/Persistence_interface_matrix.h +463 -0
- multipers/gudhi/gudhi/Multi_persistence/Point.h +853 -0
- multipers/gudhi/gudhi/Off_reader.h +173 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +834 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +838 -0
- multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +833 -0
- multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1367 -0
- multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1157 -0
- multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +869 -0
- multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +905 -0
- multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +122 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +260 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +288 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +170 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +247 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +571 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +182 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +130 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +235 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +312 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1092 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +923 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +914 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +930 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1071 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +203 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +886 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +984 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1213 -0
- multipers/gudhi/gudhi/Persistence_matrix/index_mapper.h +58 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +227 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +200 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +166 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +319 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +562 -0
- multipers/gudhi/gudhi/Persistence_on_a_line.h +152 -0
- multipers/gudhi/gudhi/Persistence_on_rectangle.h +617 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
- multipers/gudhi/gudhi/Persistent_cohomology.h +769 -0
- multipers/gudhi/gudhi/Points_off_io.h +171 -0
- multipers/gudhi/gudhi/Projective_cover_kernel.h +379 -0
- multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +559 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +121 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
- multipers/gudhi/gudhi/Simplex_tree/filtration_value_utils.h +155 -0
- multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
- multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +60 -0
- multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +105 -0
- multipers/gudhi/gudhi/Simplex_tree.h +3170 -0
- multipers/gudhi/gudhi/Slicer.h +848 -0
- multipers/gudhi/gudhi/Thread_safe_slicer.h +393 -0
- multipers/gudhi/gudhi/distance_functions.h +62 -0
- multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
- multipers/gudhi/gudhi/multi_simplex_tree_helpers.h +147 -0
- multipers/gudhi/gudhi/persistence_interval.h +263 -0
- multipers/gudhi/gudhi/persistence_matrix_options.h +188 -0
- multipers/gudhi/gudhi/reader_utils.h +367 -0
- multipers/gudhi/gudhi/simple_mdspan.h +484 -0
- multipers/gudhi/gudhi/slicer_helpers.h +779 -0
- multipers/gudhi/tmp_h0_pers/mma_interface_h0.h +223 -0
- multipers/gudhi/tmp_h0_pers/naive_merge_tree.h +536 -0
- multipers/io.cpython-312-darwin.so +0 -0
- multipers/io.pyx +472 -0
- multipers/ml/__init__.py +0 -0
- multipers/ml/accuracies.py +90 -0
- multipers/ml/invariants_with_persistable.py +79 -0
- multipers/ml/kernels.py +176 -0
- multipers/ml/mma.py +713 -0
- multipers/ml/one.py +472 -0
- multipers/ml/point_clouds.py +352 -0
- multipers/ml/signed_measures.py +1667 -0
- multipers/ml/sliced_wasserstein.py +461 -0
- multipers/ml/tools.py +113 -0
- multipers/mma_structures.cpython-312-darwin.so +0 -0
- multipers/mma_structures.pxd +134 -0
- multipers/mma_structures.pyx +1483 -0
- multipers/mma_structures.pyx.tp +1126 -0
- multipers/multi_parameter_rank_invariant/diff_helpers.h +85 -0
- multipers/multi_parameter_rank_invariant/euler_characteristic.h +95 -0
- multipers/multi_parameter_rank_invariant/function_rips.h +317 -0
- multipers/multi_parameter_rank_invariant/hilbert_function.h +761 -0
- multipers/multi_parameter_rank_invariant/persistence_slices.h +149 -0
- multipers/multi_parameter_rank_invariant/rank_invariant.h +350 -0
- multipers/multiparameter_edge_collapse.py +41 -0
- multipers/multiparameter_module_approximation/approximation.h +2541 -0
- multipers/multiparameter_module_approximation/debug.h +107 -0
- multipers/multiparameter_module_approximation/format_python-cpp.h +292 -0
- multipers/multiparameter_module_approximation/utilities.h +428 -0
- multipers/multiparameter_module_approximation.cpython-312-darwin.so +0 -0
- multipers/multiparameter_module_approximation.pyx +286 -0
- multipers/ops.cpython-312-darwin.so +0 -0
- multipers/ops.pyx +231 -0
- multipers/pickle.py +89 -0
- multipers/plots.py +550 -0
- multipers/point_measure.cpython-312-darwin.so +0 -0
- multipers/point_measure.pyx +409 -0
- multipers/simplex_tree_multi.cpython-312-darwin.so +0 -0
- multipers/simplex_tree_multi.pxd +136 -0
- multipers/simplex_tree_multi.pyx +11719 -0
- multipers/simplex_tree_multi.pyx.tp +2102 -0
- multipers/slicer.cpython-312-darwin.so +0 -0
- multipers/slicer.pxd +2097 -0
- multipers/slicer.pxd.tp +263 -0
- multipers/slicer.pyx +13042 -0
- multipers/slicer.pyx.tp +1259 -0
- multipers/tensor/tensor.h +672 -0
- multipers/tensor.pxd +13 -0
- multipers/test.pyx +44 -0
- multipers/tests/__init__.py +70 -0
- multipers/torch/__init__.py +1 -0
- multipers/torch/diff_grids.py +240 -0
- multipers/torch/rips_density.py +310 -0
- multipers/vector_interface.pxd +46 -0
- multipers-2.4.0b1.dist-info/METADATA +131 -0
- multipers-2.4.0b1.dist-info/RECORD +184 -0
- multipers-2.4.0b1.dist-info/WHEEL +6 -0
- multipers-2.4.0b1.dist-info/licenses/LICENSE +21 -0
- multipers-2.4.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import multipers.array_api.numpy as npapi
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def api_from_tensor(x, *, verbose: bool = False, strict=False):
    """Pick the array backend module that matches ``x``.

    The numpy backend is always tried first; the torch backend module is
    imported only when the numpy check does not accept ``x``.

    Args:
        x: Object whose type decides the backend.
        verbose: In non-strict mode, print which backend was selected.
        strict: When true, use each backend's ``is_tensor`` check instead of
            its ``is_promotable`` check.

    Returns:
        The ``multipers.array_api.numpy`` or ``multipers.array_api.torch``
        module.

    Raises:
        ValueError: If neither backend accepts ``x``.
    """
    if not strict:
        if npapi.is_promotable(x):
            if verbose:
                print("using numpy backend")
            return npapi
        # Deferred import: torch is loaded only when numpy cannot take x.
        import multipers.array_api.torch as torchapi

        if torchapi.is_promotable(x):
            if verbose:
                print("using torch backend")
            return torchapi
        raise ValueError(f"Unsupported type {type(x)=}")

    # Strict mode: x must already be a tensor of one of the backends.
    if npapi.is_tensor(x):
        return npapi
    import multipers.array_api.torch as torchapi

    if torchapi.is_tensor(x):
        return torchapi
    raise ValueError(f"Unsupported (strict) type {type(x)=}")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def api_from_tensors(*args):
    """Pick one array backend module able to handle every argument.

    The numpy backend is returned when all arguments pass its
    ``is_promotable`` check; otherwise the torch backend is imported and
    checked the same way.

    Args:
        *args: One or more array-like objects.

    Returns:
        The ``multipers.array_api.numpy`` or ``multipers.array_api.torch``
        module.

    Raises:
        AssertionError: If called without arguments (when asserts are enabled).
        ValueError: If no single backend accepts all arguments.
    """
    assert len(args) > 0, "no tensor given"
    import multipers.array_api.numpy as npapi

    if all(npapi.is_promotable(candidate) for candidate in args):
        return npapi

    # only torch for now
    import multipers.array_api.torch as torchapi

    if all(torchapi.is_promotable(candidate) for candidate in args):
        return torchapi
    raise ValueError(f"Incompatible types got {[type(x) for x in args]=}.")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def to_numpy(x):
    """Convert ``x`` with the ``asnumpy`` function of its matching backend.

    Raises:
        ValueError: Propagated from ``api_from_tensor`` when no backend
            accepts ``x``.
    """
    return api_from_tensor(x).asnumpy(x)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def check_keops():
    """Report whether KeOps is usable.

    Returns:
        ``False`` when ``os.name`` is ``"nt"``; otherwise the result of the
        numpy backend's ``check_keops``.
    """
    import os

    # see https://github.com/getkeops/keops/pull/421
    running_on_nt = os.name == "nt"
    return False if running_on_nt else npapi.check_keops()
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from contextlib import nullcontext
|
|
2
|
+
|
|
3
|
+
import numpy as _np
|
|
4
|
+
from scipy.spatial.distance import cdist
|
|
5
|
+
|
|
6
|
+
# Backend-neutral names bound to their NumPy implementations.
backend = _np

# Construction and conversion.
tensor = _np.array
astensor = _np.asarray
asnumpy = _np.asarray
empty = _np.empty
zeros = _np.zeros

# Joining, repeating, selecting and searching.
cat = _np.concatenate
stack = _np.stack
repeat_interleave = _np.repeat
where = _np.where
unique = _np.unique
searchsorted = _np.searchsorted

# Reductions and element-wise math (several reuse builtin names).
min = _np.min
max = _np.max
abs = _np.abs
exp = _np.exp
sin = _np.sin
cos = _np.cos
norm = _np.linalg.norm
cdist = cdist  # type: ignore[no-redef]

# Constants and placeholders.
inf = _np.inf
no_grad = nullcontext  # context manager that does nothing
LazyTensor = None  # rebound by check_keops() when the pykeops import succeeds
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def clip(x, min=None, max=None):
    """Limit the values of ``x`` to ``[min, max]`` using ``numpy.clip``.

    Args:
        x: Array-like input.
        min: Lower bound, or ``None`` for no lower bound.
        max: Upper bound, or ``None`` for no upper bound.
    """
    lower, upper = min, max
    return _np.clip(x, lower, upper)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def relu(x):
    """Return ``x`` with every entry that is not ``>= 0`` replaced by 0."""
    keep = x >= 0
    return _np.where(keep, x, 0)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def split_with_sizes(arr, sizes):
    """Split ``arr`` into consecutive pieces whose lengths are ``sizes``.

    Returns:
        The list of sub-arrays produced by ``numpy.split``.
    """
    cumulative = _np.cumsum(sizes)
    # The final cumulative sum is dropped: the split points are the
    # boundaries between pieces, not the end of the last piece.
    return _np.split(arr, cumulative[:-1])
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Test keops
# Cached outcome of check_keops(): None until the first call, then a bool.
_is_keops_available = None


def check_keops():
    """Check once whether pykeops (numpy bindings) works, and cache the answer.

    The check imports ``pykeops.numpy``, evaluates a small ``Genred``
    reduction and compares it with the expected values. When the import and
    evaluation complete, the module-level ``LazyTensor`` is rebound to
    pykeops' ``LazyTensor``.

    Returns:
        bool: The cached availability flag.
    """
    global _is_keops_available, LazyTensor
    if _is_keops_available is not None:
        return _is_keops_available
    try:
        import pykeops.numpy as pknp
        from pykeops.numpy import LazyTensor as LT

        formula = "SqNorm2(x - y)"
        var = ["x = Vi(3)", "y = Vj(3)"]
        expected_res = _np.array([63.0, 90.0])
        x = _np.arange(1, 10).reshape(-1, 3).astype("float32")
        y = _np.arange(3, 9).reshape(-1, 3).astype("float32")

        my_conv = pknp.Genred(formula, var)
        _is_keops_available = _np.allclose(my_conv(x, y).flatten(), expected_res)
        LazyTensor = LT
    except Exception:
        # Any import/compile/runtime failure means "not available", but
        # KeyboardInterrupt and SystemExit are deliberately not swallowed.
        from warnings import warn

        warn("Could not initialize keops (numpy). using workarounds")
        _is_keops_available = False

    return _is_keops_available
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def from_numpy(x):
    """Return ``x`` as a NumPy array via ``numpy.asarray``."""
    as_array = _np.asarray(x)
    return as_array
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def ascontiguous(x):
    """Return ``x`` as the result of ``numpy.ascontiguousarray``."""
    contiguous = _np.ascontiguousarray(x)
    return contiguous
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def sort(x, axis=-1):
    """Return a sorted copy of ``x`` along ``axis`` (last axis by default)."""
    ordered = _np.sort(x, axis=axis)
    return ordered
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def device(x):  # type: ignore[no-unused-arg]
    """Return ``None`` for any ``x``; the argument is not inspected."""
    placement = None
    return placement
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# type: ignore[no-unused-arg]
|
|
93
|
+
def linspace(low, high, r, device=None, dtype=None):
    """Return ``r`` evenly spaced values from ``low`` to ``high``.

    Args:
        low: First value.
        high: Last value.
        r: Number of samples.
        device: Ignored (presumably kept for parity with the torch backend's
            ``linspace`` — confirm).
        dtype: Passed through to ``numpy.linspace``.
    """
    return _np.linspace(start=low, stop=high, num=r, dtype=dtype)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def cartesian_product(*arrays, dtype=None):
    """Return all combinations of the entries of ``arrays`` as rows.

    Args:
        *arrays: Coordinate arrays passed to ``numpy.meshgrid``.
        dtype: Target dtype passed to ``astype``.

    Returns:
        Array with ``len(arrays)`` columns, one row per combination, in
        ``indexing="ij"`` order.
    """
    grids = _np.meshgrid(*arrays, indexing="ij")
    stacked = _np.stack(grids, axis=-1)
    rows = stacked.reshape(-1, len(arrays))
    # NOTE(review): with dtype=None, astype converts to NumPy's default
    # float dtype (float64) rather than keeping the input dtype — confirm
    # this is intended.
    return rows.astype(dtype)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def quantile_closest(x, q, axis=None):
    """Compute the ``q``-quantile of ``x`` with NumPy's
    ``"closest_observation"`` method.

    Args:
        x: Input data.
        q: Quantile(s) to compute.
        axis: Axis to reduce over; ``None`` is passed through to
            ``numpy.quantile``.
    """
    options = {"axis": axis, "method": "closest_observation"}
    return _np.quantile(x, q, **options)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def minvalues(x: _np.ndarray, **kwargs):
    """Return the minimum of ``x``; ``kwargs`` are forwarded to ``numpy.min``."""
    smallest = _np.min(x, **kwargs)
    return smallest
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def maxvalues(x: _np.ndarray, **kwargs):
    """Return the maximum of ``x``; ``kwargs`` are forwarded to ``numpy.max``."""
    largest = _np.max(x, **kwargs)
    return largest
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def is_tensor(x):
    """Return ``True`` only when ``x`` is a ``numpy.ndarray`` instance."""
    array_type = _np.ndarray
    return isinstance(x, array_type)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def is_promotable(x):
    """Return ``True`` when ``x`` is a ``numpy.ndarray``, ``list`` or ``tuple``."""
    accepted = (_np.ndarray, list, tuple)
    return isinstance(x, accepted)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def has_grad(_):
    """Return ``False`` for any argument; the argument is not inspected."""
    requires_gradient = False
    return requires_gradient
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
import numpy as _np
|
|
2
|
+
import torch as _t
|
|
3
|
+
|
|
4
|
+
# Backend-neutral names bound to their PyTorch implementations.
backend = _t

# Construction and conversion.
tensor = _t.tensor
astensor = _t.as_tensor
empty = _t.empty
zeros = _t.zeros
linspace = _t.linspace
cartesian_product = _t.cartesian_prod

# Joining, repeating, selecting and searching.
cat = _t.cat
stack = _t.stack
repeat_interleave = _t.repeat_interleave
where = _t.where
searchsorted = _t.searchsorted

# Reductions and element-wise math (several reuse builtin names).
min = _t.min
max = _t.max
abs = _t.abs
exp = _t.exp
sin = _t.sin
cos = _t.cos
relu = _t.relu
norm = _t.norm
cdist = _t.cdist

# Constants, context managers and placeholders.
inf = _t.inf
no_grad = _t.no_grad
LazyTensor = None  # rebound by check_keops() when the pykeops import succeeds


# Cached outcome of check_keops(): None until the first call.
_is_keops_available = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def clip(x, min=None, max=None):
    """Limit the values of ``x`` to ``[min, max]`` using ``torch.clamp``.

    Args:
        x: Input tensor.
        min: Lower bound, or ``None`` for no lower bound.
        max: Upper bound, or ``None`` for no upper bound.
    """
    return _t.clamp(x, min=min, max=max)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def split_with_sizes(arr, sizes):
    """Split ``arr`` into consecutive pieces whose lengths are ``sizes``.

    Delegates to the tensor's own ``split_with_sizes`` method.
    """
    pieces = arr.split_with_sizes(sizes)
    return pieces
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def check_keops():
    """Check once whether pykeops (torch bindings) works, and cache the answer.

    The check imports ``pykeops.torch``, evaluates a small ``Genred``
    reduction and compares it with the expected values. When the import and
    evaluation complete, the module-level ``LazyTensor`` is rebound to
    pykeops' ``LazyTensor``.

    Returns:
        bool: The cached availability flag.
    """
    global _is_keops_available, LazyTensor
    if _is_keops_available is not None:
        return _is_keops_available
    try:
        import pykeops.torch as pknp
        from pykeops.torch import LazyTensor as LT

        formula = "SqNorm2(x - y)"
        var = ["x = Vi(3)", "y = Vj(3)"]
        expected_res = _t.tensor([63.0, 90.0])
        x = _t.arange(1, 10, dtype=_t.float32).view(-1, 3)
        y = _t.arange(3, 9, dtype=_t.float32).view(-1, 3)

        my_conv = pknp.Genred(formula, var)
        _is_keops_available = _t.allclose(
            my_conv(x, y).view(-1), expected_res.type(_t.float32)
        )
        LazyTensor = LT

    except Exception:
        # Any import/compile/runtime failure means "not available", but
        # KeyboardInterrupt and SystemExit are deliberately not swallowed.
        from warnings import warn

        warn("Could not initialize keops (torch). using workarounds")

        _is_keops_available = False

    return _is_keops_available
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Runs the KeOps probe when this module is imported; on success it rebinds
# the module-level LazyTensor, on failure it emits a warning.
check_keops()
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def from_numpy(x):
    """Wrap the NumPy array ``x`` as a tensor via ``torch.from_numpy``."""
    wrapped = _t.from_numpy(x)
    return wrapped
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def ascontiguous(x):
    """Convert ``x`` with ``torch.as_tensor`` and return it contiguous."""
    as_tensor = _t.as_tensor(x)
    return as_tensor.contiguous()
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def device(x):
    """Return the ``device`` attribute of ``x``."""
    placement = x.device
    return placement
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def sort(x, axis=-1):
    """Return the sorted values of ``x`` along ``axis`` (indices are dropped)."""
    ordered = _t.sort(x, dim=axis)
    return ordered.values
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# in our context, this allows to get a correct gradient.
def unique(x, assume_sorted=False, _mean=True):
    """Return the sorted distinct values of ``x``.

    When ``x`` does not require grad this is a plain ``Tensor.unique`` call.
    Otherwise each distinct value is rebuilt from entries of ``x`` so the
    result stays differentiable with respect to ``x``.

    Args:
        x: Input tensor; must be 1-D when it requires grad.
        assume_sorted: Skip the sort of ``x`` in the grad path. The caller is
            then responsible for ``x`` being sorted.
        _mean: In the grad path, represent each run of equal values by its
            mean (``True``) or by its first entry (``False``).

    Returns:
        Tensor of distinct values in ascending order.

    Raises:
        ValueError: If ``x`` requires grad and is not 1-D.
    """
    if not x.requires_grad:
        # Always request sorted output: `sorted=False` gives no ordering
        # guarantee, unlike the grad path below and the numpy backend.
        return x.unique(sorted=True)
    if x.ndim != 1:
        raise ValueError(f"Got ndim!=1. {x=}")
    if not assume_sorted:
        x = x.sort().values
    _, counts = _t.unique(x, sorted=True, return_counts=True)
    if _mean:
        return _t.segment_reduce(
            data=x, reduce="mean", lengths=counts, unsafe=True, axis=0
        )
    # Start index of each run of equal values, computed with torch so the
    # indices stay on x's device and empty input is handled.
    run_starts = _t.cumsum(counts, dim=0) - counts
    return x[run_starts]
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def quantile_closest(x, q, axis=None):
    """Compute the ``q``-quantile of ``x`` with ``torch.quantile`` and
    ``interpolation="nearest"``.

    Args:
        x: Input tensor.
        q: Quantile(s) to compute.
        axis: Dimension to reduce over; passed as ``dim``.
    """
    options = {"dim": axis, "interpolation": "nearest"}
    return _t.quantile(x, q, **options)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def minvalues(x: _t.Tensor, **kwargs):
    """Return the minimum values of ``x``, without indices.

    Args:
        x: Input tensor.
        **kwargs: Forwarded to ``torch.min`` (e.g. ``dim``, ``keepdim``).

    Returns:
        A tensor of minima: the global minimum when no dimension is given,
        otherwise the per-dimension minima.
    """
    result = _t.min(x, **kwargs)
    # torch.min returns a plain tensor for a full reduction and a
    # (values, indices) pair when a dimension is given; on a plain tensor
    # `.values` would be a bound method rather than data.
    if isinstance(result, _t.Tensor):
        return result
    return result.values
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def maxvalues(x: _t.Tensor, **kwargs):
    """Return the maximum values of ``x``, without indices.

    Args:
        x: Input tensor.
        **kwargs: Forwarded to ``torch.max`` (e.g. ``dim``, ``keepdim``).

    Returns:
        A tensor of maxima: the global maximum when no dimension is given,
        otherwise the per-dimension maxima.
    """
    result = _t.max(x, **kwargs)
    # torch.max returns a plain tensor for a full reduction and a
    # (values, indices) pair when a dimension is given; on a plain tensor
    # `.values` would be a bound method rather than data.
    if isinstance(result, _t.Tensor):
        return result
    return result.values
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def asnumpy(x: _t.Tensor):
    """Return the data of ``x`` as a NumPy array.

    The tensor is moved to the CPU and detached from the autograd graph
    before the conversion.
    """
    on_cpu = x.cpu()
    detached = on_cpu.detach()
    return detached.numpy()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def is_tensor(x):
    """Return ``True`` only when ``x`` is a ``torch.Tensor`` instance."""
    tensor_type = _t.Tensor
    return isinstance(x, tensor_type)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def is_promotable(x):
    """Return ``True`` only when ``x`` is a ``torch.Tensor`` instance.

    Unlike the numpy backend's check, lists and tuples are not accepted here.
    """
    tensor_type = _t.Tensor
    return isinstance(x, tensor_type)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def has_grad(x):
    """Return the ``requires_grad`` attribute of ``x``."""
    requires_gradient = x.requires_grad
    return requires_gradient
|