mlpack-4.6.2-cp313-cp313-win_amd64.whl → mlpack-4.7.0-cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
mlpack/include/mlpack/methods/ann/layer/gru.hpp (new file)

@@ -0,0 +1,195 @@
+/**
+ * @file methods/ann/layer/gru.hpp
+ * @author Sumedh Ghaisas
+ * @author Zachary Ng
+ *
+ * Definition of the GRU layer.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license.  You should have received a copy of the
+ * 3-clause BSD license along with mlpack.  If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+#ifndef MLPACK_METHODS_ANN_LAYER_GRU_HPP
+#define MLPACK_METHODS_ANN_LAYER_GRU_HPP
+
+#include <mlpack/prereqs.hpp>
+
+namespace mlpack {
+
+/**
+ * An implementation of a GRU network layer using the following algorithm:
+ *
+ * r_t = sigmoid(W_r x_t + U_r y_{t - 1})
+ * z_t = sigmoid(W_z x_t + U_z y_{t - 1})
+ * h_t = tanh(W_h x_t + r_t % (U_h y_{t - 1}))
+ * y_t = (1 - z_t) % y_{t - 1} + z_t % h_t
+ *
+ * For more information, read the following paper:
+ *
+ * @code
+ * @inproceedings{chung2015gated,
+ *   title     = {Gated Feedback Recurrent Neural Networks},
+ *   author    = {Chung, Junyoung and G{\"u}l{\c{c}}ehre, Caglar and Cho,
+ *                Kyunghyun and Bengio, Yoshua},
+ *   booktitle = {ICML},
+ *   pages     = {2067--2075},
+ *   year      = {2015},
+ *   url       = {https://arxiv.org/abs/1502.02367}
+ * }
+ * @endcode
+ *
+ * This cell can be used in RNNs.
+ *
+ * @tparam MatType Type of the input data (arma::colvec, arma::mat,
+ *     arma::sp_mat or arma::cube).
+ */
+template <typename MatType = arma::mat>
+class GRU : public RecurrentLayer<MatType>
+{
+ public:
+  // Create the GRU object.
+  GRU();
+
+  /**
+   * Create the GRU layer object using the specified parameters.
+   *
+   * @param outSize The number of output units.
+   */
+  GRU(const size_t outSize);
+
+  // Clone the GRU object. This handles polymorphism correctly.
+  GRU* Clone() const { return new GRU(*this); }
+
+  // Copy the given GRU object.
+  GRU(const GRU& other);
+  // Take ownership of the given GRU object's data.
+  GRU(GRU&& other);
+  // Copy the given GRU object.
+  GRU& operator=(const GRU& other);
+  // Take ownership of the given GRU object's data.
+  GRU& operator=(GRU&& other);
+
+  virtual ~GRU() { }
+
+  /**
+   * Reset the layer parameters. This method is called to assign the
+   * allocated memory to the internal learnable parameters.
+   */
+  void SetWeights(const MatType& weightsIn);
+
+  /**
+   * Ordinary feed-forward pass of a neural network, evaluating the function
+   * f(x) by propagating the activity forward through f.
+   *
+   * @param input Input data used for evaluating the specified function.
+   * @param output Resulting output activation.
+   */
+  void Forward(const MatType& input, MatType& output);
+
+  /**
+   * Ordinary feed-backward pass of a neural network, calculating the function
+   * f(x) by propagating x backwards through f, using the results of the feed
+   * forward pass.
+   *
+   * @param input The input data (x) given to the forward pass.
+   * @param output The propagated data (f(x)) resulting from Forward().
+   * @param gy Propagated error from the next layer.
+   * @param g Matrix to store the propagated error in for the previous layer.
+   */
+  void Backward(const MatType& /* input */,
+                const MatType& output,
+                const MatType& gy,
+                MatType& g);
+
+  /*
+   * Calculate the gradient using the output delta and the input activation.
+   *
+   * @param input Original input data provided to Forward().
+   * @param error Error as computed by `Backward()`.
+   * @param gradient Matrix to store the gradients in.
+   */
+  void Gradient(const MatType& input,
+                const MatType& /* error */,
+                MatType& gradient);
+
+  // Get the parameters.
+  MatType const& Parameters() const { return weights; }
+  // Modify the parameters.
+  MatType& Parameters() { return weights; }
+
+  // Get the total number of trainable parameters.
+  size_t WeightSize() const;
+
+  // Get the total number of recurrent state parameters.
+  size_t RecurrentSize() const;
+
+  // Given a properly set InputDimensions(), compute the output dimensions.
+  void ComputeOutputDimensions()
+  {
+    inSize = this->inputDimensions[0];
+    for (size_t i = 1; i < this->inputDimensions.size(); ++i)
+      inSize *= this->inputDimensions[i];
+    this->outputDimensions = std::vector<size_t>(this->inputDimensions.size(),
+        1);
+
+    // The GRU layer flattens its input.
+    this->outputDimensions[0] = outSize;
+  }
+
+  // Update the internal aliases of the layer when the step changes.
+  void OnStepChanged(const size_t step,
+                     const size_t batchSize,
+                     const size_t activeBatchSize,
+                     const bool backwards);
+
+  /**
+   * Serialize the layer.
+   */
+  template<typename Archive>
+  void serialize(Archive& ar, const uint32_t /* version */);
+
+ private:
+  // Locally-stored number of input units.
+  size_t inSize;
+
+  // Locally-stored number of output units.
+  size_t outSize;
+
+  // Locally-stored weight object.
+  MatType weights;
+
+  // Weight aliases for input connections.
+  MatType resetGateWeight;
+  MatType updateGateWeight;
+  MatType hiddenGateWeight;
+
+  // Weight aliases for recurrent connections.
+  MatType recurrentResetGateWeight;
+  MatType recurrentUpdateGateWeight;
+  MatType recurrentHiddenGateWeight;
+
+  // Recurrent state aliases.
+  MatType resetGate;
+  MatType updateGate;
+  MatType hiddenGate;
+  MatType currentOutput;
+  MatType prevOutput;
+
+  // Backwards workspace.
+  MatType workspace;
+  MatType deltaReset;
+  MatType deltaUpdate;
+  MatType deltaHidden;
+  // These correspond to, e.g., dy_{t + 1}.
+  MatType nextDeltaReset;
+  MatType nextDeltaUpdate;
+  MatType nextDeltaHidden;
+}; // class GRU
+
+} // namespace mlpack
+
+// Include implementation.
+#include "gru_impl.hpp"
+
+#endif
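The doc comment above fixes the GRU recurrence; note that `%` there is Armadillo's elementwise (Schur) product, not modulus. As a standalone illustration (not part of the package contents), here is a minimal Armadillo sketch of one forward step under those same equations; the function and parameter names are hypothetical, chosen to mirror the weight aliases declared in the class:

// Sketch of one GRU forward step, following the equations in gru.hpp above.
// GruStep and its parameter names are illustrative, not part of mlpack.
#include <armadillo>

arma::mat GruStep(const arma::mat& x,      // input, inSize x batchSize
                  const arma::mat& yPrev,  // previous output, outSize x batchSize
                  const arma::mat& Wr, const arma::mat& Ur,  // reset gate
                  const arma::mat& Wz, const arma::mat& Uz,  // update gate
                  const arma::mat& Wh, const arma::mat& Uh)  // hidden gate
{
  // r_t = sigmoid(W_r x_t + U_r y_{t - 1})
  const arma::mat r = 1.0 / (1.0 + arma::exp(-(Wr * x + Ur * yPrev)));
  // z_t = sigmoid(W_z x_t + U_z y_{t - 1})
  const arma::mat z = 1.0 / (1.0 + arma::exp(-(Wz * x + Uz * yPrev)));
  // h_t = tanh(W_h x_t + r_t % (U_h y_{t - 1}))
  const arma::mat h = arma::tanh(Wh * x + r % (Uh * yPrev));
  // y_t = (1 - z_t) % y_{t - 1} + z_t % h_t
  return (1.0 - z) % yPrev + z % h;
}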
mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp (new file)

@@ -0,0 +1,325 @@
+/**
+ * @file methods/ann/layer/gru_impl.hpp
+ * @author Sumedh Ghaisas
+ * @author Zachary Ng
+ *
+ * Implementation of the GRU class, which implements a GRU network layer.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license.  You should have received a copy of the
+ * 3-clause BSD license along with mlpack.  If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+#ifndef MLPACK_METHODS_ANN_LAYER_GRU_IMPL_HPP
+#define MLPACK_METHODS_ANN_LAYER_GRU_IMPL_HPP
+
+// In case it hasn't yet been included.
+#include "gru.hpp"
+
+namespace mlpack {
+
+template<typename MatType>
+GRU<MatType>::GRU() :
+    RecurrentLayer<MatType>()
+{
+  // Nothing to do here.
+}
+
+template<typename MatType>
+GRU<MatType>::GRU(const size_t outSize) :
+    RecurrentLayer<MatType>(),
+    inSize(0),
+    outSize(outSize)
+{
+  // Nothing to do here.
+}
+
+template<typename MatType>
+GRU<MatType>::GRU(const GRU& other) :
+    RecurrentLayer<MatType>(other),
+    inSize(other.inSize),
+    outSize(other.outSize)
+{
+  // Nothing to do here.
+}
+
+template<typename MatType>
+GRU<MatType>::GRU(GRU&& other) :
+    RecurrentLayer<MatType>(std::move(other)),
+    inSize(other.inSize),
+    outSize(other.outSize)
+{
+  // Nothing to do here.
+}
+
+template<typename MatType>
+GRU<MatType>& GRU<MatType>::operator=(const GRU& other)
+{
+  if (this != &other)
+  {
+    RecurrentLayer<MatType>::operator=(other);
+    inSize = other.inSize;
+    outSize = other.outSize;
+  }
+
+  return *this;
+}
+
+template<typename MatType>
+GRU<MatType>& GRU<MatType>::operator=(GRU&& other)
+{
+  if (this != &other)
+  {
+    RecurrentLayer<MatType>::operator=(std::move(other));
+    inSize = other.inSize;
+    outSize = other.outSize;
+  }
+
+  return *this;
+}
+
+template<typename MatType>
+void GRU<MatType>::SetWeights(const MatType& weightsIn)
+{
+  MakeAlias(weights, weightsIn, weightsIn.n_rows, weightsIn.n_cols);
+
+  const size_t inputWeightSize = outSize * inSize;
+  MakeAlias(resetGateWeight, weightsIn, outSize, inSize, 0);
+  MakeAlias(updateGateWeight, weightsIn, outSize, inSize, inputWeightSize);
+  MakeAlias(hiddenGateWeight, weightsIn, outSize, inSize, inputWeightSize * 2);
+
+  const size_t recurrentWeightOffset = inputWeightSize * 3;
+  const size_t recurrentWeightSize = outSize * outSize;
+  MakeAlias(recurrentResetGateWeight, weightsIn, outSize, outSize,
+      recurrentWeightOffset);
+  MakeAlias(recurrentUpdateGateWeight, weightsIn, outSize, outSize,
+      recurrentWeightOffset + recurrentWeightSize);
+  MakeAlias(recurrentHiddenGateWeight, weightsIn, outSize, outSize,
+      recurrentWeightOffset + recurrentWeightSize * 2);
+}
+
+template<typename MatType>
+void GRU<MatType>::Forward(const MatType& input, MatType& output)
+{
+  // Compute the internal state using the following algorithm:
+  // r_t = sigmoid(W_r x_t + U_r y_{t - 1})
+  // z_t = sigmoid(W_z x_t + U_z y_{t - 1})
+  // h_t = tanh(W_h x_t + r_t % (U_h y_{t - 1}))
+  // y_t = (1 - z_t) % y_{t - 1} + z_t % h_t
+
+  // Process the non-recurrent input.
+  updateGate = updateGateWeight * input;
+  resetGate = resetGateWeight * input;
+
+  // Add the recurrent input.
+  if (this->HasPreviousStep())
+  {
+    resetGate += recurrentResetGateWeight * prevOutput;
+    updateGate += recurrentUpdateGateWeight * prevOutput;
+  }
+
+  // Apply the sigmoid activation function.
+  resetGate = 1 / (1 + exp(-resetGate));
+  updateGate = 1 / (1 + exp(-updateGate));
+
+  // Calculate the candidate activation vector.
+  hiddenGate = hiddenGateWeight * input;
+
+  // Add the recurrent portion to the activation vector.
+  if (this->HasPreviousStep())
+  {
+    hiddenGate += resetGate % (recurrentHiddenGateWeight * prevOutput);
+  }
+
+  // Apply the tanh activation function.
+  hiddenGate = tanh(hiddenGate);
+
+  // Compute the output.
+  output = updateGate % hiddenGate;
+
+  // Add the recurrent portion to the output.
+  if (this->HasPreviousStep())
+  {
+    output += (1 - updateGate) % prevOutput;
+  }
+
+  currentOutput = output;
+}
+
+template<typename MatType>
+void GRU<MatType>::Backward(
+    const MatType& /* input */,
+    const MatType& /* output */,
+    const MatType& gy,
+    MatType& g)
+{
+  // Work backwards to get the error at each gate.
+  // y_t = (1 - z_t) % y_{t - 1} + z_t % h_t
+  // dh_t = dy % z_t
+  deltaHidden = gy % updateGate;
+  // The hidden gate uses a tanh activation function.
+  // The derivative of tanh(x) is 1 - tanh^2(x), but
+  // tanh has already been applied to hiddenGate in Forward().
+  deltaHidden = deltaHidden % (1 - square(hiddenGate));
+
+  // y_t = (1 - z_t) % y_{t - 1} + z_t % h_t
+  // dz_t = dy % h_t - dy % y_{t - 1}
+  deltaUpdate = gy % hiddenGate;
+  if (this->HasPreviousStep())
+    deltaUpdate -= gy % prevOutput;
+  // The reset and update gates use sigmoid activation.
+  // The derivative is sigmoid(x) * (1 - sigmoid(x)); since sigmoid has
+  // already been applied to the gates, it's just `x * (1 - x)`.
+  deltaUpdate = deltaUpdate % (updateGate % (1 - updateGate));
+
+  if (this->HasPreviousStep())
+  {
+    // h_t = tanh(W_h x_t + r_t % (U_h y_{t - 1}))
+    // dr_t = dh_t % (U_h y_{t - 1})
+    deltaReset = deltaHidden % (recurrentHiddenGateWeight * prevOutput);
+    deltaReset = deltaReset % (resetGate % (1 - resetGate));
+  }
+  else
+  {
+    deltaReset.zeros(deltaHidden.n_rows, deltaHidden.n_cols);
+  }
+
+  // Calculate the input error.
+  // r_t = sigmoid(W_r x_t + U_r y_{t - 1})
+  // z_t = sigmoid(W_z x_t + U_z y_{t - 1})
+  // h_t = tanh(W_h x_t + r_t % (U_h y_{t - 1}))
+  // dx_t = W_r * dr_t + W_z * dz_t + W_h * dh_t
+  g = resetGateWeight.t() * deltaReset +
+      updateGateWeight.t() * deltaUpdate +
+      hiddenGateWeight.t() * deltaHidden;
+}
+
+template<typename MatType>
+void GRU<MatType>::Gradient(
+    const MatType& input,
+    const MatType& /* error */,
+    MatType& gradient)
+{
+  size_t offset = 0;
+  // Non-recurrent reset gate weights.
+  gradient.submat(offset, 0, offset + resetGateWeight.n_elem - 1, 0) =
+      vectorise(deltaReset * input.t());
+  offset += resetGateWeight.n_elem;
+  // Non-recurrent update gate weights.
+  gradient.submat(offset, 0, offset + updateGateWeight.n_elem - 1, 0) =
+      vectorise(deltaUpdate * input.t());
+  offset += updateGateWeight.n_elem;
+  // Non-recurrent hidden gate weights.
+  gradient.submat(offset, 0, offset + hiddenGateWeight.n_elem - 1, 0) =
+      vectorise(deltaHidden * input.t());
+  offset += hiddenGateWeight.n_elem;
+
+  // nextDelta is not set until after the first step.
+  if (!this->AtFinalStep())
+  {
+    // Recurrent reset gate weights.
+    gradient.submat(offset, 0, offset + recurrentResetGateWeight.n_elem - 1,
+        0) = vectorise(nextDeltaReset * currentOutput.t());
+    offset += recurrentResetGateWeight.n_elem;
+    // Recurrent update gate weights.
+    gradient.submat(offset, 0, offset + recurrentUpdateGateWeight.n_elem - 1,
+        0) = vectorise(nextDeltaUpdate * currentOutput.t());
+    offset += recurrentUpdateGateWeight.n_elem;
+    // Recurrent hidden gate weights.
+    gradient.submat(offset, 0, offset + recurrentHiddenGateWeight.n_elem - 1,
+        0) = vectorise(nextDeltaHidden * currentOutput.t());
+    offset += recurrentHiddenGateWeight.n_elem;
+  }
+}
+
+template<typename MatType>
+size_t GRU<MatType>::WeightSize() const
+{
+  return outSize * inSize * 3 +  /* Input weight connections. */
+      outSize * outSize * 3;     /* Recurrent weight connections. */
+}
+
+template<typename MatType>
+size_t GRU<MatType>::RecurrentSize() const
+{
+  // The recurrent state has to store the output, reset gate, update gate,
+  // and hidden gate.
+  return outSize * 4;
+}
+
+template<typename MatType>
+void GRU<MatType>::OnStepChanged(const size_t step,
+                                 const size_t batchSize,
+                                 const size_t activeBatchSize,
+                                 const bool backwards)
+{
+  // Make aliases for the internal gate states from the recurrent state.
+  MatType& state = this->RecurrentState(step);
+
+  MakeAlias(currentOutput, state, outSize, activeBatchSize);
+  MakeAlias(resetGate, state, outSize, activeBatchSize, outSize * batchSize);
+  MakeAlias(updateGate, state, outSize, activeBatchSize, 2 * outSize *
+      batchSize);
+  MakeAlias(hiddenGate, state, outSize, activeBatchSize, 3 * outSize *
+      batchSize);
+
+  if (this->HasPreviousStep())
+  {
+    MatType& prevState = this->RecurrentState(this->PreviousStep());
+    MakeAlias(prevOutput, prevState, outSize, activeBatchSize);
+  }
+
+  // Also set the workspaces for the backwards pass, if requested.
+  if (backwards)
+  {
+    // We need to hold enough space for two time steps.
+    workspace.set_size(6 * outSize, batchSize);
+
+    if (step % 2 == 0)
+    {
+      MakeAlias(deltaReset, workspace, outSize, activeBatchSize);
+      MakeAlias(deltaUpdate, workspace, outSize, activeBatchSize,
+          outSize * batchSize);
+      MakeAlias(deltaHidden, workspace, outSize, activeBatchSize,
+          2 * outSize * batchSize);
+
+      MakeAlias(nextDeltaReset, workspace, outSize, activeBatchSize,
+          3 * outSize * batchSize);
+      MakeAlias(nextDeltaUpdate, workspace, outSize, activeBatchSize,
+          4 * outSize * batchSize);
+      MakeAlias(nextDeltaHidden, workspace, outSize, activeBatchSize,
+          5 * outSize * batchSize);
+    }
+    else
+    {
+      MakeAlias(nextDeltaReset, workspace, outSize, activeBatchSize);
+      MakeAlias(nextDeltaUpdate, workspace, outSize, activeBatchSize,
+          outSize * batchSize);
+      MakeAlias(nextDeltaHidden, workspace, outSize, activeBatchSize,
+          2 * outSize * batchSize);
+
+      MakeAlias(deltaReset, workspace, outSize, activeBatchSize,
+          3 * outSize * batchSize);
+      MakeAlias(deltaUpdate, workspace, outSize, activeBatchSize,
+          4 * outSize * batchSize);
+      MakeAlias(deltaHidden, workspace, outSize, activeBatchSize,
+          5 * outSize * batchSize);
+    }
+  }
+}
+
+template<typename MatType>
+template<typename Archive>
+void GRU<MatType>::serialize(Archive& ar, const uint32_t /* version */)
+{
+  ar(cereal::base_class<RecurrentLayer<MatType>>(this));
+
+  ar(CEREAL_NVP(inSize));
+  ar(CEREAL_NVP(outSize));
+}
+
+} // namespace mlpack
+
+#endif
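Two details worth noting in the implementation above: SetWeights() packs the flat parameter vector as the three input-gate matrices followed by the three recurrent matrices, which is why WeightSize() returns 3 * outSize * inSize + 3 * outSize * outSize; and OnStepChanged() double-buffers the backward-pass deltas in `workspace`, alternating halves by step parity so each step can read the deltas of the step after it. For context, a hypothetical usage sketch of the new layer inside mlpack's RNN container; the container call pattern and cube layout follow mlpack 4.x conventions, and the sizes and data are illustrative:

// Hypothetical usage sketch: the new GRU derives from RecurrentLayer, so it
// should plug into the RNN container like the existing LSTM layer does.
#include <mlpack.hpp>

int main()
{
  // 5 input dimensions, 256 sequences, 10 time steps.
  arma::cube data(5, 256, 10, arma::fill::randn);
  arma::cube labels(1, 256, 10, arma::fill::randn);

  mlpack::RNN<mlpack::MeanSquaredError> model(10 /* BPTT steps */);
  model.Add<mlpack::GRU<>>(16);   // 16 output (recurrent) units.
  model.Add<mlpack::Linear>(1);
  model.Train(data, labels);
}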
mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp

@@ -46,9 +46,12 @@ namespace mlpack {
  * type to differ from the input type (Default: arma::mat).
  */
 template <typename MatType = arma::mat>
-class HardTanHType : public Layer<MatType>
+class HardTanH : public Layer<MatType>
 {
  public:
+  // Convenience typedef to access the element type of the weights and data.
+  using ElemType = typename MatType::elem_type;
+
   /**
    * Create the HardTanH object using the specified parameters. The range
    * of the linear region can be adjusted by specifying the maxValue and
@@ -57,24 +60,24 @@ class HardTanHType : public Layer<MatType>
    * @param maxValue Range of the linear region maximum value.
    * @param minValue Range of the linear region minimum value.
    */
-  HardTanHType(const double maxValue = 1, const double minValue = -1);
+  HardTanH(const double maxValue = 1, const double minValue = -1);
 
-  virtual ~HardTanHType() { }
+  virtual ~HardTanH() { }
 
   //! Copy the other HardTanH layer
-  HardTanHType(const HardTanHType& layer);
+  HardTanH(const HardTanH& layer);
 
   //! Take ownership of the members of the other HardTanH Layer
-  HardTanHType(HardTanHType&& layer);
+  HardTanH(HardTanH&& layer);
 
   //! Copy the other HardTanH layer
-  HardTanHType& operator=(const HardTanHType& layer);
+  HardTanH& operator=(const HardTanH& layer);
 
   //! Take ownership of the members of the other HardTanH Layer
-  HardTanHType& operator=(HardTanHType&& layer);
+  HardTanH& operator=(HardTanH&& layer);
 
-  //! Clone the HardTanHType object. This handles polymorphism correctly.
-  HardTanHType* Clone() const { return new HardTanHType(*this); }
+  //! Clone the HardTanH object. This handles polymorphism correctly.
+  HardTanH* Clone() const { return new HardTanH(*this); }
 
   /**
    * Ordinary feed forward pass of a neural network, evaluating the function
@@ -122,12 +125,7 @@ class HardTanHType : public Layer<MatType>
 
   //! Minimum value for the HardTanH function.
   double minValue;
-}; // class HardTanHType
-
-// Convenience typedefs.
-
-// Standard HardTanH layer.
-using HardTanH = HardTanHType<arma::mat>;
+}; // class HardTanH
 
 } // namespace mlpack
 
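Two changes here: the old HardTanHType class template plus its `using HardTanH = HardTanHType<arma::mat>` alias collapse into a single template named HardTanH, and the new ElemType typedef lets the double-typed clamp thresholds be converted to the matrix's element type. A hypothetical sketch of what the second change enables, a float-typed network using the layer (types and sizes are illustrative; previously, comparing float elements against the double thresholds produced double intermediates and narrowing conversions):

// Hypothetical sketch: with the ElemType conversion, HardTanH participates
// cleanly in an all-float network. Sizes and data are illustrative only.
#include <mlpack.hpp>

int main()
{
  arma::fmat data(10, 100, arma::fill::randn);
  arma::fmat labels = arma::ones<arma::fmat>(1, 100);

  mlpack::FFN<mlpack::MeanSquaredErrorType<arma::fmat>,
              mlpack::RandomInitialization,
              arma::fmat> model;
  model.Add<mlpack::LinearType<arma::fmat>>(8);
  model.Add<mlpack::HardTanH<arma::fmat>>(2.0, -2.0);  // Linear region [-2, 2].
  model.Add<mlpack::LinearType<arma::fmat>>(1);
  model.Train(data, labels);
}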
mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp

@@ -19,7 +19,7 @@
 namespace mlpack {
 
 template<typename MatType>
-HardTanHType<MatType>::HardTanHType(
+HardTanH<MatType>::HardTanH(
     const double maxValue,
     const double minValue) :
     Layer<MatType>(),
@@ -30,7 +30,7 @@ HardTanHType<MatType>::HardTanHType(
 }
 
 template<typename MatType>
-HardTanHType<MatType>::HardTanHType(const HardTanHType& layer) :
+HardTanH<MatType>::HardTanH(const HardTanH& layer) :
     Layer<MatType>(layer),
     maxValue(layer.maxValue),
     minValue(layer.minValue)
@@ -39,7 +39,7 @@ HardTanHType<MatType>::HardTanHType(const HardTanHType& layer) :
 }
 
 template<typename MatType>
-HardTanHType<MatType>::HardTanHType(HardTanHType&& layer) :
+HardTanH<MatType>::HardTanH(HardTanH&& layer) :
     Layer<MatType>(std::move(layer)),
     maxValue(std::move(layer.maxValue)),
     minValue(std::move(layer.minValue))
@@ -48,8 +48,8 @@ HardTanHType<MatType>::HardTanHType(HardTanHType&& layer) :
 }
 
 template<typename MatType>
-HardTanHType<MatType>& HardTanHType<MatType>::operator=(
-    const HardTanHType& layer)
+HardTanH<MatType>& HardTanH<MatType>::operator=(
+    const HardTanH& layer)
 {
   if (&layer != this)
   {
@@ -62,7 +62,7 @@ HardTanHType<MatType>& HardTanHType<MatType>::operator=(
 }
 
 template<typename MatType>
-HardTanHType<MatType>& HardTanHType<MatType>::operator=(HardTanHType&& layer)
+HardTanH<MatType>& HardTanH<MatType>::operator=(HardTanH&& layer)
 {
   if (&layer != this)
   {
@@ -74,19 +74,19 @@ HardTanHType<MatType>& HardTanHType<MatType>::operator=(HardTanHType&& layer)
   return *this;
 }
 template<typename MatType>
-void HardTanHType<MatType>::Forward(
+void HardTanH<MatType>::Forward(
     const MatType& input, MatType& output)
 {
   #pragma omp parallel for
   for (size_t i = 0; i < input.n_elem; ++i)
   {
-    output(i) = (input(i) > maxValue ? maxValue :
-        (input(i) < minValue ? minValue : input(i)));
+    output(i) = (input(i) > ElemType(maxValue) ? ElemType(maxValue) :
+        (input(i) < ElemType(minValue) ? ElemType(minValue) : input(i)));
   }
 }
 
 template<typename MatType>
-void HardTanHType<MatType>::Backward(
+void HardTanH<MatType>::Backward(
     const MatType& input,
     const MatType& /* output */,
     const MatType& gy,
@@ -99,7 +99,7 @@ void HardTanHType<MatType>::Backward(
   {
     // input should not have any values greater than maxValue
     // and lesser than minValue
-    if (input(i) <= minValue || input(i) >= maxValue)
+    if (input(i) <= ElemType(minValue) || input(i) >= ElemType(maxValue))
     {
       g(i) = 0;
     }
@@ -108,7 +108,7 @@ void HardTanHType<MatType>::Backward(
 
 template<typename MatType>
 template<typename Archive>
-void HardTanHType<MatType>::serialize(
+void HardTanH<MatType>::serialize(
     Archive& ar,
     const uint32_t /* version */)
 {