mlpack 4.6.2__cp39-cp39-win_amd64.whl → 4.7.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +6 -6
- mlpack/adaboost_classify.cp39-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp39-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp39-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp39-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp39-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp39-win_amd64.pyd +0 -0
- mlpack/cf.cp39-win_amd64.pyd +0 -0
- mlpack/dbscan.cp39-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp39-win_amd64.pyd +0 -0
- mlpack/det.cp39-win_amd64.pyd +0 -0
- mlpack/emst.cp39-win_amd64.pyd +0 -0
- mlpack/fastmks.cp39-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp39-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp39-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp39-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp39-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp39-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp39-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp39-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp39-win_amd64.pyd +0 -0
- mlpack/image_converter.cp39-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp39-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp39-win_amd64.pyd +0 -0
- mlpack/kfn.cp39-win_amd64.pyd +0 -0
- mlpack/kmeans.cp39-win_amd64.pyd +0 -0
- mlpack/knn.cp39-win_amd64.pyd +0 -0
- mlpack/krann.cp39-win_amd64.pyd +0 -0
- mlpack/lars.cp39-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp39-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp39-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp39-win_amd64.pyd +0 -0
- mlpack/lmnn.cp39-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp39-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp39-win_amd64.pyd +0 -0
- mlpack/lsh.cp39-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp39-win_amd64.pyd +0 -0
- mlpack/nbc.cp39-win_amd64.pyd +0 -0
- mlpack/nca.cp39-win_amd64.pyd +0 -0
- mlpack/nmf.cp39-win_amd64.pyd +0 -0
- mlpack/pca.cp39-win_amd64.pyd +0 -0
- mlpack/perceptron.cp39-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp39-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp39-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp39-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp39-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp39-win_amd64.pyd +0 -0
- mlpack/radical.cp39-win_amd64.pyd +0 -0
- mlpack/random_forest.cp39-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp39-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp39-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +397 -378
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack.libs/.load-order-mlpack-4.7.0 +2 -0
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- mlpack.libs/.load-order-mlpack-4.6.2 +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -22,8 +22,8 @@
|
|
|
22
22
|
namespace mlpack {
|
|
23
23
|
|
|
24
24
|
template <typename MatType, typename RegularizerType>
|
|
25
|
-
|
|
26
|
-
|
|
25
|
+
MultiheadAttention<MatType, RegularizerType>::
|
|
26
|
+
MultiheadAttention() :
|
|
27
27
|
tgtSeqLen(0),
|
|
28
28
|
srcSeqLen(0),
|
|
29
29
|
embedDim(0),
|
|
@@ -35,11 +35,11 @@ MultiheadAttentionType() :
|
|
|
35
35
|
}
|
|
36
36
|
|
|
37
37
|
template <typename MatType, typename RegularizerType>
|
|
38
|
-
|
|
39
|
-
|
|
38
|
+
MultiheadAttention<MatType, RegularizerType>::
|
|
39
|
+
MultiheadAttention(
|
|
40
40
|
const size_t tgtSeqLen,
|
|
41
41
|
const size_t numHeads,
|
|
42
|
-
const
|
|
42
|
+
const CubeType& attnmask,
|
|
43
43
|
const MatType& keypaddingmask,
|
|
44
44
|
const bool selfAttention) :
|
|
45
45
|
tgtSeqLen(tgtSeqLen),
|
|
@@ -53,7 +53,7 @@ MultiheadAttentionType(
|
|
|
53
53
|
}
|
|
54
54
|
|
|
55
55
|
template <typename MatType, typename RegularizerType>
|
|
56
|
-
void
|
|
56
|
+
void MultiheadAttention<MatType, RegularizerType>::SetWeights(
|
|
57
57
|
const MatType& weightsIn)
|
|
58
58
|
{
|
|
59
59
|
MakeAlias(weights, weightsIn, (4 * embedDim + 4) * embedDim, 1);
|
|
@@ -70,7 +70,7 @@ void MultiheadAttentionType<MatType, RegularizerType>::SetWeights(
|
|
|
70
70
|
}
|
|
71
71
|
|
|
72
72
|
template <typename MatType, typename RegularizerType>
|
|
73
|
-
void
|
|
73
|
+
void MultiheadAttention<MatType, RegularizerType>::
|
|
74
74
|
Forward(const MatType& input, MatType& output)
|
|
75
75
|
{
|
|
76
76
|
if (input.n_rows != embedDim *
|
|
@@ -122,7 +122,7 @@ Forward(const MatType& input, MatType& output)
|
|
|
122
122
|
|
|
123
123
|
// The scaling factor sqrt(headDim) is used to prevent exploding values
|
|
124
124
|
// after dot product i.e. when qProj is multiplied with kProj.
|
|
125
|
-
qProj /= std::sqrt(headDim);
|
|
125
|
+
qProj /= ElemType(std::sqrt(headDim));
|
|
126
126
|
|
|
127
127
|
// Split the qProj, kProj and vProj into n heads. That's what Multihead
|
|
128
128
|
// Attention is.
|
|
@@ -131,40 +131,16 @@ Forward(const MatType& input, MatType& output)
|
|
|
131
131
|
vProj.reshape(srcSeqLen, headDim, numHeads * batchSize);
|
|
132
132
|
|
|
133
133
|
// Calculate the scores i.e. perform the matrix multiplication operation
|
|
134
|
-
// on qProj and kProj. Here score =
|
|
135
|
-
scores = MultiplyCube2Cube(
|
|
136
|
-
|
|
137
|
-
// Apply the attention mask if provided. The attention mask is used to black-
|
|
138
|
-
// out future sequences and generally used in Encoder-Decoder attention.
|
|
139
|
-
// The attention mask has elements -inf or 0.
|
|
140
|
-
// The shape of the attention mask : (tgtSeqLen, srcSeqLen).
|
|
141
|
-
if (!attnMask.is_empty())
|
|
142
|
-
{
|
|
143
|
-
if (attnMask.n_rows != tgtSeqLen || attnMask.n_cols != srcSeqLen)
|
|
144
|
-
Log::Fatal << "The size of the 'attn_mask' is not correct.\n";
|
|
145
|
-
scores.each_slice() += attnMask;
|
|
146
|
-
}
|
|
147
|
-
|
|
148
|
-
// Apply the key padding mask when provided. It blacks-out any particular
|
|
149
|
-
// word in the sequence.
|
|
150
|
-
// The key padding mask has elements -inf or 0
|
|
151
|
-
// The shape of keyPaddingMask : (1, srcSeqLen).
|
|
152
|
-
if (!keyPaddingMask.is_empty())
|
|
153
|
-
{
|
|
154
|
-
if (keyPaddingMask.n_rows != 1 || keyPaddingMask.n_cols != srcSeqLen)
|
|
155
|
-
Log::Fatal << "The size of the 'keyPaddingMask' is not correct.\n";
|
|
156
|
-
scores.each_slice() += repmat(keyPaddingMask, tgtSeqLen, 1);
|
|
157
|
-
}
|
|
134
|
+
// on qProj and kProj. Here score = kProj . qProj'
|
|
135
|
+
scores = MultiplyCube2Cube(kProj, qProj, false, true);
|
|
158
136
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
softmax.Forward(scores.slice(i), scores.slice(i));
|
|
162
|
-
}
|
|
137
|
+
// Apply softmax to non-masked elements.
|
|
138
|
+
MaskedForwardSoftmax(scores, numHeads, batchSize, attnMask, keyPaddingMask);
|
|
163
139
|
|
|
164
140
|
// Calculate the attention output i.e. matrix multiplication of softmax
|
|
165
141
|
// output and vProj.
|
|
166
142
|
// The shape of attnOutput : (tgtSeqLen, headDim, numHeads * batchSize).
|
|
167
|
-
attnOut = MultiplyCube2Cube(scores, vProj,
|
|
143
|
+
attnOut = MultiplyCube2Cube(scores, vProj, true, false);
|
|
168
144
|
|
|
169
145
|
// Now we will concatenate output of all the heads i.e. we will reshape
|
|
170
146
|
// attnOut to (tgtSeqLen, embedDim, batchSize).
|
|
@@ -173,13 +149,13 @@ Forward(const MatType& input, MatType& output)
|
|
|
173
149
|
// The final output is the linear projection of attention output.
|
|
174
150
|
for (size_t i = 0; i < batchSize; ++i)
|
|
175
151
|
{
|
|
176
|
-
output.col(i) = vectorise(trans(attnOut.slice(i) * outWt
|
|
152
|
+
output.col(i) = vectorise(trans(attnOut.slice(i) * outWt.t()
|
|
177
153
|
+ repmat(outBias, tgtSeqLen, 1)));
|
|
178
154
|
}
|
|
179
155
|
}
|
|
180
156
|
|
|
181
157
|
template <typename MatType, typename RegularizerType>
|
|
182
|
-
void
|
|
158
|
+
void MultiheadAttention<MatType, RegularizerType>::
|
|
183
159
|
Backward(const MatType& /* input */,
|
|
184
160
|
const MatType& /* output */,
|
|
185
161
|
const MatType& gy,
|
|
@@ -207,7 +183,7 @@ Backward(const MatType& /* input */,
|
|
|
207
183
|
// The shape of gyTemp : (embedDim, tgtSeqLen, batchSize).
|
|
208
184
|
// The shape of outWt : (embedDim, embedDim).
|
|
209
185
|
// The shape of the result : (tgtSeqLen, embedDim, batchSize).
|
|
210
|
-
gyTemp = MultiplyCube2Mat(gyTemp, outWt, true,
|
|
186
|
+
gyTemp = MultiplyCube2Mat(gyTemp, outWt, true, false);
|
|
211
187
|
|
|
212
188
|
// Now since the shape of gyTemp is (tgtSeqLen, embedDim, batchSize). We will
|
|
213
189
|
// split it into n heads.
|
|
@@ -216,9 +192,9 @@ Backward(const MatType& /* input */,
|
|
|
216
192
|
|
|
217
193
|
// Obtain backpropagted error of value.
|
|
218
194
|
// Shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
|
|
219
|
-
// Shape of scores : (
|
|
195
|
+
// Shape of scores : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
|
|
220
196
|
// The shape of tmp : (srcSeqLen, headDim, numHeads * batchSize).
|
|
221
|
-
CubeType tmp = MultiplyCube2Cube(scores, gyTemp,
|
|
197
|
+
CubeType tmp = MultiplyCube2Cube(scores, gyTemp, false, false);
|
|
222
198
|
|
|
223
199
|
// Concatenate results of all the attention heads.
|
|
224
200
|
tmp.reshape(srcSeqLen, embedDim, batchSize);
|
|
@@ -239,8 +215,8 @@ Backward(const MatType& /* input */,
|
|
|
239
215
|
|
|
240
216
|
// The shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
|
|
241
217
|
// The shape of vProj : (srcSeqLen, headDim, numHeads * batchSize).
|
|
242
|
-
// So the new shape of gyTemp : (
|
|
243
|
-
gyTemp = MultiplyCube2Cube(
|
|
218
|
+
// So the new shape of gyTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
|
|
219
|
+
gyTemp = MultiplyCube2Cube(vProj, gyTemp, false, true);
|
|
244
220
|
|
|
245
221
|
for (size_t i = 0; i < numHeads * batchSize; ++i)
|
|
246
222
|
{
|
|
@@ -251,9 +227,9 @@ Backward(const MatType& /* input */,
|
|
|
251
227
|
|
|
252
228
|
// Obtain backpropagated error of key.
|
|
253
229
|
// The shape of qProj : (tgtSeqLen, headDim, numHeads * batchSize).
|
|
254
|
-
// The shape of gyTemp : (
|
|
230
|
+
// The shape of gyTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
|
|
255
231
|
// The new shape of tmp : (srcSeqLen, headDim, numHeads * batchSize).
|
|
256
|
-
tmp = MultiplyCube2Cube(gyTemp, qProj,
|
|
232
|
+
tmp = MultiplyCube2Cube(gyTemp, qProj, false, false);
|
|
257
233
|
|
|
258
234
|
// Concatenate results of all the attention heads.
|
|
259
235
|
tmp.reshape(srcSeqLen, embedDim, batchSize);
|
|
@@ -276,9 +252,10 @@ Backward(const MatType& /* input */,
|
|
|
276
252
|
|
|
277
253
|
// Obtain backpropagated error of the query.
|
|
278
254
|
// The shape of kProj : (srcSeqLen, headDim, numHeads * batchSize).
|
|
279
|
-
// The shape of gyTemp : (
|
|
255
|
+
// The shape of gyTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
|
|
280
256
|
// The new shape of tmp : (tgtSeqLen, headDim, numHeads * batchSize).
|
|
281
|
-
tmp = MultiplyCube2Cube(gyTemp, kProj) /
|
|
257
|
+
tmp = MultiplyCube2Cube(gyTemp, kProj, true, false) /
|
|
258
|
+
ElemType(std::sqrt(headDim));
|
|
282
259
|
|
|
283
260
|
// Concatenate results of all the attention heads.
|
|
284
261
|
tmp.reshape(tgtSeqLen, embedDim, batchSize);
|
|
@@ -300,7 +277,7 @@ Backward(const MatType& /* input */,
|
|
|
300
277
|
}
|
|
301
278
|
|
|
302
279
|
template <typename MatType, typename RegularizerType>
|
|
303
|
-
void
|
|
280
|
+
void MultiheadAttention<MatType, RegularizerType>::
|
|
304
281
|
Gradient(const MatType& input,
|
|
305
282
|
const MatType& error,
|
|
306
283
|
MatType& gradient)
|
|
@@ -327,7 +304,7 @@ Gradient(const MatType& input,
|
|
|
327
304
|
const size_t wtSize = embedDim * embedDim;
|
|
328
305
|
|
|
329
306
|
// The shape of gradient : (4 * embedDim * embedDim + 4 * embedDim, 1).
|
|
330
|
-
gradient.set_size(
|
|
307
|
+
gradient.set_size(size(weights));
|
|
331
308
|
|
|
332
309
|
const CubeType q, k, v;
|
|
333
310
|
MakeAlias(const_cast<CubeType&>(q), input, embedDim, tgtSeqLen, batchSize,
|
|
@@ -356,22 +333,23 @@ Gradient(const MatType& input,
|
|
|
356
333
|
|
|
357
334
|
// Gradient wrt. outWt, i.e. dL/d(outWt). We will take sum of gyTemp along
|
|
358
335
|
// the slices and vectorise the output.
|
|
359
|
-
|
|
336
|
+
CubeType tmpCube = sum(gyTemp, 2);
|
|
337
|
+
gradient.rows(3 * wtSize, 4 * wtSize - 1) = vectorise(tmpCube.slice(0).t());
|
|
360
338
|
|
|
361
339
|
// Partial derivative wrt. attnOut.
|
|
362
340
|
// The shape of outWt : (embedDim, embedDim).
|
|
363
341
|
// The shape of errorTemp : (embedDim, tgtSeqLen, batchSize).
|
|
364
342
|
// The shape of gyTemp : (tgtSeqLen, embedDim, batchSize).
|
|
365
|
-
gyTemp = MultiplyCube2Mat(errorTemp, outWt, true,
|
|
343
|
+
gyTemp = MultiplyCube2Mat(errorTemp, outWt, true, false);
|
|
366
344
|
|
|
367
345
|
// Now we will split it into n heads i.e. reshape it into a cube of shape
|
|
368
346
|
// (tgtSeqLen, headDim, numHeads * batchSize).
|
|
369
347
|
gyTemp.reshape(tgtSeqLen, headDim, numHeads * batchSize);
|
|
370
348
|
|
|
371
349
|
// Shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
|
|
372
|
-
// Shape of scores : (
|
|
350
|
+
// Shape of scores : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
|
|
373
351
|
// The new shape of errorTemp : (srcSeqLen, headDim, numHeads * batchSize).
|
|
374
|
-
errorTemp = MultiplyCube2Cube(scores, gyTemp,
|
|
352
|
+
errorTemp = MultiplyCube2Cube(scores, gyTemp, false, false);
|
|
375
353
|
|
|
376
354
|
// Now we will concatenate the propagated errors from all heads i.e. we
|
|
377
355
|
// will reshape errorTemp to (srcSeqLen, embedDim, batchSize).
|
|
@@ -393,22 +371,23 @@ Gradient(const MatType& input,
|
|
|
393
371
|
|
|
394
372
|
// Now, the shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
|
|
395
373
|
// The shape of vProj : (srcSeqLen, headDim, numHeads * batchSize).
|
|
396
|
-
// The new shape of errorTemp : (
|
|
397
|
-
errorTemp = MultiplyCube2Cube(
|
|
374
|
+
// The new shape of errorTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
|
|
375
|
+
errorTemp = MultiplyCube2Cube(vProj, gyTemp, false, true);
|
|
398
376
|
|
|
399
377
|
for (size_t i = 0; i < numHeads * batchSize; ++i)
|
|
400
378
|
{
|
|
401
|
-
// The shape of scores : (
|
|
402
|
-
// The shape of errorTemp : (
|
|
379
|
+
// The shape of scores : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
|
|
380
|
+
// The shape of errorTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
|
|
403
381
|
// The new shape of errorTemp remain same.
|
|
404
382
|
softmax.Backward({} /* unused */, scores.slice(i), errorTemp.slice(i),
|
|
405
383
|
errorTemp.slice(i));
|
|
406
384
|
}
|
|
407
385
|
|
|
386
|
+
|
|
408
387
|
// The shape of qProj : (tgtSeqLen, headDim, numHeads * batchSize).
|
|
409
|
-
// The shape of errorTemp : (
|
|
388
|
+
// The shape of errorTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
|
|
410
389
|
// The shape of gyTemp : (srcSeqLen, headDim, numHeads * batchSize).
|
|
411
|
-
gyTemp = MultiplyCube2Cube(errorTemp, qProj,
|
|
390
|
+
gyTemp = MultiplyCube2Cube(errorTemp, qProj, false, false);
|
|
412
391
|
|
|
413
392
|
// We will now conctenate the propagated errors from all heads.
|
|
414
393
|
// The new shape of gyTemp : (srcSeqLen, embedDim, batchSize).
|
|
@@ -429,13 +408,13 @@ Gradient(const MatType& input,
|
|
|
429
408
|
gradient.rows(wtSize, 2 * wtSize - 1) = vectorise(sum(gyTemp, 2));
|
|
430
409
|
|
|
431
410
|
// The shape of kProj : (srcSeqLen, headDim, numHeads * batchSize).
|
|
432
|
-
// The shape of errorTemp : (
|
|
411
|
+
// The shape of errorTemp : (srcSeqLen, tgtSeqLen, numHeads * batchSize).
|
|
433
412
|
// The shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
|
|
434
|
-
gyTemp = MultiplyCube2Cube(errorTemp, kProj,
|
|
413
|
+
gyTemp = MultiplyCube2Cube(errorTemp, kProj, true, false);
|
|
435
414
|
|
|
436
415
|
// Now, we will concatenate propagated error of all heads.
|
|
437
416
|
gyTemp.reshape(tgtSeqLen, embedDim, batchSize);
|
|
438
|
-
gyTemp /= std::sqrt(headDim);
|
|
417
|
+
gyTemp /= ElemType(std::sqrt(headDim));
|
|
439
418
|
|
|
440
419
|
// Gradient wrt. qBias, i.e. dL/d(qBias). We will take summation over all the
|
|
441
420
|
// batches of gyTemp and over all the sequences.
|
|
@@ -457,7 +436,7 @@ Gradient(const MatType& input,
|
|
|
457
436
|
|
|
458
437
|
template <typename MatType, typename RegularizerType>
|
|
459
438
|
template <typename Archive>
|
|
460
|
-
void
|
|
439
|
+
void MultiheadAttention<MatType, RegularizerType>::
|
|
461
440
|
serialize(Archive& ar, const uint32_t /* version */)
|
|
462
441
|
{
|
|
463
442
|
ar(cereal::base_class<Layer<MatType>>(this));
|
|
@@ -492,6 +471,124 @@ serialize(Archive& ar, const uint32_t /* version */)
|
|
|
492
471
|
}
|
|
493
472
|
}
|
|
494
473
|
|
|
474
|
+
template<typename MatType, typename RegularizerType>
|
|
475
|
+
void MultiheadAttention<MatType, RegularizerType>::MaskedForwardSoftmax(
|
|
476
|
+
CubeType& scores,
|
|
477
|
+
const size_t numHeads,
|
|
478
|
+
const size_t batchSize,
|
|
479
|
+
const CubeType& attnMask,
|
|
480
|
+
const MatType& keyPaddingMask)
|
|
481
|
+
{
|
|
482
|
+
if (attnMask.empty() && keyPaddingMask.empty())
|
|
483
|
+
{
|
|
484
|
+
// No masking required: we can use the simple implementation.
|
|
485
|
+
for (size_t i = 0; i < scores.n_slices; ++i)
|
|
486
|
+
{
|
|
487
|
+
scores.slice(i) = exp(scores.slice(i).each_row() -
|
|
488
|
+
max(scores.slice(i), 0));
|
|
489
|
+
scores.slice(i).each_row() /= sum(scores.slice(i), 0);
|
|
490
|
+
}
|
|
491
|
+
}
|
|
492
|
+
else if (attnMask.empty() && !keyPaddingMask.empty())
|
|
493
|
+
{
|
|
494
|
+
// There is one key padding mask column for each element in the batch.
|
|
495
|
+
for (size_t i = 0; i < batchSize; ++i)
|
|
496
|
+
{
|
|
497
|
+
for (size_t h = 0; h < numHeads; ++h)
|
|
498
|
+
{
|
|
499
|
+
const size_t s = i * numHeads + h;
|
|
500
|
+
|
|
501
|
+
for (size_t c = 0; c < scores.n_cols; ++c)
|
|
502
|
+
{
|
|
503
|
+
ElemType maxVal = std::numeric_limits<ElemType>::lowest();
|
|
504
|
+
for (size_t r = 0; r < scores.n_rows; ++r)
|
|
505
|
+
if (keyPaddingMask(r, i) >= ElemType(0) && scores(r, c, s) > maxVal)
|
|
506
|
+
maxVal = scores(r, c, s);
|
|
507
|
+
|
|
508
|
+
for (size_t r = 0; r < scores.n_rows; ++r)
|
|
509
|
+
{
|
|
510
|
+
if (keyPaddingMask(r, i) < ElemType(0))
|
|
511
|
+
scores(r, c, s) = ElemType(0);
|
|
512
|
+
else
|
|
513
|
+
scores(r, c, s) = std::exp(scores(r, c, s) - maxVal);
|
|
514
|
+
}
|
|
515
|
+
|
|
516
|
+
if (maxVal != std::numeric_limits<ElemType>::lowest())
|
|
517
|
+
scores.slice(s).col(c) /= sum(scores.slice(s).col(c));
|
|
518
|
+
}
|
|
519
|
+
}
|
|
520
|
+
}
|
|
521
|
+
}
|
|
522
|
+
else if (!attnMask.empty() && keyPaddingMask.empty())
|
|
523
|
+
{
|
|
524
|
+
// There is one attention mask for each element in the batch.
|
|
525
|
+
for (size_t i = 0; i < batchSize; ++i)
|
|
526
|
+
{
|
|
527
|
+
for (size_t h = 0; h < numHeads; ++h)
|
|
528
|
+
{
|
|
529
|
+
const size_t s = i * numHeads + h;
|
|
530
|
+
|
|
531
|
+
for (size_t c = 0; c < scores.n_cols; ++c)
|
|
532
|
+
{
|
|
533
|
+
ElemType maxVal = std::numeric_limits<ElemType>::lowest();
|
|
534
|
+
for (size_t r = 0; r < scores.n_rows; ++r)
|
|
535
|
+
if (attnMask(r, c, i) >= ElemType(0) && scores(r, c, s) > maxVal)
|
|
536
|
+
maxVal = scores(r, c, s);
|
|
537
|
+
|
|
538
|
+
for (size_t r = 0; r < scores.n_rows; ++r)
|
|
539
|
+
{
|
|
540
|
+
if (attnMask(r, c, i) < ElemType(0))
|
|
541
|
+
scores(r, c, s) = ElemType(0);
|
|
542
|
+
else
|
|
543
|
+
scores(r, c, s) = std::exp(scores(r, c, s) - maxVal);
|
|
544
|
+
}
|
|
545
|
+
|
|
546
|
+
if (maxVal != std::numeric_limits<ElemType>::lowest())
|
|
547
|
+
scores.slice(s).col(c) /= sum(scores.slice(s).col(c));
|
|
548
|
+
}
|
|
549
|
+
}
|
|
550
|
+
}
|
|
551
|
+
}
|
|
552
|
+
else // !attnMask.empty() && !keyPaddingMask.empty()
|
|
553
|
+
{
|
|
554
|
+
// There is one key padding mask column for each element in the batch, and
|
|
555
|
+
// one attention mask for each element in the batch.
|
|
556
|
+
for (size_t i = 0; i < batchSize; ++i)
|
|
557
|
+
{
|
|
558
|
+
for (size_t h = 0; h < numHeads; ++h)
|
|
559
|
+
{
|
|
560
|
+
const size_t s = i * numHeads + h;
|
|
561
|
+
|
|
562
|
+
for (size_t c = 0; c < scores.n_cols; ++c)
|
|
563
|
+
{
|
|
564
|
+
ElemType maxVal = std::numeric_limits<ElemType>::lowest();
|
|
565
|
+
for (size_t r = 0; r < scores.n_rows; ++r)
|
|
566
|
+
{
|
|
567
|
+
if (attnMask(r, c, i) >= ElemType(0) &&
|
|
568
|
+
keyPaddingMask(r, i) >= ElemType(0) &&
|
|
569
|
+
scores(r, c, s) > maxVal)
|
|
570
|
+
{
|
|
571
|
+
maxVal = scores(r, c, s);
|
|
572
|
+
}
|
|
573
|
+
}
|
|
574
|
+
|
|
575
|
+
for (size_t r = 0; r < scores.n_rows; ++r)
|
|
576
|
+
{
|
|
577
|
+
if (attnMask(r, c, i) < ElemType(0) ||
|
|
578
|
+
keyPaddingMask(r, i) < ElemType(0))
|
|
579
|
+
scores(r, c, s) = ElemType(0);
|
|
580
|
+
else
|
|
581
|
+
scores(r, c, s) = std::exp(scores(r, c, s) - maxVal);
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
if (maxVal != std::numeric_limits<ElemType>::lowest())
|
|
585
|
+
scores.slice(s).col(c) /= sum(scores.slice(s).col(c));
|
|
586
|
+
}
|
|
587
|
+
}
|
|
588
|
+
}
|
|
589
|
+
}
|
|
590
|
+
}
|
|
591
|
+
|
|
495
592
|
} // namespace mlpack
|
|
496
593
|
|
|
497
594
|
#endif
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
//
|
|
2
1
|
/**
|
|
3
2
|
* @filer methods/ann/layer/nearest_interpolation.hpp
|
|
4
3
|
* @author Andrew Furey
|
|
@@ -29,14 +28,18 @@ namespace mlpack {
|
|
|
29
28
|
* arma::sp_mat or arma::cube).
|
|
30
29
|
*/
|
|
31
30
|
template<typename MatType = arma::mat>
|
|
32
|
-
class
|
|
31
|
+
class NearestInterpolation : public Layer<MatType>
|
|
33
32
|
{
|
|
34
33
|
public:
|
|
34
|
+
// Convenience typedefs.
|
|
35
|
+
using ElemType = typename MatType::elem_type;
|
|
35
36
|
using CubeType = typename GetCubeType<MatType>::type;
|
|
36
|
-
//! Create the NearestInterpolation object.
|
|
37
|
-
NearestInterpolationType();
|
|
38
37
|
|
|
39
|
-
|
|
38
|
+
// Create the NearestInterpolation object.
|
|
39
|
+
NearestInterpolation();
|
|
40
|
+
|
|
41
|
+
/**
|
|
42
|
+
* Create NearestInterpolation Object with the same scaleFactor along
|
|
40
43
|
* each dimension.
|
|
41
44
|
* NOTE: scaleFactors must be a two element vector, the first element
|
|
42
45
|
* for scaling the first dimension and the second element for scaling
|
|
@@ -44,25 +47,25 @@ class NearestInterpolationType : public Layer<MatType>
|
|
|
44
47
|
*
|
|
45
48
|
* If the input dimensions are n x m x ..., then the output dimensions
|
|
46
49
|
* will be (n x scaleFactors[0]) x (m x scaleFactors[1]) x ...
|
|
47
|
-
*
|
|
50
|
+
*
|
|
48
51
|
* @param scaleFactor Scale factors to scale each dimension by.
|
|
49
52
|
*/
|
|
50
|
-
|
|
53
|
+
NearestInterpolation(const std::vector<double> scaleFactors);
|
|
51
54
|
|
|
52
|
-
|
|
53
|
-
return new
|
|
55
|
+
NearestInterpolation* Clone() const {
|
|
56
|
+
return new NearestInterpolation(*this);
|
|
54
57
|
}
|
|
55
58
|
|
|
56
|
-
virtual ~
|
|
59
|
+
virtual ~NearestInterpolation() { }
|
|
57
60
|
|
|
58
|
-
//! Copy the given
|
|
59
|
-
|
|
60
|
-
//! Take ownership of the given
|
|
61
|
-
|
|
62
|
-
//! Copy the given
|
|
63
|
-
|
|
64
|
-
//! Take ownership of the given
|
|
65
|
-
|
|
61
|
+
//! Copy the given NearestInterpolation layer.
|
|
62
|
+
NearestInterpolation(const NearestInterpolation& other);
|
|
63
|
+
//! Take ownership of the given NearestInterpolation layer.
|
|
64
|
+
NearestInterpolation(NearestInterpolation&& other);
|
|
65
|
+
//! Copy the given NearestInterpolation layer.
|
|
66
|
+
NearestInterpolation& operator=(const NearestInterpolation& other);
|
|
67
|
+
//! Take ownership of the given NearestInterpolation layer.
|
|
68
|
+
NearestInterpolation& operator=(NearestInterpolation&& other);
|
|
66
69
|
|
|
67
70
|
/**
|
|
68
71
|
* Forward pass through the layer. The layer interpolates
|
|
@@ -81,12 +84,14 @@ class NearestInterpolationType : public Layer<MatType>
|
|
|
81
84
|
* the input size.
|
|
82
85
|
*
|
|
83
86
|
* @param * (input) The input matrix.
|
|
84
|
-
* @param
|
|
85
|
-
* @param
|
|
87
|
+
* @param * (output) The output matrix.
|
|
88
|
+
* @param gy The computed backward gradient.
|
|
89
|
+
* @param g The resulting down-sampled output.
|
|
86
90
|
*/
|
|
87
|
-
void Backward(const MatType& /*input*/,
|
|
88
|
-
const MatType&
|
|
89
|
-
MatType&
|
|
91
|
+
void Backward(const MatType& /* input */,
|
|
92
|
+
const MatType& /* output */,
|
|
93
|
+
const MatType& gy,
|
|
94
|
+
MatType& g);
|
|
90
95
|
|
|
91
96
|
//! Compute the output dimensions of the layer, based on the internal values
|
|
92
97
|
//! of `InputDimensions()`.
|
|
@@ -103,8 +108,6 @@ class NearestInterpolationType : public Layer<MatType>
|
|
|
103
108
|
std::vector<double> scaleFactors;
|
|
104
109
|
}; // class NearestInterpolation
|
|
105
110
|
|
|
106
|
-
using NearestInterpolation = NearestInterpolationType<arma::mat>;
|
|
107
|
-
|
|
108
111
|
} // namespace mlpack
|
|
109
112
|
|
|
110
113
|
// Include implementation.
|
|
@@ -19,16 +19,16 @@
|
|
|
19
19
|
namespace mlpack {
|
|
20
20
|
|
|
21
21
|
template<typename MatType>
|
|
22
|
-
|
|
23
|
-
|
|
22
|
+
NearestInterpolation<MatType>::NearestInterpolation():
|
|
23
|
+
Layer<MatType>()
|
|
24
24
|
{
|
|
25
25
|
// Nothing to do here.
|
|
26
26
|
}
|
|
27
27
|
|
|
28
28
|
template<typename MatType>
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
29
|
+
NearestInterpolation<MatType>::
|
|
30
|
+
NearestInterpolation(const std::vector<double> scaleFactors) :
|
|
31
|
+
Layer<MatType>()
|
|
32
32
|
{
|
|
33
33
|
if (scaleFactors.size() != 2) {
|
|
34
34
|
throw std::runtime_error("Scale factors must have 2 dimensions");
|
|
@@ -37,27 +37,27 @@ NearestInterpolationType(const std::vector<double> scaleFactors) :
|
|
|
37
37
|
}
|
|
38
38
|
|
|
39
39
|
template<typename MatType>
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
40
|
+
NearestInterpolation<MatType>::
|
|
41
|
+
NearestInterpolation(const NearestInterpolation& other) :
|
|
42
|
+
Layer<MatType>(),
|
|
43
|
+
scaleFactors(other.scaleFactors)
|
|
44
44
|
{
|
|
45
45
|
// Nothing to do here.
|
|
46
46
|
}
|
|
47
47
|
|
|
48
48
|
template<typename MatType>
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
49
|
+
NearestInterpolation<MatType>::
|
|
50
|
+
NearestInterpolation(NearestInterpolation&& other) :
|
|
51
|
+
Layer<MatType>(std::move(other)),
|
|
52
|
+
scaleFactors(std::move(other.scaleFactors))
|
|
53
53
|
{
|
|
54
54
|
// Nothing to do here.
|
|
55
55
|
}
|
|
56
56
|
|
|
57
57
|
template<typename MatType>
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
operator=(const
|
|
58
|
+
NearestInterpolation<MatType>&
|
|
59
|
+
NearestInterpolation<MatType>::
|
|
60
|
+
operator=(const NearestInterpolation& other)
|
|
61
61
|
{
|
|
62
62
|
if (&other != this)
|
|
63
63
|
{
|
|
@@ -68,9 +68,9 @@ operator=(const NearestInterpolationType& other)
|
|
|
68
68
|
}
|
|
69
69
|
|
|
70
70
|
template<typename MatType>
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
operator=(
|
|
71
|
+
NearestInterpolation<MatType>&
|
|
72
|
+
NearestInterpolation<MatType>::
|
|
73
|
+
operator=(NearestInterpolation&& other)
|
|
74
74
|
{
|
|
75
75
|
if (&other != this)
|
|
76
76
|
{
|
|
@@ -81,8 +81,8 @@ operator=(NearestInterpolationType&& other)
|
|
|
81
81
|
}
|
|
82
82
|
|
|
83
83
|
template<typename MatType>
|
|
84
|
-
void
|
|
85
|
-
|
|
84
|
+
void NearestInterpolation<MatType>::Forward(
|
|
85
|
+
const MatType& input, MatType& output)
|
|
86
86
|
{
|
|
87
87
|
const size_t channels = this->inputDimensions[2];
|
|
88
88
|
|
|
@@ -100,7 +100,7 @@ void NearestInterpolationType<MatType>::Forward(
|
|
|
100
100
|
|
|
101
101
|
for (size_t i = 0; i < outRowSize; ++i)
|
|
102
102
|
{
|
|
103
|
-
size_t rOrigin = std::floor(i
|
|
103
|
+
size_t rOrigin = std::floor(i / scaleFactors[0]);
|
|
104
104
|
for (size_t j = 0; j < outColSize; ++j)
|
|
105
105
|
{
|
|
106
106
|
size_t cOrigin = std::floor(j / scaleFactors[1]);
|
|
@@ -113,10 +113,11 @@ void NearestInterpolationType<MatType>::Forward(
|
|
|
113
113
|
}
|
|
114
114
|
|
|
115
115
|
template<typename MatType>
|
|
116
|
-
void
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
116
|
+
void NearestInterpolation<MatType>::Backward(
|
|
117
|
+
const MatType& /* input */,
|
|
118
|
+
const MatType& /* output */,
|
|
119
|
+
const MatType& gy,
|
|
120
|
+
MatType& g)
|
|
120
121
|
{
|
|
121
122
|
const size_t channels = this->inputDimensions[2];
|
|
122
123
|
|
|
@@ -126,12 +127,11 @@ void NearestInterpolationType<MatType>::Backward(
|
|
|
126
127
|
const size_t inRowSize = this->inputDimensions[0];
|
|
127
128
|
const size_t inColSize = this->inputDimensions[1];
|
|
128
129
|
|
|
129
|
-
CubeType
|
|
130
|
-
CubeType
|
|
130
|
+
CubeType gTemp;
|
|
131
|
+
CubeType gyTemp;
|
|
131
132
|
|
|
132
|
-
MakeAlias(
|
|
133
|
-
MakeAlias(
|
|
134
|
-
false);
|
|
133
|
+
MakeAlias(gTemp, g, inRowSize, inColSize, channels, 0);
|
|
134
|
+
MakeAlias(gyTemp, gy, outRowSize, outColSize, channels, 0);
|
|
135
135
|
|
|
136
136
|
for (size_t i = 0; i < outRowSize; ++i)
|
|
137
137
|
{
|
|
@@ -140,15 +140,13 @@ void NearestInterpolationType<MatType>::Backward(
|
|
|
140
140
|
{
|
|
141
141
|
size_t cOrigin = std::floor(j / scaleFactors[1]);
|
|
142
142
|
for (size_t k = 0; k < channels; ++k)
|
|
143
|
-
|
|
144
|
-
outputAsCube(rOrigin, cOrigin, k) += gradientAsCube(i, j, k);
|
|
145
|
-
}
|
|
143
|
+
gTemp(rOrigin, cOrigin, k) += gyTemp(i, j, k);
|
|
146
144
|
}
|
|
147
145
|
}
|
|
148
146
|
}
|
|
149
147
|
|
|
150
148
|
template<typename MatType>
|
|
151
|
-
void
|
|
149
|
+
void NearestInterpolation<MatType>::ComputeOutputDimensions()
|
|
152
150
|
{
|
|
153
151
|
if (this->inputDimensions.size() < scaleFactors.size())
|
|
154
152
|
{
|
|
@@ -168,9 +166,10 @@ void NearestInterpolationType<MatType>::ComputeOutputDimensions()
|
|
|
168
166
|
|
|
169
167
|
template<typename MatType>
|
|
170
168
|
template<typename Archive>
|
|
171
|
-
void
|
|
172
|
-
|
|
169
|
+
void NearestInterpolation<MatType>::serialize(
|
|
170
|
+
Archive& ar, const uint32_t /* version */)
|
|
173
171
|
{
|
|
172
|
+
ar(cereal::base_class<Layer<MatType>>(this));
|
|
174
173
|
ar(CEREAL_NVP(scaleFactors));
|
|
175
174
|
}
|
|
176
175
|
|