mlpack 4.6.2__cp39-cp39-win_amd64.whl → 4.7.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +6 -6
- mlpack/adaboost_classify.cp39-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp39-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp39-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp39-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp39-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp39-win_amd64.pyd +0 -0
- mlpack/cf.cp39-win_amd64.pyd +0 -0
- mlpack/dbscan.cp39-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp39-win_amd64.pyd +0 -0
- mlpack/det.cp39-win_amd64.pyd +0 -0
- mlpack/emst.cp39-win_amd64.pyd +0 -0
- mlpack/fastmks.cp39-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp39-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp39-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp39-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp39-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp39-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp39-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp39-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp39-win_amd64.pyd +0 -0
- mlpack/image_converter.cp39-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp39-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp39-win_amd64.pyd +0 -0
- mlpack/kfn.cp39-win_amd64.pyd +0 -0
- mlpack/kmeans.cp39-win_amd64.pyd +0 -0
- mlpack/knn.cp39-win_amd64.pyd +0 -0
- mlpack/krann.cp39-win_amd64.pyd +0 -0
- mlpack/lars.cp39-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp39-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp39-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp39-win_amd64.pyd +0 -0
- mlpack/lmnn.cp39-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp39-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp39-win_amd64.pyd +0 -0
- mlpack/lsh.cp39-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp39-win_amd64.pyd +0 -0
- mlpack/nbc.cp39-win_amd64.pyd +0 -0
- mlpack/nca.cp39-win_amd64.pyd +0 -0
- mlpack/nmf.cp39-win_amd64.pyd +0 -0
- mlpack/pca.cp39-win_amd64.pyd +0 -0
- mlpack/perceptron.cp39-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp39-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp39-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp39-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp39-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp39-win_amd64.pyd +0 -0
- mlpack/radical.cp39-win_amd64.pyd +0 -0
- mlpack/random_forest.cp39-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp39-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp39-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +397 -378
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack.libs/.load-order-mlpack-4.7.0 +2 -0
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- mlpack.libs/.load-order-mlpack-4.6.2 +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file methods/ann/convolution_rules/base_convolution.hpp
|
|
3
|
+
* @author Zachary Ng
|
|
4
|
+
*
|
|
5
|
+
* Base class for convolution rules.
|
|
6
|
+
*
|
|
7
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
8
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
9
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
10
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
11
|
+
*/
|
|
12
|
+
#ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_BASE_CONVOLUTION_HPP
|
|
13
|
+
#define MLPACK_METHODS_ANN_CONVOLUTION_RULES_BASE_CONVOLUTION_HPP
|
|
14
|
+
|
|
15
|
+
#include <mlpack/prereqs.hpp>
|
|
16
|
+
#include <mlpack/core/util/using.hpp>
|
|
17
|
+
#include "border_modes.hpp"
|
|
18
|
+
|
|
19
|
+
namespace mlpack {
|
|
20
|
+
|
|
21
|
+
/**
|
|
22
|
+
* This is an abstract class that contains common functions for convolution.
|
|
23
|
+
* This class allows specification of the border type. The
|
|
24
|
+
* convolution can be computed with the valid border type or the full border
|
|
25
|
+
* type (default).
|
|
26
|
+
*
|
|
27
|
+
* FullConvolution: returns the full two-dimensional convolution.
|
|
28
|
+
* ValidConvolution: returns only those parts of the convolution that are
|
|
29
|
+
* computed without the zero-padded edges.
|
|
30
|
+
*
|
|
31
|
+
* @tparam BorderMode Type of the border mode (FullConvolution or
|
|
32
|
+
* ValidConvolution).
|
|
33
|
+
*/
|
|
34
|
+
template<typename BorderMode = FullConvolution>
|
|
35
|
+
class BaseConvolution
|
|
36
|
+
{
|
|
37
|
+
protected:
|
|
38
|
+
/**
|
|
39
|
+
* Apply padding to an input matrix.
|
|
40
|
+
*
|
|
41
|
+
* @param input Input used to perform the convolution.
|
|
42
|
+
* @param filter Filter used to perform the convolution.
|
|
43
|
+
* @param inputPadded Input with padding applied.
|
|
44
|
+
*/
|
|
45
|
+
template<typename InMatType, typename FilType, typename Border = BorderMode>
|
|
46
|
+
static void
|
|
47
|
+
PadInput(const InMatType& input,
|
|
48
|
+
const FilType& filter,
|
|
49
|
+
InMatType& inputPadded,
|
|
50
|
+
const size_t dilationW,
|
|
51
|
+
const size_t dilationH,
|
|
52
|
+
const typename std::enable_if_t<IsMatrix<InMatType>::value>* = 0)
|
|
53
|
+
{
|
|
54
|
+
if constexpr (std::is_same_v<Border, ValidConvolution>)
|
|
55
|
+
{
|
|
56
|
+
// Use valid padding (none).
|
|
57
|
+
MakeAlias(inputPadded, input, input.n_rows, input.n_cols);
|
|
58
|
+
}
|
|
59
|
+
else
|
|
60
|
+
{
|
|
61
|
+
// Use full padding
|
|
62
|
+
|
|
63
|
+
// First, compute the necessary padding for the full convolution. It is
|
|
64
|
+
// possible that this might be an overestimate. Note that these variables
|
|
65
|
+
// only hold the padding on one side of the input.
|
|
66
|
+
const size_t filterRows = filter.n_rows * dilationW - (dilationW - 1);
|
|
67
|
+
const size_t filterCols = filter.n_cols * dilationH - (dilationH - 1);
|
|
68
|
+
const size_t paddingRows = filterRows - 1;
|
|
69
|
+
const size_t paddingCols = filterCols - 1;
|
|
70
|
+
|
|
71
|
+
// Pad filter and input to the working output shape.
|
|
72
|
+
inputPadded = InMatType(input.n_rows + 2 * paddingRows,
|
|
73
|
+
input.n_cols + 2 * paddingCols);
|
|
74
|
+
inputPadded.submat(paddingRows, paddingCols,
|
|
75
|
+
paddingRows + input.n_rows - 1,
|
|
76
|
+
paddingCols + input.n_cols - 1) = input;
|
|
77
|
+
}
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
/**
|
|
81
|
+
* Apply padding to an input cube.
|
|
82
|
+
*
|
|
83
|
+
* @param input Input used to perform the convolution.
|
|
84
|
+
* @param filter Filter used to perform the convolution.
|
|
85
|
+
* @param inputPadded Input with padding applied.
|
|
86
|
+
*/
|
|
87
|
+
template<typename InCubeType, typename FilType, typename Border = BorderMode>
|
|
88
|
+
static void
|
|
89
|
+
PadInput(const InCubeType& input,
|
|
90
|
+
const FilType& filter,
|
|
91
|
+
InCubeType& inputPadded,
|
|
92
|
+
const size_t dilationW,
|
|
93
|
+
const size_t dilationH,
|
|
94
|
+
const typename std::enable_if_t<IsCube<InCubeType>::value>* = 0)
|
|
95
|
+
{
|
|
96
|
+
if constexpr (std::is_same_v<Border, ValidConvolution>)
|
|
97
|
+
{
|
|
98
|
+
// Use valid padding (none).
|
|
99
|
+
MakeAlias(inputPadded, input, input.n_rows, input.n_cols, input.n_slices);
|
|
100
|
+
}
|
|
101
|
+
else
|
|
102
|
+
{
|
|
103
|
+
// Use full padding
|
|
104
|
+
|
|
105
|
+
// First, compute the necessary padding for the full convolution. It is
|
|
106
|
+
// possible that this might be an overestimate. Note that these variables
|
|
107
|
+
// only hold the padding on one side of the input.
|
|
108
|
+
const size_t filterRows = filter.n_rows * dilationW - (dilationW - 1);
|
|
109
|
+
const size_t filterCols = filter.n_cols * dilationH - (dilationH - 1);
|
|
110
|
+
const size_t paddingRows = filterRows - 1;
|
|
111
|
+
const size_t paddingCols = filterCols - 1;
|
|
112
|
+
|
|
113
|
+
// Pad filter and input to the working output shape.
|
|
114
|
+
inputPadded = InCubeType(input.n_rows + 2 * paddingRows,
|
|
115
|
+
input.n_cols + 2 * paddingCols, input.n_slices);
|
|
116
|
+
inputPadded.subcube(paddingRows, paddingCols, 0,
|
|
117
|
+
paddingRows + input.n_rows - 1, paddingCols + input.n_cols - 1,
|
|
118
|
+
input.n_slices - 1) = input;
|
|
119
|
+
}
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
/**
|
|
123
|
+
* Initalize the output to the required size.
|
|
124
|
+
*
|
|
125
|
+
* @param inputPadded Input with padding applied.
|
|
126
|
+
* @param filter Filter used to perform the convolution.
|
|
127
|
+
* @param output Output data that contains the results of the convolution.
|
|
128
|
+
* @param dW Stride of filter application in the x direction.
|
|
129
|
+
* @param dH Stride of filter application in the y direction.
|
|
130
|
+
* @param dilationW The dilation factor in x direction.
|
|
131
|
+
* @param dilationH The dilation factor in y direction.
|
|
132
|
+
* @param outSlices The number of slices in the output cube.
|
|
133
|
+
*/
|
|
134
|
+
template<typename InMatType, typename FilType, typename OutMatType>
|
|
135
|
+
static void
|
|
136
|
+
InitalizeOutput(const InMatType& inputPadded,
|
|
137
|
+
const FilType& filter,
|
|
138
|
+
OutMatType& output,
|
|
139
|
+
const size_t dW = 1,
|
|
140
|
+
const size_t dH = 1,
|
|
141
|
+
const size_t dilationW = 1,
|
|
142
|
+
const size_t dilationH = 1,
|
|
143
|
+
const size_t /* outSlices */ = 1,
|
|
144
|
+
const typename std::enable_if_t<
|
|
145
|
+
IsMatrix<OutMatType>::value>* = 0)
|
|
146
|
+
{
|
|
147
|
+
// Compute the output size. The filterRows and filterCols computation must
|
|
148
|
+
// take into account the fact that dilation only adds rows or columns
|
|
149
|
+
// *between* filter elements. So, e.g., a dilation of 2 on a kernel size of
|
|
150
|
+
// 3x3 means an effective kernel size of 5x5, *not* 6x6.
|
|
151
|
+
const size_t filterRows = filter.n_rows * dilationW - (dilationW - 1);
|
|
152
|
+
const size_t filterCols = filter.n_cols * dilationH - (dilationH - 1);
|
|
153
|
+
const size_t outputRows = (inputPadded.n_rows - filterRows + dW) / dW;
|
|
154
|
+
const size_t outputCols = (inputPadded.n_cols - filterCols + dH) / dH;
|
|
155
|
+
output.zeros(outputRows, outputCols);
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
/**
|
|
159
|
+
* Initalize the output to the required size.
|
|
160
|
+
*
|
|
161
|
+
* @param inputPadded Input with padding applied.
|
|
162
|
+
* @param filter Filter used to perform the convolution.
|
|
163
|
+
* @param output Output data that contains the results of the convolution.
|
|
164
|
+
* @param dW Stride of filter application in the x direction.
|
|
165
|
+
* @param dH Stride of filter application in the y direction.
|
|
166
|
+
* @param dilationW The dilation factor in x direction.
|
|
167
|
+
* @param dilationH The dilation factor in y direction.
|
|
168
|
+
* @param outSlices The number of slices in the output cube.
|
|
169
|
+
*/
|
|
170
|
+
template<typename InMatType, typename FilType, typename OutCubeType>
|
|
171
|
+
static void
|
|
172
|
+
InitalizeOutput(const InMatType& inputPadded,
|
|
173
|
+
const FilType& filter,
|
|
174
|
+
OutCubeType& output,
|
|
175
|
+
const size_t dW = 1,
|
|
176
|
+
const size_t dH = 1,
|
|
177
|
+
const size_t dilationW = 1,
|
|
178
|
+
const size_t dilationH = 1,
|
|
179
|
+
const size_t outSlices = 1,
|
|
180
|
+
const typename std::enable_if_t<
|
|
181
|
+
IsCube<OutCubeType>::value>* = 0)
|
|
182
|
+
{
|
|
183
|
+
// Compute the output size. The filterRows and filterCols computation must
|
|
184
|
+
// take into account the fact that dilation only adds rows or columns
|
|
185
|
+
// *between* filter elements. So, e.g., a dilation of 2 on a kernel size of
|
|
186
|
+
// 3x3 means an effective kernel size of 5x5, *not* 6x6.
|
|
187
|
+
const size_t filterRows = filter.n_rows * dilationW - (dilationW - 1);
|
|
188
|
+
const size_t filterCols = filter.n_cols * dilationH - (dilationH - 1);
|
|
189
|
+
const size_t outputRows = (inputPadded.n_rows - filterRows + dW) / dW;
|
|
190
|
+
const size_t outputCols = (inputPadded.n_cols - filterCols + dH) / dH;
|
|
191
|
+
output.zeros(outputRows, outputCols, outSlices);
|
|
192
|
+
}
|
|
193
|
+
}; // class BaseConvolution
|
|
194
|
+
|
|
195
|
+
} // namespace mlpack
|
|
196
|
+
|
|
197
|
+
#endif
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file methods/ann/convolution_rules/im2col_convolution.hpp
|
|
3
|
+
* @author Zachary Ng
|
|
4
|
+
*
|
|
5
|
+
* Implementation of the im2col convolution. This is actually im2row because we
|
|
6
|
+
* use column major order.
|
|
7
|
+
*
|
|
8
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_IM2COL_CONVOLUTION_HPP
|
|
14
|
+
#define MLPACK_METHODS_ANN_CONVOLUTION_RULES_IM2COL_CONVOLUTION_HPP
|
|
15
|
+
|
|
16
|
+
#include "base_convolution.hpp"
|
|
17
|
+
|
|
18
|
+
namespace mlpack {
|
|
19
|
+
|
|
20
|
+
/**
|
|
21
|
+
* Computes the two-dimensional convolution. This class allows specification of
|
|
22
|
+
* the type of the border type. The convolution can be computed with the valid
|
|
23
|
+
* border type of the full border type (default).
|
|
24
|
+
*
|
|
25
|
+
* FullConvolution: returns the full two-dimensional convolution.
|
|
26
|
+
* ValidConvolution: returns only those parts of the convolution that are
|
|
27
|
+
* computed without the zero-padded edges.
|
|
28
|
+
*
|
|
29
|
+
* @tparam BorderMode Type of the border mode (FullConvolution or
|
|
30
|
+
* ValidConvolution).
|
|
31
|
+
*/
|
|
32
|
+
template<typename BorderMode = FullConvolution>
|
|
33
|
+
class Im2ColConvolution : public BaseConvolution<BorderMode>
|
|
34
|
+
{
|
|
35
|
+
public:
|
|
36
|
+
/**
|
|
37
|
+
* Perform a convolution using 3rd order tensors. Expects that `filter` has
|
|
38
|
+
* `input.n_slices * output.n_slices` slices. The Nth `input.n_slices` filters
|
|
39
|
+
* are applied to all input slices and output to the Nth output slice.
|
|
40
|
+
* eg. 2 input slices: filter 0 applies to input 0, output 0,
|
|
41
|
+
* fil 1 * in 1 = out 0, fil 2 * in 0 = out 1, fil 3 * in 1 = out 1,
|
|
42
|
+
* fil 4 * in 0 = out 2, fil 5 * in 1 = out 2, etc.
|
|
43
|
+
*
|
|
44
|
+
* @param input Input used to perform the convolution.
|
|
45
|
+
* @param filter Filter used to perform the convolution.
|
|
46
|
+
* @param output Output data that contains the results of the convolution.
|
|
47
|
+
* @param dW Stride of filter application in the x direction.
|
|
48
|
+
* @param dH Stride of filter application in the y direction.
|
|
49
|
+
* @param dilationW The dilation factor in x direction.
|
|
50
|
+
* @param dilationH The dilation factor in y direction.
|
|
51
|
+
* @param appending If true, it will not initialize the output. Instead,
|
|
52
|
+
* it will append the results to the output.
|
|
53
|
+
*/
|
|
54
|
+
template<typename CubeType>
|
|
55
|
+
static void Convolution(
|
|
56
|
+
const CubeType& input,
|
|
57
|
+
const CubeType& filter,
|
|
58
|
+
CubeType& output,
|
|
59
|
+
const size_t dW = 1,
|
|
60
|
+
const size_t dH = 1,
|
|
61
|
+
const size_t dilationW = 1,
|
|
62
|
+
const size_t dilationH = 1,
|
|
63
|
+
const bool appending = false,
|
|
64
|
+
const typename std::enable_if_t<IsCube<CubeType>::value>* = 0)
|
|
65
|
+
{
|
|
66
|
+
using MatType = typename GetDenseMatType<CubeType>::type;
|
|
67
|
+
|
|
68
|
+
CubeType inputPadded;
|
|
69
|
+
Im2ColConvolution::PadInput(input, filter, inputPadded, dilationW,
|
|
70
|
+
dilationH);
|
|
71
|
+
|
|
72
|
+
const size_t inMaps = input.n_slices;
|
|
73
|
+
const size_t outMaps = filter.n_slices / inMaps;
|
|
74
|
+
|
|
75
|
+
if (!appending)
|
|
76
|
+
Im2ColConvolution::InitalizeOutput(inputPadded, filter, output, dW, dH,
|
|
77
|
+
dilationW, dilationH, outMaps);
|
|
78
|
+
|
|
79
|
+
// `im2row` is held transposed.
|
|
80
|
+
MatType im2row(filter.n_rows * filter.n_cols * input.n_slices,
|
|
81
|
+
output.n_rows * output.n_cols, GetFillType<MatType>::none);
|
|
82
|
+
// Arrange im2row so that each row has patches from each input map.
|
|
83
|
+
for (size_t i = 0; i < input.n_slices; ++i)
|
|
84
|
+
{
|
|
85
|
+
Im2Row(inputPadded.slice(i), im2row, filter.n_rows, filter.n_cols,
|
|
86
|
+
filter.n_rows * filter.n_cols * i, dW, dH, dilationW, dilationH);
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
// The filters already have the correct order in memory, just reshape it.
|
|
90
|
+
MatType fil2col;
|
|
91
|
+
MakeAlias(fil2col, filter, filter.n_rows * filter.n_cols * inMaps,
|
|
92
|
+
outMaps);
|
|
93
|
+
|
|
94
|
+
// The output is also already in the correct order.
|
|
95
|
+
MatType tempOutput;
|
|
96
|
+
MakeAlias(tempOutput, output, output.n_rows * output.n_cols, outMaps);
|
|
97
|
+
|
|
98
|
+
tempOutput += trans(im2row) * fil2col;
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
/**
|
|
102
|
+
* Perform a convolution using dense matrix as input and a 3rd order tensors
|
|
103
|
+
* as filter and output.
|
|
104
|
+
*
|
|
105
|
+
* @param input Input used to perform the convolution.
|
|
106
|
+
* @param filter Filter used to perform the convolution.
|
|
107
|
+
* @param output Output data that contains the results of the convolution.
|
|
108
|
+
* @param dW Stride of filter application in the x direction.
|
|
109
|
+
* @param dH Stride of filter application in the y direction.
|
|
110
|
+
* @param dilationW The dilation factor in x direction.
|
|
111
|
+
* @param dilationH The dilation factor in y direction.
|
|
112
|
+
* @param appending If true, it will not initialize the output. Instead,
|
|
113
|
+
* it will append the results to the output.
|
|
114
|
+
*/
|
|
115
|
+
template<typename MatType, typename CubeType>
|
|
116
|
+
static void Convolution(
|
|
117
|
+
const MatType& input,
|
|
118
|
+
const CubeType& filter,
|
|
119
|
+
CubeType& output,
|
|
120
|
+
const size_t dW = 1,
|
|
121
|
+
const size_t dH = 1,
|
|
122
|
+
const size_t dilationW = 1,
|
|
123
|
+
const size_t dilationH = 1,
|
|
124
|
+
const bool appending = false,
|
|
125
|
+
const typename std::enable_if_t<IsMatrix<MatType>::value>* = 0,
|
|
126
|
+
const typename std::enable_if_t<IsCube<CubeType>::value>* = 0)
|
|
127
|
+
{
|
|
128
|
+
MatType inputPadded;
|
|
129
|
+
Im2ColConvolution::PadInput(input, filter, inputPadded, dilationW,
|
|
130
|
+
dilationH);
|
|
131
|
+
|
|
132
|
+
if (!appending)
|
|
133
|
+
Im2ColConvolution::InitalizeOutput(inputPadded, filter, output, dW, dH,
|
|
134
|
+
dilationW, dilationH, filter.n_slices);
|
|
135
|
+
|
|
136
|
+
// `im2row` is held transposed.
|
|
137
|
+
MatType im2row(filter.n_rows * filter.n_cols, output.n_rows * output.n_cols,
|
|
138
|
+
GetFillType<MatType>::none);
|
|
139
|
+
Im2Row(inputPadded, im2row, filter.n_rows, filter.n_cols, 0, dW, dH,
|
|
140
|
+
dilationW, dilationH);
|
|
141
|
+
|
|
142
|
+
// The filters already have the correct order in memory, just reshape it.
|
|
143
|
+
MatType fil2col;
|
|
144
|
+
MakeAlias(fil2col, filter, filter.n_rows * filter.n_cols, filter.n_slices);
|
|
145
|
+
|
|
146
|
+
// The output is also already in the correct order.
|
|
147
|
+
MatType tempOutput;
|
|
148
|
+
MakeAlias(tempOutput, output, output.n_rows * output.n_cols,
|
|
149
|
+
filter.n_slices);
|
|
150
|
+
|
|
151
|
+
tempOutput += trans(im2row) * fil2col;
|
|
152
|
+
}
|
|
153
|
+
private:
|
|
154
|
+
/**
|
|
155
|
+
* Take an input and convert each patch into columns (held transposed).
|
|
156
|
+
* This function expects that `im2row` has the expected dimensions.
|
|
157
|
+
*
|
|
158
|
+
* @param input Input used to perform the convolution.
|
|
159
|
+
* @param im2row Patches of the input as rows.
|
|
160
|
+
* @param filterRows Number of rows in a filter.
|
|
161
|
+
* @param filterCols Number of columns in a filter.
|
|
162
|
+
* @param startRow The starting row for the input image.
|
|
163
|
+
* @param dW Stride of filter application in the x direction.
|
|
164
|
+
* @param dH Stride of filter application in the y direction.
|
|
165
|
+
* @param dilationW The dilation factor in x direction.
|
|
166
|
+
* @param dilationH The dilation factor in y direction.
|
|
167
|
+
*/
|
|
168
|
+
template<typename MatType>
|
|
169
|
+
static void Im2Row(const MatType& input,
|
|
170
|
+
MatType& im2row,
|
|
171
|
+
const size_t filterRows,
|
|
172
|
+
const size_t filterCols,
|
|
173
|
+
const size_t startRow = 0,
|
|
174
|
+
const size_t dW = 1,
|
|
175
|
+
const size_t dH = 1,
|
|
176
|
+
const size_t dilationW = 1,
|
|
177
|
+
const size_t dilationH = 1)
|
|
178
|
+
{
|
|
179
|
+
using UVecType = typename GetURowType<MatType>::type;
|
|
180
|
+
|
|
181
|
+
const size_t dFilterRows = filterRows * dilationW - (dilationW - 1);
|
|
182
|
+
const size_t dFilterCols = filterCols * dilationH - (dilationH - 1);
|
|
183
|
+
const size_t outputRows = (input.n_rows - dFilterRows + dW) / dW;
|
|
184
|
+
const size_t outputCols = (input.n_cols - dFilterCols + dH) / dH;
|
|
185
|
+
const bool useDilation = (dilationW != 1) || (dilationH != 1);
|
|
186
|
+
|
|
187
|
+
size_t outCol = 0;
|
|
188
|
+
const size_t filterElems = filterRows * filterCols;
|
|
189
|
+
MatType colAlias;
|
|
190
|
+
for (size_t j = 0; j < outputCols; j++)
|
|
191
|
+
{
|
|
192
|
+
size_t inCol = j * dH;
|
|
193
|
+
for (size_t i = 0; i < outputRows; i++)
|
|
194
|
+
{
|
|
195
|
+
size_t inRow = i * dW;
|
|
196
|
+
// Use an alias instead of `.col()` to avoid the creation of a
|
|
197
|
+
// temporary subview object.
|
|
198
|
+
MakeAlias(colAlias, im2row, filterElems, 1, outCol * im2row.n_rows +
|
|
199
|
+
startRow);
|
|
200
|
+
if (useDilation)
|
|
201
|
+
colAlias = vectorise(input.submat(linspace<UVecType>(inRow,
|
|
202
|
+
inRow + dFilterRows - 1, filterRows),
|
|
203
|
+
linspace<UVecType>(inCol, inCol + dFilterCols - 1, filterCols)));
|
|
204
|
+
else
|
|
205
|
+
colAlias = vectorise(input.submat(inRow, inCol,
|
|
206
|
+
inRow + filterRows - 1, inCol + filterCols - 1));
|
|
207
|
+
outCol++;
|
|
208
|
+
}
|
|
209
|
+
}
|
|
210
|
+
}
|
|
211
|
+
}; // class Im2ColConvolution
|
|
212
|
+
|
|
213
|
+
} // namespace mlpack
|
|
214
|
+
|
|
215
|
+
#endif
|