mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -34,9 +34,10 @@ class MultiQuadFunction
|
|
|
34
34
|
* @param x Input data.
|
|
35
35
|
* @return f(x).
|
|
36
36
|
*/
|
|
37
|
-
|
|
37
|
+
template<typename ElemType>
|
|
38
|
+
static ElemType Fn(const ElemType x)
|
|
38
39
|
{
|
|
39
|
-
return std::
|
|
40
|
+
return std::sqrt(1 + x * x);
|
|
40
41
|
}
|
|
41
42
|
|
|
42
43
|
/**
|
|
@@ -48,7 +49,7 @@ class MultiQuadFunction
|
|
|
48
49
|
template<typename InputVecType, typename OutputVecType>
|
|
49
50
|
static void Fn(const InputVecType& x, OutputVecType& y)
|
|
50
51
|
{
|
|
51
|
-
y =
|
|
52
|
+
y = sqrt((1 + square(x)));
|
|
52
53
|
}
|
|
53
54
|
|
|
54
55
|
/**
|
|
@@ -61,7 +62,8 @@ class MultiQuadFunction
|
|
|
61
62
|
* @param y Result of Fn(x).
|
|
62
63
|
* @return f'(x)
|
|
63
64
|
*/
|
|
64
|
-
|
|
65
|
+
template<typename ElemType>
|
|
66
|
+
static ElemType Deriv(const ElemType x, const ElemType y)
|
|
65
67
|
{
|
|
66
68
|
return x / y;
|
|
67
69
|
}
|
|
@@ -33,7 +33,8 @@ class Poisson1Function
|
|
|
33
33
|
* @param x Input data.
|
|
34
34
|
* @return f(x).
|
|
35
35
|
*/
|
|
36
|
-
|
|
36
|
+
template<typename ElemType>
|
|
37
|
+
static ElemType Fn(const ElemType x)
|
|
37
38
|
{
|
|
38
39
|
return (x - 1) * std::exp(-x);
|
|
39
40
|
}
|
|
@@ -57,7 +58,8 @@ class Poisson1Function
|
|
|
57
58
|
* @param y Result of Fn(x).
|
|
58
59
|
* @return f'(x)
|
|
59
60
|
*/
|
|
60
|
-
|
|
61
|
+
template<typename ElemType>
|
|
62
|
+
static ElemType Deriv(const ElemType x, const ElemType /* y */)
|
|
61
63
|
{
|
|
62
64
|
return -std::exp(-x) * (x - 2);
|
|
63
65
|
}
|
|
@@ -33,9 +33,10 @@ class QuadraticFunction
|
|
|
33
33
|
* @param x Input data.
|
|
34
34
|
* @return f(x).
|
|
35
35
|
*/
|
|
36
|
-
|
|
36
|
+
template<typename ElemType>
|
|
37
|
+
static ElemType Fn(const ElemType x)
|
|
37
38
|
{
|
|
38
|
-
return std::pow(x, 2);
|
|
39
|
+
return std::pow(x, ElemType(2));
|
|
39
40
|
}
|
|
40
41
|
|
|
41
42
|
/**
|
|
@@ -47,7 +48,7 @@ class QuadraticFunction
|
|
|
47
48
|
template<typename InputVecType, typename OutputVecType>
|
|
48
49
|
static void Fn(const InputVecType& x, OutputVecType& y)
|
|
49
50
|
{
|
|
50
|
-
y =
|
|
51
|
+
y = square(x);
|
|
51
52
|
}
|
|
52
53
|
|
|
53
54
|
/**
|
|
@@ -57,7 +58,8 @@ class QuadraticFunction
|
|
|
57
58
|
* @param y Result of Fn(x).
|
|
58
59
|
* @return f'(x)
|
|
59
60
|
*/
|
|
60
|
-
|
|
61
|
+
template<typename ElemType>
|
|
62
|
+
static ElemType Deriv(const ElemType x, const ElemType /* y */)
|
|
61
63
|
{
|
|
62
64
|
return 2 * x;
|
|
63
65
|
}
|
|
@@ -50,9 +50,10 @@ class RectifierFunction
|
|
|
50
50
|
* @param x Input data.
|
|
51
51
|
* @return f(x).
|
|
52
52
|
*/
|
|
53
|
-
|
|
53
|
+
template<typename ElemType>
|
|
54
|
+
static ElemType Fn(const ElemType x)
|
|
54
55
|
{
|
|
55
|
-
return std::max(0
|
|
56
|
+
return std::max(ElemType(0), x);
|
|
56
57
|
}
|
|
57
58
|
|
|
58
59
|
/**
|
|
@@ -64,9 +65,7 @@ class RectifierFunction
|
|
|
64
65
|
template<typename MatType>
|
|
65
66
|
static void Fn(const MatType& x, MatType& y)
|
|
66
67
|
{
|
|
67
|
-
y
|
|
68
|
-
y.zeros();
|
|
69
|
-
y = max(y, x);
|
|
68
|
+
y = clamp(x, 0, std::numeric_limits<typename MatType::elem_type>::max());
|
|
70
69
|
}
|
|
71
70
|
|
|
72
71
|
/**
|
|
@@ -76,9 +75,10 @@ class RectifierFunction
|
|
|
76
75
|
* @param y Result of Fn(x).
|
|
77
76
|
* @return f'(x)
|
|
78
77
|
*/
|
|
79
|
-
|
|
78
|
+
template<typename ElemType>
|
|
79
|
+
static ElemType Deriv(const ElemType x, const ElemType /* y */)
|
|
80
80
|
{
|
|
81
|
-
return (
|
|
81
|
+
return ElemType(x > 0);
|
|
82
82
|
}
|
|
83
83
|
|
|
84
84
|
/**
|
|
@@ -89,11 +89,11 @@ class RectifierFunction
|
|
|
89
89
|
* @param dy The resulting derivatives.
|
|
90
90
|
*/
|
|
91
91
|
template<typename InputType, typename OutputType, typename DerivType>
|
|
92
|
-
static void Deriv(const InputType& x
|
|
93
|
-
const OutputType&
|
|
92
|
+
static void Deriv(const InputType& /* x */,
|
|
93
|
+
const OutputType& y,
|
|
94
94
|
DerivType& dy)
|
|
95
95
|
{
|
|
96
|
-
dy =
|
|
96
|
+
dy = sign(y);
|
|
97
97
|
}
|
|
98
98
|
}; // class RectifierFunction
|
|
99
99
|
|
|
@@ -49,9 +49,10 @@ class SILUFunction
|
|
|
49
49
|
* @param x Input data.
|
|
50
50
|
* @return f(x).
|
|
51
51
|
*/
|
|
52
|
-
|
|
52
|
+
template<typename ElemType>
|
|
53
|
+
static ElemType Fn(const ElemType x)
|
|
53
54
|
{
|
|
54
|
-
return x / (1
|
|
55
|
+
return x / (1 + std::exp(-x));
|
|
55
56
|
}
|
|
56
57
|
|
|
57
58
|
/**
|
|
@@ -63,7 +64,7 @@ class SILUFunction
|
|
|
63
64
|
template<typename InputVecType, typename OutputVecType>
|
|
64
65
|
static void Fn(const InputVecType &x, OutputVecType &y)
|
|
65
66
|
{
|
|
66
|
-
y = x / (1
|
|
67
|
+
y = x / (1 + exp(-x));
|
|
67
68
|
}
|
|
68
69
|
|
|
69
70
|
/**
|
|
@@ -73,11 +74,12 @@ class SILUFunction
|
|
|
73
74
|
* @param y Result of Fn(x).
|
|
74
75
|
* @return f'(x)
|
|
75
76
|
*/
|
|
76
|
-
|
|
77
|
+
template<typename ElemType>
|
|
78
|
+
static double Deriv(const ElemType x, const ElemType y)
|
|
77
79
|
{
|
|
78
80
|
// since y = x * sigmoid(x)
|
|
79
|
-
|
|
80
|
-
return x == 0 ? 0.5 : sigmoid * (1
|
|
81
|
+
const ElemType sigmoid = y / x; // save an exp
|
|
82
|
+
return x == 0 ? ElemType(0.5) : sigmoid * (1 + x * (1 - sigmoid));
|
|
81
83
|
// the expression above is indeterminate at 0, even though
|
|
82
84
|
// the expression solely in terms of x is defined (= 0.5)
|
|
83
85
|
}
|
|
@@ -97,10 +99,10 @@ class SILUFunction
|
|
|
97
99
|
// since y = x * sigmoid(x)
|
|
98
100
|
// DerivVecType sigmoid = y / x;
|
|
99
101
|
// dy = sigmoid % (1.0 + x % (1.0 - sigmoid));
|
|
100
|
-
dy = (y / x) % (1
|
|
102
|
+
dy = (y / x) % (1 + x - y);
|
|
101
103
|
// the expression above is indeterminate at 0, even though
|
|
102
104
|
// the expression solely in terms of x is defined (= 0.5)
|
|
103
|
-
dy(arma::find(x == 0)).fill(0.5);
|
|
105
|
+
dy(arma::find(x == 0)).fill(typename InputVecType::elem_type(0.5));
|
|
104
106
|
}
|
|
105
107
|
}; // class SILUFunction
|
|
106
108
|
|
|
@@ -48,9 +48,10 @@ class SoftplusFunction
|
|
|
48
48
|
* @param x Input data.
|
|
49
49
|
* @return f(x).
|
|
50
50
|
*/
|
|
51
|
-
|
|
51
|
+
template<typename ElemType>
|
|
52
|
+
static ElemType Fn(const ElemType x)
|
|
52
53
|
{
|
|
53
|
-
const
|
|
54
|
+
const ElemType val = std::log(1 + std::exp(x));
|
|
54
55
|
if (std::isfinite(val))
|
|
55
56
|
return val;
|
|
56
57
|
return x;
|
|
@@ -65,7 +66,7 @@ class SoftplusFunction
|
|
|
65
66
|
template<typename InputType, typename OutputType>
|
|
66
67
|
static void Fn(const InputType& x, OutputType& y)
|
|
67
68
|
{
|
|
68
|
-
y.set_size(
|
|
69
|
+
y.set_size(size(x));
|
|
69
70
|
|
|
70
71
|
for (size_t i = 0; i < x.n_elem; ++i)
|
|
71
72
|
y(i) = Fn(x(i));
|
|
@@ -78,9 +79,10 @@ class SoftplusFunction
|
|
|
78
79
|
* @param y Result of Fn(x).
|
|
79
80
|
* @return f'(x)
|
|
80
81
|
*/
|
|
81
|
-
|
|
82
|
+
template<typename ElemType>
|
|
83
|
+
static ElemType Deriv(const ElemType x, const ElemType /* y */)
|
|
82
84
|
{
|
|
83
|
-
return 1
|
|
85
|
+
return 1 / (1 + std::exp(-x));
|
|
84
86
|
}
|
|
85
87
|
|
|
86
88
|
/**
|
|
@@ -95,7 +97,7 @@ class SoftplusFunction
|
|
|
95
97
|
const OutputType& /* y */,
|
|
96
98
|
DerivType& dy)
|
|
97
99
|
{
|
|
98
|
-
dy = 1
|
|
100
|
+
dy = 1 / (1 + exp(-x));
|
|
99
101
|
}
|
|
100
102
|
|
|
101
103
|
/**
|
|
@@ -104,9 +106,10 @@ class SoftplusFunction
|
|
|
104
106
|
* @param y Input data.
|
|
105
107
|
* @return f^{-1}(y)
|
|
106
108
|
*/
|
|
107
|
-
|
|
109
|
+
template<typename ElemType>
|
|
110
|
+
static ElemType Inv(const ElemType y)
|
|
108
111
|
{
|
|
109
|
-
const
|
|
112
|
+
const ElemType val = std::log(std::exp(y) - 1);
|
|
110
113
|
if (std::isfinite(val))
|
|
111
114
|
return val;
|
|
112
115
|
return y;
|
|
@@ -121,7 +124,7 @@ class SoftplusFunction
|
|
|
121
124
|
template<typename InputType, typename OutputType>
|
|
122
125
|
static void Inv(const InputType& y, OutputType& x)
|
|
123
126
|
{
|
|
124
|
-
x.set_size(
|
|
127
|
+
x.set_size(size(y));
|
|
125
128
|
|
|
126
129
|
for (size_t i = 0; i < y.n_elem; ++i)
|
|
127
130
|
x(i) = Inv(y(i));
|
|
@@ -9,7 +9,7 @@
|
|
|
9
9
|
*
|
|
10
10
|
* @code
|
|
11
11
|
* @inproceedings{GlorotAISTATS2010,
|
|
12
|
-
* title={
|
|
12
|
+
* title={Understanding the difficulty of training deep feedforward
|
|
13
13
|
* neural networks},
|
|
14
14
|
* author={Glorot, Xavier and Bengio, Yoshua},
|
|
15
15
|
* booktitle={Proceedings of AISTATS 2010},
|
|
@@ -34,13 +34,7 @@ namespace mlpack {
|
|
|
34
34
|
*
|
|
35
35
|
* @f{eqnarray*}{
|
|
36
36
|
* f(x) &=& \frac{x}{1 + |x|} \\
|
|
37
|
-
* f'(x) &=& (1
|
|
38
|
-
* f(x) &=& \left\{
|
|
39
|
-
* \begin{array}{lr}
|
|
40
|
-
* -\frac{x}{1 - x} & : x \le 0 \\
|
|
41
|
-
* \frac{x}{1 + x} & : x > 0
|
|
42
|
-
* \end{array}
|
|
43
|
-
* \right.
|
|
37
|
+
* f'(x) &=& (1 + |f(x)|)^2 \\
|
|
44
38
|
* @f}
|
|
45
39
|
*/
|
|
46
40
|
class SoftsignFunction
|
|
@@ -52,11 +46,10 @@ class SoftsignFunction
|
|
|
52
46
|
* @param x Input data.
|
|
53
47
|
* @return f(x).
|
|
54
48
|
*/
|
|
55
|
-
|
|
49
|
+
template<typename ElemType>
|
|
50
|
+
static ElemType Fn(const ElemType x)
|
|
56
51
|
{
|
|
57
|
-
|
|
58
|
-
return x > -DBL_MAX ? x / (1.0 + std::abs(x)) : -1.0;
|
|
59
|
-
return 1.0;
|
|
52
|
+
return x / (1 + std::abs(x));
|
|
60
53
|
}
|
|
61
54
|
|
|
62
55
|
/**
|
|
@@ -68,10 +61,7 @@ class SoftsignFunction
|
|
|
68
61
|
template<typename InputVecType, typename OutputVecType>
|
|
69
62
|
static void Fn(const InputVecType& x, OutputVecType& y)
|
|
70
63
|
{
|
|
71
|
-
y
|
|
72
|
-
|
|
73
|
-
for (size_t i = 0; i < x.n_elem; ++i)
|
|
74
|
-
y(i) = Fn(x(i));
|
|
64
|
+
y = x / (1 + abs(x));
|
|
75
65
|
}
|
|
76
66
|
|
|
77
67
|
/**
|
|
@@ -81,9 +71,10 @@ class SoftsignFunction
|
|
|
81
71
|
* @param y Result of Fn(x).
|
|
82
72
|
* @return f'(x)
|
|
83
73
|
*/
|
|
84
|
-
|
|
74
|
+
template<typename ElemType>
|
|
75
|
+
static ElemType Deriv(const ElemType x, const ElemType /* y */)
|
|
85
76
|
{
|
|
86
|
-
return 1
|
|
77
|
+
return 1 / std::pow(1 + std::abs(x), ElemType(2));
|
|
87
78
|
}
|
|
88
79
|
|
|
89
80
|
/**
|
|
@@ -98,7 +89,7 @@ class SoftsignFunction
|
|
|
98
89
|
const OutputVecType& /* y */,
|
|
99
90
|
DerivVecType& dy)
|
|
100
91
|
{
|
|
101
|
-
dy = 1
|
|
92
|
+
dy = 1 / square(1 + abs(x));
|
|
102
93
|
}
|
|
103
94
|
|
|
104
95
|
/**
|
|
@@ -107,12 +98,13 @@ class SoftsignFunction
|
|
|
107
98
|
* @param y Input data.
|
|
108
99
|
* @return f^{-1}(y)
|
|
109
100
|
*/
|
|
110
|
-
|
|
101
|
+
template<typename ElemType>
|
|
102
|
+
static ElemType Inv(const ElemType y)
|
|
111
103
|
{
|
|
112
104
|
if (y > 0)
|
|
113
|
-
return
|
|
105
|
+
return -y / (y - 1);
|
|
114
106
|
else
|
|
115
|
-
return y
|
|
107
|
+
return y / (1 + y);
|
|
116
108
|
}
|
|
117
109
|
|
|
118
110
|
/**
|
|
@@ -124,7 +116,7 @@ class SoftsignFunction
|
|
|
124
116
|
template<typename InputVecType, typename OutputVecType>
|
|
125
117
|
static void Inv(const InputVecType& y, OutputVecType& x)
|
|
126
118
|
{
|
|
127
|
-
x.set_size(
|
|
119
|
+
x.set_size(size(y));
|
|
128
120
|
|
|
129
121
|
for (size_t i = 0; i < y.n_elem; ++i)
|
|
130
122
|
x(i) = Inv(y(i));
|
|
@@ -34,9 +34,10 @@ class SplineFunction
|
|
|
34
34
|
* @param x Input data.
|
|
35
35
|
* @return f(x).
|
|
36
36
|
*/
|
|
37
|
-
|
|
37
|
+
template<typename ElemType>
|
|
38
|
+
static ElemType Fn(const ElemType x)
|
|
38
39
|
{
|
|
39
|
-
return std::pow(x, 2) * std::log(1 + x);
|
|
40
|
+
return std::pow(x, ElemType(2)) * std::log(1 + x);
|
|
40
41
|
}
|
|
41
42
|
|
|
42
43
|
/**
|
|
@@ -48,7 +49,7 @@ class SplineFunction
|
|
|
48
49
|
template<typename InputVecType, typename OutputVecType>
|
|
49
50
|
static void Fn(const InputVecType& x, OutputVecType& y)
|
|
50
51
|
{
|
|
51
|
-
y =
|
|
52
|
+
y = square(x) % log(1 + x);
|
|
52
53
|
}
|
|
53
54
|
|
|
54
55
|
/**
|
|
@@ -58,9 +59,10 @@ class SplineFunction
|
|
|
58
59
|
* @param y Result of Fn(x).
|
|
59
60
|
* @return f'(x)
|
|
60
61
|
*/
|
|
61
|
-
|
|
62
|
+
template<typename ElemType>
|
|
63
|
+
static ElemType Deriv(const ElemType x, const ElemType y)
|
|
62
64
|
{
|
|
63
|
-
return
|
|
65
|
+
return (x != 0) ? (2 * y / x + std::pow(x, ElemType(2)) / (1 + x)) : 0;
|
|
64
66
|
}
|
|
65
67
|
|
|
66
68
|
/**
|
|
@@ -75,10 +77,10 @@ class SplineFunction
|
|
|
75
77
|
const OutputVecType& y,
|
|
76
78
|
DerivVecType& dy)
|
|
77
79
|
{
|
|
78
|
-
dy = 2 * y / x +
|
|
80
|
+
dy = 2 * y / x + square(x) / (1 + x);
|
|
79
81
|
// the expression above is indeterminate at 0, even though
|
|
80
82
|
// the expression solely in terms of x is defined (= 0)
|
|
81
|
-
dy(
|
|
83
|
+
dy(find(x == 0)).zeros();
|
|
82
84
|
}
|
|
83
85
|
}; // class SplineFunction
|
|
84
86
|
|
|
@@ -36,9 +36,10 @@ class SwishFunction
|
|
|
36
36
|
* @param x Input data.
|
|
37
37
|
* @return f(x).
|
|
38
38
|
*/
|
|
39
|
-
|
|
39
|
+
template<typename ElemType>
|
|
40
|
+
static ElemType Fn(const ElemType x)
|
|
40
41
|
{
|
|
41
|
-
return x / (1
|
|
42
|
+
return x / (1 + std::exp(-x));
|
|
42
43
|
}
|
|
43
44
|
|
|
44
45
|
/**
|
|
@@ -51,7 +52,7 @@ class SwishFunction
|
|
|
51
52
|
static void Fn(const MatType& x, MatType& y,
|
|
52
53
|
const typename std::enable_if_t<IsMatrix<MatType>::value>* = 0)
|
|
53
54
|
{
|
|
54
|
-
y = x / (1
|
|
55
|
+
y = x / (1 + exp(-x));
|
|
55
56
|
}
|
|
56
57
|
|
|
57
58
|
/**
|
|
@@ -64,7 +65,7 @@ class SwishFunction
|
|
|
64
65
|
static void Fn(const VecType& x, VecType& y,
|
|
65
66
|
const typename std::enable_if_t<IsVector<VecType>::value>* = 0)
|
|
66
67
|
{
|
|
67
|
-
y.set_size(
|
|
68
|
+
y.set_size(size(x));
|
|
68
69
|
|
|
69
70
|
for (size_t i = 0; i < x.n_elem; ++i)
|
|
70
71
|
y(i) = Fn(x(i));
|
|
@@ -77,10 +78,11 @@ class SwishFunction
|
|
|
77
78
|
* @param y Result of Fn(x).
|
|
78
79
|
* @return f'(x)
|
|
79
80
|
*/
|
|
80
|
-
|
|
81
|
+
template<typename ElemType>
|
|
82
|
+
static ElemType Deriv(const ElemType x, const ElemType y)
|
|
81
83
|
{
|
|
82
|
-
|
|
83
|
-
return x == 0 ? 0.5 : sigmoid * (1
|
|
84
|
+
const ElemType sigmoid = y / x; // save an exp
|
|
85
|
+
return (x == 0) ? ElemType(0.5) : sigmoid * (1 + x * (1 - sigmoid));
|
|
84
86
|
// the expression above is indeterminate at 0, even though
|
|
85
87
|
// the expression solely in terms of x is defined (= 0.5)
|
|
86
88
|
}
|
|
@@ -97,10 +99,10 @@ class SwishFunction
|
|
|
97
99
|
const OutputVecType& y,
|
|
98
100
|
DerivVecType& dy)
|
|
99
101
|
{
|
|
100
|
-
dy = (y / x) % (1
|
|
102
|
+
dy = (y / x) % (1 + x - y);
|
|
101
103
|
// the expression above is indeterminate at 0, even though
|
|
102
104
|
// the expression solely in terms of x is defined (= 0.5)
|
|
103
|
-
dy(
|
|
105
|
+
dy(find(x == 0)).fill(typename InputVecType::elem_type(0.5));
|
|
104
106
|
}
|
|
105
107
|
}; // class SwishFunction
|
|
106
108
|
|
|
@@ -48,7 +48,8 @@ class TanhExpFunction
|
|
|
48
48
|
* @param x Input data.
|
|
49
49
|
* @return f(x).
|
|
50
50
|
*/
|
|
51
|
-
|
|
51
|
+
template<typename ElemType>
|
|
52
|
+
static ElemType Fn(const ElemType x)
|
|
52
53
|
{
|
|
53
54
|
return x * std::tanh(std::exp(x));
|
|
54
55
|
}
|
|
@@ -62,7 +63,7 @@ class TanhExpFunction
|
|
|
62
63
|
template<typename InputVecType, typename OutputVecType>
|
|
63
64
|
static void Fn(const InputVecType& x, OutputVecType& y)
|
|
64
65
|
{
|
|
65
|
-
y = x %
|
|
66
|
+
y = x % tanh(exp(x));
|
|
66
67
|
}
|
|
67
68
|
|
|
68
69
|
/**
|
|
@@ -72,11 +73,12 @@ class TanhExpFunction
|
|
|
72
73
|
* @param y Result of Fn(x).
|
|
73
74
|
* @return f'(x)
|
|
74
75
|
*/
|
|
75
|
-
|
|
76
|
+
template<typename ElemType>
|
|
77
|
+
static ElemType Deriv(const ElemType x, const ElemType y)
|
|
76
78
|
{
|
|
77
79
|
// leverage both y and x
|
|
78
|
-
return x == 0 ? std::tanh(1) :
|
|
79
|
-
y / x + x * std::exp(x) * (1 - std::pow(y / x, 2));
|
|
80
|
+
return (x == 0) ? std::tanh(1) :
|
|
81
|
+
y / x + x * std::exp(x) * (1 - std::pow(y / x, ElemType(2)));
|
|
80
82
|
}
|
|
81
83
|
|
|
82
84
|
/**
|
|
@@ -92,10 +94,10 @@ class TanhExpFunction
|
|
|
92
94
|
DerivVecType& dy)
|
|
93
95
|
{
|
|
94
96
|
// leverage both y and x
|
|
95
|
-
dy = y / x + x % exp(x) % (1 -
|
|
97
|
+
dy = y / x + x % exp(x) % (1 - square(y / x));
|
|
96
98
|
// the expression above is indeterminate at 0, even though
|
|
97
99
|
// the expression solely in terms of x is defined (= tanh(1))
|
|
98
|
-
dy(
|
|
100
|
+
dy(find(x == 0)).fill(std::tanh(typename InputVecType::elem_type(1)));
|
|
99
101
|
}
|
|
100
102
|
}; // class TanhExpFunction
|
|
101
103
|
|
|
@@ -34,7 +34,8 @@ class TanhFunction
|
|
|
34
34
|
* @param x Input data.
|
|
35
35
|
* @return f(x).
|
|
36
36
|
*/
|
|
37
|
-
|
|
37
|
+
template<typename ElemType>
|
|
38
|
+
static ElemType Fn(const ElemType x)
|
|
38
39
|
{
|
|
39
40
|
return std::tanh(x);
|
|
40
41
|
}
|
|
@@ -48,7 +49,7 @@ class TanhFunction
|
|
|
48
49
|
template<typename InputVecType, typename OutputVecType>
|
|
49
50
|
static void Fn(const InputVecType& x, OutputVecType& y)
|
|
50
51
|
{
|
|
51
|
-
y =
|
|
52
|
+
y = tanh(x);
|
|
52
53
|
}
|
|
53
54
|
|
|
54
55
|
/**
|
|
@@ -58,9 +59,10 @@ class TanhFunction
|
|
|
58
59
|
* @param y Result of Fn(x).
|
|
59
60
|
* @return f'(x)
|
|
60
61
|
*/
|
|
61
|
-
|
|
62
|
+
template<typename ElemType>
|
|
63
|
+
static ElemType Deriv(const ElemType /* x */, const ElemType y)
|
|
62
64
|
{
|
|
63
|
-
return 1 - std::pow(y, 2);
|
|
65
|
+
return 1 - std::pow(y, ElemType(2));
|
|
64
66
|
}
|
|
65
67
|
|
|
66
68
|
/**
|
|
@@ -75,7 +77,7 @@ class TanhFunction
|
|
|
75
77
|
const OutputVecType& y,
|
|
76
78
|
DerivVecType& dy)
|
|
77
79
|
{
|
|
78
|
-
dy = 1 -
|
|
80
|
+
dy = 1 - square(y);
|
|
79
81
|
}
|
|
80
82
|
|
|
81
83
|
/**
|
|
@@ -84,7 +86,8 @@ class TanhFunction
|
|
|
84
86
|
* @param y Input data.
|
|
85
87
|
* @return f^{-1}(x)
|
|
86
88
|
*/
|
|
87
|
-
|
|
89
|
+
template<typename ElemType>
|
|
90
|
+
static ElemType Inv(const ElemType y)
|
|
88
91
|
{
|
|
89
92
|
return std::atanh(y);
|
|
90
93
|
}
|
|
@@ -98,7 +101,7 @@ class TanhFunction
|
|
|
98
101
|
template<typename InputVecType, typename OutputVecType>
|
|
99
102
|
static void Inv(const InputVecType& y, OutputVecType& x)
|
|
100
103
|
{
|
|
101
|
-
x =
|
|
104
|
+
x = atanh(y);
|
|
102
105
|
}
|
|
103
106
|
}; // class TanhFunction
|
|
104
107
|
|