mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -18,7 +18,7 @@ namespace mlpack {
|
|
|
18
18
|
|
|
19
19
|
// Create the LinearRecurrent layer.
|
|
20
20
|
template<typename MatType, typename RegularizerType>
|
|
21
|
-
|
|
21
|
+
LinearRecurrent<MatType, RegularizerType>::LinearRecurrent() :
|
|
22
22
|
RecurrentLayer<MatType>(),
|
|
23
23
|
inSize(0),
|
|
24
24
|
outSize(0)
|
|
@@ -27,7 +27,7 @@ LinearRecurrentType<MatType, RegularizerType>::LinearRecurrentType() :
|
|
|
27
27
|
}
|
|
28
28
|
|
|
29
29
|
template<typename MatType, typename RegularizerType>
|
|
30
|
-
|
|
30
|
+
LinearRecurrent<MatType, RegularizerType>::LinearRecurrent(
|
|
31
31
|
const size_t outSize,
|
|
32
32
|
RegularizerType regularizer) :
|
|
33
33
|
RecurrentLayer<MatType>(),
|
|
@@ -40,8 +40,8 @@ LinearRecurrentType<MatType, RegularizerType>::LinearRecurrentType(
|
|
|
40
40
|
|
|
41
41
|
// Copy constructor.
|
|
42
42
|
template<typename MatType, typename RegularizerType>
|
|
43
|
-
|
|
44
|
-
const
|
|
43
|
+
LinearRecurrent<MatType, RegularizerType>::LinearRecurrent(
|
|
44
|
+
const LinearRecurrent& layer) :
|
|
45
45
|
RecurrentLayer<MatType>(layer),
|
|
46
46
|
inSize(layer.inSize),
|
|
47
47
|
outSize(layer.outSize),
|
|
@@ -52,8 +52,8 @@ LinearRecurrentType<MatType, RegularizerType>::LinearRecurrentType(
|
|
|
52
52
|
|
|
53
53
|
// Move constructor.
|
|
54
54
|
template<typename MatType, typename RegularizerType>
|
|
55
|
-
|
|
56
|
-
|
|
55
|
+
LinearRecurrent<MatType, RegularizerType>::LinearRecurrent(
|
|
56
|
+
LinearRecurrent&& layer) :
|
|
57
57
|
RecurrentLayer<MatType>(std::move(layer)),
|
|
58
58
|
inSize(std::move(layer.inSize)),
|
|
59
59
|
outSize(std::move(layer.outSize)),
|
|
@@ -66,9 +66,9 @@ LinearRecurrentType<MatType, RegularizerType>::LinearRecurrentType(
|
|
|
66
66
|
|
|
67
67
|
// Copy operator.
|
|
68
68
|
template<typename MatType, typename RegularizerType>
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
const
|
|
69
|
+
LinearRecurrent<MatType, RegularizerType>&
|
|
70
|
+
LinearRecurrent<MatType, RegularizerType>::operator=(
|
|
71
|
+
const LinearRecurrent& layer)
|
|
72
72
|
{
|
|
73
73
|
if (&layer != this)
|
|
74
74
|
{
|
|
@@ -83,9 +83,9 @@ LinearRecurrentType<MatType, RegularizerType>::operator=(
|
|
|
83
83
|
|
|
84
84
|
// Move operator.
|
|
85
85
|
template<typename MatType, typename RegularizerType>
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
86
|
+
LinearRecurrent<MatType, RegularizerType>&
|
|
87
|
+
LinearRecurrent<MatType, RegularizerType>::operator=(
|
|
88
|
+
LinearRecurrent&& layer)
|
|
89
89
|
{
|
|
90
90
|
if (&layer != this)
|
|
91
91
|
{
|
|
@@ -104,7 +104,7 @@ LinearRecurrentType<MatType, RegularizerType>::operator=(
|
|
|
104
104
|
|
|
105
105
|
// Set the parameters of the layer.
|
|
106
106
|
template<typename MatType, typename RegularizerType>
|
|
107
|
-
void
|
|
107
|
+
void LinearRecurrent<MatType, RegularizerType>::SetWeights(
|
|
108
108
|
const MatType& weightsIn)
|
|
109
109
|
{
|
|
110
110
|
MakeAlias(parameters, weightsIn, WeightSize(), 1);
|
|
@@ -116,7 +116,7 @@ void LinearRecurrentType<MatType, RegularizerType>::SetWeights(
|
|
|
116
116
|
|
|
117
117
|
// Forward pass of linear recurrent layer.
|
|
118
118
|
template<typename MatType, typename RegularizerType>
|
|
119
|
-
void
|
|
119
|
+
void LinearRecurrent<MatType, RegularizerType>::Forward(
|
|
120
120
|
const MatType& input, MatType& output)
|
|
121
121
|
{
|
|
122
122
|
// Take the forward step: f(x) = Wx + Uh + b.
|
|
@@ -127,7 +127,7 @@ void LinearRecurrentType<MatType, RegularizerType>::Forward(
|
|
|
127
127
|
else
|
|
128
128
|
{
|
|
129
129
|
output = weights * input +
|
|
130
|
-
recurrentWeights *
|
|
130
|
+
recurrentWeights * previousOutput;
|
|
131
131
|
}
|
|
132
132
|
|
|
133
133
|
#pragma omp for
|
|
@@ -136,12 +136,12 @@ void LinearRecurrentType<MatType, RegularizerType>::Forward(
|
|
|
136
136
|
|
|
137
137
|
// Update the recurrent state if needed.
|
|
138
138
|
if (!this->AtFinalStep())
|
|
139
|
-
|
|
139
|
+
currentOutput = output;
|
|
140
140
|
}
|
|
141
141
|
|
|
142
142
|
// Backward pass of linear recurrent layer.
|
|
143
143
|
template<typename MatType, typename RegularizerType>
|
|
144
|
-
void
|
|
144
|
+
void LinearRecurrent<MatType, RegularizerType>::Backward(
|
|
145
145
|
const MatType& /* input */,
|
|
146
146
|
const MatType& /* output */,
|
|
147
147
|
const MatType& gy,
|
|
@@ -159,7 +159,7 @@ void LinearRecurrentType<MatType, RegularizerType>::Backward(
|
|
|
159
159
|
{
|
|
160
160
|
// Via the recurrence, the result is equivalent, just with the recurrent
|
|
161
161
|
// gradient as the gy parameter.
|
|
162
|
-
g += weights.t() *
|
|
162
|
+
g += weights.t() * currentGradient;
|
|
163
163
|
}
|
|
164
164
|
|
|
165
165
|
if (this->HasPreviousStep())
|
|
@@ -169,20 +169,19 @@ void LinearRecurrentType<MatType, RegularizerType>::Backward(
|
|
|
169
169
|
//
|
|
170
170
|
// With respect to the output, we can just propagate back through the
|
|
171
171
|
// recurrent weights.
|
|
172
|
-
|
|
172
|
+
previousGradient = recurrentWeights.t() * gy;
|
|
173
173
|
|
|
174
174
|
if (!this->AtFinalStep())
|
|
175
175
|
{
|
|
176
176
|
// If we also have a path from dz/dh^t, this can be added.
|
|
177
|
-
|
|
178
|
-
recurrentWeights.t() * this->RecurrentGradient(this->CurrentStep());
|
|
177
|
+
previousGradient += recurrentWeights.t() * currentGradient;
|
|
179
178
|
}
|
|
180
179
|
}
|
|
181
180
|
}
|
|
182
181
|
|
|
183
182
|
// Compute the gradient with respect to the input.
|
|
184
183
|
template<typename MatType, typename RegularizerType>
|
|
185
|
-
void
|
|
184
|
+
void LinearRecurrent<MatType, RegularizerType>::Gradient(
|
|
186
185
|
const MatType& input,
|
|
187
186
|
const MatType& error,
|
|
188
187
|
MatType& gradient)
|
|
@@ -204,7 +203,7 @@ void LinearRecurrentType<MatType, RegularizerType>::Gradient(
|
|
|
204
203
|
if (this->HasPreviousStep())
|
|
205
204
|
{
|
|
206
205
|
gradient.submat(whOffset, 0, bOffset - 1, 0) =
|
|
207
|
-
vectorise(error *
|
|
206
|
+
vectorise(error * previousOutput.t());
|
|
208
207
|
}
|
|
209
208
|
gradient.submat(bOffset, 0, gradient.n_rows - 1, 0) = sum(error, 1);
|
|
210
209
|
|
|
@@ -215,15 +214,14 @@ void LinearRecurrentType<MatType, RegularizerType>::Gradient(
|
|
|
215
214
|
if (!this->AtFinalStep())
|
|
216
215
|
{
|
|
217
216
|
gradient.submat(0, 0, whOffset - 1, 0) +=
|
|
218
|
-
vectorise(
|
|
217
|
+
vectorise(currentGradient * input.t());
|
|
219
218
|
if (this->HasPreviousStep())
|
|
220
219
|
{
|
|
221
220
|
gradient.submat(whOffset, 0, bOffset - 1, 0) +=
|
|
222
|
-
vectorise(
|
|
223
|
-
this->RecurrentState(this->PreviousStep()).t());
|
|
221
|
+
vectorise(currentGradient * previousOutput.t());
|
|
224
222
|
}
|
|
225
223
|
gradient.submat(bOffset, 0, gradient.n_rows - 1, 0) += sum(
|
|
226
|
-
|
|
224
|
+
currentGradient, 1);
|
|
227
225
|
|
|
228
226
|
// this->HiddenDeriv(this->PreviousStep()) was already computed in
|
|
229
227
|
// Backward(), so no need to do it here.
|
|
@@ -232,7 +230,7 @@ void LinearRecurrentType<MatType, RegularizerType>::Gradient(
|
|
|
232
230
|
|
|
233
231
|
// Get the total number of trainable parameters.
|
|
234
232
|
template<typename MatType, typename RegularizerType>
|
|
235
|
-
size_t
|
|
233
|
+
size_t LinearRecurrent<MatType, RegularizerType>::WeightSize() const
|
|
236
234
|
{
|
|
237
235
|
return (inSize * outSize) /* weight matrix */ +
|
|
238
236
|
(outSize * outSize) /* recurrent state matrix */ +
|
|
@@ -240,7 +238,7 @@ size_t LinearRecurrentType<MatType, RegularizerType>::WeightSize() const
|
|
|
240
238
|
}
|
|
241
239
|
|
|
242
240
|
template<typename MatType, typename RegularizerType>
|
|
243
|
-
size_t
|
|
241
|
+
size_t LinearRecurrent<MatType, RegularizerType>::RecurrentSize() const
|
|
244
242
|
{
|
|
245
243
|
return outSize;
|
|
246
244
|
}
|
|
@@ -248,7 +246,7 @@ size_t LinearRecurrentType<MatType, RegularizerType>::RecurrentSize() const
|
|
|
248
246
|
// Compute the output dimensions of the layer, assuming that inputDimension has
|
|
249
247
|
// been set.
|
|
250
248
|
template<typename MatType, typename RegularizerType>
|
|
251
|
-
void
|
|
249
|
+
void LinearRecurrent<MatType, RegularizerType>::ComputeOutputDimensions()
|
|
252
250
|
{
|
|
253
251
|
// Compute the total number of input dimensions.
|
|
254
252
|
inSize = this->inputDimensions[0];
|
|
@@ -261,10 +259,41 @@ void LinearRecurrentType<MatType, RegularizerType>::ComputeOutputDimensions()
|
|
|
261
259
|
this->outputDimensions[0] = outSize;
|
|
262
260
|
}
|
|
263
261
|
|
|
262
|
+
template<typename MatType, typename RegularizerType>
|
|
263
|
+
void LinearRecurrent<MatType, RegularizerType>::OnStepChanged(
|
|
264
|
+
const size_t step,
|
|
265
|
+
const size_t /* batchSize */,
|
|
266
|
+
const size_t activeBatchSize,
|
|
267
|
+
const bool backwards)
|
|
268
|
+
{
|
|
269
|
+
// Make aliases for the output from the recurrent state.
|
|
270
|
+
MakeAlias(currentOutput, this->RecurrentState(step),
|
|
271
|
+
outSize, activeBatchSize);
|
|
272
|
+
|
|
273
|
+
if (this->HasPreviousStep())
|
|
274
|
+
{
|
|
275
|
+
MakeAlias(previousOutput, this->RecurrentState(this->PreviousStep()),
|
|
276
|
+
outSize, activeBatchSize);
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
// Make aliases for the gradient from the recurrent gradient.
|
|
280
|
+
if (backwards)
|
|
281
|
+
{
|
|
282
|
+
MakeAlias(currentGradient, this->RecurrentGradient(step),
|
|
283
|
+
outSize, activeBatchSize);
|
|
284
|
+
|
|
285
|
+
if (this->HasPreviousStep())
|
|
286
|
+
{
|
|
287
|
+
MakeAlias(previousGradient, this->RecurrentGradient(this->PreviousStep()),
|
|
288
|
+
outSize, activeBatchSize);
|
|
289
|
+
}
|
|
290
|
+
}
|
|
291
|
+
}
|
|
292
|
+
|
|
264
293
|
// Serialize the layer.
|
|
265
294
|
template<typename MatType, typename RegularizerType>
|
|
266
295
|
template<typename Archive>
|
|
267
|
-
void
|
|
296
|
+
void LinearRecurrent<MatType, RegularizerType>::serialize(
|
|
268
297
|
Archive& ar, const uint32_t /* version */)
|
|
269
298
|
{
|
|
270
299
|
ar(cereal::base_class<RecurrentLayer<MatType>>(this));
|
|
@@ -29,37 +29,31 @@ namespace mlpack {
|
|
|
29
29
|
* computation.
|
|
30
30
|
*/
|
|
31
31
|
template <typename MatType = arma::mat>
|
|
32
|
-
class
|
|
32
|
+
class LogSoftMax : public Layer<MatType>
|
|
33
33
|
{
|
|
34
34
|
public:
|
|
35
|
+
// Convenience typedef to access the element type of the weights and data.
|
|
36
|
+
using ElemType = typename MatType::elem_type;
|
|
37
|
+
|
|
35
38
|
/**
|
|
36
39
|
* Create the LogSoftmax layer.
|
|
37
40
|
*/
|
|
38
|
-
|
|
41
|
+
LogSoftMax();
|
|
39
42
|
|
|
40
|
-
//! Clone the
|
|
41
|
-
|
|
43
|
+
//! Clone the LogSoftMax object. This handles polymorphism correctly.
|
|
44
|
+
LogSoftMax* Clone() const { return new LogSoftMax(*this); }
|
|
42
45
|
|
|
43
46
|
// Virtual destructor.
|
|
44
|
-
virtual ~
|
|
47
|
+
virtual ~LogSoftMax() { }
|
|
45
48
|
|
|
46
|
-
//! Copy the given
|
|
47
|
-
|
|
48
|
-
//! Take ownership of the given
|
|
49
|
-
|
|
50
|
-
//! Copy the given
|
|
51
|
-
|
|
52
|
-
//! Take ownership of the given
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
/**
|
|
56
|
-
* A wrapper function to call the correct implementation according to the
|
|
57
|
-
* specific matrix type (e.g., arma, coot).
|
|
58
|
-
*
|
|
59
|
-
* @param input Input data used for evaluating the specified function.
|
|
60
|
-
* @param output Resulting output activation.
|
|
61
|
-
*/
|
|
62
|
-
void Forward(const MatType& input, MatType& output);
|
|
49
|
+
//! Copy the given LogSoftMax.
|
|
50
|
+
LogSoftMax(const LogSoftMax& other);
|
|
51
|
+
//! Take ownership of the given LogSoftMax.
|
|
52
|
+
LogSoftMax(LogSoftMax&& other);
|
|
53
|
+
//! Copy the given LogSoftMax.
|
|
54
|
+
LogSoftMax& operator=(const LogSoftMax& other);
|
|
55
|
+
//! Take ownership of the given LogSoftMax.
|
|
56
|
+
LogSoftMax& operator=(LogSoftMax&& other);
|
|
63
57
|
|
|
64
58
|
/**
|
|
65
59
|
* Ordinary feed forward pass of a neural network, evaluating the function
|
|
@@ -68,15 +62,7 @@ class LogSoftMaxType : public Layer<MatType>
|
|
|
68
62
|
* @param input Input data used for evaluating the specified function.
|
|
69
63
|
* @param output Resulting output activation.
|
|
70
64
|
*/
|
|
71
|
-
void
|
|
72
|
-
const typename std::enable_if_t<
|
|
73
|
-
arma::is_arma_type<MatType>::value>* = 0);
|
|
74
|
-
|
|
75
|
-
#ifdef MLPACK_HAS_COOT
|
|
76
|
-
void ForwardImpl(const MatType& input, MatType& output,
|
|
77
|
-
const typename std::enable_if_t<
|
|
78
|
-
coot::is_coot_type<MatType>::value>* = 0);
|
|
79
|
-
#endif
|
|
65
|
+
void Forward(const MatType& input, MatType& output);
|
|
80
66
|
|
|
81
67
|
/**
|
|
82
68
|
* Ordinary feed backward pass of a neural network, calculating the function
|
|
@@ -101,11 +87,6 @@ class LogSoftMaxType : public Layer<MatType>
|
|
|
101
87
|
}
|
|
102
88
|
}; // class LogSoftmaxType
|
|
103
89
|
|
|
104
|
-
// Convenience typedefs.
|
|
105
|
-
|
|
106
|
-
// Standard Linear layer using no regularization.
|
|
107
|
-
using LogSoftMax = LogSoftMaxType<arma::mat>;
|
|
108
|
-
|
|
109
90
|
} // namespace mlpack
|
|
110
91
|
|
|
111
92
|
// Include implementation.
|
|
@@ -18,29 +18,29 @@
|
|
|
18
18
|
namespace mlpack {
|
|
19
19
|
|
|
20
20
|
template<typename MatType>
|
|
21
|
-
|
|
21
|
+
LogSoftMax<MatType>::LogSoftMax() :
|
|
22
22
|
Layer<MatType>()
|
|
23
23
|
{
|
|
24
24
|
// Nothing to do here.
|
|
25
25
|
}
|
|
26
26
|
|
|
27
27
|
template<typename MatType>
|
|
28
|
-
|
|
28
|
+
LogSoftMax<MatType>::LogSoftMax(const LogSoftMax& other) :
|
|
29
29
|
Layer<MatType>(other)
|
|
30
30
|
{
|
|
31
31
|
// Nothing to do here.
|
|
32
32
|
}
|
|
33
33
|
|
|
34
34
|
template<typename MatType>
|
|
35
|
-
|
|
35
|
+
LogSoftMax<MatType>::LogSoftMax(LogSoftMax&& other) :
|
|
36
36
|
Layer<MatType>(std::move(other))
|
|
37
37
|
{
|
|
38
38
|
// Nothing to do here.
|
|
39
39
|
}
|
|
40
40
|
|
|
41
41
|
template<typename MatType>
|
|
42
|
-
|
|
43
|
-
|
|
42
|
+
LogSoftMax<MatType>&
|
|
43
|
+
LogSoftMax<MatType>::operator=(const LogSoftMax& other)
|
|
44
44
|
{
|
|
45
45
|
if (&other != this)
|
|
46
46
|
{
|
|
@@ -51,8 +51,8 @@ LogSoftMaxType<MatType>::operator=(const LogSoftMaxType& other)
|
|
|
51
51
|
}
|
|
52
52
|
|
|
53
53
|
template<typename MatType>
|
|
54
|
-
|
|
55
|
-
|
|
54
|
+
LogSoftMax<MatType>&
|
|
55
|
+
LogSoftMax<MatType>::operator=(LogSoftMax&& other)
|
|
56
56
|
{
|
|
57
57
|
if (&other != this)
|
|
58
58
|
{
|
|
@@ -63,85 +63,69 @@ LogSoftMaxType<MatType>::operator=(LogSoftMaxType&& other)
|
|
|
63
63
|
}
|
|
64
64
|
|
|
65
65
|
template<typename MatType>
|
|
66
|
-
void
|
|
66
|
+
void LogSoftMax<MatType>::Forward(const MatType& input, MatType& output)
|
|
67
67
|
{
|
|
68
|
-
|
|
69
|
-
}
|
|
70
|
-
|
|
71
|
-
template<typename MatType>
|
|
72
|
-
void LogSoftMaxType<MatType>::ForwardImpl(
|
|
73
|
-
const MatType& input,
|
|
74
|
-
MatType& output,
|
|
75
|
-
const typename std::enable_if_t<arma::is_arma_type<MatType>::value>*)
|
|
76
|
-
{
|
|
77
|
-
MatType maxInput = repmat(max(input, 0), input.n_rows, 1);
|
|
78
|
-
output = (maxInput - input);
|
|
79
|
-
|
|
80
|
-
// Approximation of the base-e exponential function. The accuracy, however, is
|
|
81
|
-
// about 0.00001 lower than using exp. Credits go to Leon Bottou.
|
|
82
|
-
#pragma omp parallel for
|
|
83
|
-
for (size_t i = 0; i < output.n_elem; ++i)
|
|
68
|
+
if constexpr (IsArma<MatType>::value)
|
|
84
69
|
{
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
if (x < 13.0)
|
|
70
|
+
MatType maxInput = repmat(max(input, 0), input.n_rows, 1);
|
|
71
|
+
output = (maxInput - input);
|
|
72
|
+
|
|
73
|
+
// Approximation of the base-e exponential function. The accuracy, however,
|
|
74
|
+
// is about 0.00001 lower than using exp. Credits go to Leon Bottou.
|
|
75
|
+
#pragma omp parallel for
|
|
76
|
+
for (size_t i = 0; i < output.n_elem; ++i)
|
|
94
77
|
{
|
|
95
|
-
double
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
78
|
+
double x = output(i);
|
|
79
|
+
//! Fast approximation of exp(-x) for x positive.
|
|
80
|
+
static constexpr double A0 = 1.0;
|
|
81
|
+
static constexpr double A1 = 0.125;
|
|
82
|
+
static constexpr double A2 = 0.0078125;
|
|
83
|
+
static constexpr double A3 = 0.00032552083;
|
|
84
|
+
static constexpr double A4 = 1.0172526e-5;
|
|
85
|
+
|
|
86
|
+
if (x < 13.0)
|
|
87
|
+
{
|
|
88
|
+
double y = A0 + x * (A1 + x * (A2 + x * (A3 + x * A4)));
|
|
89
|
+
y *= y;
|
|
90
|
+
y *= y;
|
|
91
|
+
y *= y;
|
|
92
|
+
y = 1 / y;
|
|
93
|
+
output(i) = ElemType(y);
|
|
94
|
+
}
|
|
95
|
+
else
|
|
96
|
+
{
|
|
97
|
+
output(i) = 0;
|
|
98
|
+
}
|
|
101
99
|
}
|
|
102
|
-
|
|
100
|
+
|
|
101
|
+
#pragma omp parallel for
|
|
102
|
+
for (size_t col = 0; col < maxInput.n_cols; ++col)
|
|
103
103
|
{
|
|
104
|
-
|
|
104
|
+
ElemType colSum = 0;
|
|
105
|
+
for (size_t row = 0; row < output.n_rows; ++row)
|
|
106
|
+
{
|
|
107
|
+
colSum += output(row, col);
|
|
108
|
+
}
|
|
109
|
+
ElemType logSum = std::log(colSum);
|
|
110
|
+
for (size_t row = 0; row < maxInput.n_rows; ++row)
|
|
111
|
+
{
|
|
112
|
+
maxInput(row, col) += logSum;
|
|
113
|
+
}
|
|
105
114
|
}
|
|
115
|
+
output = input - maxInput;
|
|
106
116
|
}
|
|
107
|
-
|
|
108
|
-
#pragma omp parallel for
|
|
109
|
-
for (size_t col = 0; col < maxInput.n_cols; ++col)
|
|
117
|
+
else if constexpr (IsCoot<MatType>::value)
|
|
110
118
|
{
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
double logSum = std::log(colSum);
|
|
117
|
-
for (size_t row = 0; row < maxInput.n_rows; ++row)
|
|
118
|
-
{
|
|
119
|
-
maxInput(row, col) += logSum;
|
|
120
|
-
}
|
|
119
|
+
MatType maxInput = repmat(max(input), input.n_rows, 1);
|
|
120
|
+
output = (maxInput - input);
|
|
121
|
+
output = exp(-output);
|
|
122
|
+
maxInput.each_row() += log(sum(output));
|
|
123
|
+
output = input - maxInput;
|
|
121
124
|
}
|
|
122
|
-
|
|
123
|
-
output = input - maxInput;
|
|
124
|
-
}
|
|
125
|
-
|
|
126
|
-
#ifdef MLPACK_HAS_COOT
|
|
127
|
-
|
|
128
|
-
template<typename MatType>
|
|
129
|
-
void LogSoftMaxType<MatType>::ForwardImpl(
|
|
130
|
-
const MatType& input,
|
|
131
|
-
MatType& output,
|
|
132
|
-
const typename std::enable_if_t<coot::is_coot_type<MatType>::value>*)
|
|
133
|
-
{
|
|
134
|
-
MatType maxInput = repmat(max(input), input.n_rows, 1);
|
|
135
|
-
output = (maxInput - input);
|
|
136
|
-
output = exp(output * -1);
|
|
137
|
-
maxInput.each_row() += log(sum(output));
|
|
138
|
-
output = input - maxInput;
|
|
139
125
|
}
|
|
140
126
|
|
|
141
|
-
#endif
|
|
142
|
-
|
|
143
127
|
template<typename MatType>
|
|
144
|
-
void
|
|
128
|
+
void LogSoftMax<MatType>::Backward(
|
|
145
129
|
const MatType& /* input */,
|
|
146
130
|
const MatType& output,
|
|
147
131
|
const MatType& gy,
|
|
@@ -53,11 +53,14 @@ namespace mlpack {
|
|
|
53
53
|
* computation.
|
|
54
54
|
*/
|
|
55
55
|
template<typename MatType = arma::mat>
|
|
56
|
-
class
|
|
56
|
+
class LSTM : public RecurrentLayer<MatType>
|
|
57
57
|
{
|
|
58
58
|
public:
|
|
59
|
-
|
|
60
|
-
|
|
59
|
+
// Convenience typedef to access the element type of the weights and data.
|
|
60
|
+
using ElemType = typename MatType::elem_type;
|
|
61
|
+
|
|
62
|
+
// Create the LSTM object.
|
|
63
|
+
LSTM();
|
|
61
64
|
|
|
62
65
|
/**
|
|
63
66
|
* Create the LSTM layer object using the specified parameters.
|
|
@@ -65,21 +68,21 @@ class LSTMType : public RecurrentLayer<MatType>
|
|
|
65
68
|
* @param outSize The number of output units.
|
|
66
69
|
* @param rho Maximum number of steps to backpropagate through time (BPTT).
|
|
67
70
|
*/
|
|
68
|
-
|
|
71
|
+
LSTM(const size_t outSize);
|
|
69
72
|
|
|
70
|
-
|
|
71
|
-
|
|
73
|
+
// Clone the LSTM object. This handles polymorphism correctly.
|
|
74
|
+
LSTM* Clone() const { return new LSTM(*this); }
|
|
72
75
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
76
|
+
// Copy the given LSTM object.
|
|
77
|
+
LSTM(const LSTM& other);
|
|
78
|
+
// Take ownership of the given LSTM object's data.
|
|
79
|
+
LSTM(LSTM&& other);
|
|
80
|
+
// Copy the given LSTM object.
|
|
81
|
+
LSTM& operator=(const LSTM& other);
|
|
82
|
+
// Take ownership of the given LSTM object's data.
|
|
83
|
+
LSTM& operator=(LSTM&& other);
|
|
81
84
|
|
|
82
|
-
virtual ~
|
|
85
|
+
virtual ~LSTM() { }
|
|
83
86
|
|
|
84
87
|
/**
|
|
85
88
|
* Reset the layer parameter. The method is called to
|
|
@@ -217,6 +220,12 @@ class LSTMType : public RecurrentLayer<MatType>
|
|
|
217
220
|
this->outputDimensions[0] = outSize;
|
|
218
221
|
}
|
|
219
222
|
|
|
223
|
+
// Update the internal aliases of the layer when the step changes.
|
|
224
|
+
void OnStepChanged(const size_t step,
|
|
225
|
+
const size_t batchSize,
|
|
226
|
+
const size_t activeBatchSize,
|
|
227
|
+
const bool backwards);
|
|
228
|
+
|
|
220
229
|
/**
|
|
221
230
|
* Serialize the layer.
|
|
222
231
|
*/
|
|
@@ -287,20 +296,8 @@ class LSTMType : public RecurrentLayer<MatType>
|
|
|
287
296
|
MatType nextDeltaForgetGate;
|
|
288
297
|
MatType nextDeltaOutputGate;
|
|
289
298
|
MatType nextDeltaCell;
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
// the correct places in the current recurrent state methods.
|
|
293
|
-
void SetInternalAliases(const size_t batchSize);
|
|
294
|
-
|
|
295
|
-
// Calling this function will set up workspace memory for the backward pass,
|
|
296
|
-
// if necessary.
|
|
297
|
-
void SetBackwardWorkspace(const size_t batchSize);
|
|
298
|
-
}; // class LSTMType
|
|
299
|
-
|
|
300
|
-
// Convenience typedefs.
|
|
301
|
-
|
|
302
|
-
// Standard LSTM layer.
|
|
303
|
-
using LSTM = LSTMType<arma::mat>;
|
|
299
|
+
MatType nextForgetGate;
|
|
300
|
+
}; // class LSTM
|
|
304
301
|
|
|
305
302
|
} // namespace mlpack
|
|
306
303
|
|