mlpack 4.6.1__cp312-cp312-win_amd64.whl → 4.7.0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp312-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp312-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp312-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp312-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp312-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp312-win_amd64.pyd +0 -0
- mlpack/cf.cp312-win_amd64.pyd +0 -0
- mlpack/dbscan.cp312-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp312-win_amd64.pyd +0 -0
- mlpack/det.cp312-win_amd64.pyd +0 -0
- mlpack/emst.cp312-win_amd64.pyd +0 -0
- mlpack/fastmks.cp312-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp312-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp312-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp312-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp312-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp312-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp312-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp312-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp312-win_amd64.pyd +0 -0
- mlpack/image_converter.cp312-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +25 -16
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +53 -43
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +194 -57
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +130 -315
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/ccov.hpp +1 -0
- mlpack/include/mlpack/core/math/ccov_impl.hpp +4 -5
- mlpack/include/mlpack/core/math/make_alias.hpp +100 -3
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -21
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/sfinae_utility.hpp +24 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +2 -3
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/network_init.hpp +5 -5
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +68 -68
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +29 -32
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +24 -23
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +28 -27
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +27 -26
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +30 -31
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +32 -27
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +185 -89
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +29 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +38 -39
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +22 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +45 -32
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +16 -2
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +8 -7
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +145 -50
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +245 -53
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_impl.hpp +3 -8
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/fitness_functions/gini_gain.hpp +5 -8
- mlpack/include/mlpack/methods/decision_tree/fitness_functions/information_gain.hpp +5 -8
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/gmm/diagonal_gmm_impl.hpp +2 -1
- mlpack/include/mlpack/methods/gmm/eigenvalue_ratio_constraint.hpp +3 -3
- mlpack/include/mlpack/methods/gmm/gmm_impl.hpp +2 -1
- mlpack/include/mlpack/methods/hmm/hmm_impl.hpp +10 -5
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +61 -41
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +77 -67
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp312-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp312-win_amd64.pyd +0 -0
- mlpack/kfn.cp312-win_amd64.pyd +0 -0
- mlpack/kmeans.cp312-win_amd64.pyd +0 -0
- mlpack/knn.cp312-win_amd64.pyd +0 -0
- mlpack/krann.cp312-win_amd64.pyd +0 -0
- mlpack/lars.cp312-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp312-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp312-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp312-win_amd64.pyd +0 -0
- mlpack/lmnn.cp312-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp312-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp312-win_amd64.pyd +0 -0
- mlpack/lsh.cp312-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp312-win_amd64.pyd +0 -0
- mlpack/nbc.cp312-win_amd64.pyd +0 -0
- mlpack/nca.cp312-win_amd64.pyd +0 -0
- mlpack/nmf.cp312-win_amd64.pyd +0 -0
- mlpack/pca.cp312-win_amd64.pyd +0 -0
- mlpack/perceptron.cp312-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp312-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp312-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp312-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp312-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp312-win_amd64.pyd +0 -0
- mlpack/radical.cp312-win_amd64.pyd +0 -0
- mlpack/random_forest.cp312-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp312-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp312-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.1.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.1.dist-info → mlpack-4.7.0.dist-info}/RECORD +407 -388
- {mlpack-4.6.1.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.1.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.1.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
mlpack/__init__.py
CHANGED
|
@@ -11,14 +11,14 @@ http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
# start delvewheel patch
|
|
14
|
-
def
|
|
14
|
+
def _delvewheel_patch_1_12_0():
|
|
15
15
|
import os
|
|
16
16
|
if os.path.isdir(libs_dir := os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'mlpack.libs'))):
|
|
17
17
|
os.add_dll_directory(libs_dir)
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
|
|
21
|
-
del
|
|
20
|
+
_delvewheel_patch_1_12_0()
|
|
21
|
+
del _delvewheel_patch_1_12_0
|
|
22
22
|
# end delvewheel patch
|
|
23
23
|
|
|
24
24
|
import warnings
|
|
@@ -74,4 +74,4 @@ from .adaboost import *
|
|
|
74
74
|
from .linear_regression_train import linear_regression_train
|
|
75
75
|
from .linear_regression_predict import linear_regression_predict
|
|
76
76
|
from .linear_regression import *
|
|
77
|
-
__version__='4.
|
|
77
|
+
__version__='4.7.0'
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
mlpack/cf.cp312-win_amd64.pyd
CHANGED
|
Binary file
|
|
Binary file
|
|
Binary file
|
mlpack/det.cp312-win_amd64.pyd
CHANGED
|
Binary file
|
mlpack/emst.cp312-win_amd64.pyd
CHANGED
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
mlpack/include/mlpack/base.hpp
CHANGED
|
@@ -86,6 +86,7 @@
|
|
|
86
86
|
#include <armadillo>
|
|
87
87
|
#include <mlpack/core/util/arma_traits.hpp>
|
|
88
88
|
#include <mlpack/core/util/omp_reductions.hpp>
|
|
89
|
+
#include <mlpack/core/arma_extend/find_nan.hpp>
|
|
89
90
|
|
|
90
91
|
// On Visual Studio, disable C4519 (default arguments for function templates)
|
|
91
92
|
// since it's by default an error, which doesn't even make any sense because
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file find_nan.hpp
|
|
3
|
+
* @author Ryan Curtin
|
|
4
|
+
*
|
|
5
|
+
* When find_nan() is not available (Armadillo < 11.4), provide an internal
|
|
6
|
+
* mlpack implementation that operates the same way. It is slower.
|
|
7
|
+
*/
|
|
8
|
+
#ifndef MLPACK_CORE_ARMA_EXTEND_FIND_NAN_HPP
|
|
9
|
+
#define MLPACK_CORE_ARMA_EXTEND_FIND_NAN_HPP
|
|
10
|
+
|
|
11
|
+
namespace mlpack {
|
|
12
|
+
|
|
13
|
+
#if ARMA_VERSION_MAJOR < 11 || \
|
|
14
|
+
(ARMA_VERSION_MAJOR == 11 && ARMA_VERSION_MINOR < 4)
|
|
15
|
+
|
|
16
|
+
template<typename T>
|
|
17
|
+
arma::uvec find_nan(const T& m,
|
|
18
|
+
const std::enable_if_t<arma::is_arma_type<T>::value>* = 0)
|
|
19
|
+
{
|
|
20
|
+
typedef typename T::elem_type ElemType;
|
|
21
|
+
|
|
22
|
+
if (!std::numeric_limits<ElemType>::has_quiet_NaN)
|
|
23
|
+
return arma::uvec(); // There can't be any NaNs.
|
|
24
|
+
|
|
25
|
+
// find_nonfinite() exists on older Armadillo, and we can also search for +Inf
|
|
26
|
+
// and -Inf.
|
|
27
|
+
arma::uvec nonfiniteIndices = arma::find_nonfinite(m);
|
|
28
|
+
if (nonfiniteIndices.n_elem == 0)
|
|
29
|
+
return arma::uvec();
|
|
30
|
+
|
|
31
|
+
arma::uvec infIndices = arma::find(
|
|
32
|
+
m == std::numeric_limits<ElemType>::infinity());
|
|
33
|
+
arma::uvec neginfIndices = arma::find(
|
|
34
|
+
m == -std::numeric_limits<ElemType>::infinity());
|
|
35
|
+
|
|
36
|
+
arma::uvec result(nonfiniteIndices.n_elem -
|
|
37
|
+
(infIndices.n_elem + neginfIndices.n_elem));
|
|
38
|
+
if (result.n_elem == 0)
|
|
39
|
+
return result;
|
|
40
|
+
|
|
41
|
+
size_t infIndex = 0;
|
|
42
|
+
size_t neginfIndex = 0;
|
|
43
|
+
size_t outputIndex = 0;
|
|
44
|
+
for (size_t i = 0; i < nonfiniteIndices.n_elem; ++i)
|
|
45
|
+
{
|
|
46
|
+
if (infIndex < infIndices.n_elem &&
|
|
47
|
+
nonfiniteIndices[i] == infIndices[infIndex])
|
|
48
|
+
++infIndex;
|
|
49
|
+
else if (neginfIndex < neginfIndices.n_elem &&
|
|
50
|
+
nonfiniteIndices[i] == neginfIndices[neginfIndex])
|
|
51
|
+
++neginfIndex;
|
|
52
|
+
else
|
|
53
|
+
result[outputIndex++] = nonfiniteIndices[i];
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
return result;
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
#endif
|
|
60
|
+
|
|
61
|
+
} // namespace mlpack
|
|
62
|
+
|
|
63
|
+
#endif
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file core/cereal/low_precision.hpp
|
|
3
|
+
* @author Ryan Curtin
|
|
4
|
+
*
|
|
5
|
+
* Extra shims necessary for cereal to serialize to JSON for low-precision types
|
|
6
|
+
* (e.g. FP16, BF16, etc.).
|
|
7
|
+
*
|
|
8
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef MLPACK_CORE_CEREAL_LOW_PRECISION_HPP
|
|
14
|
+
#define MLPACK_CORE_CEREAL_LOW_PRECISION_HPP
|
|
15
|
+
|
|
16
|
+
namespace cereal {
|
|
17
|
+
|
|
18
|
+
// Because our serialization is always done with name-value pairs, we can catch
|
|
19
|
+
// any FP16 serialization at the NVP level with a specialized implementation of
|
|
20
|
+
// the load and save functions for the JSON archive (the only one that does not
|
|
21
|
+
// serialize low-precision correctly).
|
|
22
|
+
|
|
23
|
+
#if defined(ARMA_HAVE_FP16)
|
|
24
|
+
|
|
25
|
+
inline void CEREAL_SAVE_FUNCTION_NAME(JSONOutputArchive &ar,
|
|
26
|
+
NameValuePair<arma::fp16&> const& t)
|
|
27
|
+
{
|
|
28
|
+
ar.setNextName(t.name);
|
|
29
|
+
std::ostringstream oss;
|
|
30
|
+
oss.precision(std::numeric_limits<arma::fp16>::max_digits10);
|
|
31
|
+
oss << t.value;
|
|
32
|
+
ar(oss.str());
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
inline void CEREAL_LOAD_FUNCTION_NAME(JSONInputArchive& ar,
|
|
36
|
+
NameValuePair<arma::fp16&>& t)
|
|
37
|
+
{
|
|
38
|
+
ar.setNextName(t.name);
|
|
39
|
+
std::string encoded;
|
|
40
|
+
ar.loadValue(encoded);
|
|
41
|
+
t.value = arma::fp16(std::stof(encoded));
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
#endif
|
|
45
|
+
|
|
46
|
+
} // namespace cereal
|
|
47
|
+
|
|
48
|
+
#endif
|
|
@@ -57,12 +57,12 @@ class CVBase
|
|
|
57
57
|
|
|
58
58
|
/**
|
|
59
59
|
* Assert that MLAlgorithm takes the numClasses parameter and a
|
|
60
|
-
*
|
|
60
|
+
* DatasetInfo parameter and store them.
|
|
61
61
|
*
|
|
62
62
|
* @param datasetInfo Type information for each dimension of the dataset.
|
|
63
63
|
* @param numClasses Number of classes in the dataset.
|
|
64
64
|
*/
|
|
65
|
-
CVBase(const
|
|
65
|
+
CVBase(const DatasetInfo& datasetInfo,
|
|
66
66
|
const size_t numClasses);
|
|
67
67
|
|
|
68
68
|
/**
|
|
@@ -101,9 +101,9 @@ class CVBase
|
|
|
101
101
|
static_assert(MIE::IsSupported,
|
|
102
102
|
"The given MLAlgorithm is not supported by MetaInfoExtractor");
|
|
103
103
|
|
|
104
|
-
//! A variable for storing a
|
|
105
|
-
const
|
|
106
|
-
//! An indicator whether a
|
|
104
|
+
//! A variable for storing a DatasetInfo parameter if it is passed.
|
|
105
|
+
const DatasetInfo datasetInfo;
|
|
106
|
+
//! An indicator whether a DatasetInfo parameter has been passed.
|
|
107
107
|
const bool isDatasetInfoPassed;
|
|
108
108
|
//! A variable for storing the numClasses parameter if it is passed.
|
|
109
109
|
size_t numClasses;
|
|
@@ -145,7 +145,7 @@ class CVBase
|
|
|
145
145
|
|
|
146
146
|
/**
|
|
147
147
|
* Construct a trained MLAlgorithm model if MLAlgorithm takes the
|
|
148
|
-
* numClasses parameter and a
|
|
148
|
+
* numClasses parameter and a DatasetInfo parameter.
|
|
149
149
|
*/
|
|
150
150
|
template<typename... MLAlgorithmArgs,
|
|
151
151
|
bool Enabled = MIE::TakesNumClasses & MIE::TakesDatasetInfo,
|
|
@@ -183,7 +183,7 @@ class CVBase
|
|
|
183
183
|
|
|
184
184
|
/**
|
|
185
185
|
* Construct a trained MLAlgorithm model if MLAlgorithm takes the
|
|
186
|
-
* numClasses parameter and a
|
|
186
|
+
* numClasses parameter and a DatasetInfo parameter.
|
|
187
187
|
*/
|
|
188
188
|
template<typename... MLAlgorithmArgs,
|
|
189
189
|
bool Enabled = MIE::TakesNumClasses & MIE::TakesDatasetInfo,
|
|
@@ -196,13 +196,13 @@ class CVBase
|
|
|
196
196
|
const MLAlgorithmArgs&... args);
|
|
197
197
|
|
|
198
198
|
/**
|
|
199
|
-
* When MLAlgorithm supports a
|
|
199
|
+
* When MLAlgorithm supports a DatasetInfo parameter, training should be
|
|
200
200
|
* treated separately - there are models that can be constructed with and
|
|
201
201
|
* without a data:DatasetInfo parameter and models that can be constructed
|
|
202
|
-
* only with a
|
|
202
|
+
* only with a DatasetInfo parameter.
|
|
203
203
|
*
|
|
204
204
|
* Construct a trained MLAlgorithm model when it can be constructed without a
|
|
205
|
-
*
|
|
205
|
+
* DatasetInfo parameter.
|
|
206
206
|
*/
|
|
207
207
|
template<bool ConstructableWithoutDatasetInfo,
|
|
208
208
|
typename... MLAlgorithmArgs,
|
|
@@ -213,7 +213,7 @@ class CVBase
|
|
|
213
213
|
|
|
214
214
|
/**
|
|
215
215
|
* Construct a trained MLAlgorithm model when it can't be constructed without
|
|
216
|
-
* a
|
|
216
|
+
* a DatasetInfo parameter.
|
|
217
217
|
*/
|
|
218
218
|
template<bool ConstructableWithoutDatasetInfo,
|
|
219
219
|
typename... MLAlgorithmArgs,
|
|
@@ -54,7 +54,7 @@ template<typename MLAlgorithm,
|
|
|
54
54
|
CVBase<MLAlgorithm,
|
|
55
55
|
MatType,
|
|
56
56
|
PredictionsType,
|
|
57
|
-
WeightsType>::CVBase(const
|
|
57
|
+
WeightsType>::CVBase(const DatasetInfo& datasetInfo,
|
|
58
58
|
const size_t numClasses) :
|
|
59
59
|
datasetInfo(datasetInfo),
|
|
60
60
|
isDatasetInfoPassed(true),
|
|
@@ -63,7 +63,7 @@ CVBase<MLAlgorithm,
|
|
|
63
63
|
static_assert(MIE::TakesNumClasses,
|
|
64
64
|
"The given MLAlgorithm does not take the numClasses parameter");
|
|
65
65
|
static_assert(MIE::TakesDatasetInfo,
|
|
66
|
-
"The given MLAlgorithm does not accept a
|
|
66
|
+
"The given MLAlgorithm does not accept a DatasetInfo parameter");
|
|
67
67
|
}
|
|
68
68
|
|
|
69
69
|
template<typename MLAlgorithm,
|
|
@@ -184,9 +184,9 @@ MLAlgorithm CVBase<MLAlgorithm,
|
|
|
184
184
|
{
|
|
185
185
|
static_assert(
|
|
186
186
|
std::is_constructible_v<MLAlgorithm, const MatType&,
|
|
187
|
-
const
|
|
187
|
+
const DatasetInfo, const PredictionsType&, const size_t,
|
|
188
188
|
MLAlgorithmArgs...>,
|
|
189
|
-
"The given MLAlgorithm is not constructible with a
|
|
189
|
+
"The given MLAlgorithm is not constructible with a DatasetInfo "
|
|
190
190
|
"parameter and the passed arguments");
|
|
191
191
|
|
|
192
192
|
static const bool constructableWithoutDatasetInfo =
|
|
@@ -256,9 +256,9 @@ MLAlgorithm CVBase<MLAlgorithm,
|
|
|
256
256
|
{
|
|
257
257
|
static_assert(
|
|
258
258
|
std::is_constructible_v<MLAlgorithm, const MatType&,
|
|
259
|
-
const
|
|
259
|
+
const DatasetInfo, const PredictionsType&, const size_t,
|
|
260
260
|
const WeightsType&, MLAlgorithmArgs...>,
|
|
261
|
-
"The given MLAlgorithm is not constructible with a
|
|
261
|
+
"The given MLAlgorithm is not constructible with a DatasetInfo "
|
|
262
262
|
"parameter and the passed arguments");
|
|
263
263
|
|
|
264
264
|
static const bool constructableWithoutDatasetInfo =
|
|
@@ -302,7 +302,7 @@ MLAlgorithm CVBase<MLAlgorithm,
|
|
|
302
302
|
{
|
|
303
303
|
if (!isDatasetInfoPassed)
|
|
304
304
|
throw std::invalid_argument(
|
|
305
|
-
"The given MLAlgorithm requires a
|
|
305
|
+
"The given MLAlgorithm requires a DatasetInfo parameter");
|
|
306
306
|
|
|
307
307
|
return MLAlgorithm(xs, datasetInfo, ys, numClasses, args...);
|
|
308
308
|
}
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
|
|
15
15
|
#include <mlpack/core/cv/meta_info_extractor.hpp>
|
|
16
16
|
#include <mlpack/core/cv/cv_base.hpp>
|
|
17
|
+
#include <mlpack/core/util/arma_traits.hpp>
|
|
17
18
|
|
|
18
19
|
namespace mlpack {
|
|
19
20
|
|
|
@@ -96,7 +97,7 @@ class KFoldCV
|
|
|
96
97
|
|
|
97
98
|
/**
|
|
98
99
|
* This constructor can be used for multiclass classification algorithms that
|
|
99
|
-
* can take a
|
|
100
|
+
* can take a DatasetInfo parameter.
|
|
100
101
|
*
|
|
101
102
|
* @param k Number of folds (should be at least 2).
|
|
102
103
|
* @param xs Data points to cross-validate on.
|
|
@@ -107,7 +108,7 @@ class KFoldCV
|
|
|
107
108
|
*/
|
|
108
109
|
KFoldCV(const size_t k,
|
|
109
110
|
const MatType& xs,
|
|
110
|
-
const
|
|
111
|
+
const DatasetInfo& datasetInfo,
|
|
111
112
|
const PredictionsType& ys,
|
|
112
113
|
const size_t numClasses,
|
|
113
114
|
const bool shuffle = true);
|
|
@@ -149,7 +150,7 @@ class KFoldCV
|
|
|
149
150
|
|
|
150
151
|
/**
|
|
151
152
|
* This constructor can be used for multiclass classification algorithms that
|
|
152
|
-
* can take a
|
|
153
|
+
* can take a DatasetInfo parameter and support weighted learning.
|
|
153
154
|
*
|
|
154
155
|
* @param k Number of folds (should be at least 2).
|
|
155
156
|
* @param xs Data points to cross-validate on.
|
|
@@ -161,7 +162,7 @@ class KFoldCV
|
|
|
161
162
|
*/
|
|
162
163
|
KFoldCV(const size_t k,
|
|
163
164
|
const MatType& xs,
|
|
164
|
-
const
|
|
165
|
+
const DatasetInfo& datasetInfo,
|
|
165
166
|
const PredictionsType& ys,
|
|
166
167
|
const size_t numClasses,
|
|
167
168
|
const WeightsType& weights,
|
|
@@ -280,30 +281,38 @@ class KFoldCV
|
|
|
280
281
|
/**
|
|
281
282
|
* Get the ith training subset from a variable of a matrix type.
|
|
282
283
|
*/
|
|
283
|
-
template<typename
|
|
284
|
-
inline
|
|
285
|
-
|
|
284
|
+
template<typename SubsetMatType>
|
|
285
|
+
inline SubsetMatType GetTrainingSubset(
|
|
286
|
+
SubsetMatType& m,
|
|
287
|
+
const size_t i,
|
|
288
|
+
const typename std::enable_if_t<IsMatrix<SubsetMatType>::value>* = 0);
|
|
286
289
|
|
|
287
290
|
/**
|
|
288
291
|
* Get the ith training subset from a variable of a row type.
|
|
289
292
|
*/
|
|
290
|
-
template<typename
|
|
291
|
-
inline
|
|
292
|
-
|
|
293
|
+
template<typename SubsetRowType>
|
|
294
|
+
inline SubsetRowType GetTrainingSubset(
|
|
295
|
+
SubsetRowType& r,
|
|
296
|
+
const size_t i,
|
|
297
|
+
const typename std::enable_if_t<IsRow<SubsetRowType>::value>* = 0);
|
|
293
298
|
|
|
294
299
|
/**
|
|
295
300
|
* Get the ith validation subset from a variable of a matrix type.
|
|
296
301
|
*/
|
|
297
|
-
template<typename
|
|
298
|
-
inline
|
|
299
|
-
|
|
302
|
+
template<typename SubsetMatType>
|
|
303
|
+
inline SubsetMatType GetValidationSubset(
|
|
304
|
+
SubsetMatType& m,
|
|
305
|
+
const size_t i,
|
|
306
|
+
const typename std::enable_if_t<IsMatrix<SubsetMatType>::value>* = 0);
|
|
300
307
|
|
|
301
308
|
/**
|
|
302
309
|
* Get the ith validation subset from a variable of a row type.
|
|
303
310
|
*/
|
|
304
|
-
template<typename
|
|
305
|
-
inline
|
|
306
|
-
|
|
311
|
+
template<typename SubsetRowType>
|
|
312
|
+
inline SubsetRowType GetValidationSubset(
|
|
313
|
+
SubsetRowType& r,
|
|
314
|
+
const size_t i,
|
|
315
|
+
const typename std::enable_if_t<IsRow<SubsetRowType>::value>* = 0);
|
|
307
316
|
};
|
|
308
317
|
|
|
309
318
|
} // namespace mlpack
|
|
@@ -58,7 +58,7 @@ KFoldCV<MLAlgorithm,
|
|
|
58
58
|
PredictionsType,
|
|
59
59
|
WeightsType>::KFoldCV(const size_t k,
|
|
60
60
|
const MatType& xs,
|
|
61
|
-
const
|
|
61
|
+
const DatasetInfo& datasetInfo,
|
|
62
62
|
const PredictionsType& ys,
|
|
63
63
|
const size_t numClasses,
|
|
64
64
|
const bool shuffle) :
|
|
@@ -111,7 +111,7 @@ KFoldCV<MLAlgorithm,
|
|
|
111
111
|
PredictionsType,
|
|
112
112
|
WeightsType>::KFoldCV(const size_t k,
|
|
113
113
|
const MatType& xs,
|
|
114
|
-
const
|
|
114
|
+
const DatasetInfo& datasetInfo,
|
|
115
115
|
const PredictionsType& ys,
|
|
116
116
|
const size_t numClasses,
|
|
117
117
|
const WeightsType& weights,
|
|
@@ -270,7 +270,7 @@ double KFoldCV<MLAlgorithm,
|
|
|
270
270
|
return 0.0;
|
|
271
271
|
}
|
|
272
272
|
|
|
273
|
-
return
|
|
273
|
+
return mean(evaluations.elem(find_finite(evaluations)));
|
|
274
274
|
}
|
|
275
275
|
|
|
276
276
|
template<typename MLAlgorithm,
|
|
@@ -300,7 +300,7 @@ double KFoldCV<MLAlgorithm,
|
|
|
300
300
|
modelPtr.reset(new MLAlgorithm(std::move(model)));
|
|
301
301
|
}
|
|
302
302
|
|
|
303
|
-
return
|
|
303
|
+
return mean(evaluations);
|
|
304
304
|
}
|
|
305
305
|
|
|
306
306
|
template<typename MLAlgorithm,
|
|
@@ -375,14 +375,15 @@ template<typename MLAlgorithm,
|
|
|
375
375
|
typename MatType,
|
|
376
376
|
typename PredictionsType,
|
|
377
377
|
typename WeightsType>
|
|
378
|
-
template<typename
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
const size_t i
|
|
378
|
+
template<typename SubsetMatType>
|
|
379
|
+
SubsetMatType KFoldCV<MLAlgorithm,
|
|
380
|
+
Metric,
|
|
381
|
+
MatType,
|
|
382
|
+
PredictionsType,
|
|
383
|
+
WeightsType>::GetTrainingSubset(
|
|
384
|
+
SubsetMatType& m,
|
|
385
|
+
const size_t i,
|
|
386
|
+
const typename std::enable_if_t<IsMatrix<SubsetMatType>::value>*)
|
|
386
387
|
{
|
|
387
388
|
// If this is not the first fold, we have to handle it a little bit
|
|
388
389
|
// differently, since the last fold may contain slightly more than 'binSize'
|
|
@@ -390,8 +391,9 @@ arma::Mat<ElementType> KFoldCV<MLAlgorithm,
|
|
|
390
391
|
const size_t subsetSize = (i != 0) ? lastBinSize + (k - 2) * binSize :
|
|
391
392
|
(k - 1) * binSize;
|
|
392
393
|
|
|
393
|
-
|
|
394
|
-
|
|
394
|
+
SubsetMatType alias;
|
|
395
|
+
MakeAlias(alias, m, m.n_rows, subsetSize, m.n_rows * binSize * i);
|
|
396
|
+
return alias;
|
|
395
397
|
}
|
|
396
398
|
|
|
397
399
|
template<typename MLAlgorithm,
|
|
@@ -399,14 +401,15 @@ template<typename MLAlgorithm,
|
|
|
399
401
|
typename MatType,
|
|
400
402
|
typename PredictionsType,
|
|
401
403
|
typename WeightsType>
|
|
402
|
-
template<typename
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
const size_t i
|
|
404
|
+
template<typename SubsetRowType>
|
|
405
|
+
SubsetRowType KFoldCV<MLAlgorithm,
|
|
406
|
+
Metric,
|
|
407
|
+
MatType,
|
|
408
|
+
PredictionsType,
|
|
409
|
+
WeightsType>::GetTrainingSubset(
|
|
410
|
+
SubsetRowType& r,
|
|
411
|
+
const size_t i,
|
|
412
|
+
const typename std::enable_if_t<IsRow<SubsetRowType>::value>*)
|
|
410
413
|
{
|
|
411
414
|
// If this is not the first fold, we have to handle it a little bit
|
|
412
415
|
// differently, since the last fold may contain slightly more than 'binSize'
|
|
@@ -414,7 +417,9 @@ arma::Row<ElementType> KFoldCV<MLAlgorithm,
|
|
|
414
417
|
const size_t subsetSize = (i != 0) ? lastBinSize + (k - 2) * binSize :
|
|
415
418
|
(k - 1) * binSize;
|
|
416
419
|
|
|
417
|
-
|
|
420
|
+
SubsetRowType alias;
|
|
421
|
+
MakeAlias(alias, r, subsetSize, r.n_rows * binSize * i);
|
|
422
|
+
return alias;
|
|
418
423
|
}
|
|
419
424
|
|
|
420
425
|
template<typename MLAlgorithm,
|
|
@@ -422,18 +427,21 @@ template<typename MLAlgorithm,
|
|
|
422
427
|
typename MatType,
|
|
423
428
|
typename PredictionsType,
|
|
424
429
|
typename WeightsType>
|
|
425
|
-
template<typename
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
const size_t i
|
|
430
|
+
template<typename SubsetMatType>
|
|
431
|
+
SubsetMatType KFoldCV<MLAlgorithm,
|
|
432
|
+
Metric,
|
|
433
|
+
MatType,
|
|
434
|
+
PredictionsType,
|
|
435
|
+
WeightsType>::GetValidationSubset(
|
|
436
|
+
SubsetMatType& m,
|
|
437
|
+
const size_t i,
|
|
438
|
+
const typename std::enable_if_t<IsMatrix<SubsetMatType>::value>*)
|
|
433
439
|
{
|
|
434
440
|
const size_t subsetSize = (i == 0) ? lastBinSize : binSize;
|
|
435
|
-
|
|
436
|
-
|
|
441
|
+
SubsetMatType alias;
|
|
442
|
+
MakeAlias(alias, m, m.n_rows, subsetSize,
|
|
443
|
+
m.n_rows * ValidationSubsetFirstCol(i));
|
|
444
|
+
return alias;
|
|
437
445
|
}
|
|
438
446
|
|
|
439
447
|
template<typename MLAlgorithm,
|
|
@@ -441,18 +449,20 @@ template<typename MLAlgorithm,
|
|
|
441
449
|
typename MatType,
|
|
442
450
|
typename PredictionsType,
|
|
443
451
|
typename WeightsType>
|
|
444
|
-
template<typename
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
const size_t i
|
|
452
|
+
template<typename SubsetRowType>
|
|
453
|
+
SubsetRowType KFoldCV<MLAlgorithm,
|
|
454
|
+
Metric,
|
|
455
|
+
MatType,
|
|
456
|
+
PredictionsType,
|
|
457
|
+
WeightsType>::GetValidationSubset(
|
|
458
|
+
SubsetRowType& r,
|
|
459
|
+
const size_t i,
|
|
460
|
+
const typename std::enable_if_t<IsRow<SubsetRowType>::value>*)
|
|
452
461
|
{
|
|
453
462
|
const size_t subsetSize = (i == 0) ? lastBinSize : binSize;
|
|
454
|
-
|
|
455
|
-
|
|
463
|
+
SubsetRowType alias;
|
|
464
|
+
MakeAlias(alias, r, subsetSize, r.n_rows * ValidationSubsetFirstCol(i));
|
|
465
|
+
return alias;
|
|
456
466
|
}
|
|
457
467
|
|
|
458
468
|
} // namespace mlpack
|