mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -20,14 +20,14 @@
|
|
|
20
20
|
namespace mlpack {
|
|
21
21
|
|
|
22
22
|
template<typename MatType>
|
|
23
|
-
|
|
23
|
+
MeanPooling<MatType>::MeanPooling() :
|
|
24
24
|
Layer<MatType>()
|
|
25
25
|
{
|
|
26
26
|
// Nothing to do here.
|
|
27
27
|
}
|
|
28
28
|
|
|
29
29
|
template<typename MatType>
|
|
30
|
-
|
|
30
|
+
MeanPooling<MatType>::MeanPooling(
|
|
31
31
|
const size_t kernelWidth,
|
|
32
32
|
const size_t kernelHeight,
|
|
33
33
|
const size_t strideWidth,
|
|
@@ -45,8 +45,8 @@ MeanPoolingType<MatType>::MeanPoolingType(
|
|
|
45
45
|
}
|
|
46
46
|
|
|
47
47
|
template<typename MatType>
|
|
48
|
-
|
|
49
|
-
const
|
|
48
|
+
MeanPooling<MatType>::MeanPooling(
|
|
49
|
+
const MeanPooling& other) :
|
|
50
50
|
Layer<MatType>(other),
|
|
51
51
|
kernelWidth(other.kernelWidth),
|
|
52
52
|
kernelHeight(other.kernelHeight),
|
|
@@ -59,8 +59,8 @@ MeanPoolingType<MatType>::MeanPoolingType(
|
|
|
59
59
|
}
|
|
60
60
|
|
|
61
61
|
template<typename MatType>
|
|
62
|
-
|
|
63
|
-
|
|
62
|
+
MeanPooling<MatType>::MeanPooling(
|
|
63
|
+
MeanPooling&& other) :
|
|
64
64
|
Layer<MatType>(std::move(other)),
|
|
65
65
|
kernelWidth(std::move(other.kernelWidth)),
|
|
66
66
|
kernelHeight(std::move(other.kernelHeight)),
|
|
@@ -73,8 +73,8 @@ MeanPoolingType<MatType>::MeanPoolingType(
|
|
|
73
73
|
}
|
|
74
74
|
|
|
75
75
|
template<typename MatType>
|
|
76
|
-
|
|
77
|
-
|
|
76
|
+
MeanPooling<MatType>&
|
|
77
|
+
MeanPooling<MatType>::operator=(const MeanPooling& other)
|
|
78
78
|
{
|
|
79
79
|
if (&other != this)
|
|
80
80
|
{
|
|
@@ -91,8 +91,8 @@ MeanPoolingType<MatType>::operator=(const MeanPoolingType& other)
|
|
|
91
91
|
}
|
|
92
92
|
|
|
93
93
|
template<typename MatType>
|
|
94
|
-
|
|
95
|
-
|
|
94
|
+
MeanPooling<MatType>&
|
|
95
|
+
MeanPooling<MatType>::operator=(MeanPooling&& other)
|
|
96
96
|
{
|
|
97
97
|
if (&other != this)
|
|
98
98
|
{
|
|
@@ -109,7 +109,7 @@ MeanPoolingType<MatType>::operator=(MeanPoolingType&& other)
|
|
|
109
109
|
}
|
|
110
110
|
|
|
111
111
|
template<typename MatType>
|
|
112
|
-
void
|
|
112
|
+
void MeanPooling<MatType>::Forward(
|
|
113
113
|
const MatType& input, MatType& output)
|
|
114
114
|
{
|
|
115
115
|
// Create Alias of input as 2D image as input is 1D vector.
|
|
@@ -127,7 +127,7 @@ void MeanPoolingType<MatType>::Forward(
|
|
|
127
127
|
}
|
|
128
128
|
|
|
129
129
|
template<typename MatType>
|
|
130
|
-
void
|
|
130
|
+
void MeanPooling<MatType>::Backward(
|
|
131
131
|
const MatType& input,
|
|
132
132
|
const MatType& /* output */,
|
|
133
133
|
const MatType& gy,
|
|
@@ -154,7 +154,7 @@ void MeanPoolingType<MatType>::Backward(
|
|
|
154
154
|
}
|
|
155
155
|
|
|
156
156
|
template<typename MatType>
|
|
157
|
-
void
|
|
157
|
+
void MeanPooling<MatType>::ComputeOutputDimensions()
|
|
158
158
|
{
|
|
159
159
|
this->outputDimensions = this->inputDimensions;
|
|
160
160
|
|
|
@@ -184,7 +184,7 @@ void MeanPoolingType<MatType>::ComputeOutputDimensions()
|
|
|
184
184
|
|
|
185
185
|
template<typename MatType>
|
|
186
186
|
template<typename Archive>
|
|
187
|
-
void
|
|
187
|
+
void MeanPooling<MatType>::serialize(
|
|
188
188
|
Archive& ar,
|
|
189
189
|
const uint32_t /* version */)
|
|
190
190
|
{
|
|
@@ -199,7 +199,7 @@ void MeanPoolingType<MatType>::serialize(
|
|
|
199
199
|
}
|
|
200
200
|
|
|
201
201
|
template<typename MatType>
|
|
202
|
-
void
|
|
202
|
+
void MeanPooling<MatType>::PoolingOperation(
|
|
203
203
|
const CubeType& input,
|
|
204
204
|
CubeType& output)
|
|
205
205
|
{
|
|
@@ -242,7 +242,7 @@ void MeanPoolingType<MatType>::PoolingOperation(
|
|
|
242
242
|
}
|
|
243
243
|
|
|
244
244
|
template<typename MatType>
|
|
245
|
-
void
|
|
245
|
+
void MeanPooling<MatType>::Unpooling(
|
|
246
246
|
const MatType& error,
|
|
247
247
|
MatType& output)
|
|
248
248
|
{
|
|
@@ -78,10 +78,10 @@ class MultiLayer : public Layer<MatType>
|
|
|
78
78
|
* @param start Index of first layer to pass data through.
|
|
79
79
|
* @param end Index of last layer to pass data through.
|
|
80
80
|
*/
|
|
81
|
-
void
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
81
|
+
void PartialForward(const MatType& input,
|
|
82
|
+
MatType& output,
|
|
83
|
+
const size_t start,
|
|
84
|
+
const size_t end);
|
|
85
85
|
|
|
86
86
|
/**
|
|
87
87
|
* Perform a backward pass with the given data. `gy` is expected to be the
|
|
@@ -164,9 +164,24 @@ class MultiLayer : public Layer<MatType>
|
|
|
164
164
|
* @param args The layer parameter.
|
|
165
165
|
*/
|
|
166
166
|
template <typename LayerType, typename... Args>
|
|
167
|
-
void Add(Args
|
|
167
|
+
void Add(Args&&... args)
|
|
168
168
|
{
|
|
169
|
-
network.push_back(new LayerType(args...));
|
|
169
|
+
network.push_back(new LayerType(std::forward<Args>(args)...));
|
|
170
|
+
layerOutputs.push_back(MatType());
|
|
171
|
+
layerDeltas.push_back(MatType());
|
|
172
|
+
layerGradients.push_back(MatType());
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
/**
|
|
176
|
+
* Add a new module to the model, using the MatType of this layer.
|
|
177
|
+
*
|
|
178
|
+
* @param args The layer parameter.
|
|
179
|
+
*/
|
|
180
|
+
template<template<typename...> typename LayerType,
|
|
181
|
+
typename... Args>
|
|
182
|
+
void Add(Args&&... args)
|
|
183
|
+
{
|
|
184
|
+
network.push_back(new LayerType<MatType>(std::forward<Args>(args)...));
|
|
170
185
|
layerOutputs.push_back(MatType());
|
|
171
186
|
layerDeltas.push_back(MatType());
|
|
172
187
|
layerGradients.push_back(MatType());
|
|
@@ -177,6 +192,7 @@ class MultiLayer : public Layer<MatType>
|
|
|
177
192
|
*
|
|
178
193
|
* @param layer The Layer to be added to the model.
|
|
179
194
|
*/
|
|
195
|
+
[[deprecated("Will be removed in mlpack 5.0.0. Pass a reference instead.")]]
|
|
180
196
|
void Add(Layer<MatType>* layer)
|
|
181
197
|
{
|
|
182
198
|
network.push_back(layer);
|
|
@@ -185,6 +201,20 @@ class MultiLayer : public Layer<MatType>
|
|
|
185
201
|
layerGradients.push_back(MatType());
|
|
186
202
|
}
|
|
187
203
|
|
|
204
|
+
template<typename LayerType>
|
|
205
|
+
void Add(LayerType&& layer,
|
|
206
|
+
typename std::enable_if_t<
|
|
207
|
+
!std::is_pointer_v<std::remove_reference_t<LayerType>>>* = 0)
|
|
208
|
+
{
|
|
209
|
+
using NewLayerType =
|
|
210
|
+
typename std::remove_cv_t<std::remove_reference_t<LayerType>>;
|
|
211
|
+
|
|
212
|
+
network.push_back(new NewLayerType(std::forward<LayerType>(layer)));
|
|
213
|
+
layerOutputs.push_back(MatType());
|
|
214
|
+
layerDeltas.push_back(MatType());
|
|
215
|
+
layerGradients.push_back(MatType());
|
|
216
|
+
}
|
|
217
|
+
|
|
188
218
|
//! Get the network (series of layers) held by this MultiLayer.
|
|
189
219
|
const std::vector<Layer<MatType>*>& Network() const
|
|
190
220
|
{
|
|
@@ -81,6 +81,8 @@ MultiLayer<MatType>& MultiLayer<MatType>::operator=(const MultiLayer& other)
|
|
|
81
81
|
{
|
|
82
82
|
Layer<MatType>::operator=(other);
|
|
83
83
|
|
|
84
|
+
for (size_t i = 0; i < network.size(); ++i)
|
|
85
|
+
delete network[i];
|
|
84
86
|
network.clear();
|
|
85
87
|
layerOutputs.clear();
|
|
86
88
|
layerDeltas.clear();
|
|
@@ -120,6 +122,8 @@ MultiLayer<MatType>& MultiLayer<MatType>::operator=(MultiLayer&& other)
|
|
|
120
122
|
totalInputSize = std::move(other.totalInputSize);
|
|
121
123
|
totalOutputSize = std::move(other.totalOutputSize);
|
|
122
124
|
|
|
125
|
+
for (size_t i = 0; i < network.size(); ++i)
|
|
126
|
+
delete network[i];
|
|
123
127
|
network = std::move(other.network);
|
|
124
128
|
|
|
125
129
|
layerOutputs.resize(network.size(), MatType());
|
|
@@ -138,11 +142,11 @@ template<typename MatType>
|
|
|
138
142
|
void MultiLayer<MatType>::Forward(
|
|
139
143
|
const MatType& input, MatType& output)
|
|
140
144
|
{
|
|
141
|
-
|
|
145
|
+
PartialForward(input, output, 0, network.size() - 1);
|
|
142
146
|
}
|
|
143
147
|
|
|
144
148
|
template<typename MatType>
|
|
145
|
-
void MultiLayer<MatType>::
|
|
149
|
+
void MultiLayer<MatType>::PartialForward(
|
|
146
150
|
const MatType& input,
|
|
147
151
|
MatType& output,
|
|
148
152
|
const size_t start,
|
|
@@ -66,14 +66,17 @@ template <
|
|
|
66
66
|
typename MatType = arma::mat,
|
|
67
67
|
typename RegularizerType = NoRegularizer
|
|
68
68
|
>
|
|
69
|
-
class
|
|
69
|
+
class MultiheadAttention : public Layer<MatType>
|
|
70
70
|
{
|
|
71
71
|
public:
|
|
72
|
+
// Convenience typedefs.
|
|
73
|
+
using ElemType = typename MatType::elem_type;
|
|
72
74
|
using CubeType = typename GetCubeType<MatType>::type;
|
|
75
|
+
|
|
73
76
|
/**
|
|
74
77
|
* Default constructor.
|
|
75
78
|
*/
|
|
76
|
-
|
|
79
|
+
MultiheadAttention();
|
|
77
80
|
|
|
78
81
|
/**
|
|
79
82
|
* Create the MultiheadAttention object using the specified modules.
|
|
@@ -87,17 +90,17 @@ class MultiheadAttentionType : public Layer<MatType>
|
|
|
87
90
|
* @param selfAttention Use self-attention; source key, query, and value all
|
|
88
91
|
* come from the same inputs
|
|
89
92
|
*/
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
93
|
+
MultiheadAttention(const size_t tgtSeqLen,
|
|
94
|
+
const size_t numHeads,
|
|
95
|
+
const CubeType& attnMask = CubeType(),
|
|
96
|
+
const MatType& keyPaddingMask = MatType(),
|
|
97
|
+
const bool selfAttention = false);
|
|
95
98
|
|
|
96
|
-
//! Clone the
|
|
99
|
+
//! Clone the MultiheadAttention object. This handles polymorphism
|
|
97
100
|
//! correctly.
|
|
98
|
-
|
|
101
|
+
MultiheadAttention* Clone() const override
|
|
99
102
|
{
|
|
100
|
-
return new
|
|
103
|
+
return new MultiheadAttention(*this);
|
|
101
104
|
}
|
|
102
105
|
|
|
103
106
|
/**
|
|
@@ -175,9 +178,9 @@ class MultiheadAttentionType : public Layer<MatType>
|
|
|
175
178
|
size_t& NumHeads() { return numHeads; }
|
|
176
179
|
|
|
177
180
|
//! Get the two dimensional Attention Mask. Contains values 0 or 1.
|
|
178
|
-
|
|
181
|
+
CubeType const& AttentionMask() const { return attnMask; }
|
|
179
182
|
//! Modify the two dimensional Attention Mask. Should take values 0 or 1.
|
|
180
|
-
|
|
183
|
+
CubeType& AttentionMask() { return attnMask; }
|
|
181
184
|
|
|
182
185
|
//! Get Key Padding Mask. Contains values 0 or 1.
|
|
183
186
|
MatType const& KeyPaddingMask() const { return keyPaddingMask; }
|
|
@@ -265,8 +268,11 @@ class MultiheadAttentionType : public Layer<MatType>
|
|
|
265
268
|
}
|
|
266
269
|
|
|
267
270
|
private:
|
|
268
|
-
|
|
269
|
-
|
|
271
|
+
static void MaskedForwardSoftmax(CubeType& scores,
|
|
272
|
+
const size_t numHeads,
|
|
273
|
+
const size_t batchSize,
|
|
274
|
+
const CubeType& attnMask,
|
|
275
|
+
const MatType& keyPaddingMask);
|
|
270
276
|
|
|
271
277
|
//! Target sequence length.
|
|
272
278
|
size_t tgtSeqLen;
|
|
@@ -283,11 +289,12 @@ class MultiheadAttentionType : public Layer<MatType>
|
|
|
283
289
|
//! Dimensionality of each head.
|
|
284
290
|
size_t headDim;
|
|
285
291
|
|
|
286
|
-
//! Two dimensional Attention Mask of shape (tgtSeqLen,
|
|
287
|
-
//! the values [-Inf, 0]
|
|
288
|
-
|
|
292
|
+
//! Two dimensional Attention Mask of shape (srcSeqLen, tgtSeqLen, batchSize).
|
|
293
|
+
//! Takes the values [-Inf, 0]
|
|
294
|
+
CubeType attnMask;
|
|
289
295
|
|
|
290
|
-
//! Key Padding Mask.
|
|
296
|
+
//! Key Padding Mask. The shape of keyPaddingMask : (srcSeqLen, batchSize)
|
|
297
|
+
//! the values [-Inf, 0]
|
|
291
298
|
MatType keyPaddingMask;
|
|
292
299
|
|
|
293
300
|
//! Whether or not self-attention is used (source key, value, and query all
|
|
@@ -337,15 +344,12 @@ class MultiheadAttentionType : public Layer<MatType>
|
|
|
337
344
|
CubeType attnOut;
|
|
338
345
|
|
|
339
346
|
//! Softmax layer to represent the probabilities of next sequence.
|
|
340
|
-
|
|
347
|
+
Softmax<MatType> softmax;
|
|
341
348
|
|
|
342
349
|
//! Locally-stored regularizer object.
|
|
343
350
|
RegularizerType regularizer;
|
|
344
351
|
}; // class MultiheadAttention
|
|
345
352
|
|
|
346
|
-
// Standard MultiheadAttention layer using no regularization.
|
|
347
|
-
using MultiheadAttention = MultiheadAttentionType<arma::mat, NoRegularizer>;
|
|
348
|
-
|
|
349
353
|
} // namespace mlpack
|
|
350
354
|
|
|
351
355
|
// Include implementation.
|