mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -32,14 +32,16 @@ template<typename SortPolicy,
|
|
|
32
32
|
NeighborSearch<SortPolicy, DistanceType, MatType, TreeType,
|
|
33
33
|
DualTreeTraversalType, SingleTreeTraversalType>::
|
|
34
34
|
NeighborSearch(MatType referenceSetIn,
|
|
35
|
-
const
|
|
35
|
+
const NeighborSearchStrategy strategy,
|
|
36
36
|
const double epsilon,
|
|
37
37
|
const DistanceType distance) :
|
|
38
|
-
referenceTree(
|
|
38
|
+
referenceTree(strategy == NAIVE ? NULL :
|
|
39
39
|
BuildTree<Tree>(std::move(referenceSetIn), oldFromNewReferences)),
|
|
40
|
-
referenceSet(
|
|
41
|
-
&referenceTree->Dataset()),
|
|
42
|
-
searchMode(
|
|
40
|
+
referenceSet(strategy == NAIVE ?
|
|
41
|
+
new MatType(std::move(referenceSetIn)) : &referenceTree->Dataset()),
|
|
42
|
+
searchMode(StrategyToMode(strategy)),
|
|
43
|
+
searchModeMod(false),
|
|
44
|
+
searchStrategy(strategy),
|
|
43
45
|
epsilon(epsilon),
|
|
44
46
|
distance(distance),
|
|
45
47
|
baseCases(0),
|
|
@@ -51,6 +53,34 @@ NeighborSearch(MatType referenceSetIn,
|
|
|
51
53
|
}
|
|
52
54
|
|
|
53
55
|
// Construct the object.
|
|
56
|
+
template<typename SortPolicy,
|
|
57
|
+
typename DistanceType,
|
|
58
|
+
typename MatType,
|
|
59
|
+
template<typename TreeDistanceType,
|
|
60
|
+
typename TreeStatType,
|
|
61
|
+
typename TreeMatType> class TreeType,
|
|
62
|
+
template<typename> class DualTreeTraversalType,
|
|
63
|
+
template<typename> class SingleTreeTraversalType>
|
|
64
|
+
NeighborSearch<SortPolicy, DistanceType, MatType, TreeType,
|
|
65
|
+
DualTreeTraversalType, SingleTreeTraversalType>::
|
|
66
|
+
NeighborSearch(Tree referenceTree,
|
|
67
|
+
const NeighborSearchStrategy strategy,
|
|
68
|
+
const double epsilon) :
|
|
69
|
+
referenceTree(new Tree(std::move(referenceTree))),
|
|
70
|
+
referenceSet(&this->referenceTree->Dataset()),
|
|
71
|
+
searchMode(StrategyToMode(strategy)),
|
|
72
|
+
searchModeMod(false),
|
|
73
|
+
searchStrategy(strategy),
|
|
74
|
+
epsilon(epsilon),
|
|
75
|
+
distance(this->referenceTree->Distance()),
|
|
76
|
+
baseCases(0),
|
|
77
|
+
scores(0),
|
|
78
|
+
treeNeedsReset(false)
|
|
79
|
+
{
|
|
80
|
+
if (epsilon < 0)
|
|
81
|
+
throw std::invalid_argument("epsilon must be non-negative");
|
|
82
|
+
}
|
|
83
|
+
|
|
54
84
|
template<typename SortPolicy,
|
|
55
85
|
typename DistanceType,
|
|
56
86
|
typename MatType,
|
|
@@ -68,6 +98,8 @@ NeighborSearch(Tree referenceTree,
|
|
|
68
98
|
referenceTree(new Tree(std::move(referenceTree))),
|
|
69
99
|
referenceSet(&this->referenceTree->Dataset()),
|
|
70
100
|
searchMode(mode),
|
|
101
|
+
searchModeMod(false),
|
|
102
|
+
searchStrategy(ModeToStrategy(mode)),
|
|
71
103
|
epsilon(epsilon),
|
|
72
104
|
distance(distance),
|
|
73
105
|
baseCases(0),
|
|
@@ -89,12 +121,14 @@ template<typename SortPolicy,
|
|
|
89
121
|
template<typename> class SingleTreeTraversalType>
|
|
90
122
|
NeighborSearch<SortPolicy, DistanceType, MatType, TreeType,
|
|
91
123
|
DualTreeTraversalType, SingleTreeTraversalType>::
|
|
92
|
-
NeighborSearch(const
|
|
124
|
+
NeighborSearch(const NeighborSearchStrategy strategy,
|
|
93
125
|
const double epsilon,
|
|
94
126
|
const DistanceType distance) :
|
|
95
127
|
referenceTree(NULL),
|
|
96
|
-
referenceSet(
|
|
97
|
-
searchMode(
|
|
128
|
+
referenceSet(strategy == NAIVE ? new MatType() : NULL),
|
|
129
|
+
searchMode(StrategyToMode(strategy)),
|
|
130
|
+
searchModeMod(false),
|
|
131
|
+
searchStrategy(strategy),
|
|
98
132
|
epsilon(epsilon),
|
|
99
133
|
distance(distance),
|
|
100
134
|
baseCases(0),
|
|
@@ -105,7 +139,7 @@ NeighborSearch(const NeighborSearchMode mode,
|
|
|
105
139
|
throw std::invalid_argument("epsilon must be non-negative");
|
|
106
140
|
|
|
107
141
|
// Build the tree on the empty dataset, if necessary.
|
|
108
|
-
if (
|
|
142
|
+
if (strategy != NAIVE)
|
|
109
143
|
{
|
|
110
144
|
referenceTree = BuildTree<Tree>(std::move(MatType()),
|
|
111
145
|
oldFromNewReferences);
|
|
@@ -130,6 +164,8 @@ NeighborSearch(const NeighborSearch& other) :
|
|
|
130
164
|
referenceSet(other.referenceTree ? &referenceTree->Dataset() :
|
|
131
165
|
new MatType(*other.referenceSet)),
|
|
132
166
|
searchMode(other.searchMode),
|
|
167
|
+
searchModeMod(other.searchModeMod),
|
|
168
|
+
searchStrategy(other.searchStrategy),
|
|
133
169
|
epsilon(other.epsilon),
|
|
134
170
|
distance(other.distance),
|
|
135
171
|
baseCases(other.baseCases),
|
|
@@ -155,6 +191,8 @@ NeighborSearch(NeighborSearch&& other) :
|
|
|
155
191
|
referenceTree(other.referenceTree),
|
|
156
192
|
referenceSet(other.referenceSet),
|
|
157
193
|
searchMode(other.searchMode),
|
|
194
|
+
searchModeMod(other.searchModeMod),
|
|
195
|
+
searchStrategy(other.searchStrategy),
|
|
158
196
|
epsilon(other.epsilon),
|
|
159
197
|
distance(std::move(other.distance)),
|
|
160
198
|
baseCases(other.baseCases),
|
|
@@ -165,7 +203,9 @@ NeighborSearch(NeighborSearch&& other) :
|
|
|
165
203
|
other.referenceTree = BuildTree<Tree>(std::move(MatType()),
|
|
166
204
|
other.oldFromNewReferences);
|
|
167
205
|
other.referenceSet = &other.referenceTree->Dataset();
|
|
168
|
-
other.searchMode = DUAL_TREE_MODE
|
|
206
|
+
other.searchMode = DUAL_TREE_MODE;
|
|
207
|
+
other.searchModeMod = false;
|
|
208
|
+
other.searchStrategy = DUAL_TREE;
|
|
169
209
|
other.epsilon = 0.0;
|
|
170
210
|
other.baseCases = 0;
|
|
171
211
|
other.scores = 0;
|
|
@@ -208,6 +248,8 @@ NeighborSearch<SortPolicy,
|
|
|
208
248
|
referenceSet = other.referenceTree ? &referenceTree->Dataset() :
|
|
209
249
|
new MatType(*other.referenceSet);
|
|
210
250
|
searchMode = other.searchMode;
|
|
251
|
+
searchModeMod = other.searchModeMod;
|
|
252
|
+
searchStrategy = other.searchStrategy;
|
|
211
253
|
epsilon = other.epsilon;
|
|
212
254
|
distance = other.distance;
|
|
213
255
|
baseCases = other.baseCases;
|
|
@@ -250,6 +292,8 @@ NeighborSearch<SortPolicy,
|
|
|
250
292
|
referenceTree = other.referenceTree;
|
|
251
293
|
referenceSet = other.referenceSet;
|
|
252
294
|
searchMode = other.searchMode;
|
|
295
|
+
searchModeMod = other.searchModeMod;
|
|
296
|
+
searchStrategy = other.searchStrategy;
|
|
253
297
|
epsilon = other.epsilon;
|
|
254
298
|
distance = other.distance;
|
|
255
299
|
baseCases = other.baseCases;
|
|
@@ -263,7 +307,9 @@ NeighborSearch<SortPolicy,
|
|
|
263
307
|
other.referenceTree = BuildTree<Tree>(std::move(MatType()),
|
|
264
308
|
other.oldFromNewReferences);
|
|
265
309
|
other.referenceSet = &other.referenceTree->Dataset();
|
|
266
|
-
other.searchMode = DUAL_TREE_MODE
|
|
310
|
+
other.searchMode = DUAL_TREE_MODE;
|
|
311
|
+
other.searchModeMod = false;
|
|
312
|
+
other.searchStrategy = DUAL_TREE;
|
|
267
313
|
other.epsilon = 0.0;
|
|
268
314
|
other.baseCases = 0;
|
|
269
315
|
other.scores = 0;
|
|
@@ -300,6 +346,17 @@ void NeighborSearch<SortPolicy, DistanceType, MatType, TreeType,
|
|
|
300
346
|
DualTreeTraversalType, SingleTreeTraversalType>::
|
|
301
347
|
Train(MatType referenceSetIn)
|
|
302
348
|
{
|
|
349
|
+
// For reverse compatibility; can be removed in mlpack 5.0.0.
|
|
350
|
+
if (searchModeMod)
|
|
351
|
+
{
|
|
352
|
+
searchStrategy = ModeToStrategy(searchMode);
|
|
353
|
+
searchModeMod = false;
|
|
354
|
+
}
|
|
355
|
+
else
|
|
356
|
+
{
|
|
357
|
+
searchMode = StrategyToMode(searchStrategy);
|
|
358
|
+
}
|
|
359
|
+
|
|
303
360
|
// Clean up the old tree, if we built one.
|
|
304
361
|
if (referenceTree)
|
|
305
362
|
{
|
|
@@ -313,7 +370,7 @@ Train(MatType referenceSetIn)
|
|
|
313
370
|
}
|
|
314
371
|
|
|
315
372
|
// We may need to rebuild the tree.
|
|
316
|
-
if (
|
|
373
|
+
if (searchStrategy != NAIVE)
|
|
317
374
|
{
|
|
318
375
|
referenceTree = BuildTree<Tree>(std::move(referenceSetIn),
|
|
319
376
|
oldFromNewReferences);
|
|
@@ -336,7 +393,18 @@ template<typename SortPolicy,
|
|
|
336
393
|
void NeighborSearch<SortPolicy, DistanceType, MatType, TreeType,
|
|
337
394
|
DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree referenceTree)
|
|
338
395
|
{
|
|
339
|
-
|
|
396
|
+
// For reverse compatibility; can be removed in mlpack 5.0.0.
|
|
397
|
+
if (searchModeMod)
|
|
398
|
+
{
|
|
399
|
+
searchStrategy = ModeToStrategy(searchMode);
|
|
400
|
+
searchModeMod = false;
|
|
401
|
+
}
|
|
402
|
+
else
|
|
403
|
+
{
|
|
404
|
+
searchMode = StrategyToMode(searchStrategy);
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
if (searchStrategy == NAIVE)
|
|
340
408
|
throw std::invalid_argument("cannot train on given reference tree when "
|
|
341
409
|
"naive search (without trees) is desired");
|
|
342
410
|
|
|
@@ -374,6 +442,17 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
374
442
|
arma::Mat<IndexType>& neighbors,
|
|
375
443
|
arma::Mat<ElemType>& distances)
|
|
376
444
|
{
|
|
445
|
+
// For reverse compatibility; can be removed in mlpack 5.0.0.
|
|
446
|
+
if (searchModeMod)
|
|
447
|
+
{
|
|
448
|
+
searchStrategy = ModeToStrategy(searchMode);
|
|
449
|
+
searchModeMod = false;
|
|
450
|
+
}
|
|
451
|
+
else
|
|
452
|
+
{
|
|
453
|
+
searchMode = StrategyToMode(searchStrategy);
|
|
454
|
+
}
|
|
455
|
+
|
|
377
456
|
if (k > referenceSet->n_cols)
|
|
378
457
|
{
|
|
379
458
|
std::stringstream ss;
|
|
@@ -398,7 +477,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
398
477
|
// Mapping is only necessary if the tree rearranges points.
|
|
399
478
|
if (TreeTraits<Tree>::RearrangesDataset)
|
|
400
479
|
{
|
|
401
|
-
if (
|
|
480
|
+
if (searchStrategy == DUAL_TREE)
|
|
402
481
|
{
|
|
403
482
|
distancePtr = new arma::Mat<ElemType>; // Query indices need to be mapped.
|
|
404
483
|
neighborPtr = new arma::Mat<IndexType>;
|
|
@@ -413,9 +492,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
413
492
|
|
|
414
493
|
using RuleType = NeighborSearchRules<SortPolicy, DistanceType, Tree>;
|
|
415
494
|
|
|
416
|
-
switch (
|
|
495
|
+
switch (searchStrategy)
|
|
417
496
|
{
|
|
418
|
-
case
|
|
497
|
+
case NAIVE:
|
|
419
498
|
{
|
|
420
499
|
// Create the helper object for the tree traversal.
|
|
421
500
|
RuleType rules(*referenceSet, querySet, k, distance, epsilon);
|
|
@@ -430,7 +509,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
430
509
|
rules.GetResults(*neighborPtr, *distancePtr);
|
|
431
510
|
break;
|
|
432
511
|
}
|
|
433
|
-
case
|
|
512
|
+
case SINGLE_TREE:
|
|
434
513
|
{
|
|
435
514
|
// Create the helper object for the tree traversal.
|
|
436
515
|
RuleType rules(*referenceSet, querySet, k, distance, epsilon);
|
|
@@ -453,7 +532,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
453
532
|
rules.GetResults(*neighborPtr, *distancePtr);
|
|
454
533
|
break;
|
|
455
534
|
}
|
|
456
|
-
case
|
|
535
|
+
case DUAL_TREE:
|
|
457
536
|
{
|
|
458
537
|
// Build the query tree.
|
|
459
538
|
Tree* queryTree = BuildTree<Tree>(querySet, oldFromNewQueries);
|
|
@@ -479,7 +558,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
479
558
|
delete queryTree;
|
|
480
559
|
break;
|
|
481
560
|
}
|
|
482
|
-
case
|
|
561
|
+
case GREEDY_SINGLE_TREE:
|
|
483
562
|
{
|
|
484
563
|
// Create the helper object for the tree traversal.
|
|
485
564
|
RuleType rules(*referenceSet, querySet, k, distance);
|
|
@@ -507,7 +586,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
507
586
|
// Map points back to original indices, if necessary.
|
|
508
587
|
if (TreeTraits<Tree>::RearrangesDataset)
|
|
509
588
|
{
|
|
510
|
-
if (
|
|
589
|
+
if (searchStrategy == DUAL_TREE && !oldFromNewReferences.empty())
|
|
511
590
|
{
|
|
512
591
|
// We must map both query and reference indices.
|
|
513
592
|
neighbors.set_size(k, querySet.n_cols);
|
|
@@ -530,7 +609,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
530
609
|
delete neighborPtr;
|
|
531
610
|
delete distancePtr;
|
|
532
611
|
}
|
|
533
|
-
else if (
|
|
612
|
+
else if (searchStrategy == DUAL_TREE)
|
|
534
613
|
{
|
|
535
614
|
// We must map query indices only.
|
|
536
615
|
neighbors.set_size(k, querySet.n_cols);
|
|
@@ -581,6 +660,17 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
581
660
|
arma::Mat<ElemType>& distances,
|
|
582
661
|
bool sameSet)
|
|
583
662
|
{
|
|
663
|
+
// For reverse compatibility; can be removed in mlpack 5.0.0.
|
|
664
|
+
if (searchModeMod)
|
|
665
|
+
{
|
|
666
|
+
searchStrategy = ModeToStrategy(searchMode);
|
|
667
|
+
searchModeMod = false;
|
|
668
|
+
}
|
|
669
|
+
else
|
|
670
|
+
{
|
|
671
|
+
searchMode = StrategyToMode(searchStrategy);
|
|
672
|
+
}
|
|
673
|
+
|
|
584
674
|
if (k > referenceSet->n_cols)
|
|
585
675
|
{
|
|
586
676
|
std::stringstream ss;
|
|
@@ -590,9 +680,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
590
680
|
}
|
|
591
681
|
|
|
592
682
|
// Make sure we are in dual-tree mode.
|
|
593
|
-
if (
|
|
594
|
-
throw std::invalid_argument("
|
|
595
|
-
"query tree when
|
|
683
|
+
if (searchStrategy != DUAL_TREE)
|
|
684
|
+
throw std::invalid_argument("Cannot call NeighborSearch::Search() with a "
|
|
685
|
+
"query tree when search strategy is not DUAL_TREE!");
|
|
596
686
|
|
|
597
687
|
baseCases = 0;
|
|
598
688
|
scores = 0;
|
|
@@ -659,6 +749,17 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
659
749
|
arma::Mat<IndexType>& neighbors,
|
|
660
750
|
arma::Mat<ElemType>& distances)
|
|
661
751
|
{
|
|
752
|
+
// For reverse compatibility; can be removed in mlpack 5.0.0.
|
|
753
|
+
if (searchModeMod)
|
|
754
|
+
{
|
|
755
|
+
searchStrategy = ModeToStrategy(searchMode);
|
|
756
|
+
searchModeMod = false;
|
|
757
|
+
}
|
|
758
|
+
else
|
|
759
|
+
{
|
|
760
|
+
searchMode = StrategyToMode(searchStrategy);
|
|
761
|
+
}
|
|
762
|
+
|
|
662
763
|
if (k > referenceSet->n_cols)
|
|
663
764
|
{
|
|
664
765
|
std::stringstream ss;
|
|
@@ -697,9 +798,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
697
798
|
RuleType rules(*referenceSet, *referenceSet, k, distance, epsilon,
|
|
698
799
|
true /* don't return the same point as nearest neighbor */);
|
|
699
800
|
|
|
700
|
-
switch (
|
|
801
|
+
switch (searchStrategy)
|
|
701
802
|
{
|
|
702
|
-
case
|
|
803
|
+
case NAIVE:
|
|
703
804
|
{
|
|
704
805
|
// The naive brute-force solution.
|
|
705
806
|
for (size_t i = 0; i < referenceSet->n_cols; ++i)
|
|
@@ -709,7 +810,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
709
810
|
baseCases += referenceSet->n_cols * referenceSet->n_cols;
|
|
710
811
|
break;
|
|
711
812
|
}
|
|
712
|
-
case
|
|
813
|
+
case SINGLE_TREE:
|
|
713
814
|
{
|
|
714
815
|
// Create the traverser.
|
|
715
816
|
SingleTreeTraversalType<RuleType> traverser(rules);
|
|
@@ -727,27 +828,12 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
727
828
|
<< std::endl;
|
|
728
829
|
break;
|
|
729
830
|
}
|
|
730
|
-
case
|
|
831
|
+
case DUAL_TREE:
|
|
731
832
|
{
|
|
732
833
|
// The dual-tree monochromatic search case may require resetting the
|
|
733
834
|
// bounds in the tree.
|
|
734
835
|
if (treeNeedsReset)
|
|
735
|
-
|
|
736
|
-
std::stack<Tree*> nodes;
|
|
737
|
-
nodes.push(referenceTree);
|
|
738
|
-
while (!nodes.empty())
|
|
739
|
-
{
|
|
740
|
-
Tree* node = nodes.top();
|
|
741
|
-
nodes.pop();
|
|
742
|
-
|
|
743
|
-
// Reset bounds of this node.
|
|
744
|
-
node->Stat().Reset();
|
|
745
|
-
|
|
746
|
-
// Then add the children.
|
|
747
|
-
for (size_t i = 0; i < node->NumChildren(); ++i)
|
|
748
|
-
nodes.push(&node->Child(i));
|
|
749
|
-
}
|
|
750
|
-
}
|
|
836
|
+
ResetTree(*referenceTree);
|
|
751
837
|
|
|
752
838
|
// Create the traverser.
|
|
753
839
|
DualTreeTraversalType<RuleType> traverser(rules);
|
|
@@ -762,8 +848,6 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
762
848
|
else
|
|
763
849
|
{
|
|
764
850
|
traverser.Traverse(*referenceTree, *referenceTree);
|
|
765
|
-
// Next time we perform this search, we'll need to reset the tree.
|
|
766
|
-
treeNeedsReset = true;
|
|
767
851
|
}
|
|
768
852
|
|
|
769
853
|
scores += rules.Scores();
|
|
@@ -778,7 +862,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
|
|
|
778
862
|
treeNeedsReset = true;
|
|
779
863
|
break;
|
|
780
864
|
}
|
|
781
|
-
case
|
|
865
|
+
case GREEDY_SINGLE_TREE:
|
|
782
866
|
{
|
|
783
867
|
// Create the traverser.
|
|
784
868
|
GreedySingleTreeTraverser<Tree, RuleType> traverser(rules);
|
|
@@ -855,7 +939,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::EffectiveError(
|
|
|
855
939
|
}
|
|
856
940
|
}
|
|
857
941
|
|
|
858
|
-
if (numCases)
|
|
942
|
+
if (numCases > 0)
|
|
859
943
|
effectiveError /= numCases;
|
|
860
944
|
|
|
861
945
|
return effectiveError;
|
|
@@ -893,6 +977,35 @@ DualTreeTraversalType, SingleTreeTraversalType>::Recall(
|
|
|
893
977
|
return ((double) found) / realNeighbors.n_elem;
|
|
894
978
|
}
|
|
895
979
|
|
|
980
|
+
template<typename SortPolicy,
|
|
981
|
+
typename DistanceType,
|
|
982
|
+
typename MatType,
|
|
983
|
+
template<typename TreeDistanceType,
|
|
984
|
+
typename TreeStatType,
|
|
985
|
+
typename TreeMatType> class TreeType,
|
|
986
|
+
template<typename> class DualTreeTraversalType,
|
|
987
|
+
template<typename> class SingleTreeTraversalType>
|
|
988
|
+
void NeighborSearch<
|
|
989
|
+
SortPolicy, DistanceType, MatType, TreeType,
|
|
990
|
+
DualTreeTraversalType, SingleTreeTraversalType
|
|
991
|
+
>::ResetTree(Tree& tree)
|
|
992
|
+
{
|
|
993
|
+
std::stack<Tree*> nodes;
|
|
994
|
+
nodes.push(&tree);
|
|
995
|
+
while (!nodes.empty())
|
|
996
|
+
{
|
|
997
|
+
Tree* node = nodes.top();
|
|
998
|
+
nodes.pop();
|
|
999
|
+
|
|
1000
|
+
// Reset bounds of this node.
|
|
1001
|
+
node->Stat().Reset();
|
|
1002
|
+
|
|
1003
|
+
// Then add the children.
|
|
1004
|
+
for (size_t i = 0; i < node->NumChildren(); ++i)
|
|
1005
|
+
nodes.push(&node->Child(i));
|
|
1006
|
+
}
|
|
1007
|
+
}
|
|
1008
|
+
|
|
896
1009
|
//! Serialize the NeighborSearch model.
|
|
897
1010
|
template<typename SortPolicy,
|
|
898
1011
|
typename DistanceType,
|
|
@@ -905,15 +1018,41 @@ template<typename SortPolicy,
|
|
|
905
1018
|
template<typename Archive>
|
|
906
1019
|
void NeighborSearch<SortPolicy, DistanceType, MatType, TreeType,
|
|
907
1020
|
DualTreeTraversalType, SingleTreeTraversalType>::serialize(
|
|
908
|
-
Archive& ar, const uint32_t
|
|
1021
|
+
Archive& ar, const uint32_t version)
|
|
909
1022
|
{
|
|
1023
|
+
if (cereal::is_saving<Archive>())
|
|
1024
|
+
{
|
|
1025
|
+
// For reverse compatibility; can be removed in mlpack 5.0.0.
|
|
1026
|
+
if (searchModeMod)
|
|
1027
|
+
{
|
|
1028
|
+
searchStrategy = ModeToStrategy(searchMode);
|
|
1029
|
+
searchModeMod = false;
|
|
1030
|
+
}
|
|
1031
|
+
else
|
|
1032
|
+
{
|
|
1033
|
+
searchMode = StrategyToMode(searchStrategy);
|
|
1034
|
+
}
|
|
1035
|
+
}
|
|
1036
|
+
|
|
910
1037
|
// Serialize preferences for search.
|
|
911
|
-
|
|
1038
|
+
if (version == 0)
|
|
1039
|
+
{
|
|
1040
|
+
ar(CEREAL_NVP(searchMode));
|
|
1041
|
+
searchModeMod = false;
|
|
1042
|
+
searchStrategy = ModeToStrategy(searchMode);
|
|
1043
|
+
}
|
|
1044
|
+
else
|
|
1045
|
+
{
|
|
1046
|
+
ar(CEREAL_NVP(searchStrategy));
|
|
1047
|
+
searchModeMod = false;
|
|
1048
|
+
searchMode = StrategyToMode(searchStrategy);
|
|
1049
|
+
}
|
|
1050
|
+
|
|
912
1051
|
ar(CEREAL_NVP(treeNeedsReset));
|
|
913
1052
|
|
|
914
1053
|
// If we are doing naive search, we serialize the dataset. Otherwise we
|
|
915
1054
|
// serialize the tree.
|
|
916
|
-
if (
|
|
1055
|
+
if (searchStrategy == NAIVE)
|
|
917
1056
|
{
|
|
918
1057
|
// Delete the current reference set, if necessary and if we are loading.
|
|
919
1058
|
if (cereal::is_loading<Archive>() && referenceSet)
|
|
@@ -14,6 +14,8 @@
|
|
|
14
14
|
#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_STAT_HPP
|
|
15
15
|
|
|
16
16
|
#include <mlpack/prereqs.hpp>
|
|
17
|
+
#include "sort_policies/nearest_neighbor_sort.hpp"
|
|
18
|
+
#include "sort_policies/furthest_neighbor_sort.hpp"
|
|
17
19
|
|
|
18
20
|
namespace mlpack {
|
|
19
21
|
|
|
@@ -100,6 +102,14 @@ class NeighborSearchStat
|
|
|
100
102
|
}
|
|
101
103
|
};
|
|
102
104
|
|
|
105
|
+
// This is the type that must be used as the StatisticType for
|
|
106
|
+
// k-nearest-neighbor search (e.g. the KNN or KNNType<> class).
|
|
107
|
+
using NearestNeighborStat = NeighborSearchStat<NearestNeighborSort>;
|
|
108
|
+
|
|
109
|
+
// This is the type that must be used as the StatisticType for
|
|
110
|
+
// k-furthest-neighbor search (e.g. the KFN or KFNType<> class).
|
|
111
|
+
using FurthestNeighborStat = NeighborSearchStat<FurthestNeighborSort>;
|
|
112
|
+
|
|
103
113
|
} // namespace mlpack
|
|
104
114
|
|
|
105
115
|
#endif
|
|
@@ -49,9 +49,9 @@ class NSWrapperBase
|
|
|
49
49
|
virtual const arma::mat& Dataset() const = 0;
|
|
50
50
|
|
|
51
51
|
//! Get the search strategy.
|
|
52
|
-
virtual
|
|
52
|
+
virtual NeighborSearchStrategy SearchStrategy() const = 0;
|
|
53
53
|
//! Modify the search strategy.
|
|
54
|
-
virtual
|
|
54
|
+
virtual NeighborSearchStrategy& SearchStrategy() = 0;
|
|
55
55
|
|
|
56
56
|
//! Get the approximation parameter epsilon.
|
|
57
57
|
virtual double Epsilon() const = 0;
|
|
@@ -103,9 +103,9 @@ class NSWrapper : public NSWrapperBase
|
|
|
103
103
|
public:
|
|
104
104
|
//! Construct the NSWrapper object, initializing the internally-held
|
|
105
105
|
//! NeighborSearch object.
|
|
106
|
-
NSWrapper(const
|
|
106
|
+
NSWrapper(const NeighborSearchStrategy searchStrategy,
|
|
107
107
|
const double epsilon) :
|
|
108
|
-
ns(
|
|
108
|
+
ns(searchStrategy, epsilon)
|
|
109
109
|
{
|
|
110
110
|
// Nothing else to do.
|
|
111
111
|
}
|
|
@@ -121,9 +121,9 @@ class NSWrapper : public NSWrapperBase
|
|
|
121
121
|
const arma::mat& Dataset() const { return ns.ReferenceSet(); }
|
|
122
122
|
|
|
123
123
|
//! Get the search strategy.
|
|
124
|
-
|
|
124
|
+
NeighborSearchStrategy SearchStrategy() const { return ns.SearchStrategy(); }
|
|
125
125
|
//! Modify the search strategy.
|
|
126
|
-
|
|
126
|
+
NeighborSearchStrategy& SearchStrategy() { return ns.SearchStrategy(); }
|
|
127
127
|
|
|
128
128
|
//! Get epsilon, the approximation parameter.
|
|
129
129
|
double Epsilon() const { return ns.Epsilon(); }
|
|
@@ -201,12 +201,12 @@ class LeafSizeNSWrapper :
|
|
|
201
201
|
public:
|
|
202
202
|
//! Construct the LeafSizeNSWrapper by delegating to the NSWrapper
|
|
203
203
|
//! constructor.
|
|
204
|
-
LeafSizeNSWrapper(const
|
|
204
|
+
LeafSizeNSWrapper(const NeighborSearchStrategy searchStrategy,
|
|
205
205
|
const double epsilon) :
|
|
206
206
|
NSWrapper<SortPolicy,
|
|
207
207
|
TreeType,
|
|
208
208
|
DualTreeTraversalType,
|
|
209
|
-
SingleTreeTraversalType>(
|
|
209
|
+
SingleTreeTraversalType>(searchStrategy, epsilon)
|
|
210
210
|
{
|
|
211
211
|
// Nothing to do.
|
|
212
212
|
}
|
|
@@ -270,7 +270,7 @@ class SpillNSWrapper :
|
|
|
270
270
|
{
|
|
271
271
|
public:
|
|
272
272
|
//! Construct the SpillNSWrapper.
|
|
273
|
-
SpillNSWrapper(const
|
|
273
|
+
SpillNSWrapper(const NeighborSearchStrategy searchStrategy,
|
|
274
274
|
const double epsilon) :
|
|
275
275
|
NSWrapper<
|
|
276
276
|
SortPolicy,
|
|
@@ -281,7 +281,7 @@ class SpillNSWrapper :
|
|
|
281
281
|
SPTree<EuclideanDistance,
|
|
282
282
|
NeighborSearchStat<SortPolicy>,
|
|
283
283
|
arma::mat>::template DefeatistSingleTreeTraverser>(
|
|
284
|
-
|
|
284
|
+
searchStrategy, epsilon)
|
|
285
285
|
{
|
|
286
286
|
// Nothing to do.
|
|
287
287
|
}
|
|
@@ -430,9 +430,9 @@ class NSModel
|
|
|
430
430
|
//! Expose the dataset.
|
|
431
431
|
const arma::mat& Dataset() const;
|
|
432
432
|
|
|
433
|
-
//! Expose
|
|
434
|
-
|
|
435
|
-
|
|
433
|
+
//! Expose the search strategy.
|
|
434
|
+
NeighborSearchStrategy SearchStrategy() const;
|
|
435
|
+
NeighborSearchStrategy& SearchStrategy();
|
|
436
436
|
|
|
437
437
|
//! Expose LeafSize.
|
|
438
438
|
size_t LeafSize() const { return leafSize; }
|
|
@@ -459,13 +459,13 @@ class NSModel
|
|
|
459
459
|
bool& RandomBasis() { return randomBasis; }
|
|
460
460
|
|
|
461
461
|
//! Initialize the model type. (This does not perform any training.)
|
|
462
|
-
void InitializeModel(const
|
|
462
|
+
void InitializeModel(const NeighborSearchStrategy searchStrategy,
|
|
463
463
|
const double epsilon);
|
|
464
464
|
|
|
465
465
|
//! Build the reference tree.
|
|
466
466
|
void BuildModel(util::Timers& timers,
|
|
467
467
|
arma::mat&& referenceSet,
|
|
468
|
-
const
|
|
468
|
+
const NeighborSearchStrategy searchStrategy,
|
|
469
469
|
const double epsilon = 0);
|
|
470
470
|
|
|
471
471
|
//! Perform neighbor search. The query set will be reordered.
|