mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -54,9 +54,10 @@ class ElishFunction
|
|
|
54
54
|
* @param x Input data.
|
|
55
55
|
* @return f(x).
|
|
56
56
|
*/
|
|
57
|
-
|
|
57
|
+
template<typename ElemType>
|
|
58
|
+
static ElemType Fn(const ElemType x)
|
|
58
59
|
{
|
|
59
|
-
if (x < 0
|
|
60
|
+
if (x < 0)
|
|
60
61
|
return (std::exp(x) - 1) / (1 + std::exp(-x));
|
|
61
62
|
|
|
62
63
|
return x / (1 + std::exp(-x));
|
|
@@ -71,8 +72,8 @@ class ElishFunction
|
|
|
71
72
|
template<typename InputVecType, typename OutputVecType>
|
|
72
73
|
static void Fn(const InputVecType& x, OutputVecType& y)
|
|
73
74
|
{
|
|
74
|
-
y = ((x < 0
|
|
75
|
-
+ ((x >= 0
|
|
75
|
+
y = (conv_to<InputVecType>::from(x < 0) % ((exp(x) - 1) / (1 + exp(-x))))
|
|
76
|
+
+ (conv_to<InputVecType>::from(x >= 0) % (x / (1 + exp(-x))));
|
|
76
77
|
}
|
|
77
78
|
|
|
78
79
|
/**
|
|
@@ -82,17 +83,19 @@ class ElishFunction
|
|
|
82
83
|
* @param y Result of Fn(x).
|
|
83
84
|
* @return f'(x).
|
|
84
85
|
*/
|
|
85
|
-
|
|
86
|
+
template<typename ElemType>
|
|
87
|
+
static ElemType Deriv(const ElemType x, const ElemType y)
|
|
86
88
|
{
|
|
87
|
-
if (x < 0
|
|
89
|
+
if (x < 0)
|
|
88
90
|
{
|
|
89
91
|
return std::exp(x) - 2 / (1 + std::exp(x)) +
|
|
90
92
|
2 / std::pow(1 + std::exp(x) , 2);
|
|
91
93
|
}
|
|
92
94
|
else if (x == 0)
|
|
93
95
|
{
|
|
94
|
-
|
|
95
|
-
|
|
96
|
+
// The expression below is indeterminate at 0, even though the expression
|
|
97
|
+
// solely in terms of x is defined (= 0.5).
|
|
98
|
+
return ElemType(0.5);
|
|
96
99
|
}
|
|
97
100
|
else
|
|
98
101
|
{
|
|
@@ -118,12 +121,14 @@ class ElishFunction
|
|
|
118
121
|
// the expression solely in terms of x is defined (= 0.5)
|
|
119
122
|
// only calculate exp(x) once for each element where x < 0
|
|
120
123
|
// this gives approx 3x speedup, despite allocating the temp vector
|
|
121
|
-
DerivVecType ex = (x < 0) % exp(x);
|
|
122
|
-
dy = ((x < 0) %
|
|
123
|
-
|
|
124
|
+
DerivVecType ex = conv_to<DerivVecType>::from(x < 0) % exp(x);
|
|
125
|
+
dy = (conv_to<InputVecType>::from(x < 0) %
|
|
126
|
+
((ex - 2 / (1 + ex) + 2 / square(1 + ex)))) +
|
|
127
|
+
(conv_to<InputVecType>::from(x > 0) %
|
|
128
|
+
((y / x) % (1 + x - y)));
|
|
124
129
|
// need to do this here, because the /x above gives nans even when the
|
|
125
130
|
// condition is not met (e.g. when x > 0 is false)
|
|
126
|
-
dy(arma::find(x == 0)).fill(0.5);
|
|
131
|
+
dy(arma::find(x == 0)).fill(typename InputVecType::elem_type(0.5));
|
|
127
132
|
}
|
|
128
133
|
}; // class ElishFunction
|
|
129
134
|
|
|
@@ -45,9 +45,10 @@ class ElliotFunction
|
|
|
45
45
|
* @param x Input data.
|
|
46
46
|
* @return f(x).
|
|
47
47
|
*/
|
|
48
|
-
|
|
48
|
+
template<typename ElemType>
|
|
49
|
+
static ElemType Fn(const ElemType x)
|
|
49
50
|
{
|
|
50
|
-
return x / (1
|
|
51
|
+
return x / (1 + std::abs(x));
|
|
51
52
|
}
|
|
52
53
|
|
|
53
54
|
/**
|
|
@@ -56,10 +57,10 @@ class ElliotFunction
|
|
|
56
57
|
* @param x Input data.
|
|
57
58
|
* @param y The resulting output activation.
|
|
58
59
|
*/
|
|
59
|
-
template
|
|
60
|
+
template<typename InputVecType, typename OutputVecType>
|
|
60
61
|
static void Fn(const InputVecType &x, OutputVecType &y)
|
|
61
62
|
{
|
|
62
|
-
y = x / (1
|
|
63
|
+
y = x / (1 + arma::abs(x));
|
|
63
64
|
}
|
|
64
65
|
|
|
65
66
|
/**
|
|
@@ -69,9 +70,10 @@ class ElliotFunction
|
|
|
69
70
|
* @param y Result of Fn(x).
|
|
70
71
|
* @return f'(x).
|
|
71
72
|
*/
|
|
72
|
-
|
|
73
|
+
template<typename ElemType>
|
|
74
|
+
static ElemType Deriv(const ElemType x, const ElemType /* y */)
|
|
73
75
|
{
|
|
74
|
-
return 1
|
|
76
|
+
return 1 / std::pow(1 + std::abs(x), ElemType(2));
|
|
75
77
|
}
|
|
76
78
|
|
|
77
79
|
/**
|
|
@@ -86,7 +88,7 @@ class ElliotFunction
|
|
|
86
88
|
const OutputVecType& /* y */,
|
|
87
89
|
DerivVecType &dy)
|
|
88
90
|
{
|
|
89
|
-
dy = 1
|
|
91
|
+
dy = 1 / square(1 + abs(x));
|
|
90
92
|
}
|
|
91
93
|
}; // class ElliotFunction
|
|
92
94
|
|
|
@@ -22,7 +22,7 @@ namespace mlpack {
|
|
|
22
22
|
*
|
|
23
23
|
* @f{eqnarray*}{
|
|
24
24
|
* f(x) &=& e^{-1 * x^2} \\
|
|
25
|
-
* f'(x) &=& 2 * -x * f(x)
|
|
25
|
+
* f'(x) &=& 2 * -x * f(x)
|
|
26
26
|
* @f}
|
|
27
27
|
*/
|
|
28
28
|
class GaussianFunction
|
|
@@ -34,10 +34,10 @@ class GaussianFunction
|
|
|
34
34
|
* @param x Input data.
|
|
35
35
|
* @return f(x).
|
|
36
36
|
*/
|
|
37
|
-
template<typename
|
|
38
|
-
static
|
|
37
|
+
template<typename ElemType>
|
|
38
|
+
static ElemType Fn(const ElemType x)
|
|
39
39
|
{
|
|
40
|
-
return std::exp(-
|
|
40
|
+
return std::exp(-std::pow(x, ElemType(2)));
|
|
41
41
|
}
|
|
42
42
|
|
|
43
43
|
/**
|
|
@@ -49,7 +49,7 @@ class GaussianFunction
|
|
|
49
49
|
template<typename InputVecType, typename OutputVecType>
|
|
50
50
|
static void Fn(const InputVecType& x, OutputVecType& y)
|
|
51
51
|
{
|
|
52
|
-
y = exp(-
|
|
52
|
+
y = exp(-square(x));
|
|
53
53
|
}
|
|
54
54
|
|
|
55
55
|
/**
|
|
@@ -59,7 +59,8 @@ class GaussianFunction
|
|
|
59
59
|
* @param y Result of Fn(x).
|
|
60
60
|
* @return f'(x)
|
|
61
61
|
*/
|
|
62
|
-
|
|
62
|
+
template<typename ElemType>
|
|
63
|
+
static ElemType Deriv(const ElemType x, const ElemType y)
|
|
63
64
|
{
|
|
64
65
|
return -2 * x * y;
|
|
65
66
|
}
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file methods/ann/activation_functions/gelu_exact_function.hpp
|
|
3
|
+
* @author Kumar Utkarsh
|
|
4
|
+
*
|
|
5
|
+
* Definition and implementation of the exact Gaussian Error Linear Unit (GELU)
|
|
6
|
+
* function.
|
|
7
|
+
*
|
|
8
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GELU_EXACT_FUNCTION_HPP
|
|
14
|
+
#define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GELU_EXACT_FUNCTION_HPP
|
|
15
|
+
|
|
16
|
+
#include <mlpack/prereqs.hpp>
|
|
17
|
+
|
|
18
|
+
namespace mlpack {
|
|
19
|
+
|
|
20
|
+
/**
|
|
21
|
+
* The exact GELU function, defined by
|
|
22
|
+
*
|
|
23
|
+
* @f{eqnarray*}{
|
|
24
|
+
* f(x) = x * Phi(x) \\
|
|
25
|
+
* Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) \\
|
|
26
|
+
* f'(x) = Phi(x) + x * phi(x) \\
|
|
27
|
+
* phi(x) = (1 / sqrt(2\pi)) * exp(-x^2 / 2)
|
|
28
|
+
* @f}
|
|
29
|
+
*/
|
|
30
|
+
class GELUExactFunction
|
|
31
|
+
{
|
|
32
|
+
public:
|
|
33
|
+
//! Compute the exact GELU function for a single value.
|
|
34
|
+
static double Fn(const double x)
|
|
35
|
+
{
|
|
36
|
+
return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
//! Compute the exact GELU function for matrices/vectors.
|
|
40
|
+
template<typename InputVecType, typename OutputVecType>
|
|
41
|
+
static void Fn(const InputVecType& x, OutputVecType& y)
|
|
42
|
+
{
|
|
43
|
+
y = 0.5 * x % (1.0 + erf(x / std::sqrt(2.0)));
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
// Compute the first derivative of the exact GELU function for a single value
|
|
47
|
+
static double Deriv(const double x, const double y )
|
|
48
|
+
{
|
|
49
|
+
const double phi = std::exp(-0.5 * x * x) / std::sqrt(2.0 * M_PI);
|
|
50
|
+
// Reuse y to avoid costly Phi(x) computation.
|
|
51
|
+
return (x == 0.0) ? 0.5 : (y / x + x * phi);
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
//! Compute the first derivative for matrices/vectors.
|
|
55
|
+
template<typename InputVecType, typename OutputVecType, typename DerivVecType>
|
|
56
|
+
static void Deriv(const InputVecType& x,
|
|
57
|
+
const OutputVecType& y,
|
|
58
|
+
DerivVecType& dy)
|
|
59
|
+
{
|
|
60
|
+
dy.set_size(x.n_elem);
|
|
61
|
+
// Reuse y to avoid costly Phi(x) computation.
|
|
62
|
+
for (size_t i = 0; i < x.n_elem; ++i)
|
|
63
|
+
{
|
|
64
|
+
if (x[i] == 0.0) dy[i] = 0.5;
|
|
65
|
+
else dy[i] = y[i] / x[i] +
|
|
66
|
+
x[i] * std::exp(-0.5 * x[i] * x[i]) / std::sqrt(2.0 * M_PI);
|
|
67
|
+
}
|
|
68
|
+
}
|
|
69
|
+
}; // class GELUExactFunction
|
|
70
|
+
|
|
71
|
+
} // namespace mlpack
|
|
72
|
+
|
|
73
|
+
#endif
|
|
@@ -37,10 +37,12 @@ class GELUFunction
|
|
|
37
37
|
* @param x Input data.
|
|
38
38
|
* @return f(x).
|
|
39
39
|
*/
|
|
40
|
-
|
|
40
|
+
template<typename ElemType>
|
|
41
|
+
static ElemType Fn(const ElemType x)
|
|
41
42
|
{
|
|
42
|
-
return
|
|
43
|
-
|
|
43
|
+
return (x / 2) *
|
|
44
|
+
(1 + std::tanh(std::sqrt(2 / arma::Datum<ElemType>::pi) *
|
|
45
|
+
(x + ElemType(0.044715) * std::pow(x, ElemType(3)))));
|
|
44
46
|
}
|
|
45
47
|
|
|
46
48
|
/**
|
|
@@ -52,8 +54,11 @@ class GELUFunction
|
|
|
52
54
|
template<typename InputVecType, typename OutputVecType>
|
|
53
55
|
static void Fn(const InputVecType& x, OutputVecType& y)
|
|
54
56
|
{
|
|
55
|
-
|
|
56
|
-
|
|
57
|
+
typedef typename InputVecType::elem_type ElemType;
|
|
58
|
+
|
|
59
|
+
y = (x / 2) %
|
|
60
|
+
(1 + tanh(std::sqrt(2 / arma::Datum<ElemType>::pi) *
|
|
61
|
+
(x + ElemType(0.044715) * pow(x, ElemType(3)))));
|
|
57
62
|
}
|
|
58
63
|
|
|
59
64
|
/**
|
|
@@ -63,13 +68,16 @@ class GELUFunction
|
|
|
63
68
|
* @param y Result of Fn(x).
|
|
64
69
|
* @return f'(x)
|
|
65
70
|
*/
|
|
66
|
-
|
|
71
|
+
template<typename ElemType>
|
|
72
|
+
static ElemType Deriv(const ElemType x, const ElemType /* y */)
|
|
67
73
|
{
|
|
68
|
-
if (x < -10) return 0
|
|
69
|
-
return 0.5 * std::tanh(0.0356774
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
74
|
+
if (x < -10) return 0; // catch overflows
|
|
75
|
+
return ElemType(0.5) * std::tanh(ElemType(0.0356774) *
|
|
76
|
+
std::pow(x, ElemType(3)) + ElemType(0.797885) * x) +
|
|
77
|
+
(ElemType(0.0535161) * std::pow(x, ElemType(3)) +
|
|
78
|
+
ElemType(0.398942) * x) *
|
|
79
|
+
std::pow(1 / std::cosh(ElemType(0.0356774) * std::pow(x, 3) +
|
|
80
|
+
ElemType(0.797885) * x), 2) + ElemType(0.5);
|
|
73
81
|
}
|
|
74
82
|
|
|
75
83
|
/**
|
|
@@ -84,11 +92,14 @@ class GELUFunction
|
|
|
84
92
|
const OutputVecType& /* y */,
|
|
85
93
|
DerivVecType& dy)
|
|
86
94
|
{
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
0.797885 * x),
|
|
91
|
-
|
|
95
|
+
typedef typename InputVecType::elem_type ElemType;
|
|
96
|
+
|
|
97
|
+
dy = ElemType(0.5) * tanh(ElemType(0.0356774) * pow(x, ElemType(3)) +
|
|
98
|
+
ElemType(0.797885) * x) + (ElemType(0.0535161) * pow(x, ElemType(3)) +
|
|
99
|
+
ElemType(0.398942) * x) %
|
|
100
|
+
pow(1 / cosh(ElemType(0.0356774) * pow(x, ElemType(3)) +
|
|
101
|
+
ElemType(0.797885) * x), 2) + ElemType(0.5);
|
|
102
|
+
dy(find(x < -10)).fill(0); // catch overflows
|
|
92
103
|
}
|
|
93
104
|
}; // class GELUFunction
|
|
94
105
|
|
|
@@ -39,9 +39,10 @@ class HardSigmoidFunction
|
|
|
39
39
|
* @param x Input data.
|
|
40
40
|
* @return f(x).
|
|
41
41
|
*/
|
|
42
|
-
|
|
42
|
+
template<typename ElemType>
|
|
43
|
+
static ElemType Fn(const ElemType x)
|
|
43
44
|
{
|
|
44
|
-
return std::min(1
|
|
45
|
+
return std::min(ElemType(1), std::max(ElemType(0), x / 5 + ElemType(0.5)));
|
|
45
46
|
}
|
|
46
47
|
|
|
47
48
|
/**
|
|
@@ -67,13 +68,14 @@ class HardSigmoidFunction
|
|
|
67
68
|
* @param y Result of Fn(x).
|
|
68
69
|
* @return f'(x)
|
|
69
70
|
*/
|
|
70
|
-
|
|
71
|
+
template<typename ElemType>
|
|
72
|
+
static ElemType Deriv(const ElemType /* x */, const ElemType y)
|
|
71
73
|
{
|
|
72
|
-
if (y == 0
|
|
74
|
+
if (y == 0 || y == 1)
|
|
73
75
|
{
|
|
74
|
-
return 0
|
|
76
|
+
return 0;
|
|
75
77
|
}
|
|
76
|
-
return 0.2;
|
|
78
|
+
return ElemType(0.2);
|
|
77
79
|
}
|
|
78
80
|
|
|
79
81
|
/**
|
|
@@ -52,7 +52,8 @@ class HardSwishFunction
|
|
|
52
52
|
* @param x Input data.
|
|
53
53
|
* @return f(x).
|
|
54
54
|
*/
|
|
55
|
-
|
|
55
|
+
template<typename ElemType>
|
|
56
|
+
static ElemType Fn(const ElemType x)
|
|
56
57
|
{
|
|
57
58
|
if (x <= -3)
|
|
58
59
|
return 0;
|
|
@@ -68,7 +69,7 @@ class HardSwishFunction
|
|
|
68
69
|
* @param x Input data.
|
|
69
70
|
* @param y The resulting output activation.
|
|
70
71
|
*/
|
|
71
|
-
template
|
|
72
|
+
template<typename InputVecType, typename OutputVecType>
|
|
72
73
|
static void Fn(const InputVecType &x, OutputVecType &y)
|
|
73
74
|
{
|
|
74
75
|
y.set_size(size(x));
|
|
@@ -85,14 +86,15 @@ class HardSwishFunction
|
|
|
85
86
|
* @param * (y) Result of Fn(x).
|
|
86
87
|
* @return f'(x).
|
|
87
88
|
*/
|
|
88
|
-
|
|
89
|
+
template<typename ElemType>
|
|
90
|
+
static ElemType Deriv(const ElemType x, const ElemType /* y */)
|
|
89
91
|
{
|
|
90
92
|
if (x <= -3)
|
|
91
93
|
return 0;
|
|
92
94
|
else if (x >= 3)
|
|
93
95
|
return 1;
|
|
94
96
|
|
|
95
|
-
return (2 * x + 3
|
|
97
|
+
return (2 * x + 3) / 6;
|
|
96
98
|
}
|
|
97
99
|
|
|
98
100
|
/**
|
|
@@ -56,12 +56,13 @@ class HyperSinhFunction
|
|
|
56
56
|
* @param x Input data.
|
|
57
57
|
* @return f(x).
|
|
58
58
|
*/
|
|
59
|
-
|
|
59
|
+
template<typename ElemType>
|
|
60
|
+
static ElemType Fn(const ElemType x)
|
|
60
61
|
{
|
|
61
62
|
if (x > 0)
|
|
62
|
-
return (std::sinh(x) / 3
|
|
63
|
+
return (std::sinh(x) / 3);
|
|
63
64
|
else
|
|
64
|
-
return (std::pow(x, 3
|
|
65
|
+
return (std::pow(x, ElemType(3)) / 4);
|
|
65
66
|
}
|
|
66
67
|
|
|
67
68
|
/**
|
|
@@ -94,12 +95,13 @@ class HyperSinhFunction
|
|
|
94
95
|
* @param y Input activation.
|
|
95
96
|
* @return f'(x)
|
|
96
97
|
*/
|
|
97
|
-
|
|
98
|
+
template<typename ElemType>
|
|
99
|
+
static ElemType Deriv(const ElemType /* x */, const ElemType y)
|
|
98
100
|
{
|
|
99
101
|
if (y > 0)
|
|
100
|
-
return (std::pow((1.0 / 9.0) + (y * y), 0.5));
|
|
102
|
+
return (std::pow(ElemType(1.0 / 9.0) + (y * y), ElemType(0.5)));
|
|
101
103
|
else
|
|
102
|
-
return (3
|
|
104
|
+
return (3 * std::pow(std::pow(y, ElemType(2)) / 4, ElemType(1.0 / 3.0)));
|
|
103
105
|
}
|
|
104
106
|
|
|
105
107
|
/**
|
|
@@ -113,17 +115,20 @@ class HyperSinhFunction
|
|
|
113
115
|
const OutputVecType& y,
|
|
114
116
|
DerivVecType& dy)
|
|
115
117
|
{
|
|
118
|
+
typedef typename InputVecType::elem_type ElemType;
|
|
119
|
+
|
|
116
120
|
dy.set_size(size(y));
|
|
117
121
|
#pragma omp for
|
|
118
122
|
for (size_t i = 0; i < y.n_elem; ++i)
|
|
119
123
|
{
|
|
120
124
|
if (y(i) > 0)
|
|
121
125
|
{
|
|
122
|
-
dy(i) = (std::pow((1.0 / 9.0) + (y(i) * y(i)), 0.5));
|
|
126
|
+
dy(i) = (std::pow(ElemType(1.0 / 9.0) + (y(i) * y(i)), ElemType(0.5)));
|
|
123
127
|
}
|
|
124
128
|
else
|
|
125
129
|
{
|
|
126
|
-
dy(i) = (3
|
|
130
|
+
dy(i) = (3 * std::pow(std::pow(y(i), ElemType(2)) / 4,
|
|
131
|
+
ElemType(1.0 / 3.0)));
|
|
127
132
|
}
|
|
128
133
|
}
|
|
129
134
|
}
|
|
@@ -33,7 +33,8 @@ class IdentityFunction
|
|
|
33
33
|
* @param x Input data.
|
|
34
34
|
* @return f(x).
|
|
35
35
|
*/
|
|
36
|
-
|
|
36
|
+
template<typename ElemType>
|
|
37
|
+
static ElemType Fn(const ElemType x)
|
|
37
38
|
{
|
|
38
39
|
return x;
|
|
39
40
|
}
|
|
@@ -59,9 +60,10 @@ class IdentityFunction
|
|
|
59
60
|
* @param * (y) Result of Fn(x).
|
|
60
61
|
* @return f'(x)
|
|
61
62
|
*/
|
|
62
|
-
|
|
63
|
+
template<typename ElemType>
|
|
64
|
+
static ElemType Deriv(const ElemType /* x */, const ElemType /* y */)
|
|
63
65
|
{
|
|
64
|
-
return 1
|
|
66
|
+
return 1;
|
|
65
67
|
}
|
|
66
68
|
|
|
67
69
|
/**
|
|
@@ -76,7 +78,7 @@ class IdentityFunction
|
|
|
76
78
|
const OutputVecType& /* y */,
|
|
77
79
|
DerivVecType& dy)
|
|
78
80
|
{
|
|
79
|
-
dy.ones(
|
|
81
|
+
dy.ones(size(x));
|
|
80
82
|
}
|
|
81
83
|
|
|
82
84
|
/**
|
|
@@ -33,9 +33,10 @@ class InvQuadFunction
|
|
|
33
33
|
* @param x Input data.
|
|
34
34
|
* @return f(x).
|
|
35
35
|
*/
|
|
36
|
-
|
|
36
|
+
template<typename ElemType>
|
|
37
|
+
static ElemType Fn(const ElemType x)
|
|
37
38
|
{
|
|
38
|
-
return 1 / (
|
|
39
|
+
return 1 / (1 + x * x);
|
|
39
40
|
}
|
|
40
41
|
|
|
41
42
|
/**
|
|
@@ -47,7 +48,7 @@ class InvQuadFunction
|
|
|
47
48
|
template<typename InputVecType, typename OutputVecType>
|
|
48
49
|
static void Fn(const InputVecType& x, OutputVecType& y)
|
|
49
50
|
{
|
|
50
|
-
y = 1 / (1 +
|
|
51
|
+
y = 1 / (1 + square(x));
|
|
51
52
|
}
|
|
52
53
|
|
|
53
54
|
/**
|
|
@@ -57,9 +58,10 @@ class InvQuadFunction
|
|
|
57
58
|
* @param y Result of Fn(x).
|
|
58
59
|
* @return f'(x)
|
|
59
60
|
*/
|
|
60
|
-
|
|
61
|
+
template<typename ElemType>
|
|
62
|
+
static ElemType Deriv(const ElemType x, const ElemType /* y */)
|
|
61
63
|
{
|
|
62
|
-
return -2 * x / std::pow(1 + std::pow(x, 2), 2);
|
|
64
|
+
return -2 * x / std::pow(1 + std::pow(x, ElemType(2)), ElemType(2));
|
|
63
65
|
}
|
|
64
66
|
|
|
65
67
|
/**
|
|
@@ -74,7 +76,7 @@ class InvQuadFunction
|
|
|
74
76
|
const OutputVecType& /* y */,
|
|
75
77
|
DerivVecType &dy)
|
|
76
78
|
{
|
|
77
|
-
dy = -
|
|
79
|
+
dy = -2 * x / square(1 + square(x));
|
|
78
80
|
}
|
|
79
81
|
}; // class InvQuadFunction
|
|
80
82
|
|
|
@@ -47,7 +47,8 @@ class LiSHTFunction
|
|
|
47
47
|
* @param x Input data.
|
|
48
48
|
* @return f(x).
|
|
49
49
|
*/
|
|
50
|
-
|
|
50
|
+
template<typename ElemType>
|
|
51
|
+
static ElemType Fn(const ElemType x)
|
|
51
52
|
{
|
|
52
53
|
return x * std::tanh(x);
|
|
53
54
|
}
|
|
@@ -61,7 +62,7 @@ class LiSHTFunction
|
|
|
61
62
|
template <typename InputVecType, typename OutputVecType>
|
|
62
63
|
static void Fn(const InputVecType &x, OutputVecType &y)
|
|
63
64
|
{
|
|
64
|
-
y = x %
|
|
65
|
+
y = x % tanh(x);
|
|
65
66
|
}
|
|
66
67
|
|
|
67
68
|
/**
|
|
@@ -71,9 +72,10 @@ class LiSHTFunction
|
|
|
71
72
|
* @param y Result of Fn(x).
|
|
72
73
|
* @return f'(x)
|
|
73
74
|
*/
|
|
74
|
-
|
|
75
|
+
template<typename ElemType>
|
|
76
|
+
static ElemType Deriv(const ElemType x, const ElemType /* y */)
|
|
75
77
|
{
|
|
76
|
-
return std::tanh(x) + x * (1 - std::pow(std::tanh(x), 2));
|
|
78
|
+
return std::tanh(x) + x * (1 - std::pow(std::tanh(x), ElemType(2)));
|
|
77
79
|
}
|
|
78
80
|
|
|
79
81
|
/**
|
|
@@ -88,7 +90,7 @@ class LiSHTFunction
|
|
|
88
90
|
const OutputVecType& /* y */,
|
|
89
91
|
DerivVecType& dy)
|
|
90
92
|
{
|
|
91
|
-
dy =
|
|
93
|
+
dy = tanh(x) + x % (1 - square(tanh(x)));
|
|
92
94
|
}
|
|
93
95
|
}; // class LishtFunction
|
|
94
96
|
|
|
@@ -34,18 +34,18 @@ class LogisticFunction
|
|
|
34
34
|
* @param x Input data.
|
|
35
35
|
* @return f(x).
|
|
36
36
|
*/
|
|
37
|
-
template<typename
|
|
38
|
-
static
|
|
37
|
+
template<typename ElemType>
|
|
38
|
+
static ElemType Fn(const ElemType x)
|
|
39
39
|
{
|
|
40
|
-
if (x < arma::Datum<
|
|
40
|
+
if (x < arma::Datum<ElemType>::log_max)
|
|
41
41
|
{
|
|
42
|
-
if (x > -arma::Datum<
|
|
43
|
-
return 1
|
|
42
|
+
if (x > -arma::Datum<ElemType>::log_max)
|
|
43
|
+
return 1 / (1 + std::exp(-x));
|
|
44
44
|
|
|
45
|
-
return 0
|
|
45
|
+
return 0;
|
|
46
46
|
}
|
|
47
47
|
|
|
48
|
-
return 1
|
|
48
|
+
return 1;
|
|
49
49
|
}
|
|
50
50
|
|
|
51
51
|
/**
|
|
@@ -57,7 +57,7 @@ class LogisticFunction
|
|
|
57
57
|
template<typename InputVecType, typename OutputVecType>
|
|
58
58
|
static void Fn(const InputVecType& x, OutputVecType& y)
|
|
59
59
|
{
|
|
60
|
-
y = (1
|
|
60
|
+
y = (1 / (1 + exp(-x)));
|
|
61
61
|
}
|
|
62
62
|
|
|
63
63
|
/**
|
|
@@ -67,9 +67,10 @@ class LogisticFunction
|
|
|
67
67
|
* @param y Result of Fn(x).
|
|
68
68
|
* @return f'(x)
|
|
69
69
|
*/
|
|
70
|
-
|
|
70
|
+
template<typename ElemType>
|
|
71
|
+
static ElemType Deriv(const ElemType /* x */, const ElemType y)
|
|
71
72
|
{
|
|
72
|
-
return y * (1
|
|
73
|
+
return y * (1 - y);
|
|
73
74
|
}
|
|
74
75
|
|
|
75
76
|
/**
|
|
@@ -84,7 +85,7 @@ class LogisticFunction
|
|
|
84
85
|
const OutputVecType& y,
|
|
85
86
|
DerivVecType& dy)
|
|
86
87
|
{
|
|
87
|
-
dy = y % (1
|
|
88
|
+
dy = y % (1 - y);
|
|
88
89
|
}
|
|
89
90
|
|
|
90
91
|
/**
|
|
@@ -93,7 +94,8 @@ class LogisticFunction
|
|
|
93
94
|
* @param y Input data.
|
|
94
95
|
* @return f^{-1}(y)
|
|
95
96
|
*/
|
|
96
|
-
|
|
97
|
+
template<typename ElemType>
|
|
98
|
+
static ElemType Inv(const ElemType y)
|
|
97
99
|
{
|
|
98
100
|
return arma::trunc_log(y / (1 - y));
|
|
99
101
|
}
|
|
@@ -45,7 +45,8 @@ class MishFunction
|
|
|
45
45
|
* @param x Input data.
|
|
46
46
|
* @return f(x).
|
|
47
47
|
*/
|
|
48
|
-
|
|
48
|
+
template<typename ElemType>
|
|
49
|
+
static ElemType Fn(const ElemType x)
|
|
49
50
|
{
|
|
50
51
|
return x * (std::exp(2 * x) + 2 * std::exp(x)) /
|
|
51
52
|
(2 + 2 * std::exp(x) + std::exp(2 * x));
|
|
@@ -57,7 +58,7 @@ class MishFunction
|
|
|
57
58
|
* @param x Input data.
|
|
58
59
|
* @param y The resulting output activation.
|
|
59
60
|
*/
|
|
60
|
-
template
|
|
61
|
+
template<typename InputVecType, typename OutputVecType>
|
|
61
62
|
static void Fn(const InputVecType &x, OutputVecType &y)
|
|
62
63
|
{
|
|
63
64
|
y = x % (exp(2 * x) + 2 * exp(x)) / (2 + 2 * exp(x) + exp(2 * x));
|
|
@@ -70,11 +71,12 @@ class MishFunction
|
|
|
70
71
|
* @param y Result of Fn(x).
|
|
71
72
|
* @return f'(x)
|
|
72
73
|
*/
|
|
73
|
-
|
|
74
|
+
template<typename ElemType>
|
|
75
|
+
static ElemType Deriv(const ElemType x, const ElemType /* y */)
|
|
74
76
|
{
|
|
75
77
|
return std::exp(x) * (4 * (x + 1) + std::exp(x) * (4 * x + 6) +
|
|
76
78
|
4 * std::exp(2 * x) + std::exp(3 * x)) /
|
|
77
|
-
std::pow(std::exp(2 * x) + 2 * std::exp(x) + 2, 2);
|
|
79
|
+
std::pow(std::exp(2 * x) + 2 * std::exp(x) + 2, ElemType(2));
|
|
78
80
|
}
|
|
79
81
|
|
|
80
82
|
/**
|
|
@@ -91,7 +93,7 @@ class MishFunction
|
|
|
91
93
|
{
|
|
92
94
|
dy = exp(x) % (4 * (x + 1) + exp(x) % (4 * x + 6) +
|
|
93
95
|
4 * exp(2 * x) + exp(3 * x)) /
|
|
94
|
-
|
|
96
|
+
square(exp(2 * x) + 2 * exp(x) + 2);
|
|
95
97
|
}
|
|
96
98
|
}; // class MishFunction
|
|
97
99
|
|