mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -94,12 +94,12 @@ bool RPTreeMeanSplit<BoundType, MatType>::GetDotMedian(
|
|
|
94
94
|
for (size_t k = 0; k < samples.n_elem; ++k)
|
|
95
95
|
values[k] = dot(data.col(samples[k]), direction);
|
|
96
96
|
|
|
97
|
-
const ElemType maximum =
|
|
97
|
+
const ElemType maximum = max(values);
|
|
98
98
|
const ElemType minimum = min(values);
|
|
99
99
|
if (minimum == maximum)
|
|
100
100
|
return false;
|
|
101
101
|
|
|
102
|
-
splitVal =
|
|
102
|
+
splitVal = median(values);
|
|
103
103
|
|
|
104
104
|
if (splitVal == maximum)
|
|
105
105
|
splitVal = minimum;
|
|
@@ -111,29 +111,29 @@ template<typename BoundType, typename MatType>
|
|
|
111
111
|
bool RPTreeMeanSplit<BoundType, MatType>::GetMeanMedian(
|
|
112
112
|
const MatType& data,
|
|
113
113
|
const arma::uvec& samples,
|
|
114
|
-
arma::Col<ElemType>&
|
|
114
|
+
arma::Col<ElemType>& meanCol,
|
|
115
115
|
ElemType& splitVal)
|
|
116
116
|
{
|
|
117
117
|
arma::Col<ElemType> values(samples.n_elem);
|
|
118
118
|
|
|
119
|
-
|
|
119
|
+
meanCol = mean(data.cols(samples), 1);
|
|
120
120
|
|
|
121
121
|
arma::Col<ElemType> tmp(data.n_rows);
|
|
122
122
|
|
|
123
123
|
for (size_t k = 0; k < samples.n_elem; ++k)
|
|
124
124
|
{
|
|
125
125
|
tmp = data.col(samples[k]);
|
|
126
|
-
tmp -=
|
|
126
|
+
tmp -= meanCol;
|
|
127
127
|
|
|
128
128
|
values[k] = dot(tmp, tmp);
|
|
129
129
|
}
|
|
130
130
|
|
|
131
|
-
const ElemType maximum =
|
|
131
|
+
const ElemType maximum = max(values);
|
|
132
132
|
const ElemType minimum = min(values);
|
|
133
133
|
if (minimum == maximum)
|
|
134
134
|
return false;
|
|
135
135
|
|
|
136
|
-
splitVal =
|
|
136
|
+
splitVal = median(values);
|
|
137
137
|
|
|
138
138
|
if (splitVal == maximum)
|
|
139
139
|
splitVal = minimum;
|
|
@@ -143,10 +143,10 @@ class CellBound
|
|
|
143
143
|
ElemType& MinWidth() { return minWidth; }
|
|
144
144
|
|
|
145
145
|
//! Get the distance metric associated with this bound.
|
|
146
|
-
[[deprecated("Will be removed in 5.0.0; use Distance()")]]
|
|
146
|
+
[[deprecated("Will be removed in mlpack 5.0.0; use Distance()")]]
|
|
147
147
|
const DistanceType& Metric() const { return distance; }
|
|
148
148
|
//! Modify the distance metric associated with this bound.
|
|
149
|
-
[[deprecated("Will be removed in 5.0.0; use Distance()")]]
|
|
149
|
+
[[deprecated("Will be removed in mlpack 5.0.0; use Distance()")]]
|
|
150
150
|
DistanceType& Metric() { return distance; }
|
|
151
151
|
|
|
152
152
|
//! Get the distance metric associated with this bound.
|
|
@@ -33,7 +33,7 @@ inline CosineTree<MatType>::CosineTree(const MatType& dataset) :
|
|
|
33
33
|
for (size_t i = 0; i < numColumns; ++i)
|
|
34
34
|
{
|
|
35
35
|
indices[i] = i;
|
|
36
|
-
double l2Norm = (double)
|
|
36
|
+
double l2Norm = (double) norm(dataset.col(i), 2);
|
|
37
37
|
l2NormsSquared(i) = l2Norm * l2Norm;
|
|
38
38
|
}
|
|
39
39
|
|
|
@@ -92,7 +92,7 @@ inline CosineTree<MatType>::CosineTree(const MatType& dataset,
|
|
|
92
92
|
|
|
93
93
|
// Define root node of the tree and add it to the queue.
|
|
94
94
|
CosineTree root(dataset);
|
|
95
|
-
VecType tempVector =
|
|
95
|
+
VecType tempVector = VecType(dataset.n_rows, GetFillType<VecType>::zeros);
|
|
96
96
|
root.L2Error(-1.0); // We don't know what the error is.
|
|
97
97
|
root.BasisVector(tempVector);
|
|
98
98
|
treeQueue.push_back(&root);
|
|
@@ -412,8 +412,8 @@ inline void CosineTree<MatType>::ModifiedGramSchmidt(
|
|
|
412
412
|
}
|
|
413
413
|
|
|
414
414
|
// Normalize the modified centroid vector.
|
|
415
|
-
if (
|
|
416
|
-
newBasisVector /=
|
|
415
|
+
if (norm(newBasisVector, 2))
|
|
416
|
+
newBasisVector /= norm(newBasisVector, 2);
|
|
417
417
|
}
|
|
418
418
|
|
|
419
419
|
template<typename MatType>
|
|
@@ -475,7 +475,7 @@ inline double CosineTree<MatType>::MonteCarloError(
|
|
|
475
475
|
}
|
|
476
476
|
|
|
477
477
|
// Calculate the Frobenius norm squared of the projected vector.
|
|
478
|
-
double frobProjection =
|
|
478
|
+
double frobProjection = norm(projection, "frob");
|
|
479
479
|
double frobProjectionSquared = frobProjection * frobProjection;
|
|
480
480
|
|
|
481
481
|
// Calculate the weighted projection magnitude.
|
|
@@ -483,8 +483,8 @@ inline double CosineTree<MatType>::MonteCarloError(
|
|
|
483
483
|
}
|
|
484
484
|
|
|
485
485
|
// Compute mean and standard deviation of the weighted samples.
|
|
486
|
-
double mu =
|
|
487
|
-
double sigma =
|
|
486
|
+
double mu = mean(weightedMagnitudes);
|
|
487
|
+
double sigma = stddev(weightedMagnitudes);
|
|
488
488
|
|
|
489
489
|
if (!sigma)
|
|
490
490
|
{
|
|
@@ -536,7 +536,7 @@ inline void CosineTree<MatType>::CosineNodeSplit()
|
|
|
536
536
|
|
|
537
537
|
// Compute maximum and minimum cosine values.
|
|
538
538
|
double cosineMax, cosineMin;
|
|
539
|
-
cosineMax =
|
|
539
|
+
cosineMax = max(cosines % (cosines < 1));
|
|
540
540
|
cosineMin = min(cosines);
|
|
541
541
|
|
|
542
542
|
std::vector<size_t> leftIndices, rightIndices;
|
|
@@ -670,8 +670,8 @@ inline void CosineTree<MatType>::CalculateCosines(
|
|
|
670
670
|
else
|
|
671
671
|
{
|
|
672
672
|
cosines(i) =
|
|
673
|
-
std::abs(
|
|
674
|
-
|
|
673
|
+
std::abs(norm_dot(dataset->col(indices[splitPointIndex]),
|
|
674
|
+
dataset->col(indices[i])));
|
|
675
675
|
}
|
|
676
676
|
}
|
|
677
677
|
}
|
|
@@ -402,6 +402,16 @@ class Octree
|
|
|
402
402
|
const VecType& point,
|
|
403
403
|
typename std::enable_if_t<IsVector<VecType>::value>* = 0) const;
|
|
404
404
|
|
|
405
|
+
//! Return the index of the beginning point of this subset.
|
|
406
|
+
size_t Begin() const { return begin; }
|
|
407
|
+
//! Modify the index of the beginning point of this subset.
|
|
408
|
+
size_t& Begin() { return begin; }
|
|
409
|
+
|
|
410
|
+
//! Return the number of points in this subset.
|
|
411
|
+
size_t Count() const { return count; }
|
|
412
|
+
//! Modify the number of points in this subset.
|
|
413
|
+
size_t& Count() { return count; }
|
|
414
|
+
|
|
405
415
|
//! Store the center of the bounding region in the given vector.
|
|
406
416
|
template<typename VecType>
|
|
407
417
|
void Center(VecType& center) const { bound.Center(center); }
|
|
@@ -288,8 +288,13 @@ Octree<DistanceType, StatisticType, MatType>::Octree(
|
|
|
288
288
|
// Calculate empirical center of data.
|
|
289
289
|
bound |= dataset->cols(begin, begin + count - 1);
|
|
290
290
|
|
|
291
|
-
|
|
292
|
-
|
|
291
|
+
ElemType maxWidth = 0.0;
|
|
292
|
+
for (size_t i = 0; i < bound.Dim(); ++i)
|
|
293
|
+
if (bound[i].Hi() - bound[i].Lo() > maxWidth)
|
|
294
|
+
maxWidth = bound[i].Hi() - bound[i].Lo();
|
|
295
|
+
|
|
296
|
+
if (maxWidth != 0.0)
|
|
297
|
+
SplitNode(center, width, maxLeafSize);
|
|
293
298
|
|
|
294
299
|
// Calculate the distance from the empirical center of this node to the
|
|
295
300
|
// empirical center of the parent.
|
|
@@ -323,8 +328,13 @@ Octree<DistanceType, StatisticType, MatType>::Octree(
|
|
|
323
328
|
// Calculate empirical center of data.
|
|
324
329
|
bound |= dataset->cols(begin, begin + count - 1);
|
|
325
330
|
|
|
326
|
-
|
|
327
|
-
|
|
331
|
+
ElemType maxWidth = 0.0;
|
|
332
|
+
for (size_t i = 0; i < bound.Dim(); ++i)
|
|
333
|
+
if (bound[i].Hi() - bound[i].Lo() > maxWidth)
|
|
334
|
+
maxWidth = bound[i].Hi() - bound[i].Lo();
|
|
335
|
+
|
|
336
|
+
if (maxWidth != 0.0)
|
|
337
|
+
SplitNode(center, width, oldFromNew, maxLeafSize);
|
|
328
338
|
|
|
329
339
|
// Calculate the distance from the empirical center of this node to the
|
|
330
340
|
// empirical center of the parent.
|
|
@@ -12,6 +12,16 @@
|
|
|
12
12
|
#ifndef MLPACK_CORE_UTIL_ARMA_TRAITS_HPP
|
|
13
13
|
#define MLPACK_CORE_UTIL_ARMA_TRAITS_HPP
|
|
14
14
|
|
|
15
|
+
// Get whether or not the given type is any non-field Armadillo type
|
|
16
|
+
// This includes sparse, dense, and cube types
|
|
17
|
+
template<typename T>
|
|
18
|
+
struct IsArma
|
|
19
|
+
{
|
|
20
|
+
constexpr static bool value = arma::is_arma_type<T>::value ||
|
|
21
|
+
arma::is_arma_cube_type<T>::value ||
|
|
22
|
+
arma::is_arma_sparse_type<T>::value;
|
|
23
|
+
};
|
|
24
|
+
|
|
15
25
|
// Structs have public members by default (that's why they are chosen over
|
|
16
26
|
// classes).
|
|
17
27
|
|
|
@@ -154,6 +164,15 @@ struct GetRowType<arma::SpMat<eT>>
|
|
|
154
164
|
using type = arma::SpRow<eT>;
|
|
155
165
|
};
|
|
156
166
|
|
|
167
|
+
template<typename MatType, typename T = void>
|
|
168
|
+
struct GetURowType;
|
|
169
|
+
|
|
170
|
+
template<typename MatType>
|
|
171
|
+
struct GetURowType<MatType, std::enable_if_t<IsArma<MatType>::value>>
|
|
172
|
+
{
|
|
173
|
+
using type = arma::Row<arma::uword>;
|
|
174
|
+
};
|
|
175
|
+
|
|
157
176
|
// Get the column vector type corresponding to a given MatType.
|
|
158
177
|
|
|
159
178
|
template<typename MatType>
|
|
@@ -162,8 +181,11 @@ struct GetColType
|
|
|
162
181
|
using type = arma::Col<typename MatType::elem_type>;
|
|
163
182
|
};
|
|
164
183
|
|
|
184
|
+
template<typename MatType, typename T = void>
|
|
185
|
+
struct GetUColType;
|
|
186
|
+
|
|
165
187
|
template<typename MatType>
|
|
166
|
-
struct GetUColType
|
|
188
|
+
struct GetUColType<MatType, std::enable_if_t<IsArma<MatType>::value>>
|
|
167
189
|
{
|
|
168
190
|
using type = arma::Col<arma::uword>;
|
|
169
191
|
};
|
|
@@ -239,16 +261,6 @@ struct GetCubeType<arma::Mat<eT>>
|
|
|
239
261
|
using type = arma::Cube<eT>;
|
|
240
262
|
};
|
|
241
263
|
|
|
242
|
-
#if defined(MLPACK_HAS_COOT)
|
|
243
|
-
|
|
244
|
-
template<typename eT>
|
|
245
|
-
struct GetCubeType<coot::Mat<eT>>
|
|
246
|
-
{
|
|
247
|
-
using type = coot::Cube<eT>;
|
|
248
|
-
};
|
|
249
|
-
|
|
250
|
-
#endif
|
|
251
|
-
|
|
252
264
|
// Get the sparse matrix type corresponding to a given MatType.
|
|
253
265
|
|
|
254
266
|
template<typename MatType>
|
|
@@ -356,35 +368,10 @@ struct IsDense<arma::Mat<eT>>
|
|
|
356
368
|
constexpr static bool value = true;
|
|
357
369
|
};
|
|
358
370
|
|
|
359
|
-
// Get whether or not the given type is any non-field Armadillo type
|
|
360
|
-
// This includes sparse, dense, and cube types
|
|
361
371
|
template<typename T>
|
|
362
|
-
struct
|
|
372
|
+
struct IsSparse
|
|
363
373
|
{
|
|
364
|
-
constexpr static bool value = arma::
|
|
365
|
-
arma::is_arma_cube_type<T>::value ||
|
|
366
|
-
arma::is_arma_sparse_type<T>::value;
|
|
374
|
+
constexpr static bool value = arma::is_arma_sparse_type<T>::value;
|
|
367
375
|
};
|
|
368
376
|
|
|
369
|
-
#if defined(MLPACK_HAS_COOT)
|
|
370
|
-
|
|
371
|
-
// Get whether or not the given type is any Bandicoot type
|
|
372
|
-
// This includes dense and cube types
|
|
373
|
-
template<typename T>
|
|
374
|
-
struct IsCoot
|
|
375
|
-
{
|
|
376
|
-
constexpr static bool value = coot::is_coot_type<T>::value ||
|
|
377
|
-
coot::is_coot_cube_type<T>::value;
|
|
378
|
-
};
|
|
379
|
-
|
|
380
|
-
#else
|
|
381
|
-
|
|
382
|
-
template<typename T>
|
|
383
|
-
struct IsCoot
|
|
384
|
-
{
|
|
385
|
-
constexpr static bool value = false;
|
|
386
|
-
};
|
|
387
|
-
|
|
388
|
-
#endif
|
|
389
|
-
|
|
390
377
|
#endif
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file core/util/coot_traits.hpp
|
|
3
|
+
* @author Ryan Curtin
|
|
4
|
+
* @author Omar Shrit
|
|
5
|
+
*
|
|
6
|
+
* Some traits used for template metaprogramming (SFINAE) with Bandicoot types.
|
|
7
|
+
*
|
|
8
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef MLPACK_CORE_UTIL_COOT_TRAITS_HPP
|
|
14
|
+
#define MLPACK_CORE_UTIL_COOT_TRAITS_HPP
|
|
15
|
+
|
|
16
|
+
#if defined(MLPACK_HAS_COOT)
|
|
17
|
+
|
|
18
|
+
// Get whether or not the given type is any Bandicoot type
|
|
19
|
+
// This includes dense and cube types
|
|
20
|
+
template<typename T>
|
|
21
|
+
struct IsCoot
|
|
22
|
+
{
|
|
23
|
+
constexpr static bool value = coot::is_coot_type<T>::value ||
|
|
24
|
+
coot::is_coot_cube_type<T>::value;
|
|
25
|
+
};
|
|
26
|
+
|
|
27
|
+
template<typename eT>
|
|
28
|
+
struct GetCubeType<coot::Mat<eT>>
|
|
29
|
+
{
|
|
30
|
+
using type = coot::Cube<eT>;
|
|
31
|
+
};
|
|
32
|
+
|
|
33
|
+
template<typename eT>
|
|
34
|
+
struct GetDenseMatType<coot::Cube<eT>>
|
|
35
|
+
{
|
|
36
|
+
using type = coot::Mat<eT>;
|
|
37
|
+
};
|
|
38
|
+
|
|
39
|
+
template<typename MatType>
|
|
40
|
+
struct GetURowType<MatType, std::enable_if_t<IsCoot<MatType>::value>>
|
|
41
|
+
{
|
|
42
|
+
using type = coot::Row<coot::uword>;
|
|
43
|
+
};
|
|
44
|
+
|
|
45
|
+
template<typename MatType>
|
|
46
|
+
struct GetUColType<MatType, std::enable_if_t<IsCoot<MatType>::value>>
|
|
47
|
+
{
|
|
48
|
+
using type = coot::Col<coot::uword>;
|
|
49
|
+
};
|
|
50
|
+
|
|
51
|
+
template<typename eT>
|
|
52
|
+
struct IsVector<coot::Col<eT> >
|
|
53
|
+
{
|
|
54
|
+
static const bool value = true;
|
|
55
|
+
};
|
|
56
|
+
|
|
57
|
+
template<typename eT>
|
|
58
|
+
struct IsVector<coot::Row<eT> >
|
|
59
|
+
{
|
|
60
|
+
static const bool value = true;
|
|
61
|
+
};
|
|
62
|
+
|
|
63
|
+
template<typename eT>
|
|
64
|
+
struct IsVector<coot::subview_col<eT> >
|
|
65
|
+
{
|
|
66
|
+
static const bool value = true;
|
|
67
|
+
};
|
|
68
|
+
|
|
69
|
+
template<typename eT>
|
|
70
|
+
struct IsVector<coot::subview_row<eT> >
|
|
71
|
+
{
|
|
72
|
+
static const bool value = true;
|
|
73
|
+
};
|
|
74
|
+
|
|
75
|
+
template<typename eT>
|
|
76
|
+
struct IsMatrix<coot::Mat<eT> >
|
|
77
|
+
{
|
|
78
|
+
static const bool value = true;
|
|
79
|
+
};
|
|
80
|
+
|
|
81
|
+
template<typename eT>
|
|
82
|
+
struct IsCube<coot::Cube<eT> >
|
|
83
|
+
{
|
|
84
|
+
static const bool value = true;
|
|
85
|
+
};
|
|
86
|
+
|
|
87
|
+
#else
|
|
88
|
+
|
|
89
|
+
template<typename T>
|
|
90
|
+
struct IsCoot
|
|
91
|
+
{
|
|
92
|
+
constexpr static bool value = false;
|
|
93
|
+
};
|
|
94
|
+
|
|
95
|
+
#endif // defined(MLPACK_HAS_COOT)
|
|
96
|
+
|
|
97
|
+
#endif
|
|
@@ -30,7 +30,6 @@ class Timers;
|
|
|
30
30
|
#include "params.hpp"
|
|
31
31
|
|
|
32
32
|
namespace mlpack {
|
|
33
|
-
namespace data {
|
|
34
33
|
|
|
35
34
|
class IncrementPolicy;
|
|
36
35
|
|
|
@@ -44,7 +43,6 @@ using DatasetInfo = DatasetMapper<IncrementPolicy, std::string>;
|
|
|
44
43
|
// DatasetInfo.
|
|
45
44
|
void CheckCategoricalParam(util::Params& p, const std::string& paramName);
|
|
46
45
|
|
|
47
|
-
} // namespace data
|
|
48
46
|
} // namespace mlpack
|
|
49
47
|
|
|
50
48
|
#endif
|
|
@@ -757,10 +757,10 @@
|
|
|
757
757
|
* here---it will cause problems.
|
|
758
758
|
* @param ALIAS One-character string representing the alias of the parameter.
|
|
759
759
|
*/
|
|
760
|
-
#define TUPLE_TYPE std::tuple<mlpack::
|
|
760
|
+
#define TUPLE_TYPE std::tuple<mlpack::DatasetInfo, arma::mat>
|
|
761
761
|
#define PARAM_MATRIX_AND_INFO_IN(ID, DESC, ALIAS) \
|
|
762
762
|
PARAM(TUPLE_TYPE, ID, DESC, ALIAS, \
|
|
763
|
-
"std::tuple<mlpack::
|
|
763
|
+
"std::tuple<mlpack::DatasetInfo, arma::mat>", false, true, true, \
|
|
764
764
|
TUPLE_TYPE())
|
|
765
765
|
|
|
766
766
|
/**
|
|
@@ -789,10 +789,10 @@
|
|
|
789
789
|
* here---it will cause problems.
|
|
790
790
|
* @param ALIAS One-character string representing the alias of the parameter.
|
|
791
791
|
*/
|
|
792
|
-
#define TUPLE_TYPE std::tuple<mlpack::
|
|
792
|
+
#define TUPLE_TYPE std::tuple<mlpack::DatasetInfo, arma::mat>
|
|
793
793
|
#define PARAM_MATRIX_AND_INFO_IN_REQ(ID, DESC, ALIAS) \
|
|
794
794
|
PARAM(TUPLE_TYPE, ID, DESC, ALIAS, \
|
|
795
|
-
"std::tuple<mlpack::
|
|
795
|
+
"std::tuple<mlpack::DatasetInfo, arma::mat>", true, true, true, \
|
|
796
796
|
TUPLE_TYPE())
|
|
797
797
|
|
|
798
798
|
/**
|
|
@@ -286,11 +286,11 @@ inline void Params::CheckInputMatrices()
|
|
|
286
286
|
{
|
|
287
287
|
CheckInputMatrix(Get<arma::rowvec>(paramName), paramName);
|
|
288
288
|
}
|
|
289
|
-
else if (paramType == "std::tuple<mlpack::
|
|
289
|
+
else if (paramType == "std::tuple<mlpack::DatasetInfo, arma::mat>")
|
|
290
290
|
{
|
|
291
291
|
// Note that CheckCategoricalParam() is a utility function that must be
|
|
292
292
|
// defined after DatasetInfo is fully defined.
|
|
293
|
-
|
|
293
|
+
CheckCategoricalParam(*this, paramName);
|
|
294
294
|
}
|
|
295
295
|
}
|
|
296
296
|
}
|
|
@@ -18,16 +18,24 @@
|
|
|
18
18
|
#define MLPACK_CORE_UTIL_USING_HPP
|
|
19
19
|
|
|
20
20
|
#include "arma_traits.hpp"
|
|
21
|
+
#include "coot_traits.hpp"
|
|
21
22
|
|
|
22
23
|
namespace mlpack {
|
|
23
24
|
|
|
24
25
|
#ifdef MLPACK_HAS_COOT
|
|
25
26
|
|
|
26
27
|
/* using for bandicoot namespace*/
|
|
27
|
-
using coot::
|
|
28
|
+
using coot::accu;
|
|
29
|
+
using coot::all;
|
|
30
|
+
using coot::conv_to;
|
|
28
31
|
using coot::dot;
|
|
32
|
+
using coot::exp;
|
|
33
|
+
using coot::find;
|
|
34
|
+
using coot::find_nan;
|
|
35
|
+
using coot::find_nonfinite;
|
|
29
36
|
using coot::join_cols;
|
|
30
37
|
using coot::join_rows;
|
|
38
|
+
using coot::linspace;
|
|
31
39
|
using coot::log;
|
|
32
40
|
using coot::min;
|
|
33
41
|
using coot::max;
|
|
@@ -41,21 +49,38 @@ using coot::randn;
|
|
|
41
49
|
using coot::randu;
|
|
42
50
|
using coot::repmat;
|
|
43
51
|
using coot::sign;
|
|
52
|
+
using coot::size;
|
|
53
|
+
using coot::sort_index;
|
|
44
54
|
using coot::sqrt;
|
|
45
55
|
using coot::square;
|
|
46
56
|
using coot::sum;
|
|
47
57
|
using coot::trans;
|
|
48
58
|
using coot::vectorise;
|
|
49
59
|
using coot::zeros;
|
|
60
|
+
#else
|
|
61
|
+
|
|
62
|
+
// Only use arma::conv_to if Bandicoot is not available: Bandicoot's conv_to
|
|
63
|
+
// supports Armadillo types too.
|
|
64
|
+
using arma::conv_to;
|
|
50
65
|
|
|
51
66
|
#endif
|
|
52
67
|
|
|
53
68
|
/* using for armadillo namespace */
|
|
54
|
-
using arma::
|
|
69
|
+
using arma::accu;
|
|
70
|
+
using arma::all;
|
|
55
71
|
using arma::dot;
|
|
72
|
+
using arma::exp;
|
|
73
|
+
using arma::find;
|
|
74
|
+
#if ARMA_VERSION_MAJOR > 11 || \
|
|
75
|
+
(ARMA_VERSION_MAJOR == 11 && ARMA_VERSION_MINOR >= 4)
|
|
76
|
+
using arma::find_nan;
|
|
77
|
+
#endif
|
|
78
|
+
using arma::find_nonfinite;
|
|
56
79
|
using arma::join_cols;
|
|
57
80
|
using arma::join_rows;
|
|
81
|
+
using arma::linspace;
|
|
58
82
|
using arma::log;
|
|
83
|
+
using arma::linspace;
|
|
59
84
|
using arma::min;
|
|
60
85
|
using arma::max;
|
|
61
86
|
using arma::mean;
|
|
@@ -68,6 +93,8 @@ using arma::randn;
|
|
|
68
93
|
using arma::randu;
|
|
69
94
|
using arma::repmat;
|
|
70
95
|
using arma::sign;
|
|
96
|
+
using arma::size;
|
|
97
|
+
using arma::sort_index;
|
|
71
98
|
using arma::sqrt;
|
|
72
99
|
using arma::square;
|
|
73
100
|
using arma::sum;
|
|
@@ -15,10 +15,12 @@
|
|
|
15
15
|
#include <string>
|
|
16
16
|
|
|
17
17
|
// The version of mlpack. If this is a git repository, this will be a version
|
|
18
|
-
// with higher number than the most recent release
|
|
18
|
+
// with higher number than the most recent release, and the MLPACK_PRERELEASE
|
|
19
|
+
// macro will be defined.
|
|
19
20
|
#define MLPACK_VERSION_MAJOR 4
|
|
20
|
-
#define MLPACK_VERSION_MINOR
|
|
21
|
-
#define MLPACK_VERSION_PATCH
|
|
21
|
+
#define MLPACK_VERSION_MINOR 7
|
|
22
|
+
#define MLPACK_VERSION_PATCH 0
|
|
23
|
+
//#define MLPACK_PRERELEASE
|
|
22
24
|
|
|
23
25
|
// The name of the version (for use by --version).
|
|
24
26
|
namespace mlpack {
|
|
@@ -20,16 +20,13 @@ namespace util {
|
|
|
20
20
|
// name.
|
|
21
21
|
inline std::string GetVersion()
|
|
22
22
|
{
|
|
23
|
-
#ifndef MLPACK_GIT_VERSION
|
|
24
23
|
std::stringstream o;
|
|
25
24
|
o << "mlpack " << MLPACK_VERSION_MAJOR << "." << MLPACK_VERSION_MINOR
|
|
26
25
|
<< "." << MLPACK_VERSION_PATCH;
|
|
26
|
+
#if defined(MLPACK_PRERELEASE)
|
|
27
|
+
o << " (prerelease)";
|
|
28
|
+
#endif
|
|
27
29
|
return o.str();
|
|
28
|
-
#else
|
|
29
|
-
// This file is generated by CMake as necessary and contains just a return
|
|
30
|
-
// statement with the git revision in it.
|
|
31
|
-
#include "gitversion.hpp"
|
|
32
|
-
#endif
|
|
33
30
|
}
|
|
34
31
|
|
|
35
32
|
} // namespace util
|
|
@@ -84,7 +84,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
84
84
|
timers.Stop("adaboost_classification");
|
|
85
85
|
|
|
86
86
|
Row<size_t> results;
|
|
87
|
-
|
|
87
|
+
RevertLabels(predictedLabels, m->Mappings(), results);
|
|
88
88
|
|
|
89
89
|
params.Get<arma::Row<size_t>>("predictions") = std::move(results);
|
|
90
90
|
}
|
|
@@ -108,7 +108,7 @@ BINDING_EXAMPLE(
|
|
|
108
108
|
BINDING_SEE_ALSO("AdaBoost on Wikipedia", "https://en.wikipedia.org/wiki/"
|
|
109
109
|
"AdaBoost");
|
|
110
110
|
BINDING_SEE_ALSO("Improved boosting algorithms using confidence-rated "
|
|
111
|
-
"predictions (pdf)", "http://
|
|
111
|
+
"predictions (pdf)", "http://www.schapire.net/papers/SchapireSi98.pdf");
|
|
112
112
|
BINDING_SEE_ALSO("Perceptron", "#perceptron");
|
|
113
113
|
BINDING_SEE_ALSO("Decision Trees", "#decision_tree");
|
|
114
114
|
BINDING_SEE_ALSO("AdaBoost C++ class documentation",
|
|
@@ -202,7 +202,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
202
202
|
Row<size_t> labels;
|
|
203
203
|
|
|
204
204
|
// Normalize the labels.
|
|
205
|
-
|
|
205
|
+
NormalizeLabels(labelsIn, labels, m->Mappings());
|
|
206
206
|
|
|
207
207
|
// Get other training parameters.
|
|
208
208
|
const double tolerance = params.Get<double>("tolerance");
|
|
@@ -253,7 +253,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
253
253
|
}
|
|
254
254
|
|
|
255
255
|
Row<size_t> results;
|
|
256
|
-
|
|
256
|
+
RevertLabels(predictedLabels, m->Mappings(), results);
|
|
257
257
|
|
|
258
258
|
// Save the predicted labels.
|
|
259
259
|
if (params.Has("predictions"))
|
|
@@ -84,7 +84,7 @@ BINDING_EXAMPLE(
|
|
|
84
84
|
BINDING_SEE_ALSO("AdaBoost on Wikipedia", "https://en.wikipedia.org/wiki/"
|
|
85
85
|
"AdaBoost");
|
|
86
86
|
BINDING_SEE_ALSO("Improved boosting algorithms using confidence-rated "
|
|
87
|
-
"predictions (pdf)", "http://
|
|
87
|
+
"predictions (pdf)", "http://www.schapire.net/papers/SchapireSi98.pdf");
|
|
88
88
|
BINDING_SEE_ALSO("Perceptron", "#perceptron");
|
|
89
89
|
BINDING_SEE_ALSO("Decision Trees", "#decision_tree");
|
|
90
90
|
BINDING_SEE_ALSO("AdaBoost C++ class documentation",
|
|
@@ -157,7 +157,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
157
157
|
Row<size_t> labels;
|
|
158
158
|
|
|
159
159
|
// Normalize the labels.
|
|
160
|
-
|
|
160
|
+
NormalizeLabels(labelsIn, labels, m->Mappings());
|
|
161
161
|
|
|
162
162
|
// Get other training parameters.
|
|
163
163
|
const double tolerance = params.Get<double>("tolerance");
|
|
@@ -34,7 +34,8 @@ class BipolarSigmoidFunction
|
|
|
34
34
|
* @param x Input data.
|
|
35
35
|
* @return f(x).
|
|
36
36
|
*/
|
|
37
|
-
|
|
37
|
+
template<typename ElemType>
|
|
38
|
+
static ElemType Fn(const ElemType x)
|
|
38
39
|
{
|
|
39
40
|
return (1 - std::exp(-x)) / (1 + std::exp(-x));
|
|
40
41
|
}
|
|
@@ -58,9 +59,10 @@ class BipolarSigmoidFunction
|
|
|
58
59
|
* @param y Result of Fn(x).
|
|
59
60
|
* @return f'(x)
|
|
60
61
|
*/
|
|
61
|
-
|
|
62
|
+
template<typename ElemType>
|
|
63
|
+
static ElemType Deriv(const ElemType /* x */, const ElemType y)
|
|
62
64
|
{
|
|
63
|
-
return (1
|
|
65
|
+
return (1 - std::pow(y, ElemType(2))) / 2;
|
|
64
66
|
}
|
|
65
67
|
|
|
66
68
|
/**
|
|
@@ -75,7 +77,7 @@ class BipolarSigmoidFunction
|
|
|
75
77
|
const OutputVecType& y,
|
|
76
78
|
DerivVecType& dy)
|
|
77
79
|
{
|
|
78
|
-
dy = (1
|
|
80
|
+
dy = (1 - square(y)) / 2;
|
|
79
81
|
}
|
|
80
82
|
}; // class BipolarSigmoidFunction
|
|
81
83
|
|