mlpack-4.6.2-cp313-cp313-win_amd64.whl → mlpack-4.7.0-cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp

@@ -18,7 +18,7 @@
 namespace mlpack {
 
 template<typename MatType>
-LSTMType<MatType>::LSTMType() :
+LSTM<MatType>::LSTM() :
     RecurrentLayer<MatType>(),
     inSize(0),
     outSize(0)
@@ -27,7 +27,7 @@ LSTMType<MatType>::LSTMType() :
 }
 
 template<typename MatType>
-LSTMType<MatType>::LSTMType(const size_t outSize) :
+LSTM<MatType>::LSTM(const size_t outSize) :
     RecurrentLayer<MatType>(),
     inSize(0),
     outSize(outSize)
@@ -36,7 +36,7 @@ LSTMType<MatType>::LSTMType(const size_t outSize) :
 }
 
 template<typename MatType>
-LSTMType<MatType>::LSTMType(const LSTMType& layer) :
+LSTM<MatType>::LSTM(const LSTM& layer) :
     RecurrentLayer<MatType>(layer),
     inSize(layer.inSize),
     outSize(layer.outSize)
@@ -45,7 +45,7 @@ LSTMType<MatType>::LSTMType(const LSTMType& layer) :
 }
 
 template<typename MatType>
-LSTMType<MatType>::LSTMType(LSTMType&& layer) :
+LSTM<MatType>::LSTM(LSTM&& layer) :
     RecurrentLayer<MatType>(std::move(layer)),
     inSize(layer.inSize),
     outSize(layer.outSize)
@@ -55,7 +55,7 @@ LSTMType<MatType>::LSTMType(LSTMType&& layer) :
 }
 
 template<typename MatType>
-LSTMType<MatType>& LSTMType<MatType>::operator=(const LSTMType& layer)
+LSTM<MatType>& LSTM<MatType>::operator=(const LSTM& layer)
 {
   if (this != &layer)
   {
@@ -68,7 +68,7 @@ LSTMType<MatType>& LSTMType<MatType>::operator=(const LSTMType& layer)
 }
 
 template<typename MatType>
-LSTMType<MatType>& LSTMType<MatType>::operator=(LSTMType&& layer)
+LSTM<MatType>& LSTM<MatType>::operator=(LSTM&& layer)
 {
   if (this != &layer)
   {
@@ -84,7 +84,7 @@ LSTMType<MatType>& LSTMType<MatType>::operator=(LSTMType&& layer)
 }
 
 template<typename MatType>
-void LSTMType<MatType>::SetWeights(const MatType& weights)
+void LSTM<MatType>::SetWeights(const MatType& weights)
 {
   // Set the weight parameters for the inputs.
   const size_t inputWeightSize = outSize * inSize;
@@ -123,14 +123,10 @@ void LSTMType<MatType>::SetWeights(const MatType& weights)
 
 // Forward when cellState is not needed.
 template<typename MatType>
-void LSTMType<MatType>::Forward(const MatType& input, MatType& output)
+void LSTM<MatType>::Forward(const MatType& input, MatType& output)
 {
   // Convenience alias.
-  const size_t
-
-  // The internal quantities are stored as recurrent state; so, set aliases
-  // correctly for this time step.
-  SetInternalAliases(batchSize);
+  const size_t activeBatchSize = input.n_cols;
 
   // Compute internal state:
   //
@@ -142,25 +138,29 @@ void LSTMType<MatType>::Forward(const MatType& input, MatType& output)
   // y_t = tanh(c_t) % o_t
 
   // Start by computing all non-recurrent portions.
-  blockInput = blockInputWeight * input + repmat(blockInputBias, 1,
-
-
-
+  blockInput = blockInputWeight * input + repmat(blockInputBias, 1,
+      activeBatchSize);
+  inputGate = inputGateWeight * input + repmat(inputGateBias, 1,
+      activeBatchSize);
+  forgetGate = forgetGateWeight * input + repmat(forgetGateBias, 1,
+      activeBatchSize);
+  outputGate = outputGateWeight * input + repmat(outputGateBias, 1,
+      activeBatchSize);
 
   // Now add in recurrent portions, if needed.
   if (this->HasPreviousStep())
   {
     blockInput += recurrentBlockInputWeight * prevRecurrent;
     inputGate += recurrentInputGateWeight * prevRecurrent +
-        repmat(peepholeInputGateWeight, 1,
+        repmat(peepholeInputGateWeight, 1, activeBatchSize) % prevCell;
     forgetGate += recurrentForgetGateWeight * prevRecurrent +
-        repmat(peepholeForgetGateWeight, 1,
+        repmat(peepholeForgetGateWeight, 1, activeBatchSize) % prevCell;
   }
 
   // Apply nonlinearities. (TODO: fast sigmoid?)
   blockInput = tanh(blockInput);
-  inputGate = 1
-  forgetGate = 1
+  inputGate = 1 / (1 + exp(-inputGate));
+  forgetGate = 1 / (1 + exp(-forgetGate));
 
   // Compute the cell state.
   if (this->HasPreviousStep())
@@ -172,17 +172,18 @@ void LSTMType<MatType>::Forward(const MatType& input, MatType& output)
   if (this->HasPreviousStep())
   {
     outputGate += recurrentOutputGateWeight * prevRecurrent +
-        repmat(peepholeOutputGateWeight, 1,
+        repmat(peepholeOutputGateWeight, 1, activeBatchSize) % thisCell;
   }
   else
   {
     // If we don't have a previous step, we still have to consider the peephole
     // connection.
-    outputGate += repmat(peepholeOutputGateWeight, 1,
+    outputGate += repmat(peepholeOutputGateWeight, 1, activeBatchSize) %
+        thisCell;
   }
 
   // Apply nonlinearity for output gate.
-  outputGate = 1
+  outputGate = 1 / (1 + exp(-outputGate));
 
   // Finally, we can compute the output itself.
   output = tanh(thisCell) % outputGate;
@@ -193,7 +194,7 @@ void LSTMType<MatType>::Forward(const MatType& input, MatType& output)
 }
 
 template<typename MatType>
-void LSTMType<MatType>::Backward(
+void LSTM<MatType>::Backward(
     const MatType& /* input */,
     const MatType& output,
     const MatType& gy,
@@ -219,12 +220,7 @@ void LSTMType<MatType>::Backward(
   // dz_t = dc_t % i_t % (1 - z_t .^ 2)
   //
   // dx_t = W_z^T dz_t + W_i^T di_t + W_f^T df_t + W_o^T do_t
-
-  // Before we start, set all the internal aliases, which will contain this time
-  // step's values as computed in Forward().
-  const size_t batchSize = output.n_cols;
-  SetInternalAliases(batchSize);
-  SetBackwardWorkspace(batchSize);
+  const size_t activeBatchSize = output.n_cols;
 
   // First attempt...
   if (this->AtFinalStep())
@@ -239,35 +235,31 @@ void LSTMType<MatType>::Backward(
         recurrentOutputGateWeight.t() * nextDeltaOutputGate;
   }
 
-  deltaOutputGate = deltaY % tanh(thisCell) % (outputGate % (1
+  deltaOutputGate = deltaY % tanh(thisCell) % (outputGate % (1 - outputGate));
 
   // Only first two terms if at final step
   if (this->AtFinalStep())
   {
-    deltaCell = deltaY % outputGate % (1
-        repmat(peepholeOutputGateWeight, 1,
+    deltaCell = deltaY % outputGate % (1 - square(tanh(thisCell))) +
+        repmat(peepholeOutputGateWeight, 1, activeBatchSize) % deltaOutputGate;
   }
   else
   {
-
-
-
-
-
-
-    deltaCell = deltaY % outputGate % (1.0 - square(tanh(thisCell))) +
-        repmat(peepholeOutputGateWeight, 1, batchSize) % deltaOutputGate +
-        repmat(peepholeInputGateWeight, 1, batchSize) % nextDeltaInputGate +
-        repmat(peepholeForgetGateWeight, 1, batchSize) % nextDeltaForgetGate +
+    deltaCell = deltaY % outputGate % (1 - square(tanh(thisCell))) +
+        repmat(peepholeOutputGateWeight, 1, activeBatchSize) % deltaOutputGate +
+        repmat(peepholeInputGateWeight, 1, activeBatchSize) %
+            nextDeltaInputGate +
+        repmat(peepholeForgetGateWeight, 1, activeBatchSize) %
+            nextDeltaForgetGate +
         nextDeltaCell % nextForgetGate;
   }
 
   if (this->HasPreviousStep())
-    deltaForgetGate = deltaCell % prevCell % (forgetGate % (1
+    deltaForgetGate = deltaCell % prevCell % (forgetGate % (1 - forgetGate));
   else
     deltaForgetGate.zeros();
-  deltaInputGate = deltaCell % blockInput % (inputGate % (1
-  deltaBlockInput = deltaCell % inputGate % (1
+  deltaInputGate = deltaCell % blockInput % (inputGate % (1 - inputGate));
+  deltaBlockInput = deltaCell % inputGate % (1 - square(blockInput));
 
   // Finally, compute deltaX (which is what we wanted all along).
   g = blockInputWeight.t() * deltaBlockInput +
@@ -280,15 +272,11 @@ void LSTMType<MatType>::Backward(
 }
 
 template<typename MatType>
-void LSTMType<MatType>::Gradient(
+void LSTM<MatType>::Gradient(
     const MatType& input,
     const MatType& /* error */,
     MatType& gradient)
 {
-  // This implementation depends on Gradient() being called just after
-  // Backward(), which is something we can safely assume. So, the workspace
-  // aliases are already set by SetBackwardWorkspace().
-  //
   // In this implementation we won't use aliases; we'll just address the correct
   // part of the gradient directly.
 
@@ -390,7 +378,7 @@ void LSTMType<MatType>::Gradient(
 }
 
 template<typename MatType>
-size_t LSTMType<MatType>::WeightSize() const
+size_t LSTM<MatType>::WeightSize() const
 {
   return 4 * inSize * outSize /* input weight connections */ +
       4 * outSize /* input bias */ +
@@ -399,7 +387,7 @@ size_t LSTMType<MatType>::WeightSize() const
 }
 
 template<typename MatType>
-size_t LSTMType<MatType>::RecurrentSize() const
+size_t LSTM<MatType>::RecurrentSize() const
 {
   // We have to account for the cell, recurrent connection, and the four
   // internal matrices: block input, input gate, forget gate, and output gate.
@@ -410,97 +398,113 @@ size_t LSTMType<MatType>::RecurrentSize() const
 }
 
 template<typename MatType>
-void
+void LSTM<MatType>::OnStepChanged(const size_t step,
+                                  const size_t batchSize,
+                                  const size_t activeBatchSize,
+                                  const bool backwards)
 {
   // Make all of the aliases for internal state point to the correct place.
-  MatType& state = this->RecurrentState(
+  MatType& state = this->RecurrentState(step);
 
   // First make aliases for the recurrent connections.
-  MakeAlias(thisRecurrent, state, outSize,
-  MakeAlias(thisCell, state, outSize,
+  MakeAlias(thisRecurrent, state, outSize, activeBatchSize);
+  MakeAlias(thisCell, state, outSize, activeBatchSize, outSize * batchSize);
 
   // Now make aliases for the internal state members that we use as scratch
   // space for computation.
-  MakeAlias(blockInput, state, outSize,
-
-  MakeAlias(
-
+  MakeAlias(blockInput, state, outSize, activeBatchSize, 2 * outSize *
+      batchSize);
+  MakeAlias(inputGate, state, outSize, activeBatchSize, 3 * outSize *
+      batchSize);
+  MakeAlias(forgetGate, state, outSize, activeBatchSize, 4 * outSize *
+      batchSize);
+  MakeAlias(outputGate, state, outSize, activeBatchSize, 5 * outSize *
+      batchSize);
 
   // Make aliases for the previous time step, too, if we can.
   if (this->HasPreviousStep())
   {
     MatType& prevState = this->RecurrentState(this->PreviousStep());
 
-    MakeAlias(prevRecurrent, prevState, outSize,
-    MakeAlias(prevCell, prevState, outSize,
+    MakeAlias(prevRecurrent, prevState, outSize, activeBatchSize);
+    MakeAlias(prevCell, prevState, outSize, activeBatchSize, outSize *
+        batchSize);
   }
-}
-
-template<typename MatType>
-void LSTMType<MatType>::SetBackwardWorkspace(const size_t batchSize)
-{
-  // We need to hold enough space for two time steps.
-  workspace.set_size(12 * outSize, batchSize);
 
-
-
-  MakeAlias(deltaY, workspace, outSize, batchSize);
-  MakeAlias(deltaBlockInput, workspace, outSize, batchSize,
-      outSize * batchSize);
-  MakeAlias(deltaInputGate, workspace, outSize, batchSize,
-      2 * outSize * batchSize);
-  MakeAlias(deltaForgetGate, workspace, outSize, batchSize,
-      3 * outSize * batchSize);
-  MakeAlias(deltaOutputGate, workspace, outSize, batchSize,
-      4 * outSize * batchSize);
-  MakeAlias(deltaCell, workspace, outSize, batchSize,
-      5 * outSize * batchSize);
-
-  MakeAlias(nextDeltaY, workspace, outSize, batchSize,
-      6 * outSize * batchSize);
-  MakeAlias(nextDeltaBlockInput, workspace, outSize, batchSize,
-      7 * outSize * batchSize);
-  MakeAlias(nextDeltaInputGate, workspace, outSize, batchSize,
-      8 * outSize * batchSize);
-  MakeAlias(nextDeltaForgetGate, workspace, outSize, batchSize,
-      9 * outSize * batchSize);
-  MakeAlias(nextDeltaOutputGate, workspace, outSize, batchSize,
-      10 * outSize * batchSize);
-  MakeAlias(nextDeltaCell, workspace, outSize, batchSize,
-      11 * outSize * batchSize);
-}
-else
+  // Also set the workspaces for the backwards pass, if requested.
+  if (backwards)
   {
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    // We need to hold enough space for two time steps.
+    workspace.set_size(12 * outSize, batchSize);
+
+    if (step % 2 == 0)
+    {
+      MakeAlias(deltaY, workspace, outSize, activeBatchSize);
+      MakeAlias(deltaBlockInput, workspace, outSize, activeBatchSize,
+          outSize * batchSize);
+      MakeAlias(deltaInputGate, workspace, outSize, activeBatchSize,
+          2 * outSize * batchSize);
+      MakeAlias(deltaForgetGate, workspace, outSize, activeBatchSize,
+          3 * outSize * batchSize);
+      MakeAlias(deltaOutputGate, workspace, outSize, activeBatchSize,
+          4 * outSize * batchSize);
+      MakeAlias(deltaCell, workspace, outSize, activeBatchSize,
+          5 * outSize * batchSize);
+
+      MakeAlias(nextDeltaY, workspace, outSize, activeBatchSize,
+          6 * outSize * batchSize);
+      MakeAlias(nextDeltaBlockInput, workspace, outSize, activeBatchSize,
+          7 * outSize * batchSize);
+      MakeAlias(nextDeltaInputGate, workspace, outSize, activeBatchSize,
+          8 * outSize * batchSize);
+      MakeAlias(nextDeltaForgetGate, workspace, outSize, activeBatchSize,
+          9 * outSize * batchSize);
+      MakeAlias(nextDeltaOutputGate, workspace, outSize, activeBatchSize,
+          10 * outSize * batchSize);
+      MakeAlias(nextDeltaCell, workspace, outSize, activeBatchSize,
+          11 * outSize * batchSize);
+    }
+    else
+    {
+      MakeAlias(nextDeltaY, workspace, outSize, activeBatchSize);
+      MakeAlias(nextDeltaBlockInput, workspace, outSize, activeBatchSize,
+          outSize * batchSize);
+      MakeAlias(nextDeltaInputGate, workspace, outSize, activeBatchSize,
+          2 * outSize * batchSize);
+      MakeAlias(nextDeltaForgetGate, workspace, outSize, activeBatchSize,
+          3 * outSize * batchSize);
+      MakeAlias(nextDeltaOutputGate, workspace, outSize, activeBatchSize,
+          4 * outSize * batchSize);
+      MakeAlias(nextDeltaCell, workspace, outSize, activeBatchSize,
+          5 * outSize * batchSize);
+
+      MakeAlias(deltaY, workspace, outSize, activeBatchSize,
+          6 * outSize * batchSize);
+      MakeAlias(deltaBlockInput, workspace, outSize, activeBatchSize,
+          7 * outSize * batchSize);
+      MakeAlias(deltaInputGate, workspace, outSize, activeBatchSize,
+          8 * outSize * batchSize);
+      MakeAlias(deltaForgetGate, workspace, outSize, activeBatchSize,
+          9 * outSize * batchSize);
+      MakeAlias(deltaOutputGate, workspace, outSize, activeBatchSize,
+          10 * outSize * batchSize);
+      MakeAlias(deltaCell, workspace, outSize, activeBatchSize,
+          11 * outSize * batchSize);
+    }
+
+    if (!this->AtFinalStep())
+    {
+      // To update the cell state, we actually need to use the forget gate
+      // values from the next time step.
+      MakeAlias(nextForgetGate, this->RecurrentState(this->CurrentStep() + 1),
+          outSize, activeBatchSize, 4 * outSize * batchSize);
+    }
   }
 }
 
 template<typename MatType>
 template<typename Archive>
-void
+void LSTM<MatType>::serialize(Archive& ar, const uint32_t /* version */)
 {
   ar(cereal::base_class<RecurrentLayer<MatType>>(this));
 
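For reference, the Forward() hunks above implement the usual peephole LSTM recurrence. The symbol mapping below is ours (blockInput → z_t, inputGate → i_t, forgetGate → f_t, outputGate → o_t, thisCell → c_t, prevCell → c_{t-1}), and the cell-state update line itself lies outside the context shown in the diff, so this is a sketch rather than a quotation of the source; \odot is the code's element-wise % and \sigma the logistic sigmoid:

% Peephole LSTM forward step (notation ours, not mlpack's)
z_t = \tanh(W_z x_t + R_z y_{t-1} + b_z)
i_t = \sigma(W_i x_t + R_i y_{t-1} + p_i \odot c_{t-1} + b_i)
f_t = \sigma(W_f x_t + R_f y_{t-1} + p_f \odot c_{t-1} + b_f)
c_t = z_t \odot i_t + c_{t-1} \odot f_t
o_t = \sigma(W_o x_t + R_o y_{t-1} + p_o \odot c_t + b_o)
y_t = \tanh(c_t) \odot o_t

The visible structural change is that the number of active sequences at the current step is now passed in as activeBatchSize (and through the new OnStepChanged() hook), rather than being derived inside Forward()/Backward() via the removed SetInternalAliases() and SetBackwardWorkspace() helpers.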
mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp

@@ -56,12 +56,15 @@ class MaxPoolingRule
  * computation.
  */
 template<typename MatType = arma::mat>
-class MaxPoolingType : public Layer<MatType>
+class MaxPooling : public Layer<MatType>
 {
  public:
+  // Convenience typedefs.
+  using ElemType = typename MatType::elem_type;
   using CubeType = typename GetCubeType<MatType>::type;
-
-
+
+  // Create the MaxPooling object.
+  MaxPooling();
 
   /**
    * Create the MaxPooling object using the specified number of units.
@@ -73,26 +76,26 @@ class MaxPoolingType : public Layer<MatType>
    * @param floor If true, then a pooling operation that would oly part of the
    *     input will be skipped.
    */
-
+  MaxPooling(const size_t kernelWidth,
             const size_t kernelHeight,
             const size_t strideWidth = 1,
             const size_t strideHeight = 1,
             const bool floor = true);
 
   // Virtual destructor.
-  virtual ~
+  virtual ~MaxPooling() { }
 
-  //! Copy the given
-
-  //! Take ownership of the given
-
-  //! Copy the given
-
-  //! Take ownership of the given
-
+  //! Copy the given MaxPooling.
+  MaxPooling(const MaxPooling& other);
+  //! Take ownership of the given MaxPooling.
+  MaxPooling(MaxPooling&& other);
+  //! Copy the given MaxPooling.
+  MaxPooling& operator=(const MaxPooling& other);
+  //! Take ownership of the given MaxPooling.
+  MaxPooling& operator=(MaxPooling&& other);
 
-  //! Clone the
-
+  //! Clone the MaxPooling object. This handles polymorphism correctly.
+  MaxPooling* Clone() const { return new MaxPooling(*this); }
 
   /**
    * Ordinary feed forward pass of a neural network, evaluating the function
@@ -306,10 +309,7 @@ class MaxPoolingType : public Layer<MatType>
 
   //! Locally-stored pooling indices.
   arma::Cube<size_t> poolingIndices;
-}; // class
-
-// Standard MaxPooling layer.
-using MaxPooling = MaxPoolingType<arma::mat>;
+}; // class MaxPooling
 
 } // namespace mlpack
 
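The net effect of these hunks is that the class template is now named MaxPooling<MatType> directly (with MatType defaulting to arma::mat), and the old MaxPoolingType name together with the `using MaxPooling = MaxPoolingType<arma::mat>;` alias is gone. A minimal usage sketch under the assumption that the mlpack 4.x single-header include and the FFN/Add API are used as usual; the kernel and stride values here are arbitrary illustration values:

#include <mlpack.hpp>

using namespace mlpack;

int main()
{
  FFN<> model;
  // 2x2 max pooling with stride 2.  MaxPooling<> now names the class template
  // itself (MatType defaults to arma::mat) instead of the removed
  // MaxPoolingType<arma::mat> alias.
  model.Add<MaxPooling<>>(2, 2, 2, 2);
  return 0;
}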
mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp

@@ -19,14 +19,14 @@
 namespace mlpack {
 
 template<typename MatType>
-
+MaxPooling<MatType>::MaxPooling() :
     Layer<MatType>()
 {
   // Nothing to do here.
 }
 
 template<typename MatType>
-
+MaxPooling<MatType>::MaxPooling(
     const size_t kernelWidth,
     const size_t kernelHeight,
     const size_t strideWidth,
@@ -44,8 +44,8 @@ MaxPoolingType<MatType>::MaxPoolingType(
 }
 
 template<typename MatType>
-MaxPoolingType<MatType>::MaxPoolingType(
-    const
+MaxPooling<MatType>::MaxPooling(
+    const MaxPooling& other) :
     Layer<MatType>(other),
     kernelWidth(other.kernelWidth),
     kernelHeight(other.kernelHeight),
@@ -59,8 +59,8 @@ MaxPoolingType<MatType>::MaxPoolingType(
 }
 
 template<typename MatType>
-MaxPoolingType<MatType>::MaxPoolingType(
-
+MaxPooling<MatType>::MaxPooling(
+    MaxPooling&& other) :
     Layer<MatType>(std::move(other)),
     kernelWidth(std::move(other.kernelWidth)),
     kernelHeight(std::move(other.kernelHeight)),
@@ -74,8 +74,8 @@ MaxPoolingType<MatType>::MaxPoolingType(
 }
 
 template<typename MatType>
-
-
+MaxPooling<MatType>&
+MaxPooling<MatType>::operator=(const MaxPooling& other)
 {
   if (&other != this)
   {
@@ -93,8 +93,8 @@ MaxPoolingType<MatType>::operator=(const MaxPoolingType& other)
 }
 
 template<typename MatType>
-
-MaxPoolingType<MatType>::operator=(MaxPoolingType&& other)
+MaxPooling<MatType>&
+MaxPooling<MatType>::operator=(MaxPooling&& other)
 {
   if (&other != this)
   {
@@ -112,7 +112,7 @@ MaxPoolingType<MatType>::operator=(MaxPoolingType&& other)
 }
 
 template<typename MatType>
-void MaxPoolingType<MatType>::Forward(const MatType& input, MatType& output)
+void MaxPooling<MatType>::Forward(const MatType& input, MatType& output)
 {
   using CubeType = typename GetCubeType<MatType>::type;
   CubeType inputTemp;
@@ -139,7 +139,7 @@ void MaxPoolingType<MatType>::Forward(const MatType& input, MatType& output)
 }
 
 template<typename MatType>
-void MaxPoolingType<MatType>::Backward(
+void MaxPooling<MatType>::Backward(
     const MatType& input,
     const MatType& /* output */,
     const MatType& gy,
@@ -167,7 +167,7 @@ void MaxPoolingType<MatType>::Backward(
 }
 
 template<typename MatType>
-void MaxPoolingType<MatType>::ComputeOutputDimensions()
+void MaxPooling<MatType>::ComputeOutputDimensions()
 {
   this->outputDimensions = this->inputDimensions;
 
@@ -197,7 +197,7 @@ void MaxPoolingType<MatType>::ComputeOutputDimensions()
 
 template<typename MatType>
 template<typename Archive>
-void
+void MaxPooling<MatType>::serialize(
     Archive& ar,
     const uint32_t /* version */)
 
mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp

@@ -26,12 +26,15 @@ namespace mlpack {
  * computation.
  */
 template <typename MatType = arma::mat>
-class MeanPoolingType : public Layer<MatType>
+class MeanPooling : public Layer<MatType>
 {
  public:
+  // Convenience typedefs.
+  using ElemType = typename MatType::elem_type;
   using CubeType = typename GetCubeType<MatType>::type;
-
-
+
+  // Create the MeanPooling object.
+  MeanPooling();
 
   /**
    * Create the MeanPooling object using the specified number of units.
@@ -43,26 +46,26 @@ class MeanPoolingType : public Layer<MatType>
    * @param floor If true, then a pooling operation that would oly part of the
    *     input will be skipped.
    */
-
-
-
-
-
+  MeanPooling(const size_t kernelWidth,
              const size_t kernelHeight,
              const size_t strideWidth = 1,
              const size_t strideHeight = 1,
              const bool floor = true);
 
   // Virtual destructor.
-  virtual ~
+  virtual ~MeanPooling() { }
 
-  //! Copy the given
-
-  //! Take ownership of the given
-
-  //! Copy the given
-
-  //! Take ownership of the given
-
+  //! Copy the given MeanPooling.
+  MeanPooling(const MeanPooling& other);
+  //! Take ownership of the given MeanPooling.
+  MeanPooling(MeanPooling&& other);
+  //! Copy the given MeanPooling.
+  MeanPooling& operator=(const MeanPooling& other);
+  //! Take ownership of the given MeanPooling.
+  MeanPooling& operator=(MeanPooling&& other);
 
-  //! Clone the
-
+  //! Clone the MeanPooling object. This handles polymorphism correctly.
+  MeanPooling* Clone() const { return new MeanPooling(*this); }
 
   /**
    * Ordinary feed forward pass of a neural network, evaluating the function
@@ -149,7 +152,7 @@ class MeanPoolingType : public Layer<MatType>
   */
  typename MatType::elem_type Pooling(const MatType& input)
  {
-    return
+    return mean(vectorise(input));
  }
 
  //! Locally-stored width of the pooling window.
@@ -169,10 +172,7 @@ class MeanPoolingType : public Layer<MatType>
 
   //! Locally-stored number channels.
   size_t channels;
-}; // class
-
-// Standard MeanPooling layer.
-using MeanPooling = MeanPoolingType<arma::mat>;
+}; // class MeanPooling
 
 } // namespace mlpack
 
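Beyond the MeanPoolingType → MeanPooling rename, the only behavioural detail visible in these hunks is the Pooling() helper, which reduces a pooling window to the arithmetic mean of all of its elements. A standalone illustration of that expression in plain Armadillo (not mlpack-specific; the window values are arbitrary):

#include <armadillo>

int main()
{
  // A 2x2 pooling window.
  arma::mat window = { { 1.0, 2.0 },
                       { 3.0, 4.0 } };

  // Same expression as MeanPooling's Pooling() helper: mean over all elements.
  const double pooled = arma::mean(arma::vectorise(window));  // 2.5

  return (pooled == 2.5) ? 0 : 1;
}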