mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -313,7 +313,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
313
313
|
// Now, do the training.
|
|
314
314
|
if (params.Has("training"))
|
|
315
315
|
{
|
|
316
|
-
|
|
316
|
+
NormalizeLabels(rawLabels, labels, model->mappings);
|
|
317
317
|
numClasses = params.Get<int>("num_classes") == 0 ?
|
|
318
318
|
model->mappings.n_elem : params.Get<int>("num_classes");
|
|
319
319
|
model->svm.Lambda() = lambda;
|
|
@@ -410,7 +410,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
410
410
|
}
|
|
411
411
|
|
|
412
412
|
model->svm.Classify(testSet, predictedLabels);
|
|
413
|
-
|
|
413
|
+
RevertLabels(predictedLabels, model->mappings, predictions);
|
|
414
414
|
|
|
415
415
|
// Calculate accuracy, if desired.
|
|
416
416
|
if (params.Has("test_labels"))
|
|
@@ -419,7 +419,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
419
419
|
arma::Row<size_t> testRawLabels =
|
|
420
420
|
std::move(params.Get<arma::Row<size_t>>("test_labels"));
|
|
421
421
|
|
|
422
|
-
|
|
422
|
+
NormalizeLabels(testRawLabels, testLabels, model->mappings);
|
|
423
423
|
|
|
424
424
|
if (testSet.n_cols != testLabels.n_elem)
|
|
425
425
|
{
|
|
@@ -326,7 +326,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
326
326
|
// Now, normalize the labels.
|
|
327
327
|
arma::Col<size_t> mappings;
|
|
328
328
|
arma::Row<size_t> labels;
|
|
329
|
-
|
|
329
|
+
NormalizeLabels(rawLabels, labels, mappings);
|
|
330
330
|
|
|
331
331
|
arma::mat distance;
|
|
332
332
|
|
|
@@ -183,7 +183,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
183
183
|
if (params.Has("reference"))
|
|
184
184
|
{
|
|
185
185
|
// Workaround: this avoids printing load information twice for the CLI
|
|
186
|
-
// bindings, where GetPrintable() will trigger a call to
|
|
186
|
+
// bindings, where GetPrintable() will trigger a call to Load(),
|
|
187
187
|
// which prints loading information in the middle of the Log::Info
|
|
188
188
|
// message.
|
|
189
189
|
(void) params.Get<arma::mat>("reference");
|
|
@@ -81,7 +81,7 @@ inline void MatrixCompletion::CheckValues()
|
|
|
81
81
|
if (indices(0, i) >= m || indices(1, i) >= n)
|
|
82
82
|
Log::Fatal << "MatrixCompletion::CheckValues(): indices ("
|
|
83
83
|
<< indices(0, i) << ", " << indices(1, i)
|
|
84
|
-
<< ") are out of bounds for matrix of size " << m << " x n!"
|
|
84
|
+
<< ") are out of bounds for matrix of size " << m << " x " << n << "!"
|
|
85
85
|
<< std::endl;
|
|
86
86
|
}
|
|
87
87
|
}
|
|
@@ -364,7 +364,7 @@ void NaiveBayesClassifier<ModelMatType>::Classify(
|
|
|
364
364
|
ModelMatType logLikelihoods;
|
|
365
365
|
LogLikelihood(data, logLikelihoods);
|
|
366
366
|
|
|
367
|
-
predictionProbs.set_size(
|
|
367
|
+
predictionProbs.set_size(size(logLikelihoods));
|
|
368
368
|
double maxValue, logProbX;
|
|
369
369
|
for (size_t j = 0; j < data.n_cols; ++j)
|
|
370
370
|
{
|
|
@@ -152,14 +152,14 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
152
152
|
{
|
|
153
153
|
// Load labels.
|
|
154
154
|
Row<size_t> rawLabels = std::move(params.Get<Row<size_t>>("labels"));
|
|
155
|
-
|
|
155
|
+
NormalizeLabels(rawLabels, labels, model->mappings);
|
|
156
156
|
}
|
|
157
157
|
else
|
|
158
158
|
{
|
|
159
159
|
// Use the last row of the training data as the labels.
|
|
160
160
|
Log::Info << "Using last dimension of training data as training labels."
|
|
161
161
|
<< endl;
|
|
162
|
-
|
|
162
|
+
NormalizeLabels(trainingData.row(trainingData.n_rows - 1), labels,
|
|
163
163
|
model->mappings);
|
|
164
164
|
// Remove the label row.
|
|
165
165
|
trainingData.shed_row(trainingData.n_rows - 1);
|
|
@@ -200,7 +200,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
200
200
|
{
|
|
201
201
|
// Un-normalize labels to prepare output.
|
|
202
202
|
Row<size_t> rawResults;
|
|
203
|
-
|
|
203
|
+
RevertLabels(predictions, model->mappings, rawResults);
|
|
204
204
|
|
|
205
205
|
if (params.Has("predictions"))
|
|
206
206
|
params.Get<Row<size_t>>("predictions") = std::move(rawResults);
|
|
@@ -217,7 +217,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
217
217
|
// Now, normalize the labels.
|
|
218
218
|
arma::Col<size_t> mappings;
|
|
219
219
|
arma::Row<size_t> labels;
|
|
220
|
-
|
|
220
|
+
NormalizeLabels(rawLabels, labels, mappings);
|
|
221
221
|
|
|
222
222
|
arma::mat distance;
|
|
223
223
|
|
|
@@ -179,16 +179,16 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
179
179
|
const string algorithm = params.Get<string>("algorithm");
|
|
180
180
|
RequireParamInSet<string>(params, "algorithm", { "naive", "single_tree",
|
|
181
181
|
"dual_tree", "greedy" }, true, "unknown neighbor search algorithm");
|
|
182
|
-
|
|
182
|
+
NeighborSearchStrategy searchStrategy = DUAL_TREE;
|
|
183
183
|
|
|
184
184
|
if (algorithm == "naive")
|
|
185
|
-
|
|
185
|
+
searchStrategy = NAIVE;
|
|
186
186
|
else if (algorithm == "single_tree")
|
|
187
|
-
|
|
187
|
+
searchStrategy = SINGLE_TREE;
|
|
188
188
|
else if (algorithm == "dual_tree")
|
|
189
|
-
|
|
189
|
+
searchStrategy = DUAL_TREE;
|
|
190
190
|
else if (algorithm == "greedy")
|
|
191
|
-
|
|
191
|
+
searchStrategy = GREEDY_SINGLE_TREE;
|
|
192
192
|
|
|
193
193
|
if (params.Has("reference"))
|
|
194
194
|
{
|
|
@@ -240,7 +240,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
240
240
|
Log::Info << "Using reference data from "
|
|
241
241
|
<< params.GetPrintable<arma::mat>("reference") << "." << endl;
|
|
242
242
|
|
|
243
|
-
kfn->BuildModel(timers, std::move(referenceSet),
|
|
243
|
+
kfn->BuildModel(timers, std::move(referenceSet), searchStrategy, epsilon);
|
|
244
244
|
}
|
|
245
245
|
else
|
|
246
246
|
{
|
|
@@ -248,7 +248,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
248
248
|
kfn = params.Get<KFNModel*>("input_model");
|
|
249
249
|
|
|
250
250
|
// Adjust search mode.
|
|
251
|
-
kfn->
|
|
251
|
+
kfn->SearchStrategy() = searchStrategy;
|
|
252
252
|
kfn->Epsilon() = epsilon;
|
|
253
253
|
|
|
254
254
|
// If leaf_size wasn't provided, let's consider the current value in the
|
|
@@ -272,7 +272,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
272
272
|
if (params.Has("query"))
|
|
273
273
|
{
|
|
274
274
|
// Workaround: this avoids printing load information twice for the CLI
|
|
275
|
-
// bindings, where GetPrintable() will trigger a call to
|
|
275
|
+
// bindings, where GetPrintable() will trigger a call to Load(),
|
|
276
276
|
// which prints loading information in the middle of the Log::Info
|
|
277
277
|
// message.
|
|
278
278
|
(void) params.Get<arma::mat>("query");
|
|
@@ -187,16 +187,16 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
187
187
|
const string algorithm = params.Get<string>("algorithm");
|
|
188
188
|
RequireParamInSet<string>(params, "algorithm", { "naive", "single_tree",
|
|
189
189
|
"dual_tree", "greedy" }, true, "unknown neighbor search algorithm");
|
|
190
|
-
|
|
190
|
+
NeighborSearchStrategy searchStrategy = DUAL_TREE;
|
|
191
191
|
|
|
192
192
|
if (algorithm == "naive")
|
|
193
|
-
|
|
193
|
+
searchStrategy = NAIVE;
|
|
194
194
|
else if (algorithm == "single_tree")
|
|
195
|
-
|
|
195
|
+
searchStrategy = SINGLE_TREE;
|
|
196
196
|
else if (algorithm == "dual_tree")
|
|
197
|
-
|
|
197
|
+
searchStrategy = DUAL_TREE;
|
|
198
198
|
else if (algorithm == "greedy")
|
|
199
|
-
|
|
199
|
+
searchStrategy = GREEDY_SINGLE_TREE;
|
|
200
200
|
|
|
201
201
|
if (params.Has("reference"))
|
|
202
202
|
{
|
|
@@ -253,7 +253,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
253
253
|
Log::Info << "Using reference data from "
|
|
254
254
|
<< params.GetPrintable<arma::mat>("reference") << "." << endl;
|
|
255
255
|
|
|
256
|
-
knn->BuildModel(timers, std::move(referenceSet),
|
|
256
|
+
knn->BuildModel(timers, std::move(referenceSet), searchStrategy, epsilon);
|
|
257
257
|
}
|
|
258
258
|
else
|
|
259
259
|
{
|
|
@@ -261,7 +261,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
261
261
|
knn = params.Get<KNNModel*>("input_model");
|
|
262
262
|
|
|
263
263
|
// Adjust search mode.
|
|
264
|
-
knn->
|
|
264
|
+
knn->SearchStrategy() = searchStrategy;
|
|
265
265
|
knn->Epsilon() = epsilon;
|
|
266
266
|
|
|
267
267
|
// If leaf_size wasn't provided, let's consider the current value in the
|
|
@@ -285,7 +285,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
285
285
|
if (params.Has("query"))
|
|
286
286
|
{
|
|
287
287
|
// Workaround: this avoids printing load information twice for the CLI
|
|
288
|
-
// bindings, where GetPrintable() will trigger a call to
|
|
288
|
+
// bindings, where GetPrintable() will trigger a call to Load(),
|
|
289
289
|
// which prints loading information in the middle of the Log::Info
|
|
290
290
|
// message.
|
|
291
291
|
(void) params.Get<arma::mat>("query");
|
|
@@ -16,8 +16,6 @@
|
|
|
16
16
|
#include <mlpack/core.hpp>
|
|
17
17
|
|
|
18
18
|
#include "neighbor_search_stat.hpp"
|
|
19
|
-
#include "sort_policies/nearest_neighbor_sort.hpp"
|
|
20
|
-
#include "sort_policies/furthest_neighbor_sort.hpp"
|
|
21
19
|
#include "neighbor_search_rules.hpp"
|
|
22
20
|
#include "unmap.hpp"
|
|
23
21
|
|
|
@@ -32,7 +30,17 @@ template<typename SortPolicy,
|
|
|
32
30
|
template<typename RuleType> class SingleTreeTraversalType>
|
|
33
31
|
class LeafSizeNSWrapper;
|
|
34
32
|
|
|
35
|
-
|
|
33
|
+
// NeighborSearchStrategy represents the different neighbor search strategies
|
|
34
|
+
// available.
|
|
35
|
+
enum NeighborSearchStrategy
|
|
36
|
+
{
|
|
37
|
+
NAIVE,
|
|
38
|
+
SINGLE_TREE,
|
|
39
|
+
DUAL_TREE,
|
|
40
|
+
GREEDY_SINGLE_TREE
|
|
41
|
+
};
|
|
42
|
+
|
|
43
|
+
// This is for reverse compatibility and will be removed in mlpack 5.0.0.
|
|
36
44
|
enum NeighborSearchMode
|
|
37
45
|
{
|
|
38
46
|
NAIVE_MODE,
|
|
@@ -41,6 +49,36 @@ enum NeighborSearchMode
|
|
|
41
49
|
GREEDY_SINGLE_TREE_MODE
|
|
42
50
|
};
|
|
43
51
|
|
|
52
|
+
// This is for reverse compatibility and will be removed in mlpack 5.0.0.
|
|
53
|
+
inline NeighborSearchStrategy ModeToStrategy(const NeighborSearchMode& mode)
|
|
54
|
+
{
|
|
55
|
+
switch (mode)
|
|
56
|
+
{
|
|
57
|
+
case NAIVE_MODE: return NAIVE;
|
|
58
|
+
case SINGLE_TREE_MODE: return SINGLE_TREE;
|
|
59
|
+
case DUAL_TREE_MODE: return DUAL_TREE;
|
|
60
|
+
case GREEDY_SINGLE_TREE_MODE: return GREEDY_SINGLE_TREE;
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
// Fallback for unhandled values; also silences missing-return warnings.
|
|
64
|
+
return DUAL_TREE;
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
// This is for reverse compatibility and will be removed in mlpack 5.0.0.
|
|
68
|
+
inline NeighborSearchMode StrategyToMode(const NeighborSearchStrategy& strategy)
|
|
69
|
+
{
|
|
70
|
+
switch (strategy)
|
|
71
|
+
{
|
|
72
|
+
case NAIVE: return NAIVE_MODE;
|
|
73
|
+
case SINGLE_TREE: return SINGLE_TREE_MODE;
|
|
74
|
+
case DUAL_TREE: return DUAL_TREE_MODE;
|
|
75
|
+
case GREEDY_SINGLE_TREE: return GREEDY_SINGLE_TREE_MODE;
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
// Fallback for unhandled values; also silences missing-return warnings.
|
|
79
|
+
return DUAL_TREE_MODE;
|
|
80
|
+
}
|
|
81
|
+
|
|
44
82
|
/**
|
|
45
83
|
* The NeighborSearch class is a template class for performing distance-based
|
|
46
84
|
* neighbor searches. It takes a query dataset and a reference dataset (or just
|
|
@@ -97,26 +135,34 @@ class NeighborSearch
|
|
|
97
135
|
* pre-constructing the trees, passing std::move(yourReferenceSet).
|
|
98
136
|
*
|
|
99
137
|
* @param referenceSet Set of reference points.
|
|
100
|
-
* @param
|
|
138
|
+
* @param strategy Neighbor search strategy.
|
|
101
139
|
* @param epsilon Relative approximate error (non-negative).
|
|
102
140
|
* @param distance An optional instance of the DistanceType class.
|
|
103
141
|
*/
|
|
104
142
|
NeighborSearch(MatType referenceSet,
|
|
105
|
-
const
|
|
143
|
+
const NeighborSearchStrategy strategy = DUAL_TREE,
|
|
106
144
|
const double epsilon = 0,
|
|
107
145
|
const DistanceType distance = DistanceType());
|
|
108
146
|
|
|
147
|
+
[[deprecated("Will be removed in mlpack 5.0.0. Instead of a "
|
|
148
|
+
"NeighborSearchMode, pass a NeighborSearchStrategy.")]]
|
|
149
|
+
NeighborSearch(MatType referenceSet,
|
|
150
|
+
const NeighborSearchMode mode,
|
|
151
|
+
const double epsilon = 0,
|
|
152
|
+
const DistanceType distance = DistanceType()) :
|
|
153
|
+
NeighborSearch(std::move(referenceSet), ModeToStrategy(mode), epsilon,
|
|
154
|
+
distance) { }
|
|
155
|
+
|
|
109
156
|
/**
|
|
110
157
|
* Initialize the NeighborSearch object with a copy of the given
|
|
111
158
|
* pre-constructed reference tree (this is the tree built on the points that
|
|
112
|
-
* will be searched). Optionally, choose to use
|
|
113
|
-
*
|
|
114
|
-
*
|
|
115
|
-
* metric holds data.
|
|
159
|
+
* will be searched). Optionally, choose to use a different search strategy.
|
|
160
|
+
* Additionally, an instantiated distance metric can be given, for cases where
|
|
161
|
+
* the distance metric holds data.
|
|
116
162
|
*
|
|
117
163
|
* This method will copy the given tree. When copies must absolutely be
|
|
118
164
|
* avoided, you can avoid this copy, while taking ownership of the given tree,
|
|
119
|
-
* by passing std::move(yourReferenceTree)
|
|
165
|
+
* by passing std::move(yourReferenceTree).
|
|
120
166
|
*
|
|
121
167
|
* @note
|
|
122
168
|
* Mapping the points of the matrix back to their original indices is not done
|
|
@@ -127,12 +173,28 @@ class NeighborSearch
|
|
|
127
173
|
* @param referenceTree Pre-built tree for reference points.
|
|
128
174
|
* @param strategy Neighbor search strategy.
|
|
129
175
|
* @param epsilon Relative approximate error (non-negative).
|
|
130
|
-
* @param distance Instantiated distance metric.
|
|
131
176
|
*/
|
|
132
177
|
NeighborSearch(Tree referenceTree,
|
|
133
|
-
const
|
|
134
|
-
const double epsilon = 0
|
|
135
|
-
|
|
178
|
+
const NeighborSearchStrategy strategy = DUAL_TREE,
|
|
179
|
+
const double epsilon = 0);
|
|
180
|
+
|
|
181
|
+
[[deprecated("Will be removed in mlpack 5.0.0. Instead of a "
|
|
182
|
+
"NeighborSearchMode, pass a NeighborSearchStrategy.")]]
|
|
183
|
+
NeighborSearch(Tree referenceTree,
|
|
184
|
+
const NeighborSearchMode mode,
|
|
185
|
+
const double epsilon = 0) :
|
|
186
|
+
NeighborSearch(std::move(referenceTree), ModeToStrategy(mode), epsilon) {}
|
|
187
|
+
|
|
188
|
+
// This version is kept around for reverse compatibility; but, if you are
|
|
189
|
+
// passing a distance, you should use the overload above, which will just use
|
|
190
|
+
// the distance directly from the given tree.
|
|
191
|
+
[[deprecated("Will be removed in mlpack 5.0.0. Use the version without "
|
|
192
|
+
"`distance` instead (`referenceTree.Distance()` will be used as "
|
|
193
|
+
"the distance metric).")]]
|
|
194
|
+
NeighborSearch(Tree referenceTree,
|
|
195
|
+
const NeighborSearchMode mode,
|
|
196
|
+
const double epsilon,
|
|
197
|
+
const DistanceType distance);
|
|
136
198
|
|
|
137
199
|
/**
|
|
138
200
|
* Create a NeighborSearch object without any reference data. If Search() is
|
|
@@ -143,10 +205,17 @@ class NeighborSearch
|
|
|
143
205
|
* @param epsilon Relative approximate error (non-negative).
|
|
144
206
|
* @param distance Instantiated distance metric.
|
|
145
207
|
*/
|
|
146
|
-
NeighborSearch(const
|
|
208
|
+
NeighborSearch(const NeighborSearchStrategy strategy = DUAL_TREE,
|
|
147
209
|
const double epsilon = 0,
|
|
148
210
|
const DistanceType distance = DistanceType());
|
|
149
211
|
|
|
212
|
+
[[deprecated("Will be removed in mlpack 5.0.0. Instead of a "
|
|
213
|
+
"NeighborSearchMode, pass a NeighborSearchStrategy.")]]
|
|
214
|
+
NeighborSearch(const NeighborSearchMode mode,
|
|
215
|
+
const double epsilon = 0,
|
|
216
|
+
const DistanceType distance = DistanceType()) :
|
|
217
|
+
NeighborSearch(ModeToStrategy(mode), epsilon, distance) { }
|
|
218
|
+
|
|
150
219
|
/**
|
|
151
220
|
* Construct the NeighborSearch object by copying the given NeighborSearch
|
|
152
221
|
* object.
|
|
@@ -213,8 +282,8 @@ class NeighborSearch
|
|
|
213
282
|
*
|
|
214
283
|
* If querySet contains only a few query points, the extra cost of building a
|
|
215
284
|
* tree on the points for dual-tree search may not be warranted, and it may be
|
|
216
|
-
* worthwhile to set
|
|
217
|
-
*
|
|
285
|
+
* worthwhile to set the strategy to SINGLE_TREE (either in the constructor
|
|
286
|
+
* or with SearchStrategy()).
|
|
218
287
|
*
|
|
219
288
|
* @param querySet Set of query points (can be just one point).
|
|
220
289
|
* @param k Number of neighbors to search for.
|
|
@@ -311,17 +380,31 @@ class NeighborSearch
|
|
|
311
380
|
static double Recall(arma::Mat<IndexType>& foundNeighbors,
|
|
312
381
|
arma::Mat<IndexType>& realNeighbors);
|
|
313
382
|
|
|
314
|
-
|
|
315
|
-
|
|
383
|
+
// Reset all bounding quantities in a prebuilt external tree.
|
|
384
|
+
// When calling Search() multiple times with a prebuilt query tree, this must
|
|
385
|
+
// be called between each Search() invocation!
|
|
386
|
+
static void ResetTree(Tree& tree);
|
|
387
|
+
|
|
388
|
+
// Return the total number of base case evaluations performed during the last
|
|
389
|
+
// search.
|
|
316
390
|
size_t BaseCases() const { return baseCases; }
|
|
317
391
|
|
|
318
|
-
|
|
392
|
+
// Return the number of node combination scores during the last search.
|
|
319
393
|
size_t Scores() const { return scores; }
|
|
320
394
|
|
|
321
|
-
|
|
395
|
+
// Access the search mode.
|
|
396
|
+
[[deprecated("Will be removed in mlpack 5.0.0. Use SearchStrategy() "
|
|
397
|
+
"instead.")]]
|
|
322
398
|
NeighborSearchMode SearchMode() const { return searchMode; }
|
|
323
|
-
|
|
324
|
-
|
|
399
|
+
// Modify the search mode.
|
|
400
|
+
[[deprecated("Will be removed in mlpack 5.0.0. Use SearchStrategy() "
|
|
401
|
+
"instead.")]]
|
|
402
|
+
NeighborSearchMode& SearchMode() { searchModeMod = true; return searchMode; }
|
|
403
|
+
|
|
404
|
+
// Access the search strategy.
|
|
405
|
+
NeighborSearchStrategy SearchStrategy() const { return searchStrategy; }
|
|
406
|
+
// Modify the search strategy.
|
|
407
|
+
NeighborSearchStrategy& SearchStrategy() { return searchStrategy; }
|
|
325
408
|
|
|
326
409
|
//! Access the relative error to be considered in approximate search.
|
|
327
410
|
double Epsilon() const { return epsilon; }
|
|
@@ -341,37 +424,74 @@ class NeighborSearch
|
|
|
341
424
|
void serialize(Archive& ar, const uint32_t version);
|
|
342
425
|
|
|
343
426
|
private:
|
|
344
|
-
|
|
427
|
+
// Permutations of reference points during tree building.
|
|
345
428
|
std::vector<size_t> oldFromNewReferences;
|
|
346
|
-
|
|
429
|
+
// Pointer to the root of the reference tree.
|
|
347
430
|
Tree* referenceTree;
|
|
348
|
-
|
|
431
|
+
// Reference dataset. In some situations we may be the owner of this.
|
|
349
432
|
const MatType* referenceSet;
|
|
350
433
|
|
|
351
|
-
|
|
434
|
+
// This is only kept for reverse compatibility and will be removed in mlpack
|
|
435
|
+
// 5.0.0.
|
|
352
436
|
NeighborSearchMode searchMode;
|
|
353
|
-
|
|
437
|
+
bool searchModeMod; // also for reverse compatibility
|
|
438
|
+
// Indicates the neighbor search strategy.
|
|
439
|
+
NeighborSearchStrategy searchStrategy;
|
|
440
|
+
// Indicates the relative error to be considered in approximate search.
|
|
354
441
|
double epsilon;
|
|
355
442
|
|
|
356
|
-
|
|
443
|
+
// Instantiation of distance metric.
|
|
357
444
|
DistanceType distance;
|
|
358
445
|
|
|
359
|
-
|
|
446
|
+
// The total number of base cases.
|
|
360
447
|
size_t baseCases;
|
|
361
|
-
|
|
448
|
+
// The total number of scores (applicable for non-naive search).
|
|
362
449
|
size_t scores;
|
|
363
450
|
|
|
364
|
-
|
|
365
|
-
|
|
451
|
+
// If this is true, the reference tree bounds need to be reset on a call to
|
|
452
|
+
// Search() without a query set.
|
|
366
453
|
bool treeNeedsReset;
|
|
367
454
|
|
|
368
|
-
|
|
455
|
+
// The LeafSizeNSWrapper class should have access to internal members.
|
|
369
456
|
friend class LeafSizeNSWrapper<SortPolicy, TreeType, DualTreeTraversalType,
|
|
370
457
|
SingleTreeTraversalType>;
|
|
371
458
|
}; // class NeighborSearch
|
|
372
459
|
|
|
373
460
|
} // namespace mlpack
|
|
374
461
|
|
|
462
|
+
// The CEREAL_TEMPLATE_CLASS_VERSION() macro does not work with template
|
|
463
|
+
// template parameters so we write it manually.
|
|
464
|
+
namespace cereal {
|
|
465
|
+
namespace detail {
|
|
466
|
+
|
|
467
|
+
template<typename SortPolicy,
|
|
468
|
+
typename DistanceType,
|
|
469
|
+
typename MatType,
|
|
470
|
+
template<typename TreeDistanceType,
|
|
471
|
+
typename TreeStatType,
|
|
472
|
+
typename TreeMatType> class TreeType,
|
|
473
|
+
template<typename RuleType> class DualTreeTraversalType,
|
|
474
|
+
template<typename RuleType> class SingleTreeTraversalType>
|
|
475
|
+
struct Version<mlpack::NeighborSearch<SortPolicy, DistanceType, MatType,
|
|
476
|
+
TreeType, DualTreeTraversalType, SingleTreeTraversalType>>
|
|
477
|
+
{
|
|
478
|
+
static std::uint32_t registerVersion()
|
|
479
|
+
{
|
|
480
|
+
::cereal::detail::StaticObject<Versions>::getInstance().mapping.emplace(
|
|
481
|
+
std::type_index(typeid(mlpack::NeighborSearch<SortPolicy, DistanceType,
|
|
482
|
+
MatType, TreeType, DualTreeTraversalType,
|
|
483
|
+
SingleTreeTraversalType>)).hash_code(), 1);
|
|
484
|
+
return 1;
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
static inline const std::uint32_t version = registerVersion();
|
|
488
|
+
|
|
489
|
+
static void unused() { (void) version; }
|
|
490
|
+
}; /* end Version */
|
|
491
|
+
|
|
492
|
+
} // namespace detail
|
|
493
|
+
} // namespace cereal
|
|
494
|
+
|
|
375
495
|
// Include implementation.
|
|
376
496
|
#include "neighbor_search_impl.hpp"
|
|
377
497
|
|