multipers 2.3.1__tar.gz → 2.3.2b1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of multipers might be problematic. Click here for more details.
- {multipers-2.3.1/multipers.egg-info → multipers-2.3.2b1}/PKG-INFO +4 -25
- {multipers-2.3.1 → multipers-2.3.2b1}/_tempita_grid_gen.py +2 -2
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/_signed_measure_meta.py +71 -65
- multipers-2.3.2b1/multipers/array_api/__init__.py +39 -0
- multipers-2.3.2b1/multipers/array_api/numpy.py +34 -0
- multipers-2.3.2b1/multipers/array_api/torch.py +35 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/distances.py +6 -2
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/filtrations/density.py +23 -12
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/filtrations/filtrations.py +74 -15
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/grids.pyx +144 -61
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/Simplex_tree_multi_interface.h +35 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Multi_persistence/Box.h +3 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/One_critical_filtration.h +17 -9
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/mma_interface_matrix.h +5 -3
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/truc.h +488 -42
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/io.pyx +16 -86
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/ml/mma.py +3 -3
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/ml/signed_measures.py +60 -62
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/mma_structures.pxd +2 -1
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/mma_structures.pyx +56 -12
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/mma_structures.pyx.tp +14 -3
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/approximation.h +45 -13
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation.pyx +22 -6
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/plots.py +1 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/point_measure.pyx +6 -2
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/simplex_tree_multi.pxd +1 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/simplex_tree_multi.pyx +487 -109
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/simplex_tree_multi.pyx.tp +67 -18
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/slicer.pxd +699 -217
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/slicer.pxd.tp +22 -6
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/slicer.pyx +5311 -1364
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/slicer.pyx.tp +199 -46
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/tests/__init__.py +9 -4
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/torch/diff_grids.py +30 -7
- {multipers-2.3.1 → multipers-2.3.2b1/multipers.egg-info}/PKG-INFO +4 -25
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers.egg-info/SOURCES.txt +3 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/pyproject.toml +3 -3
- {multipers-2.3.1 → multipers-2.3.2b1}/setup.py +2 -1
- {multipers-2.3.1 → multipers-2.3.2b1}/tests/test_diff_helper.py +0 -1
- {multipers-2.3.1 → multipers-2.3.2b1}/tests/test_filtrations.py +57 -29
- {multipers-2.3.1 → multipers-2.3.2b1}/tests/test_hilbert_function.py +3 -6
- {multipers-2.3.1 → multipers-2.3.2b1}/tests/test_mma.py +12 -1
- {multipers-2.3.1 → multipers-2.3.2b1}/tests/test_point_clouds.py +17 -15
- {multipers-2.3.1 → multipers-2.3.2b1}/tests/test_signed_betti.py +4 -3
- {multipers-2.3.1 → multipers-2.3.2b1}/tests/test_signed_measure.py +22 -19
- {multipers-2.3.1 → multipers-2.3.2b1}/tests/test_simplextreemulti.py +30 -24
- {multipers-2.3.1 → multipers-2.3.2b1}/tests/test_slicer.py +127 -13
- {multipers-2.3.1 → multipers-2.3.2b1}/LICENSE +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/MANIFEST.in +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/README.md +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/__init__.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/_slicer_meta.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/data/MOL2.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/data/UCR.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/data/__init__.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/data/graphs.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/data/immuno_regions.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/data/minimal_presentation_to_st_bf.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/data/pytorch2simplextree.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/data/shape3d.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/data/synthetic.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/filtration_conversions.pxd +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/filtration_conversions.pxd.tp +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/filtrations/__init__.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/filtrations.pxd +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/function_rips.pyx +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/Persistence_slices_interface.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/Simplex_tree_interface.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/cubical_to_boundary.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Bitmap_cubical_complex.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Debug_utils.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Fields/Multi_field.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Fields/Multi_field_operators.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Fields/Multi_field_shared.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Fields/Multi_field_small.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Fields/Z2_field.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Fields/Z2_field_operators.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Fields/Zp_field.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Fields/Zp_field_operators.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Fields/Zp_field_shared.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Matrix.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Multi_critical_filtration.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Multi_persistence/Line.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Off_reader.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Persistent_cohomology.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Points_off_io.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Simple_object_pool.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Simplex_tree.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/Simplex_tree_multi.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/distance_functions.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/graph_simplicial_complex.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/persistence_interval.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/persistence_matrix_options.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/gudhi/reader_utils.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/mma_interface_coh.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/mma_interface_h0.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/naive_merge_tree.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/gudhi/scc_io.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/ml/__init__.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/ml/accuracies.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/ml/invariants_with_persistable.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/ml/kernels.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/ml/one.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/ml/point_clouds.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/ml/sliced_wasserstein.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/ml/tools.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multi_parameter_rank_invariant/diff_helpers.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multi_parameter_rank_invariant/euler_characteristic.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multi_parameter_rank_invariant/function_rips.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multi_parameter_rank_invariant/hilbert_function.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multi_parameter_rank_invariant/persistence_slices.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multi_parameter_rank_invariant/rank_invariant.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_edge_collapse.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/combinatory.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/debug.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/euler_curves.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/format_python-cpp.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/heap_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/images.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/list_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/list_column_2.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/ru_matrix.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/set_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/unordered_set_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/utilities.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/vector_column.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/vector_matrix.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/vineyards.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/multiparameter_module_approximation/vineyards_trajectories.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/pickle.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/tensor/tensor.h +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/tensor.pxd +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/test.pyx +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/torch/__init__.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers/torch/rips_density.py +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers.egg-info/dependency_links.txt +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers.egg-info/requires.txt +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/multipers.egg-info/top_level.txt +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/setup.cfg +0 -0
- {multipers-2.3.1 → multipers-2.3.2b1}/tests/test_python-cpp_conversion.py +2 -2
|
@@ -1,31 +1,10 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: multipers
|
|
3
|
-
Version: 2.3.1
|
|
3
|
+
Version: 2.3.2b1
|
|
4
4
|
Summary: Multiparameter Topological Persistence for Machine Learning
|
|
5
5
|
Author-email: David Loiseaux <david.lapous@proton.me>, Hannah Schreiber <hannah.schreiber@inria.fr>
|
|
6
6
|
Maintainer-email: David Loiseaux <david.lapous@proton.me>
|
|
7
|
-
License: MIT
|
|
8
|
-
|
|
9
|
-
Copyright (c) 2023 David Loiseaux
|
|
10
|
-
|
|
11
|
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
12
|
-
of this software and associated documentation files (the "Software"), to deal
|
|
13
|
-
in the Software without restriction, including without limitation the rights
|
|
14
|
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
15
|
-
copies of the Software, and to permit persons to whom the Software is
|
|
16
|
-
furnished to do so, subject to the following conditions:
|
|
17
|
-
|
|
18
|
-
The above copyright notice and this permission notice shall be included in all
|
|
19
|
-
copies or substantial portions of the Software.
|
|
20
|
-
|
|
21
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
22
|
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
23
|
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
24
|
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
25
|
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
26
|
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
27
|
-
SOFTWARE.
|
|
28
|
-
|
|
7
|
+
License-Expression: MIT
|
|
29
8
|
Project-URL: source, https://github.com/DavidLapous/multipers
|
|
30
9
|
Project-URL: download, https://pypi.org/project/multipers/#files
|
|
31
10
|
Project-URL: tracker, https://github.com/DavidLapous/multipers/issues
|
|
@@ -40,7 +19,6 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
|
40
19
|
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
41
20
|
Classifier: Topic :: Scientific/Engineering :: Visualization
|
|
42
21
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
43
|
-
Classifier: License :: OSI Approved :: MIT License
|
|
44
22
|
Requires-Python: >=3.10
|
|
45
23
|
Description-Content-Type: text/markdown
|
|
46
24
|
License-File: LICENSE
|
|
@@ -54,6 +32,7 @@ Requires-Dist: scikit-learn
|
|
|
54
32
|
Requires-Dist: filtration-domination
|
|
55
33
|
Requires-Dist: pykeops
|
|
56
34
|
Requires-Dist: pot
|
|
35
|
+
Dynamic: license-file
|
|
57
36
|
|
|
58
37
|
# multipers : Multiparameter Persistence for Machine Learning
|
|
59
38
|
[](https://doi.org/10.21105/joss.06773) [](https://davidlapous.github.io/multipers) [](https://github.com/DavidLapous/multipers/actions/workflows/python_PR.yml)
|
|
@@ -31,11 +31,10 @@ def signed_measure(
|
|
|
31
31
|
verbose: bool = False,
|
|
32
32
|
n_jobs: int = -1,
|
|
33
33
|
expand_collapse: bool = False,
|
|
34
|
-
backend: Optional[str] = None,
|
|
35
|
-
thread_id: str = "",
|
|
34
|
+
backend: Optional[str] = None, # deprecated
|
|
36
35
|
grid: Optional[Iterable] = None,
|
|
37
36
|
coordinate_measure: bool = False,
|
|
38
|
-
num_collapses: int = 0,
|
|
37
|
+
num_collapses: int = 0, # TODO : deprecate
|
|
39
38
|
clean: Optional[bool] = None,
|
|
40
39
|
vineyard: bool = False,
|
|
41
40
|
grid_conversion: Optional[Iterable] = None,
|
|
@@ -99,7 +98,13 @@ def signed_measure(
|
|
|
99
98
|
It is usually faster to use this backend if not in a parallel context.
|
|
100
99
|
- Rank: Same as Hilbert.
|
|
101
100
|
"""
|
|
101
|
+
if backend is not None:
|
|
102
|
+
raise ValueError("backend is deprecated. reduce the complex before this function.")
|
|
103
|
+
if num_collapses >0:
|
|
104
|
+
raise ValueError("num_collapses is deprecated. reduce the complex before this function.")
|
|
102
105
|
## TODO : add timings in verbose
|
|
106
|
+
if len(filtered_complex) == 0:
|
|
107
|
+
return [(np.empty((0,2), dtype=filtered_complex.dtype), np.empty(shape=(0,), dtype=int))]
|
|
103
108
|
if grid_conversion is not None:
|
|
104
109
|
grid = tuple(f for f in grid_conversion)
|
|
105
110
|
raise DeprecationWarning(
|
|
@@ -133,7 +138,7 @@ def signed_measure(
|
|
|
133
138
|
|
|
134
139
|
assert (
|
|
135
140
|
not plot or filtered_complex.num_parameters == 2
|
|
136
|
-
), "Can only plot 2d measures."
|
|
141
|
+
), f"Can only plot 2d measures. Got {filtered_complex.num_parameters=}."
|
|
137
142
|
|
|
138
143
|
if grid is None:
|
|
139
144
|
if not filtered_complex.is_squeezed:
|
|
@@ -141,7 +146,7 @@ def signed_measure(
|
|
|
141
146
|
filtered_complex, strategy=grid_strategy, **infer_grid_kwargs
|
|
142
147
|
)
|
|
143
148
|
else:
|
|
144
|
-
grid =
|
|
149
|
+
grid = filtered_complex.filtration_grid
|
|
145
150
|
|
|
146
151
|
if mass_default is None:
|
|
147
152
|
mass_default = mass_default
|
|
@@ -186,69 +191,70 @@ def signed_measure(
|
|
|
186
191
|
grid
|
|
187
192
|
), f"Number of parameter do not coincide. Got (grid) {len(grid)} and (filtered complex) {num_parameters}."
|
|
188
193
|
|
|
189
|
-
if is_simplextree_multi(filtered_complex_):
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
194
|
+
# if is_simplextree_multi(filtered_complex_):
|
|
195
|
+
# # if num_collapses != 0:
|
|
196
|
+
# # if verbose:
|
|
197
|
+
# # print("Collapsing edges...", end="")
|
|
198
|
+
# # filtered_complex_.collapse_edges(num_collapses)
|
|
199
|
+
# # if verbose:
|
|
200
|
+
# # print("Done.")
|
|
201
|
+
# # if backend is not None:
|
|
202
|
+
# # filtered_complex_ = mp.Slicer(filtered_complex_, vineyard=vineyard)
|
|
198
203
|
|
|
199
204
|
fix_mass_default = mass_default is not None
|
|
200
205
|
if is_slicer(filtered_complex_):
|
|
201
206
|
if verbose:
|
|
202
207
|
print("Input is a slicer.")
|
|
203
208
|
if backend is not None and not filtered_complex_.is_minpres:
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
209
|
+
raise ValueError("giving a backend to this function is deprecated")
|
|
210
|
+
# from multipers.slicer import minimal_presentation
|
|
211
|
+
#
|
|
212
|
+
# assert (
|
|
213
|
+
# invariant != "euler"
|
|
214
|
+
# ), "Euler Characteristic cannot be speed up by a backend"
|
|
215
|
+
# # This returns a list of reduced complexes
|
|
216
|
+
# if verbose:
|
|
217
|
+
# print("Reducing complex...", end="")
|
|
218
|
+
# reduced_complex = minimal_presentation(
|
|
219
|
+
# filtered_complex_,
|
|
220
|
+
# degrees=degrees,
|
|
221
|
+
# backend=backend,
|
|
222
|
+
# vineyard=vineyard,
|
|
223
|
+
# verbose=verbose,
|
|
224
|
+
# )
|
|
225
|
+
# if verbose:
|
|
226
|
+
# print("Done.")
|
|
227
|
+
# if invariant is not None and "rank" in invariant:
|
|
228
|
+
# if verbose:
|
|
229
|
+
# print("Computing rank...", end="")
|
|
230
|
+
# sms = [
|
|
231
|
+
# _rank_from_slicer(
|
|
232
|
+
# s,
|
|
233
|
+
# degrees=[d],
|
|
234
|
+
# n_jobs=n_jobs,
|
|
235
|
+
# # grid_shape=tuple(len(g) for g in grid),
|
|
236
|
+
# zero_pad=fix_mass_default,
|
|
237
|
+
# ignore_inf=ignore_infinite_filtration_values,
|
|
238
|
+
# )[0]
|
|
239
|
+
# for s, d in zip(reduced_complex, degrees)
|
|
240
|
+
# ]
|
|
241
|
+
# fix_mass_default = False
|
|
242
|
+
# if verbose:
|
|
243
|
+
# print("Done.")
|
|
244
|
+
# else:
|
|
245
|
+
# if verbose:
|
|
246
|
+
# print("Reduced slicer. Retrieving measure from it...", end="")
|
|
247
|
+
# sms = [
|
|
248
|
+
# _signed_measure_from_slicer(
|
|
249
|
+
# s,
|
|
250
|
+
# shift=(
|
|
251
|
+
# reduced_complex.minpres_degree & 1 if d is None else d & 1
|
|
252
|
+
# ),
|
|
253
|
+
# )[0]
|
|
254
|
+
# for s, d in zip(reduced_complex, degrees)
|
|
255
|
+
# ]
|
|
256
|
+
# if verbose:
|
|
257
|
+
# print("Done.")
|
|
252
258
|
else: # No backend
|
|
253
259
|
if invariant is not None and "rank" in invariant:
|
|
254
260
|
degrees = np.asarray(degrees, dtype=int)
|
|
@@ -272,7 +278,7 @@ def signed_measure(
|
|
|
272
278
|
_signed_measure_from_slicer(
|
|
273
279
|
filtered_complex_,
|
|
274
280
|
shift=(
|
|
275
|
-
filtered_complex_.minpres_degree
|
|
281
|
+
filtered_complex_.minpres_degree & 1 if d is None else d & 1
|
|
276
282
|
),
|
|
277
283
|
)[0]
|
|
278
284
|
for d in degrees
|
|
@@ -385,7 +391,7 @@ def signed_measure(
|
|
|
385
391
|
sms,
|
|
386
392
|
grid=grid,
|
|
387
393
|
mass_default=mass_default,
|
|
388
|
-
num_parameters=num_parameters,
|
|
394
|
+
# num_parameters=num_parameters,
|
|
389
395
|
)
|
|
390
396
|
if verbose:
|
|
391
397
|
print("Done.")
|
|
@@ -408,7 +414,7 @@ def _signed_measure_from_scc(
|
|
|
408
414
|
pts = np.concatenate([b[0] for b in minimal_presentation])
|
|
409
415
|
weights = np.concatenate(
|
|
410
416
|
[
|
|
411
|
-
(1 - 2 * (i
|
|
417
|
+
(1 - 2 * (i & 1)) * np.ones(len(b[0]))
|
|
412
418
|
for i, b in enumerate(minimal_presentation)
|
|
413
419
|
]
|
|
414
420
|
)
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
def api_from_tensor(x, *, verbose: bool = False):
|
|
2
|
+
import multipers.array_api.numpy as npapi
|
|
3
|
+
|
|
4
|
+
if npapi.is_promotable(x):
|
|
5
|
+
if verbose:
|
|
6
|
+
print("using numpy backend")
|
|
7
|
+
return npapi
|
|
8
|
+
import multipers.array_api.torch as torchapi
|
|
9
|
+
|
|
10
|
+
if torchapi.is_promotable(x):
|
|
11
|
+
if verbose:
|
|
12
|
+
print("using torch backend")
|
|
13
|
+
return torchapi
|
|
14
|
+
raise ValueError(f"Unsupported type {type(x)=}")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def api_from_tensors(*args):
|
|
18
|
+
assert len(args) > 0, "no tensor given"
|
|
19
|
+
import multipers.array_api.numpy as npapi
|
|
20
|
+
|
|
21
|
+
is_numpy = True
|
|
22
|
+
for x in args:
|
|
23
|
+
if not npapi.is_promotable(x):
|
|
24
|
+
is_numpy = False
|
|
25
|
+
break
|
|
26
|
+
if is_numpy:
|
|
27
|
+
return npapi
|
|
28
|
+
|
|
29
|
+
# only torch for now
|
|
30
|
+
import multipers.array_api.torch as torchapi
|
|
31
|
+
|
|
32
|
+
is_torch = True
|
|
33
|
+
for x in args:
|
|
34
|
+
if not torchapi.is_promotable(x):
|
|
35
|
+
is_torch = False
|
|
36
|
+
break
|
|
37
|
+
if is_torch:
|
|
38
|
+
return torchapi
|
|
39
|
+
raise ValueError(f"Incompatible types got {[type(x) for x in args]=}.")
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from contextlib import nullcontext
|
|
2
|
+
|
|
3
|
+
import numpy as _np
|
|
4
|
+
from scipy.spatial.distance import cdist
|
|
5
|
+
|
|
6
|
+
backend = _np
|
|
7
|
+
cat = _np.concatenate
|
|
8
|
+
norm = _np.linalg.norm
|
|
9
|
+
astensor = _np.asarray
|
|
10
|
+
asnumpy = _np.asarray
|
|
11
|
+
tensor = _np.array
|
|
12
|
+
stack = _np.stack
|
|
13
|
+
empty = _np.empty
|
|
14
|
+
where = _np.where
|
|
15
|
+
no_grad = nullcontext
|
|
16
|
+
zeros = _np.zeros
|
|
17
|
+
min = _np.min
|
|
18
|
+
max = _np.max
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def minvalues(x: _np.ndarray, **kwargs):
|
|
22
|
+
return _np.min(x, **kwargs)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def maxvalues(x: _np.ndarray, **kwargs):
|
|
26
|
+
return _np.max(x, **kwargs)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def is_promotable(x):
|
|
30
|
+
return isinstance(x, _np.ndarray | list | tuple)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def has_grad(_):
|
|
34
|
+
return False
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import torch as _t
|
|
2
|
+
|
|
3
|
+
backend = _t
|
|
4
|
+
cat = _t.cat
|
|
5
|
+
norm = _t.norm
|
|
6
|
+
astensor = _t.as_tensor
|
|
7
|
+
tensor = _t.tensor
|
|
8
|
+
stack = _t.stack
|
|
9
|
+
empty = _t.empty
|
|
10
|
+
where = _t.where
|
|
11
|
+
no_grad = _t.no_grad
|
|
12
|
+
cdist = _t.cdist
|
|
13
|
+
zeros = _t.zeros
|
|
14
|
+
min = _t.min
|
|
15
|
+
max = _t.max
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def minvalues(x: _t.Tensor, **kwargs):
|
|
19
|
+
return _t.min(x, **kwargs).values
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def maxvalues(x: _t.Tensor, **kwargs):
|
|
23
|
+
return _t.max(x, **kwargs).values
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def asnumpy(x):
|
|
27
|
+
return x.detach().numpy()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def is_promotable(x):
|
|
31
|
+
return isinstance(x, _t.Tensor)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def has_grad(x):
|
|
35
|
+
return x.requires_grad
|
|
@@ -6,7 +6,7 @@ from multipers.multiparameter_module_approximation import PyModule_type
|
|
|
6
6
|
from multipers.simplex_tree_multi import SimplexTreeMulti_type
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
def sm2diff(sm1, sm2):
|
|
9
|
+
def sm2diff(sm1, sm2, threshold=None):
|
|
10
10
|
pts = sm1[0]
|
|
11
11
|
dtype = pts.dtype
|
|
12
12
|
if isinstance(pts, np.ndarray):
|
|
@@ -45,6 +45,9 @@ def sm2diff(sm1, sm2):
|
|
|
45
45
|
)
|
|
46
46
|
x = backend_concatenate(pts1[pos_indices1], pts2[neg_indices2])
|
|
47
47
|
y = backend_concatenate(pts1[neg_indices1], pts2[pos_indices2])
|
|
48
|
+
if threshold is not None:
|
|
49
|
+
x[x>threshold]=threshold
|
|
50
|
+
y[y>threshold]=threshold
|
|
48
51
|
return x, y
|
|
49
52
|
|
|
50
53
|
|
|
@@ -55,6 +58,7 @@ def sm_distance(
|
|
|
55
58
|
reg_m: float = 0,
|
|
56
59
|
numItermax: int = 10000,
|
|
57
60
|
p: float = 1,
|
|
61
|
+
threshold=None,
|
|
58
62
|
):
|
|
59
63
|
"""
|
|
60
64
|
Computes the wasserstein distances between two signed measures,
|
|
@@ -68,7 +72,7 @@ def sm_distance(
|
|
|
68
72
|
- sinkhorn if reg != 0
|
|
69
73
|
- sinkhorn unbalanced if reg_m != 0
|
|
70
74
|
"""
|
|
71
|
-
x, y = sm2diff(sm1, sm2)
|
|
75
|
+
x, y = sm2diff(sm1, sm2, threshold=threshold)
|
|
72
76
|
loss = ot.dist(
|
|
73
77
|
x, y, metric="sqeuclidean", p=p
|
|
74
78
|
) # only euc + sqeuclidian are implemented in pot for the moment with torch backend # TODO : check later
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
from collections.abc import Callable, Iterable
|
|
2
2
|
from typing import Any, Literal, Union
|
|
3
|
-
|
|
4
3
|
import numpy as np
|
|
5
4
|
|
|
5
|
+
|
|
6
|
+
from multipers.array_api import api_from_tensor
|
|
6
7
|
global available_kernels
|
|
7
8
|
available_kernels = Union[
|
|
8
9
|
Literal[
|
|
@@ -41,13 +42,14 @@ def convolution_signed_measures(
|
|
|
41
42
|
from multipers.grids import todense
|
|
42
43
|
|
|
43
44
|
grid_iterator = todense(filtrations, product_order=True)
|
|
45
|
+
api = api_from_tensor(iterable_of_signed_measures[0][0][0])
|
|
44
46
|
match backend:
|
|
45
47
|
case "sklearn":
|
|
46
48
|
|
|
47
49
|
def convolution_signed_measures_on_grid(
|
|
48
|
-
signed_measures
|
|
50
|
+
signed_measures,
|
|
49
51
|
):
|
|
50
|
-
return
|
|
52
|
+
return api.cat(
|
|
51
53
|
[
|
|
52
54
|
_pts_convolution_sparse_old(
|
|
53
55
|
pts=pts,
|
|
@@ -67,7 +69,7 @@ def convolution_signed_measures(
|
|
|
67
69
|
def convolution_signed_measures_on_grid(
|
|
68
70
|
signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
|
|
69
71
|
) -> np.ndarray:
|
|
70
|
-
return
|
|
72
|
+
return api.cat(
|
|
71
73
|
[
|
|
72
74
|
_pts_convolution_pykeops(
|
|
73
75
|
pts=pts,
|
|
@@ -111,7 +113,7 @@ def convolution_signed_measures(
|
|
|
111
113
|
if not flatten:
|
|
112
114
|
out_shape = [-1] + [len(f) for f in filtrations] # Degree
|
|
113
115
|
convolutions = [x.reshape(out_shape) for x in convolutions]
|
|
114
|
-
return
|
|
116
|
+
return api.cat([x[None] for x in convolutions])
|
|
115
117
|
|
|
116
118
|
|
|
117
119
|
# def _test(r=1000, b=0.5, plot=True, kernel=0):
|
|
@@ -173,10 +175,17 @@ def _pts_convolution_pykeops(
|
|
|
173
175
|
"""
|
|
174
176
|
Pykeops convolution
|
|
175
177
|
"""
|
|
178
|
+
if isinstance(pts, np.ndarray):
|
|
179
|
+
_asarray_weights = lambda x : np.asarray(x, dtype=pts.dtype)
|
|
180
|
+
_asarray_grid = _asarray_weights
|
|
181
|
+
else:
|
|
182
|
+
import torch
|
|
183
|
+
_asarray_weights = lambda x : torch.from_numpy(x).type(pts.dtype)
|
|
184
|
+
_asarray_grid = lambda x : x.type(pts.dtype)
|
|
176
185
|
kde = KDE(kernel=kernel, bandwidth=bandwidth, **more_kde_args)
|
|
177
186
|
return kde.fit(
|
|
178
|
-
pts, sample_weights=
|
|
179
|
-
).score_samples(
|
|
187
|
+
pts, sample_weights=_asarray_weights(pts_weights)
|
|
188
|
+
).score_samples(_asarray_grid(grid_iterator))
|
|
180
189
|
|
|
181
190
|
|
|
182
191
|
def gaussian_kernel(x_i, y_j, bandwidth):
|
|
@@ -291,10 +300,10 @@ class KDE:
|
|
|
291
300
|
X.reshape((X.shape[0], 1, X.shape[1]))
|
|
292
301
|
) # numpts, 1, dim
|
|
293
302
|
lazy_y = LazyTensor(
|
|
294
|
-
Y.reshape((1, Y.shape[0], Y.shape[1]))
|
|
303
|
+
Y.reshape((1, Y.shape[0], Y.shape[1])).astype(X.dtype)
|
|
295
304
|
) # 1, numpts, dim
|
|
296
305
|
if x_weights is not None:
|
|
297
|
-
w = LazyTensor(x_weights[:, None], axis=0)
|
|
306
|
+
w = LazyTensor(np.asarray(x_weights, dtype=X.dtype)[:, None], axis=0)
|
|
298
307
|
return lazy_x, lazy_y, w
|
|
299
308
|
return lazy_x, lazy_y, None
|
|
300
309
|
import torch
|
|
@@ -303,9 +312,11 @@ class KDE:
|
|
|
303
312
|
from pykeops.torch import LazyTensor
|
|
304
313
|
|
|
305
314
|
lazy_x = LazyTensor(X.view(X.shape[0], 1, X.shape[1]))
|
|
306
|
-
lazy_y = LazyTensor(Y.view(1, Y.shape[0], Y.shape[1]))
|
|
315
|
+
lazy_y = LazyTensor(Y.type(X.dtype).view(1, Y.shape[0], Y.shape[1]))
|
|
307
316
|
if x_weights is not None:
|
|
308
|
-
|
|
317
|
+
if isinstance(x_weights, np.ndarray):
|
|
318
|
+
x_weights = torch.from_numpy(x_weights)
|
|
319
|
+
w = LazyTensor(x_weights[:, None].type(X.dtype), axis=0)
|
|
309
320
|
return lazy_x, lazy_y, w
|
|
310
321
|
return lazy_x, lazy_y, None
|
|
311
322
|
raise Exception("Bad tensor type.")
|
|
@@ -521,7 +532,7 @@ class KNNmean:
|
|
|
521
532
|
|
|
522
533
|
# Symbolic distance matrix:
|
|
523
534
|
if self.metric == "euclidean":
|
|
524
|
-
D_ij = ((X_i - X_j) ** 2).sum(-1)
|
|
535
|
+
D_ij = ((X_i - X_j) ** 2).sum(-1) ** (1/2)
|
|
525
536
|
elif self.metric == "manhattan":
|
|
526
537
|
D_ij = (X_i - X_j).abs().sum(-1)
|
|
527
538
|
elif self.metric == "angular":
|