mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -24,6 +24,7 @@
|
|
|
24
24
|
#include <mlpack/methods/ann/activation_functions/mish_function.hpp>
|
|
25
25
|
#include <mlpack/methods/ann/activation_functions/lisht_function.hpp>
|
|
26
26
|
#include <mlpack/methods/ann/activation_functions/gelu_function.hpp>
|
|
27
|
+
#include <mlpack/methods/ann/activation_functions/gelu_exact_function.hpp>
|
|
27
28
|
#include <mlpack/methods/ann/activation_functions/elliot_function.hpp>
|
|
28
29
|
#include <mlpack/methods/ann/activation_functions/elish_function.hpp>
|
|
29
30
|
#include <mlpack/methods/ann/activation_functions/gaussian_function.hpp>
|
|
@@ -51,6 +52,7 @@ namespace mlpack {
|
|
|
51
52
|
* - Mish
|
|
52
53
|
* - LiSHT
|
|
53
54
|
* - GELU
|
|
55
|
+
* - GELUExact
|
|
54
56
|
* - ELiSH
|
|
55
57
|
* - Elliot
|
|
56
58
|
* - Gaussian
|
|
@@ -68,6 +70,9 @@ template <
|
|
|
68
70
|
class BaseLayer : public Layer<MatType>
|
|
69
71
|
{
|
|
70
72
|
public:
|
|
73
|
+
// Convenience typedef to access the element type of the weights and data.
|
|
74
|
+
using ElemType = typename MatType::elem_type;
|
|
75
|
+
|
|
71
76
|
/**
|
|
72
77
|
* Create the BaseLayer object.
|
|
73
78
|
*/
|
|
@@ -83,7 +88,7 @@ class BaseLayer : public Layer<MatType>
|
|
|
83
88
|
// members.
|
|
84
89
|
|
|
85
90
|
//! Clone the BaseLayer object. This handles polymorphism correctly.
|
|
86
|
-
BaseLayer* Clone() const { return new BaseLayer(*this); }
|
|
91
|
+
virtual BaseLayer* Clone() const { return new BaseLayer(*this); }
|
|
87
92
|
|
|
88
93
|
/**
|
|
89
94
|
* Forward pass: apply the activation to the inputs.
|
|
@@ -131,138 +136,110 @@ class BaseLayer : public Layer<MatType>
|
|
|
131
136
|
/**
|
|
132
137
|
* Standard Sigmoid-Layer using the logistic activation function.
|
|
133
138
|
*/
|
|
134
|
-
using Sigmoid = BaseLayer<LogisticFunction, arma::mat>;
|
|
135
|
-
|
|
136
139
|
template<typename MatType = arma::mat>
|
|
137
|
-
using
|
|
140
|
+
using Sigmoid = BaseLayer<LogisticFunction, MatType>;
|
|
138
141
|
|
|
139
142
|
/**
|
|
140
143
|
* Standard rectified linear unit non-linearity layer.
|
|
141
144
|
*/
|
|
142
|
-
using ReLU = BaseLayer<RectifierFunction, arma::mat>;
|
|
143
|
-
|
|
144
145
|
template<typename MatType = arma::mat>
|
|
145
|
-
using
|
|
146
|
+
using ReLU = BaseLayer<RectifierFunction, MatType>;
|
|
146
147
|
|
|
147
148
|
/**
|
|
148
149
|
* Standard hyperbolic tangent layer.
|
|
149
150
|
*/
|
|
150
|
-
using TanH = BaseLayer<TanhFunction, arma::mat>;
|
|
151
|
-
|
|
152
151
|
template<typename MatType = arma::mat>
|
|
153
|
-
using
|
|
152
|
+
using TanH = BaseLayer<TanhFunction, MatType>;
|
|
154
153
|
|
|
155
154
|
/**
|
|
156
155
|
* Standard Softplus-Layer using the Softplus activation function.
|
|
157
156
|
*/
|
|
158
|
-
using SoftPlus = BaseLayer<SoftplusFunction, arma::mat>;
|
|
159
|
-
|
|
160
157
|
template<typename MatType = arma::mat>
|
|
161
|
-
using
|
|
158
|
+
using SoftPlus = BaseLayer<SoftplusFunction, MatType>;
|
|
162
159
|
|
|
163
160
|
/**
|
|
164
161
|
* Standard HardSigmoid-Layer using the HardSigmoid activation function.
|
|
165
162
|
*/
|
|
166
|
-
using HardSigmoid = BaseLayer<HardSigmoidFunction, arma::mat>;
|
|
167
|
-
|
|
168
163
|
template<typename MatType = arma::mat>
|
|
169
|
-
using
|
|
164
|
+
using HardSigmoid = BaseLayer<HardSigmoidFunction, MatType>;
|
|
170
165
|
|
|
171
166
|
/**
|
|
172
167
|
* Standard Swish-Layer using the Swish activation function.
|
|
173
168
|
*/
|
|
174
|
-
using Swish = BaseLayer<SwishFunction, arma::mat>;
|
|
175
|
-
|
|
176
169
|
template<typename MatType = arma::mat>
|
|
177
|
-
using
|
|
170
|
+
using Swish = BaseLayer<SwishFunction, MatType>;
|
|
178
171
|
|
|
179
172
|
/**
|
|
180
173
|
* Standard Mish-Layer using the Mish activation function.
|
|
181
174
|
*/
|
|
182
|
-
using Mish = BaseLayer<MishFunction, arma::mat>;
|
|
183
|
-
|
|
184
175
|
template<typename MatType = arma::mat>
|
|
185
|
-
using
|
|
176
|
+
using Mish = BaseLayer<MishFunction, MatType>;
|
|
186
177
|
|
|
187
178
|
/**
|
|
188
179
|
* Standard LiSHT-Layer using the LiSHT activation function.
|
|
189
180
|
*/
|
|
190
|
-
using LiSHT = BaseLayer<LiSHTFunction, arma::mat>;
|
|
191
|
-
|
|
192
181
|
template<typename MatType = arma::mat>
|
|
193
|
-
using
|
|
182
|
+
using LiSHT = BaseLayer<LiSHTFunction, MatType>;
|
|
194
183
|
|
|
195
184
|
/**
|
|
196
185
|
* Standard GELU-Layer using the GELU activation function.
|
|
197
186
|
*/
|
|
198
|
-
|
|
187
|
+
template<typename MatType = arma::mat>
|
|
188
|
+
using GELU = BaseLayer<GELUFunction, MatType>;
|
|
199
189
|
|
|
190
|
+
/**
|
|
191
|
+
* Standard GELUExact-Layer using the GELUExact activation function.
|
|
192
|
+
*/
|
|
200
193
|
template<typename MatType = arma::mat>
|
|
201
|
-
using
|
|
194
|
+
using GELUExact = BaseLayer<GELUExactFunction, MatType>;
|
|
202
195
|
|
|
203
196
|
/**
|
|
204
197
|
* Standard Elliot-Layer using the Elliot activation function.
|
|
205
198
|
*/
|
|
206
|
-
using Elliot = BaseLayer<ElliotFunction, arma::mat>;
|
|
207
|
-
|
|
208
199
|
template<typename MatType = arma::mat>
|
|
209
|
-
using
|
|
200
|
+
using Elliot = BaseLayer<ElliotFunction, MatType>;
|
|
210
201
|
|
|
211
202
|
/**
|
|
212
203
|
* Standard ELiSH-Layer using the ELiSH activation function.
|
|
213
204
|
*/
|
|
214
|
-
using Elish = BaseLayer<ElishFunction, arma::mat>;
|
|
215
|
-
|
|
216
205
|
template<typename MatType = arma::mat>
|
|
217
|
-
using
|
|
206
|
+
using Elish = BaseLayer<ElishFunction, MatType>;
|
|
218
207
|
|
|
219
208
|
/**
|
|
220
209
|
* Standard Gaussian-Layer using the Gaussian activation function.
|
|
221
210
|
*/
|
|
222
|
-
using Gaussian = BaseLayer<GaussianFunction, arma::mat>;
|
|
223
|
-
|
|
224
211
|
template<typename MatType = arma::mat>
|
|
225
|
-
using
|
|
212
|
+
using Gaussian = BaseLayer<GaussianFunction, MatType>;
|
|
226
213
|
|
|
227
214
|
/**
|
|
228
215
|
* Standard HardSwish-Layer using the HardSwish activation function.
|
|
229
216
|
*/
|
|
230
|
-
using HardSwish = BaseLayer<HardSwishFunction, arma::mat>;
|
|
231
|
-
|
|
232
217
|
template <typename MatType = arma::mat>
|
|
233
|
-
using
|
|
218
|
+
using HardSwish = BaseLayer<HardSwishFunction, MatType>;
|
|
234
219
|
|
|
235
220
|
/**
|
|
236
221
|
* Standard TanhExp-Layer using the TanhExp activation function.
|
|
237
222
|
*/
|
|
238
|
-
using TanhExp = BaseLayer<TanhExpFunction, arma::mat>;
|
|
239
|
-
|
|
240
223
|
template<typename MatType = arma::mat>
|
|
241
|
-
using
|
|
224
|
+
using TanhExp = BaseLayer<TanhExpFunction, MatType>;
|
|
242
225
|
|
|
243
226
|
/**
|
|
244
227
|
* Standard SILU-Layer using the SILU activation function.
|
|
245
228
|
*/
|
|
246
|
-
using SILU = BaseLayer<SILUFunction, arma::mat>;
|
|
247
|
-
|
|
248
229
|
template<typename MatType = arma::mat>
|
|
249
|
-
using
|
|
230
|
+
using SILU = BaseLayer<SILUFunction, MatType>;
|
|
250
231
|
|
|
251
232
|
/**
|
|
252
233
|
* Standard Hyper Sinh layer.
|
|
253
234
|
*/
|
|
254
|
-
using HyperSinh = BaseLayer<HyperSinhFunction, arma::mat>;
|
|
255
|
-
|
|
256
235
|
template<typename MatType = arma::mat>
|
|
257
|
-
using
|
|
236
|
+
using HyperSinh = BaseLayer<HyperSinhFunction, MatType>;
|
|
258
237
|
|
|
259
238
|
/**
|
|
260
239
|
* Standard Bipolar Sigmoid layer.
|
|
261
240
|
*/
|
|
262
|
-
using BipolarSigmoid = BaseLayer<BipolarSigmoidFunction, arma::mat>;
|
|
263
|
-
|
|
264
241
|
template<typename MatType = arma::mat>
|
|
265
|
-
using
|
|
242
|
+
using BipolarSigmoid = BaseLayer<BipolarSigmoidFunction, MatType>;
|
|
266
243
|
|
|
267
244
|
} // namespace mlpack
|
|
268
245
|
|
|
@@ -50,10 +50,13 @@ namespace mlpack {
|
|
|
50
50
|
* computation.
|
|
51
51
|
*/
|
|
52
52
|
template <typename MatType = arma::mat>
|
|
53
|
-
class
|
|
53
|
+
class BatchNorm : public Layer<MatType>
|
|
54
54
|
{
|
|
55
55
|
public:
|
|
56
|
+
// Convenience typedefs to access the element type of the weights and data.
|
|
57
|
+
using ElemType = typename MatType::elem_type;
|
|
56
58
|
using CubeType = typename GetCubeType<MatType>::type;
|
|
59
|
+
|
|
57
60
|
/**
|
|
58
61
|
* Create the BatchNorm object.
|
|
59
62
|
*
|
|
@@ -72,7 +75,7 @@ class BatchNormType : public Layer<MatType>
|
|
|
72
75
|
* three dimensions rows, columns and slices), and `minAxis` & `maxAxis` is
|
|
73
76
|
* 2, then we apply the same normalization across different slices.
|
|
74
77
|
*/
|
|
75
|
-
|
|
78
|
+
BatchNorm();
|
|
76
79
|
|
|
77
80
|
/**
|
|
78
81
|
* Create the BatchNorm layer object for a specified axis of input units as
|
|
@@ -93,30 +96,30 @@ class BatchNormType : public Layer<MatType>
|
|
|
93
96
|
* updating the parameters or momentum is used.
|
|
94
97
|
* @param momentum Parameter used to to update the running mean and variance.
|
|
95
98
|
*/
|
|
96
|
-
|
|
99
|
+
BatchNorm(const size_t minAxis,
|
|
97
100
|
const size_t maxAxis,
|
|
98
101
|
const double eps = 1e-8,
|
|
99
102
|
const bool average = true,
|
|
100
103
|
const double momentum = 0.1);
|
|
101
104
|
|
|
102
|
-
virtual ~
|
|
105
|
+
virtual ~BatchNorm() { }
|
|
103
106
|
|
|
104
|
-
//! Clone the
|
|
105
|
-
|
|
107
|
+
//! Clone the BatchNorm object. This handles polymorphism correctly.
|
|
108
|
+
BatchNorm* Clone() const { return new BatchNorm(*this); }
|
|
106
109
|
|
|
107
110
|
//! Copy the other BatchNorm layer (but not weights).
|
|
108
|
-
|
|
111
|
+
BatchNorm(const BatchNorm& layer);
|
|
109
112
|
|
|
110
113
|
//! Take ownership of the members of the other BatchNorm layer (but not
|
|
111
114
|
//! weights).
|
|
112
|
-
|
|
115
|
+
BatchNorm(BatchNorm&& layer);
|
|
113
116
|
|
|
114
117
|
//! Copy the other BatchNorm layer (but not weights).
|
|
115
|
-
|
|
118
|
+
BatchNorm& operator=(const BatchNorm& layer);
|
|
116
119
|
|
|
117
120
|
//! Take ownership of the members of the other BatchNorm layer (but not
|
|
118
121
|
//! weights).
|
|
119
|
-
|
|
122
|
+
BatchNorm& operator=(BatchNorm&& layer);
|
|
120
123
|
|
|
121
124
|
/**
|
|
122
125
|
* Reset the layer parameters.
|
|
@@ -189,7 +192,7 @@ class BatchNormType : public Layer<MatType>
|
|
|
189
192
|
MatType& TrainingVariance() { return runningVariance; }
|
|
190
193
|
|
|
191
194
|
//! Get the number of input units / channels.
|
|
192
|
-
size_t InputSize() const { return
|
|
195
|
+
size_t InputSize() const { return inputUnits; }
|
|
193
196
|
|
|
194
197
|
//! Get the epsilon value.
|
|
195
198
|
const double &Epsilon() const { return eps; }
|
|
@@ -203,7 +206,7 @@ class BatchNormType : public Layer<MatType>
|
|
|
203
206
|
bool Average() const { return average; }
|
|
204
207
|
|
|
205
208
|
//! Get size of weights.
|
|
206
|
-
size_t WeightSize() const { return 2 *
|
|
209
|
+
size_t WeightSize() const { return 2 * inputUnits; }
|
|
207
210
|
|
|
208
211
|
//! Compute the output dimensions of the layer given `InputDimensions()`.
|
|
209
212
|
void ComputeOutputDimensions();
|
|
@@ -253,7 +256,7 @@ class BatchNormType : public Layer<MatType>
|
|
|
253
256
|
|
|
254
257
|
//! Locally-stored number of input units. (This is the product of all
|
|
255
258
|
//! dimensions between minAxis and maxAxis, inclusive.)
|
|
256
|
-
size_t
|
|
259
|
+
size_t inputUnits;
|
|
257
260
|
|
|
258
261
|
//! Locally-stored number of higher dimension we are not applying
|
|
259
262
|
//! batch normalization to. This is the product of this->inputDimensions
|
|
@@ -273,11 +276,6 @@ class BatchNormType : public Layer<MatType>
|
|
|
273
276
|
CubeType inputMean;
|
|
274
277
|
}; // class BatchNorm
|
|
275
278
|
|
|
276
|
-
// Convenience typedefs.
|
|
277
|
-
|
|
278
|
-
// Standard Adaptive max pooling layer.
|
|
279
|
-
using BatchNorm = BatchNormType<arma::mat>;
|
|
280
|
-
|
|
281
279
|
} // namespace mlpack
|
|
282
280
|
|
|
283
281
|
// Include the implementation.
|
|
@@ -22,7 +22,7 @@
|
|
|
22
22
|
namespace mlpack {
|
|
23
23
|
|
|
24
24
|
template<typename MatType>
|
|
25
|
-
|
|
25
|
+
BatchNorm<MatType>::BatchNorm() :
|
|
26
26
|
Layer<MatType>(),
|
|
27
27
|
minAxis(2),
|
|
28
28
|
maxAxis(2),
|
|
@@ -31,14 +31,14 @@ BatchNormType<MatType>::BatchNormType() :
|
|
|
31
31
|
momentum(0.0),
|
|
32
32
|
count(0),
|
|
33
33
|
inputDimension(1),
|
|
34
|
-
|
|
34
|
+
inputUnits(0),
|
|
35
35
|
higherDimension(1)
|
|
36
36
|
{
|
|
37
37
|
// Nothing to do here.
|
|
38
38
|
}
|
|
39
39
|
|
|
40
40
|
template <typename MatType>
|
|
41
|
-
|
|
41
|
+
BatchNorm<MatType>::BatchNorm(
|
|
42
42
|
const size_t minAxis,
|
|
43
43
|
const size_t maxAxis,
|
|
44
44
|
const double eps,
|
|
@@ -52,7 +52,7 @@ BatchNormType<MatType>::BatchNormType(
|
|
|
52
52
|
momentum(momentum),
|
|
53
53
|
count(0),
|
|
54
54
|
inputDimension(1),
|
|
55
|
-
|
|
55
|
+
inputUnits(0),
|
|
56
56
|
higherDimension(1)
|
|
57
57
|
{
|
|
58
58
|
// Nothing to do here.
|
|
@@ -60,7 +60,7 @@ BatchNormType<MatType>::BatchNormType(
|
|
|
60
60
|
|
|
61
61
|
// Copy constructor.
|
|
62
62
|
template<typename MatType>
|
|
63
|
-
|
|
63
|
+
BatchNorm<MatType>::BatchNorm(const BatchNorm& layer) :
|
|
64
64
|
Layer<MatType>(layer),
|
|
65
65
|
minAxis(layer.minAxis),
|
|
66
66
|
maxAxis(layer.maxAxis),
|
|
@@ -70,7 +70,7 @@ BatchNormType<MatType>::BatchNormType(const BatchNormType& layer) :
|
|
|
70
70
|
variance(layer.variance),
|
|
71
71
|
count(layer.count),
|
|
72
72
|
inputDimension(layer.inputDimension),
|
|
73
|
-
|
|
73
|
+
inputUnits(layer.inputUnits),
|
|
74
74
|
higherDimension(layer.higherDimension),
|
|
75
75
|
runningMean(layer.runningMean),
|
|
76
76
|
runningVariance(layer.runningVariance)
|
|
@@ -80,7 +80,7 @@ BatchNormType<MatType>::BatchNormType(const BatchNormType& layer) :
|
|
|
80
80
|
|
|
81
81
|
// Move constructor.
|
|
82
82
|
template<typename MatType>
|
|
83
|
-
|
|
83
|
+
BatchNorm<MatType>::BatchNorm(BatchNorm&& layer) :
|
|
84
84
|
Layer<MatType>(std::move(layer)),
|
|
85
85
|
minAxis(std::move(layer.minAxis)),
|
|
86
86
|
maxAxis(std::move(layer.maxAxis)),
|
|
@@ -90,7 +90,7 @@ BatchNormType<MatType>::BatchNormType(BatchNormType&& layer) :
|
|
|
90
90
|
variance(std::move(layer.variance)),
|
|
91
91
|
count(std::move(layer.count)),
|
|
92
92
|
inputDimension(std::move(layer.inputDimension)),
|
|
93
|
-
|
|
93
|
+
inputUnits(std::move(layer.inputUnits)),
|
|
94
94
|
higherDimension(std::move(layer.higherDimension)),
|
|
95
95
|
runningMean(std::move(layer.runningMean)),
|
|
96
96
|
runningVariance(std::move(layer.runningVariance))
|
|
@@ -99,8 +99,8 @@ BatchNormType<MatType>::BatchNormType(BatchNormType&& layer) :
|
|
|
99
99
|
}
|
|
100
100
|
|
|
101
101
|
template<typename MatType>
|
|
102
|
-
|
|
103
|
-
|
|
102
|
+
BatchNorm<MatType>&
|
|
103
|
+
BatchNorm<MatType>::operator=(const BatchNorm& layer)
|
|
104
104
|
{
|
|
105
105
|
if (&layer != this)
|
|
106
106
|
{
|
|
@@ -113,7 +113,7 @@ BatchNormType<MatType>::operator=(const BatchNormType& layer)
|
|
|
113
113
|
variance = layer.variance;
|
|
114
114
|
count = layer.count;
|
|
115
115
|
inputDimension = layer.inputDimension;
|
|
116
|
-
|
|
116
|
+
inputUnits = layer.inputUnits;
|
|
117
117
|
higherDimension = layer.higherDimension;
|
|
118
118
|
runningMean = layer.runningMean;
|
|
119
119
|
runningVariance = layer.runningVariance;
|
|
@@ -123,9 +123,9 @@ BatchNormType<MatType>::operator=(const BatchNormType& layer)
|
|
|
123
123
|
}
|
|
124
124
|
|
|
125
125
|
template<typename MatType>
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
126
|
+
BatchNorm<MatType>&
|
|
127
|
+
BatchNorm<MatType>::operator=(
|
|
128
|
+
BatchNorm&& layer)
|
|
129
129
|
{
|
|
130
130
|
if (&layer != this)
|
|
131
131
|
{
|
|
@@ -138,7 +138,7 @@ BatchNormType<MatType>::operator=(
|
|
|
138
138
|
variance = std::move(layer.variance);
|
|
139
139
|
count = std::move(layer.count);
|
|
140
140
|
inputDimension = std::move(layer.inputDimension);
|
|
141
|
-
|
|
141
|
+
inputUnits = std::move(layer.inputUnits);
|
|
142
142
|
higherDimension = std::move(layer.higherDimension);
|
|
143
143
|
runningMean = std::move(layer.runningMean);
|
|
144
144
|
runningVariance = std::move(layer.runningVariance);
|
|
@@ -148,40 +148,40 @@ BatchNormType<MatType>::operator=(
|
|
|
148
148
|
}
|
|
149
149
|
|
|
150
150
|
template<typename MatType>
|
|
151
|
-
void
|
|
151
|
+
void BatchNorm<MatType>::SetWeights(const MatType& weightsIn)
|
|
152
152
|
{
|
|
153
153
|
MakeAlias(weights, weightsIn, WeightSize(), 1);
|
|
154
154
|
// Gamma acts as the scaling parameters for the normalized output.
|
|
155
|
-
MakeAlias(gamma, weightsIn,
|
|
155
|
+
MakeAlias(gamma, weightsIn, inputUnits, 1);
|
|
156
156
|
// Beta acts as the shifting parameters for the normalized output.
|
|
157
|
-
MakeAlias(beta, weightsIn,
|
|
157
|
+
MakeAlias(beta, weightsIn, inputUnits, 1, gamma.n_elem);
|
|
158
158
|
}
|
|
159
159
|
|
|
160
160
|
template<typename MatType>
|
|
161
|
-
void
|
|
161
|
+
void BatchNorm<MatType>::CustomInitialize(
|
|
162
162
|
MatType& W,
|
|
163
163
|
const size_t elements)
|
|
164
164
|
{
|
|
165
|
-
if (elements != 2 *
|
|
166
|
-
throw std::invalid_argument("
|
|
165
|
+
if (elements != 2 * inputUnits) {
|
|
166
|
+
throw std::invalid_argument("BatchNorm::CustomInitialize(): wrong "
|
|
167
167
|
"elements size!");
|
|
168
168
|
}
|
|
169
169
|
MatType gammaTemp;
|
|
170
170
|
MatType betaTemp;
|
|
171
171
|
// Gamma acts as the scaling parameters for the normalized output.
|
|
172
|
-
MakeAlias(gammaTemp, W,
|
|
172
|
+
MakeAlias(gammaTemp, W, inputUnits, 1);
|
|
173
173
|
// Beta acts as the shifting parameters for the normalized output.
|
|
174
|
-
MakeAlias(betaTemp, W,
|
|
174
|
+
MakeAlias(betaTemp, W, inputUnits, 1, gammaTemp.n_elem);
|
|
175
175
|
|
|
176
|
-
gammaTemp.
|
|
177
|
-
betaTemp.
|
|
176
|
+
gammaTemp.ones();
|
|
177
|
+
betaTemp.zeros();
|
|
178
178
|
|
|
179
|
-
runningMean.zeros(
|
|
180
|
-
runningVariance.ones(
|
|
179
|
+
runningMean.zeros(inputUnits, 1);
|
|
180
|
+
runningVariance.ones(inputUnits, 1);
|
|
181
181
|
}
|
|
182
182
|
|
|
183
183
|
template<typename MatType>
|
|
184
|
-
void
|
|
184
|
+
void BatchNorm<MatType>::Forward(
|
|
185
185
|
const MatType& input,
|
|
186
186
|
MatType& output)
|
|
187
187
|
{
|
|
@@ -203,31 +203,32 @@ void BatchNormType<MatType>::Forward(
|
|
|
203
203
|
// Input corresponds to output from previous layer.
|
|
204
204
|
// Used a cube for simplicity.
|
|
205
205
|
CubeType inputTemp;
|
|
206
|
-
MakeAlias(inputTemp, input, inputSize,
|
|
206
|
+
MakeAlias(inputTemp, input, inputSize, inputUnits,
|
|
207
207
|
batchSize * higherDimension, 0, false);
|
|
208
208
|
|
|
209
209
|
// Initialize output to same size and values for convenience.
|
|
210
210
|
CubeType outputTemp;
|
|
211
|
-
MakeAlias(outputTemp, output, inputSize,
|
|
211
|
+
MakeAlias(outputTemp, output, inputSize, inputUnits,
|
|
212
212
|
batchSize * higherDimension, 0, false);
|
|
213
213
|
outputTemp = inputTemp;
|
|
214
214
|
|
|
215
215
|
// Calculate mean and variance over all channels.
|
|
216
216
|
MatType mean = sum(sum(inputTemp, 2), 0) / m;
|
|
217
|
-
variance = sum(sum(
|
|
218
|
-
inputTemp.each_slice() - repmat(mean, inputSize, 1)
|
|
217
|
+
variance = sum(sum(square(
|
|
218
|
+
inputTemp.each_slice() - repmat(mean, inputSize, 1)), 2), 0) / m;
|
|
219
219
|
|
|
220
220
|
outputTemp.each_slice() -= repmat(mean, inputSize, 1);
|
|
221
221
|
|
|
222
222
|
// Used in backward propagation.
|
|
223
|
-
inputMean.set_size(
|
|
223
|
+
inputMean.set_size(size(inputTemp));
|
|
224
224
|
inputMean = outputTemp;
|
|
225
225
|
|
|
226
226
|
// Normalize output.
|
|
227
|
-
outputTemp.each_slice() /= sqrt(repmat(variance, inputSize, 1) +
|
|
227
|
+
outputTemp.each_slice() /= sqrt(repmat(variance, inputSize, 1) +
|
|
228
|
+
ElemType(eps));
|
|
228
229
|
|
|
229
230
|
// Re-used in backward propagation.
|
|
230
|
-
normalized.set_size(
|
|
231
|
+
normalized.set_size(size(inputTemp));
|
|
231
232
|
normalized = outputTemp;
|
|
232
233
|
|
|
233
234
|
outputTemp.each_slice() %= repmat(gamma.t(), inputSize, 1);
|
|
@@ -235,11 +236,11 @@ void BatchNormType<MatType>::Forward(
|
|
|
235
236
|
|
|
236
237
|
count += 1;
|
|
237
238
|
// Value for average factor which used to update running parameters.
|
|
238
|
-
|
|
239
|
+
ElemType averageFactor = ElemType(average ? 1.0 / count : momentum);
|
|
239
240
|
|
|
240
|
-
|
|
241
|
+
ElemType nElements = 0;
|
|
241
242
|
if (m - 1 != 0)
|
|
242
|
-
nElements = m * (1
|
|
243
|
+
nElements = m * (ElemType(1) / (m - 1));
|
|
243
244
|
|
|
244
245
|
// Update running mean and running variance.
|
|
245
246
|
runningMean = (1 - averageFactor) * runningMean + averageFactor *
|
|
@@ -252,35 +253,35 @@ void BatchNormType<MatType>::Forward(
|
|
|
252
253
|
// Normalize the input and scale and shift the output.
|
|
253
254
|
output = input;
|
|
254
255
|
CubeType outputTemp;
|
|
255
|
-
MakeAlias(outputTemp, output, inputSize,
|
|
256
|
+
MakeAlias(outputTemp, output, inputSize, inputUnits,
|
|
256
257
|
batchSize * higherDimension, 0, false);
|
|
257
258
|
|
|
258
259
|
outputTemp.each_slice() -= repmat(runningMean.t(), inputSize, 1);
|
|
259
260
|
outputTemp.each_slice() /= sqrt(repmat(runningVariance.t(),
|
|
260
|
-
inputSize, 1) + eps);
|
|
261
|
+
inputSize, 1) + ElemType(eps));
|
|
261
262
|
outputTemp.each_slice() %= repmat(gamma.t(), inputSize, 1);
|
|
262
263
|
outputTemp.each_slice() += repmat(beta.t(), inputSize, 1);
|
|
263
264
|
}
|
|
264
265
|
}
|
|
265
266
|
|
|
266
267
|
template<typename MatType>
|
|
267
|
-
void
|
|
268
|
+
void BatchNorm<MatType>::Backward(
|
|
268
269
|
const MatType& /* input */,
|
|
269
270
|
const MatType& /* output */,
|
|
270
271
|
const MatType& gy,
|
|
271
272
|
MatType& g)
|
|
272
273
|
{
|
|
273
|
-
const MatType stdInv = 1
|
|
274
|
+
const MatType stdInv = 1 / sqrt(variance + ElemType(eps));
|
|
274
275
|
|
|
275
276
|
const size_t batchSize = gy.n_cols;
|
|
276
277
|
const size_t inputSize = inputDimension;
|
|
277
278
|
const size_t m = inputSize * batchSize * higherDimension;
|
|
278
279
|
|
|
279
280
|
CubeType gyTemp;
|
|
280
|
-
MakeAlias(gyTemp, gy, inputSize,
|
|
281
|
+
MakeAlias(gyTemp, gy, inputSize, inputUnits,
|
|
281
282
|
batchSize * higherDimension, 0, false);
|
|
282
283
|
CubeType gTemp;
|
|
283
|
-
MakeAlias(gTemp, g, inputSize,
|
|
284
|
+
MakeAlias(gTemp, g, inputSize, inputUnits,
|
|
284
285
|
batchSize * higherDimension, 0, false);
|
|
285
286
|
|
|
286
287
|
// Step 1: dl / dxhat.
|
|
@@ -288,24 +289,24 @@ void BatchNormType<MatType>::Backward(
|
|
|
288
289
|
|
|
289
290
|
// Step 2: sum dl / dxhat * (x - mu) * -0.5 * stdInv^3.
|
|
290
291
|
MatType temp = sum(sum(norm % inputMean, 2), 0);
|
|
291
|
-
MatType vars = temp % pow(stdInv, 3)
|
|
292
|
+
MatType vars = -temp % pow(stdInv, 3) / 2;
|
|
292
293
|
|
|
293
294
|
// Step 3: dl / dxhat * 1 / stdInv + variance * 2 * (x - mu) / m +
|
|
294
295
|
// dl / dmu * 1 / m.
|
|
295
296
|
gTemp = (norm.each_slice() % repmat(stdInv, inputSize, 1)) +
|
|
296
|
-
((inputMean.each_slice() % repmat(vars, inputSize, 1) * 2
|
|
297
|
+
((inputMean.each_slice() % repmat(vars, inputSize, 1) * 2) / m);
|
|
297
298
|
|
|
298
299
|
// Step 4: sum (dl / dxhat * -1 / stdInv) + variance *
|
|
299
300
|
// sum (-2 * (x - mu)) / m.
|
|
300
301
|
MatType normTemp = sum(sum((norm.each_slice() %
|
|
301
302
|
repmat(-stdInv, inputSize, 1)) +
|
|
302
|
-
(inputMean.each_slice() % repmat(vars, inputSize, 1)
|
|
303
|
+
-2 * (inputMean.each_slice() % repmat(vars, inputSize, 1) / m),
|
|
303
304
|
2), 0) / m;
|
|
304
305
|
gTemp.each_slice() += repmat(normTemp, inputSize, 1);
|
|
305
306
|
}
|
|
306
307
|
|
|
307
308
|
template<typename MatType>
|
|
308
|
-
void
|
|
309
|
+
void BatchNorm<MatType>::Gradient(
|
|
309
310
|
const MatType& /* input */,
|
|
310
311
|
const MatType& error,
|
|
311
312
|
MatType& gradient)
|
|
@@ -313,7 +314,7 @@ void BatchNormType<MatType>::Gradient(
|
|
|
313
314
|
const size_t inputSize = inputDimension;
|
|
314
315
|
|
|
315
316
|
CubeType errorTemp;
|
|
316
|
-
MakeAlias(errorTemp, error, inputSize,
|
|
317
|
+
MakeAlias(errorTemp, error, inputSize, inputUnits,
|
|
317
318
|
error.n_cols * higherDimension, 0, false);
|
|
318
319
|
|
|
319
320
|
// Step 5: dl / dy * xhat.
|
|
@@ -326,7 +327,7 @@ void BatchNormType<MatType>::Gradient(
|
|
|
326
327
|
}
|
|
327
328
|
|
|
328
329
|
template<typename MatType>
|
|
329
|
-
void
|
|
330
|
+
void BatchNorm<MatType>::ComputeOutputDimensions()
|
|
330
331
|
{
|
|
331
332
|
if (minAxis > maxAxis)
|
|
332
333
|
{
|
|
@@ -354,9 +355,9 @@ void BatchNormType<MatType>::ComputeOutputDimensions()
|
|
|
354
355
|
for (size_t i = 0; i < mainMinAxis; i++)
|
|
355
356
|
inputDimension *= this->inputDimensions[i];
|
|
356
357
|
|
|
357
|
-
|
|
358
|
+
inputUnits = this->inputDimensions[mainMinAxis];
|
|
358
359
|
for (size_t i = mainMinAxis + 1; i <= mainMaxAxis; i++)
|
|
359
|
-
|
|
360
|
+
inputUnits *= this->inputDimensions[i];
|
|
360
361
|
|
|
361
362
|
higherDimension = 1;
|
|
362
363
|
for (size_t i = mainMaxAxis + 1; i < this->inputDimensions.size(); i++)
|
|
@@ -365,7 +366,7 @@ void BatchNormType<MatType>::ComputeOutputDimensions()
|
|
|
365
366
|
|
|
366
367
|
template<typename MatType>
|
|
367
368
|
template<typename Archive>
|
|
368
|
-
void
|
|
369
|
+
void BatchNorm<MatType>::serialize(
|
|
369
370
|
Archive& ar, const uint32_t /* version */)
|
|
370
371
|
{
|
|
371
372
|
ar(cereal::base_class<Layer<MatType>>(this));
|
|
@@ -380,7 +381,7 @@ void BatchNormType<MatType>::serialize(
|
|
|
380
381
|
ar(CEREAL_NVP(runningVariance));
|
|
381
382
|
ar(CEREAL_NVP(inputMean));
|
|
382
383
|
ar(CEREAL_NVP(inputDimension));
|
|
383
|
-
ar(CEREAL_NVP(
|
|
384
|
+
ar(CEREAL_NVP(inputUnits));
|
|
384
385
|
ar(CEREAL_NVP(higherDimension));
|
|
385
386
|
}
|
|
386
387
|
|