mlpack 4.6.2__cp312-cp312-win_amd64.whl → 4.7.0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp312-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp312-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp312-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp312-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp312-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp312-win_amd64.pyd +0 -0
- mlpack/cf.cp312-win_amd64.pyd +0 -0
- mlpack/dbscan.cp312-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp312-win_amd64.pyd +0 -0
- mlpack/det.cp312-win_amd64.pyd +0 -0
- mlpack/emst.cp312-win_amd64.pyd +0 -0
- mlpack/fastmks.cp312-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp312-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp312-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp312-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp312-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp312-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp312-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp312-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp312-win_amd64.pyd +0 -0
- mlpack/image_converter.cp312-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp312-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp312-win_amd64.pyd +0 -0
- mlpack/kfn.cp312-win_amd64.pyd +0 -0
- mlpack/kmeans.cp312-win_amd64.pyd +0 -0
- mlpack/knn.cp312-win_amd64.pyd +0 -0
- mlpack/krann.cp312-win_amd64.pyd +0 -0
- mlpack/lars.cp312-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp312-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp312-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp312-win_amd64.pyd +0 -0
- mlpack/lmnn.cp312-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp312-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp312-win_amd64.pyd +0 -0
- mlpack/lsh.cp312-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp312-win_amd64.pyd +0 -0
- mlpack/nbc.cp312-win_amd64.pyd +0 -0
- mlpack/nca.cp312-win_amd64.pyd +0 -0
- mlpack/nmf.cp312-win_amd64.pyd +0 -0
- mlpack/pca.cp312-win_amd64.pyd +0 -0
- mlpack/perceptron.cp312-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp312-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp312-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp312-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp312-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp312-win_amd64.pyd +0 -0
- mlpack/radical.cp312-win_amd64.pyd +0 -0
- mlpack/random_forest.cp312-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp312-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp312-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
mlpack/__init__.py
CHANGED
|
@@ -11,14 +11,14 @@ http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
# start delvewheel patch
|
|
14
|
-
def
|
|
14
|
+
def _delvewheel_patch_1_12_0():
|
|
15
15
|
import os
|
|
16
16
|
if os.path.isdir(libs_dir := os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'mlpack.libs'))):
|
|
17
17
|
os.add_dll_directory(libs_dir)
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
|
|
21
|
-
del
|
|
20
|
+
_delvewheel_patch_1_12_0()
|
|
21
|
+
del _delvewheel_patch_1_12_0
|
|
22
22
|
# end delvewheel patch
|
|
23
23
|
|
|
24
24
|
import warnings
|
|
@@ -74,4 +74,4 @@ from .adaboost import *
|
|
|
74
74
|
from .linear_regression_train import linear_regression_train
|
|
75
75
|
from .linear_regression_predict import linear_regression_predict
|
|
76
76
|
from .linear_regression import *
|
|
77
|
-
__version__='4.
|
|
77
|
+
__version__='4.7.0'
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
mlpack/cf.cp312-win_amd64.pyd
CHANGED
|
Binary file
|
|
Binary file
|
|
Binary file
|
mlpack/det.cp312-win_amd64.pyd
CHANGED
|
Binary file
|
mlpack/emst.cp312-win_amd64.pyd
CHANGED
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
mlpack/include/mlpack/base.hpp
CHANGED
|
@@ -86,6 +86,7 @@
|
|
|
86
86
|
#include <armadillo>
|
|
87
87
|
#include <mlpack/core/util/arma_traits.hpp>
|
|
88
88
|
#include <mlpack/core/util/omp_reductions.hpp>
|
|
89
|
+
#include <mlpack/core/arma_extend/find_nan.hpp>
|
|
89
90
|
|
|
90
91
|
// On Visual Studio, disable C4519 (default arguments for function templates)
|
|
91
92
|
// since it's by default an error, which doesn't even make any sense because
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file find_nan.hpp
|
|
3
|
+
* @author Ryan Curtin
|
|
4
|
+
*
|
|
5
|
+
* When find_nan() is not available (Armadillo < 11.4), provide an internal
|
|
6
|
+
* mlpack implementation that operates the same way. It is slower.
|
|
7
|
+
*/
|
|
8
|
+
#ifndef MLPACK_CORE_ARMA_EXTEND_FIND_NAN_HPP
|
|
9
|
+
#define MLPACK_CORE_ARMA_EXTEND_FIND_NAN_HPP
|
|
10
|
+
|
|
11
|
+
namespace mlpack {
|
|
12
|
+
|
|
13
|
+
#if ARMA_VERSION_MAJOR < 11 || \
|
|
14
|
+
(ARMA_VERSION_MAJOR == 11 && ARMA_VERSION_MINOR < 4)
|
|
15
|
+
|
|
16
|
+
template<typename T>
|
|
17
|
+
arma::uvec find_nan(const T& m,
|
|
18
|
+
const std::enable_if_t<arma::is_arma_type<T>::value>* = 0)
|
|
19
|
+
{
|
|
20
|
+
typedef typename T::elem_type ElemType;
|
|
21
|
+
|
|
22
|
+
if (!std::numeric_limits<ElemType>::has_quiet_NaN)
|
|
23
|
+
return arma::uvec(); // There can't be any NaNs.
|
|
24
|
+
|
|
25
|
+
// find_nonfinite() exists on older Armadillo, and we can also search for +Inf
|
|
26
|
+
// and -Inf.
|
|
27
|
+
arma::uvec nonfiniteIndices = arma::find_nonfinite(m);
|
|
28
|
+
if (nonfiniteIndices.n_elem == 0)
|
|
29
|
+
return arma::uvec();
|
|
30
|
+
|
|
31
|
+
arma::uvec infIndices = arma::find(
|
|
32
|
+
m == std::numeric_limits<ElemType>::infinity());
|
|
33
|
+
arma::uvec neginfIndices = arma::find(
|
|
34
|
+
m == -std::numeric_limits<ElemType>::infinity());
|
|
35
|
+
|
|
36
|
+
arma::uvec result(nonfiniteIndices.n_elem -
|
|
37
|
+
(infIndices.n_elem + neginfIndices.n_elem));
|
|
38
|
+
if (result.n_elem == 0)
|
|
39
|
+
return result;
|
|
40
|
+
|
|
41
|
+
size_t infIndex = 0;
|
|
42
|
+
size_t neginfIndex = 0;
|
|
43
|
+
size_t outputIndex = 0;
|
|
44
|
+
for (size_t i = 0; i < nonfiniteIndices.n_elem; ++i)
|
|
45
|
+
{
|
|
46
|
+
if (infIndex < infIndices.n_elem &&
|
|
47
|
+
nonfiniteIndices[i] == infIndices[infIndex])
|
|
48
|
+
++infIndex;
|
|
49
|
+
else if (neginfIndex < neginfIndices.n_elem &&
|
|
50
|
+
nonfiniteIndices[i] == neginfIndices[neginfIndex])
|
|
51
|
+
++neginfIndex;
|
|
52
|
+
else
|
|
53
|
+
result[outputIndex++] = nonfiniteIndices[i];
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
return result;
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
#endif
|
|
60
|
+
|
|
61
|
+
} // namespace mlpack
|
|
62
|
+
|
|
63
|
+
#endif
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file core/cereal/low_precision.hpp
|
|
3
|
+
* @author Ryan Curtin
|
|
4
|
+
*
|
|
5
|
+
* Extra shims necessary for cereal to serialize to JSON for low-precision types
|
|
6
|
+
* (e.g. FP16, BF16, etc.).
|
|
7
|
+
*
|
|
8
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef MLPACK_CORE_CEREAL_LOW_PRECISION_HPP
|
|
14
|
+
#define MLPACK_CORE_CEREAL_LOW_PRECISION_HPP
|
|
15
|
+
|
|
16
|
+
namespace cereal {
|
|
17
|
+
|
|
18
|
+
// Because our serialization is always done with name-value pairs, we can catch
|
|
19
|
+
// any FP16 serialization at the NVP level with a specialized implementation of
|
|
20
|
+
// the load and save functions for the JSON archive (the only one that does not
|
|
21
|
+
// serialize low-precision types correctly).
|
|
22
|
+
|
|
23
|
+
#if defined(ARMA_HAVE_FP16)
|
|
24
|
+
|
|
25
|
+
/**
 * Save an arma::fp16 value to a JSON archive as a decimal string, since the
 * JSON archive does not serialize low-precision types correctly by itself.
 *
 * The value is widened to float before formatting.  Assuming arma::fp16 is an
 * IEEE binary16 type, every fp16 value is exactly representable as a float,
 * and printing a float with std::numeric_limits<float>::max_digits10
 * significant digits guarantees an exact round trip through std::stof().
 * This avoids std::numeric_limits<arma::fp16>: if that is not specialized by
 * the toolchain, its max_digits10 is 0, which in default float formatting
 * truncates the output to a single significant digit.
 *
 * @param ar JSON output archive to write to.
 * @param t Name-value pair holding the fp16 value to save.
 */
inline void CEREAL_SAVE_FUNCTION_NAME(JSONOutputArchive &ar,
                                      NameValuePair<arma::fp16&> const& t)
{
  ar.setNextName(t.name);
  std::ostringstream oss;
  oss.precision(std::numeric_limits<float>::max_digits10);
  oss << static_cast<float>(t.value);
  ar(oss.str());
}
|
|
34
|
+
|
|
35
|
+
/**
 * Load an arma::fp16 value from a JSON archive.  The value is stored as a
 * decimal string; it is parsed as a float and then narrowed to fp16.
 *
 * @param ar JSON input archive to read from.
 * @param t Name-value pair receiving the loaded fp16 value.
 */
inline void CEREAL_LOAD_FUNCTION_NAME(JSONInputArchive& ar,
                                      NameValuePair<arma::fp16&>& t)
{
  std::string text;
  ar.setNextName(t.name);
  ar.loadValue(text);

  // std::stof() throws std::invalid_argument / std::out_of_range on bad input.
  const float parsed = std::stof(text);
  t.value = static_cast<arma::fp16>(parsed);
}
|
|
43
|
+
|
|
44
|
+
#endif
|
|
45
|
+
|
|
46
|
+
} // namespace cereal
|
|
47
|
+
|
|
48
|
+
#endif
|
|
@@ -57,12 +57,12 @@ class CVBase
|
|
|
57
57
|
|
|
58
58
|
/**
|
|
59
59
|
* Assert that MLAlgorithm takes the numClasses parameter and a
|
|
60
|
-
*
|
|
60
|
+
* DatasetInfo parameter and store them.
|
|
61
61
|
*
|
|
62
62
|
* @param datasetInfo Type information for each dimension of the dataset.
|
|
63
63
|
* @param numClasses Number of classes in the dataset.
|
|
64
64
|
*/
|
|
65
|
-
CVBase(const
|
|
65
|
+
CVBase(const DatasetInfo& datasetInfo,
|
|
66
66
|
const size_t numClasses);
|
|
67
67
|
|
|
68
68
|
/**
|
|
@@ -101,9 +101,9 @@ class CVBase
|
|
|
101
101
|
static_assert(MIE::IsSupported,
|
|
102
102
|
"The given MLAlgorithm is not supported by MetaInfoExtractor");
|
|
103
103
|
|
|
104
|
-
//! A variable for storing a
|
|
105
|
-
const
|
|
106
|
-
//! An indicator whether a
|
|
104
|
+
//! A variable for storing a DatasetInfo parameter if it is passed.
|
|
105
|
+
const DatasetInfo datasetInfo;
|
|
106
|
+
//! An indicator whether a DatasetInfo parameter has been passed.
|
|
107
107
|
const bool isDatasetInfoPassed;
|
|
108
108
|
//! A variable for storing the numClasses parameter if it is passed.
|
|
109
109
|
size_t numClasses;
|
|
@@ -145,7 +145,7 @@ class CVBase
|
|
|
145
145
|
|
|
146
146
|
/**
|
|
147
147
|
* Construct a trained MLAlgorithm model if MLAlgorithm takes the
|
|
148
|
-
* numClasses parameter and a
|
|
148
|
+
* numClasses parameter and a DatasetInfo parameter.
|
|
149
149
|
*/
|
|
150
150
|
template<typename... MLAlgorithmArgs,
|
|
151
151
|
bool Enabled = MIE::TakesNumClasses & MIE::TakesDatasetInfo,
|
|
@@ -183,7 +183,7 @@ class CVBase
|
|
|
183
183
|
|
|
184
184
|
/**
|
|
185
185
|
* Construct a trained MLAlgorithm model if MLAlgorithm takes the
|
|
186
|
-
* numClasses parameter and a
|
|
186
|
+
* numClasses parameter and a DatasetInfo parameter.
|
|
187
187
|
*/
|
|
188
188
|
template<typename... MLAlgorithmArgs,
|
|
189
189
|
bool Enabled = MIE::TakesNumClasses & MIE::TakesDatasetInfo,
|
|
@@ -196,13 +196,13 @@ class CVBase
|
|
|
196
196
|
const MLAlgorithmArgs&... args);
|
|
197
197
|
|
|
198
198
|
/**
|
|
199
|
-
* When MLAlgorithm supports a
|
|
199
|
+
* When MLAlgorithm supports a DatasetInfo parameter, training should be
|
|
200
200
|
* treated separately - there are models that can be constructed with and
|
|
201
201
|
* without a DatasetInfo parameter and models that can be constructed
|
|
202
|
-
* only with a
|
|
202
|
+
* only with a DatasetInfo parameter.
|
|
203
203
|
*
|
|
204
204
|
* Construct a trained MLAlgorithm model when it can be constructed without a
|
|
205
|
-
*
|
|
205
|
+
* DatasetInfo parameter.
|
|
206
206
|
*/
|
|
207
207
|
template<bool ConstructableWithoutDatasetInfo,
|
|
208
208
|
typename... MLAlgorithmArgs,
|
|
@@ -213,7 +213,7 @@ class CVBase
|
|
|
213
213
|
|
|
214
214
|
/**
|
|
215
215
|
* Construct a trained MLAlgorithm model when it can't be constructed without
|
|
216
|
-
* a
|
|
216
|
+
* a DatasetInfo parameter.
|
|
217
217
|
*/
|
|
218
218
|
template<bool ConstructableWithoutDatasetInfo,
|
|
219
219
|
typename... MLAlgorithmArgs,
|
|
@@ -54,7 +54,7 @@ template<typename MLAlgorithm,
|
|
|
54
54
|
CVBase<MLAlgorithm,
|
|
55
55
|
MatType,
|
|
56
56
|
PredictionsType,
|
|
57
|
-
WeightsType>::CVBase(const
|
|
57
|
+
WeightsType>::CVBase(const DatasetInfo& datasetInfo,
|
|
58
58
|
const size_t numClasses) :
|
|
59
59
|
datasetInfo(datasetInfo),
|
|
60
60
|
isDatasetInfoPassed(true),
|
|
@@ -63,7 +63,7 @@ CVBase<MLAlgorithm,
|
|
|
63
63
|
static_assert(MIE::TakesNumClasses,
|
|
64
64
|
"The given MLAlgorithm does not take the numClasses parameter");
|
|
65
65
|
static_assert(MIE::TakesDatasetInfo,
|
|
66
|
-
"The given MLAlgorithm does not accept a
|
|
66
|
+
"The given MLAlgorithm does not accept a DatasetInfo parameter");
|
|
67
67
|
}
|
|
68
68
|
|
|
69
69
|
template<typename MLAlgorithm,
|
|
@@ -184,9 +184,9 @@ MLAlgorithm CVBase<MLAlgorithm,
|
|
|
184
184
|
{
|
|
185
185
|
static_assert(
|
|
186
186
|
std::is_constructible_v<MLAlgorithm, const MatType&,
|
|
187
|
-
const
|
|
187
|
+
const DatasetInfo, const PredictionsType&, const size_t,
|
|
188
188
|
MLAlgorithmArgs...>,
|
|
189
|
-
"The given MLAlgorithm is not constructible with a
|
|
189
|
+
"The given MLAlgorithm is not constructible with a DatasetInfo "
|
|
190
190
|
"parameter and the passed arguments");
|
|
191
191
|
|
|
192
192
|
static const bool constructableWithoutDatasetInfo =
|
|
@@ -256,9 +256,9 @@ MLAlgorithm CVBase<MLAlgorithm,
|
|
|
256
256
|
{
|
|
257
257
|
static_assert(
|
|
258
258
|
std::is_constructible_v<MLAlgorithm, const MatType&,
|
|
259
|
-
const
|
|
259
|
+
const DatasetInfo, const PredictionsType&, const size_t,
|
|
260
260
|
const WeightsType&, MLAlgorithmArgs...>,
|
|
261
|
-
"The given MLAlgorithm is not constructible with a
|
|
261
|
+
"The given MLAlgorithm is not constructible with a DatasetInfo "
|
|
262
262
|
"parameter and the passed arguments");
|
|
263
263
|
|
|
264
264
|
static const bool constructableWithoutDatasetInfo =
|
|
@@ -302,7 +302,7 @@ MLAlgorithm CVBase<MLAlgorithm,
|
|
|
302
302
|
{
|
|
303
303
|
if (!isDatasetInfoPassed)
|
|
304
304
|
throw std::invalid_argument(
|
|
305
|
-
"The given MLAlgorithm requires a
|
|
305
|
+
"The given MLAlgorithm requires a DatasetInfo parameter");
|
|
306
306
|
|
|
307
307
|
return MLAlgorithm(xs, datasetInfo, ys, numClasses, args...);
|
|
308
308
|
}
|
|
@@ -97,7 +97,7 @@ class KFoldCV
|
|
|
97
97
|
|
|
98
98
|
/**
|
|
99
99
|
* This constructor can be used for multiclass classification algorithms that
|
|
100
|
-
* can take a
|
|
100
|
+
* can take a DatasetInfo parameter.
|
|
101
101
|
*
|
|
102
102
|
* @param k Number of folds (should be at least 2).
|
|
103
103
|
* @param xs Data points to cross-validate on.
|
|
@@ -108,7 +108,7 @@ class KFoldCV
|
|
|
108
108
|
*/
|
|
109
109
|
KFoldCV(const size_t k,
|
|
110
110
|
const MatType& xs,
|
|
111
|
-
const
|
|
111
|
+
const DatasetInfo& datasetInfo,
|
|
112
112
|
const PredictionsType& ys,
|
|
113
113
|
const size_t numClasses,
|
|
114
114
|
const bool shuffle = true);
|
|
@@ -150,7 +150,7 @@ class KFoldCV
|
|
|
150
150
|
|
|
151
151
|
/**
|
|
152
152
|
* This constructor can be used for multiclass classification algorithms that
|
|
153
|
-
* can take a
|
|
153
|
+
* can take a DatasetInfo parameter and support weighted learning.
|
|
154
154
|
*
|
|
155
155
|
* @param k Number of folds (should be at least 2).
|
|
156
156
|
* @param xs Data points to cross-validate on.
|
|
@@ -162,7 +162,7 @@ class KFoldCV
|
|
|
162
162
|
*/
|
|
163
163
|
KFoldCV(const size_t k,
|
|
164
164
|
const MatType& xs,
|
|
165
|
-
const
|
|
165
|
+
const DatasetInfo& datasetInfo,
|
|
166
166
|
const PredictionsType& ys,
|
|
167
167
|
const size_t numClasses,
|
|
168
168
|
const WeightsType& weights,
|
|
@@ -58,7 +58,7 @@ KFoldCV<MLAlgorithm,
|
|
|
58
58
|
PredictionsType,
|
|
59
59
|
WeightsType>::KFoldCV(const size_t k,
|
|
60
60
|
const MatType& xs,
|
|
61
|
-
const
|
|
61
|
+
const DatasetInfo& datasetInfo,
|
|
62
62
|
const PredictionsType& ys,
|
|
63
63
|
const size_t numClasses,
|
|
64
64
|
const bool shuffle) :
|
|
@@ -111,7 +111,7 @@ KFoldCV<MLAlgorithm,
|
|
|
111
111
|
PredictionsType,
|
|
112
112
|
WeightsType>::KFoldCV(const size_t k,
|
|
113
113
|
const MatType& xs,
|
|
114
|
-
const
|
|
114
|
+
const DatasetInfo& datasetInfo,
|
|
115
115
|
const PredictionsType& ys,
|
|
116
116
|
const size_t numClasses,
|
|
117
117
|
const WeightsType& weights,
|
|
@@ -270,7 +270,7 @@ double KFoldCV<MLAlgorithm,
|
|
|
270
270
|
return 0.0;
|
|
271
271
|
}
|
|
272
272
|
|
|
273
|
-
return
|
|
273
|
+
return mean(evaluations.elem(find_finite(evaluations)));
|
|
274
274
|
}
|
|
275
275
|
|
|
276
276
|
template<typename MLAlgorithm,
|
|
@@ -300,7 +300,7 @@ double KFoldCV<MLAlgorithm,
|
|
|
300
300
|
modelPtr.reset(new MLAlgorithm(std::move(model)));
|
|
301
301
|
}
|
|
302
302
|
|
|
303
|
-
return
|
|
303
|
+
return mean(evaluations);
|
|
304
304
|
}
|
|
305
305
|
|
|
306
306
|
template<typename MLAlgorithm,
|
|
@@ -25,7 +25,7 @@ namespace mlpack {
|
|
|
25
25
|
* @tparam MatType The type of data.
|
|
26
26
|
* @tparam PredictionsType The type of predictions.
|
|
27
27
|
* @tparam WeightsType The type of weights.
|
|
28
|
-
* @tparam DatasetInfo An indicator whether a
|
|
28
|
+
* @tparam DatasetInfo An indicator whether a DatasetInfo parameter should
|
|
29
29
|
* be present.
|
|
30
30
|
* @tparam NumClasses An indicator whether the numClasses parameter should be
|
|
31
31
|
* present.
|
|
@@ -101,7 +101,7 @@ struct TrainForm<MT, PT, void, false, false> : public TrainFormBase4<PT, void,
|
|
|
101
101
|
|
|
102
102
|
template<typename MT, typename PT>
|
|
103
103
|
struct TrainForm<MT, PT, void, true, false> : public TrainFormBase5<PT, void,
|
|
104
|
-
const MT&, const
|
|
104
|
+
const MT&, const DatasetInfo&, const PT&> {};
|
|
105
105
|
|
|
106
106
|
template<typename MT, typename PT, typename WT>
|
|
107
107
|
struct TrainForm<MT, PT, WT, false, false> : public TrainFormBase5<PT, WT,
|
|
@@ -109,7 +109,7 @@ struct TrainForm<MT, PT, WT, false, false> : public TrainFormBase5<PT, WT,
|
|
|
109
109
|
|
|
110
110
|
template<typename MT, typename PT, typename WT>
|
|
111
111
|
struct TrainForm<MT, PT, WT, true, false> : public TrainFormBase6<PT, WT,
|
|
112
|
-
const MT&, const
|
|
112
|
+
const MT&, const DatasetInfo&, const PT&, const WT&> {};
|
|
113
113
|
|
|
114
114
|
template<typename MT, typename PT>
|
|
115
115
|
struct TrainForm<MT, PT, void, false, true> : public TrainFormBase5<PT, void,
|
|
@@ -117,7 +117,7 @@ struct TrainForm<MT, PT, void, false, true> : public TrainFormBase5<PT, void,
|
|
|
117
117
|
|
|
118
118
|
template<typename MT, typename PT>
|
|
119
119
|
struct TrainForm<MT, PT, void, true, true> : public TrainFormBase6<PT, void,
|
|
120
|
-
const MT&, const
|
|
120
|
+
const MT&, const DatasetInfo&, const PT&, const size_t> {};
|
|
121
121
|
|
|
122
122
|
template<typename MT, typename PT, typename WT>
|
|
123
123
|
struct TrainForm<MT, PT, WT, false, true> : public TrainFormBase6<PT, WT,
|
|
@@ -125,7 +125,7 @@ struct TrainForm<MT, PT, WT, false, true> : public TrainFormBase6<PT, WT,
|
|
|
125
125
|
|
|
126
126
|
template<typename MT, typename PT, typename WT>
|
|
127
127
|
struct TrainForm<MT, PT, WT, true, true> : public TrainFormBase7<PT, WT,
|
|
128
|
-
const MT&, const
|
|
128
|
+
const MT&, const DatasetInfo&, const PT&,
|
|
129
129
|
const size_t, const WT&> {};
|
|
130
130
|
#else
|
|
131
131
|
template<typename PT, typename WT, typename... SignatureParams>
|
|
@@ -147,7 +147,7 @@ struct TrainForm<MT, PT, void, false, false> : public TrainFormBase<PT, void,
|
|
|
147
147
|
|
|
148
148
|
template<typename MT, typename PT>
|
|
149
149
|
struct TrainForm<MT, PT, void, true, false> : public TrainFormBase<PT, void,
|
|
150
|
-
const MT&, const
|
|
150
|
+
const MT&, const DatasetInfo&, const PT&> {};
|
|
151
151
|
|
|
152
152
|
template<typename MT, typename PT, typename WT>
|
|
153
153
|
struct TrainForm<MT, PT, WT, false, false> : public TrainFormBase<PT, WT,
|
|
@@ -155,7 +155,7 @@ struct TrainForm<MT, PT, WT, false, false> : public TrainFormBase<PT, WT,
|
|
|
155
155
|
|
|
156
156
|
template<typename MT, typename PT, typename WT>
|
|
157
157
|
struct TrainForm<MT, PT, WT, true, false> : public TrainFormBase<PT, WT,
|
|
158
|
-
const MT&, const
|
|
158
|
+
const MT&, const DatasetInfo&, const PT&, const WT&> {};
|
|
159
159
|
|
|
160
160
|
template<typename MT, typename PT>
|
|
161
161
|
struct TrainForm<MT, PT, void, false, true> : public TrainFormBase<PT, void,
|
|
@@ -163,7 +163,7 @@ struct TrainForm<MT, PT, void, false, true> : public TrainFormBase<PT, void,
|
|
|
163
163
|
|
|
164
164
|
template<typename MT, typename PT>
|
|
165
165
|
struct TrainForm<MT, PT, void, true, true> : public TrainFormBase<PT, void,
|
|
166
|
-
const MT&, const
|
|
166
|
+
const MT&, const DatasetInfo&, const PT&, const size_t> {};
|
|
167
167
|
|
|
168
168
|
template<typename MT, typename PT, typename WT>
|
|
169
169
|
struct TrainForm<MT, PT, WT, false, true> : public TrainFormBase<PT, WT,
|
|
@@ -171,7 +171,7 @@ struct TrainForm<MT, PT, WT, false, true> : public TrainFormBase<PT, WT,
|
|
|
171
171
|
|
|
172
172
|
template<typename MT, typename PT, typename WT>
|
|
173
173
|
struct TrainForm<MT, PT, WT, true, true> : public TrainFormBase<PT, WT,
|
|
174
|
-
const MT&, const
|
|
174
|
+
const MT&, const DatasetInfo&, const PT&,
|
|
175
175
|
const size_t, const WT&> {};
|
|
176
176
|
#endif
|
|
177
177
|
|
|
@@ -336,7 +336,7 @@ class MetaInfoExtractor
|
|
|
336
336
|
static const bool SupportsWeights = !std::is_same_v<WeightsType, void*>;
|
|
337
337
|
|
|
338
338
|
/**
|
|
339
|
-
* An indication whether MLAlgorithm takes a
|
|
339
|
+
* An indication whether MLAlgorithm takes a DatasetInfo parameter.
|
|
340
340
|
*/
|
|
341
341
|
static const bool TakesDatasetInfo = Selects<TF5>::value;
|
|
342
342
|
|
|
@@ -28,7 +28,8 @@ template<typename DataType, typename DistanceType>
|
|
|
28
28
|
DataType PairwiseDistances(const DataType& data,
|
|
29
29
|
const DistanceType& distance)
|
|
30
30
|
{
|
|
31
|
-
DataType distances = DataType(data.n_cols, data.n_cols,
|
|
31
|
+
DataType distances = DataType(data.n_cols, data.n_cols,
|
|
32
|
+
GetFillType<DataType>::none);
|
|
32
33
|
for (size_t i = 0; i < data.n_cols; i++)
|
|
33
34
|
{
|
|
34
35
|
for (size_t j = 0; j < i; j++)
|
|
@@ -27,7 +27,7 @@ double R2Score<AdjustedR2>::Evaluate(MLAlgorithm& model,
|
|
|
27
27
|
// Taking Predicted Output from the model.
|
|
28
28
|
model.Predict(data, predictedResponses);
|
|
29
29
|
// Mean value of response.
|
|
30
|
-
double meanResponses =
|
|
30
|
+
double meanResponses = mean(responses);
|
|
31
31
|
|
|
32
32
|
// Calculate the numerator i.e. residual sum of squares.
|
|
33
33
|
double residualSumSquared = accu(arma::square(responses -
|
|
@@ -22,7 +22,7 @@ double SilhouetteScore::Overall(const DataType& X,
|
|
|
22
22
|
const Metric& metric)
|
|
23
23
|
{
|
|
24
24
|
util::CheckSameSizes(X, labels, "SilhouetteScore::Overall()");
|
|
25
|
-
return
|
|
25
|
+
return mean(SamplesScore(X, labels, metric));
|
|
26
26
|
}
|
|
27
27
|
|
|
28
28
|
template<typename DataType>
|
|
@@ -105,7 +105,7 @@ class SimpleCV
|
|
|
105
105
|
|
|
106
106
|
/**
|
|
107
107
|
* This constructor can be used for multiclass classification algorithms that
|
|
108
|
-
* can take a
|
|
108
|
+
* can take a DatasetInfo parameter.
|
|
109
109
|
*
|
|
110
110
|
* @param validationSize A proportion (between 0 and 1) of data used as a
|
|
111
111
|
* validation set.
|
|
@@ -120,7 +120,7 @@ class SimpleCV
|
|
|
120
120
|
template<typename MatInType, typename PredictionsInType>
|
|
121
121
|
SimpleCV(const double validationSize,
|
|
122
122
|
MatInType&& xs,
|
|
123
|
-
const
|
|
123
|
+
const DatasetInfo& datasetInfo,
|
|
124
124
|
PredictionsInType&& ys,
|
|
125
125
|
const size_t numClasses);
|
|
126
126
|
|
|
@@ -173,7 +173,7 @@ class SimpleCV
|
|
|
173
173
|
|
|
174
174
|
/**
|
|
175
175
|
* This constructor can be used for multiclass classification algorithms that
|
|
176
|
-
* can take a
|
|
176
|
+
* can take a DatasetInfo parameter and support weighted learning.
|
|
177
177
|
*
|
|
178
178
|
* @param validationSize A proportion (between 0 and 1) of data used as a
|
|
179
179
|
* validation set.
|
|
@@ -192,7 +192,7 @@ class SimpleCV
|
|
|
192
192
|
typename WeightsInType>
|
|
193
193
|
SimpleCV(const double validationSize,
|
|
194
194
|
MatInType&& xs,
|
|
195
|
-
const
|
|
195
|
+
const DatasetInfo& datasetInfo,
|
|
196
196
|
PredictionsInType&& ys,
|
|
197
197
|
const size_t numClasses,
|
|
198
198
|
WeightsInType&& weights);
|
|
@@ -61,7 +61,7 @@ SimpleCV<MLAlgorithm,
|
|
|
61
61
|
PredictionsType,
|
|
62
62
|
WeightsType>::SimpleCV(const double validationSize,
|
|
63
63
|
MIT&& xs,
|
|
64
|
-
const
|
|
64
|
+
const DatasetInfo& datasetInfo,
|
|
65
65
|
PIT&& ys,
|
|
66
66
|
const size_t numClasses) :
|
|
67
67
|
SimpleCV(Base(datasetInfo, numClasses), validationSize,
|
|
@@ -117,7 +117,7 @@ SimpleCV<MLAlgorithm,
|
|
|
117
117
|
PredictionsType,
|
|
118
118
|
WeightsType>::SimpleCV(const double validationSize,
|
|
119
119
|
MIT&& xs,
|
|
120
|
-
const
|
|
120
|
+
const DatasetInfo& datasetInfo,
|
|
121
121
|
PIT&& ys,
|
|
122
122
|
const size_t numClasses,
|
|
123
123
|
WIT&& weights) :
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
#include <mlpack/prereqs.hpp>
|
|
17
17
|
|
|
18
18
|
namespace mlpack {
|
|
19
|
-
namespace data {
|
|
20
19
|
|
|
21
20
|
/**
|
|
22
21
|
* Given an input dataset and threshold, set values greater than threshold to
|
|
@@ -86,7 +85,6 @@ void Binarize(const arma::Mat<T>& input,
|
|
|
86
85
|
output(dimension, i) = input(dimension, i) > threshold;
|
|
87
86
|
}
|
|
88
87
|
|
|
89
|
-
} // namespace data
|
|
90
88
|
} // namespace mlpack
|
|
91
89
|
|
|
92
90
|
#endif
|
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
#include <mlpack/core/util/params.hpp>
|
|
18
18
|
|
|
19
19
|
namespace mlpack {
|
|
20
|
-
namespace data {
|
|
21
20
|
|
|
22
21
|
inline void CheckCategoricalParam(util::Params& params,
|
|
23
22
|
const std::string& paramName)
|
|
@@ -35,7 +34,6 @@ inline void CheckCategoricalParam(util::Params& params,
|
|
|
35
34
|
Log::Fatal << errMsg2 << std::endl;
|
|
36
35
|
}
|
|
37
36
|
|
|
38
|
-
} // namespace data
|
|
39
37
|
} // namespace mlpack
|
|
40
38
|
|
|
41
39
|
#endif
|