mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -13,14 +13,13 @@
|
|
|
13
13
|
#ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_NAIVE_CONVOLUTION_HPP
|
|
14
14
|
#define MLPACK_METHODS_ANN_CONVOLUTION_RULES_NAIVE_CONVOLUTION_HPP
|
|
15
15
|
|
|
16
|
-
#include
|
|
17
|
-
#include "border_modes.hpp"
|
|
16
|
+
#include "base_convolution.hpp"
|
|
18
17
|
|
|
19
18
|
namespace mlpack {
|
|
20
19
|
|
|
21
20
|
/**
|
|
22
21
|
* Computes the two-dimensional convolution. This class allows specification of
|
|
23
|
-
* the type of the border type. The convolution can be
|
|
22
|
+
* the type of the border type. The convolution can be computed with the valid
|
|
24
23
|
* border type of the full border type (default).
|
|
25
24
|
*
|
|
26
25
|
* FullConvolution: returns the full two-dimensional convolution.
|
|
@@ -31,115 +30,16 @@ namespace mlpack {
|
|
|
31
30
|
* ValidConvolution).
|
|
32
31
|
*/
|
|
33
32
|
template<typename BorderMode = FullConvolution>
|
|
34
|
-
class NaiveConvolution
|
|
33
|
+
class NaiveConvolution : public BaseConvolution<BorderMode>
|
|
35
34
|
{
|
|
36
35
|
public:
|
|
37
36
|
/**
|
|
38
|
-
* Perform a convolution
|
|
39
|
-
*
|
|
40
|
-
*
|
|
41
|
-
*
|
|
42
|
-
*
|
|
43
|
-
*
|
|
44
|
-
* @param dH Stride of filter application in the y direction.
|
|
45
|
-
* @param dilationW The dilation factor in x direction.
|
|
46
|
-
* @param dilationH The dilation factor in y direction.
|
|
47
|
-
* @param appending If true, it will not initialize the output. Instead,
|
|
48
|
-
* it will append the results to the output.
|
|
49
|
-
*/
|
|
50
|
-
template<typename InMatType, typename FilMatType, typename OutMatType,
|
|
51
|
-
typename Border = BorderMode>
|
|
52
|
-
static std::enable_if_t<std::is_same_v<Border, ValidConvolution>, void>
|
|
53
|
-
Convolution(const InMatType& input,
|
|
54
|
-
const FilMatType& filter,
|
|
55
|
-
OutMatType& output,
|
|
56
|
-
const size_t dW = 1,
|
|
57
|
-
const size_t dH = 1,
|
|
58
|
-
const size_t dilationW = 1,
|
|
59
|
-
const size_t dilationH = 1,
|
|
60
|
-
const bool appending = false,
|
|
61
|
-
const typename std::enable_if_t<IsMatrix<InMatType>::value>* = 0)
|
|
62
|
-
{
|
|
63
|
-
using eT = typename InMatType::elem_type;
|
|
64
|
-
// Compute the output size. The filterRows and filterCols computation must
|
|
65
|
-
// take into account the fact that dilation only adds rows or columns
|
|
66
|
-
// *between* filter elements. So, e.g., a dilation of 2 on a kernel size of
|
|
67
|
-
// 3x3 means an effective kernel size of 5x5, *not* 6x6.
|
|
68
|
-
if (!appending)
|
|
69
|
-
{
|
|
70
|
-
const size_t filterRows = filter.n_rows * dilationH - (dilationH - 1);
|
|
71
|
-
const size_t filterCols = filter.n_cols * dilationW - (dilationW - 1);
|
|
72
|
-
const size_t outputRows = (input.n_rows - filterRows + dH) / dH;
|
|
73
|
-
const size_t outputCols = (input.n_cols - filterCols + dW) / dW;
|
|
74
|
-
output.zeros(outputRows, outputCols);
|
|
75
|
-
}
|
|
76
|
-
|
|
77
|
-
// It seems to be about 3.5 times faster to use pointers instead of
|
|
78
|
-
// filter(ki, kj) * input(leftInput + ki, topInput + kj) and output(i, j).
|
|
79
|
-
eT* outputPtr = output.memptr();
|
|
80
|
-
|
|
81
|
-
for (size_t j = 0; j < output.n_cols; ++j)
|
|
82
|
-
{
|
|
83
|
-
for (size_t i = 0; i < output.n_rows; ++i, outputPtr++)
|
|
84
|
-
{
|
|
85
|
-
const eT* kernelPtr = filter.memptr();
|
|
86
|
-
for (size_t kj = 0; kj < filter.n_cols; ++kj)
|
|
87
|
-
{
|
|
88
|
-
const eT* inputPtr = input.colptr(kj * dilationW + j * dW) + i * dH;
|
|
89
|
-
for (size_t ki = 0; ki < filter.n_rows; ++ki, ++kernelPtr,
|
|
90
|
-
inputPtr += dilationH)
|
|
91
|
-
*outputPtr += *kernelPtr * (*inputPtr);
|
|
92
|
-
}
|
|
93
|
-
}
|
|
94
|
-
}
|
|
95
|
-
}
|
|
96
|
-
|
|
97
|
-
/**
|
|
98
|
-
* Perform a convolution (full mode).
|
|
99
|
-
*
|
|
100
|
-
* @param input Input used to perform the convolution.
|
|
101
|
-
* @param filter Filter used to perform the convolution.
|
|
102
|
-
* @param output Output data that contains the results of the convolution.
|
|
103
|
-
* @param dW Stride of filter application in the x direction.
|
|
104
|
-
* @param dH Stride of filter application in the y direction.
|
|
105
|
-
* @param dilationW The dilation factor in x direction.
|
|
106
|
-
* @param dilationH The dilation factor in y direction.
|
|
107
|
-
* @param appending If true, it will not initialize the output. Instead,
|
|
108
|
-
* it will append the results to the output.
|
|
109
|
-
*/
|
|
110
|
-
template<typename InMatType, typename FilMatType, typename OutMatType,
|
|
111
|
-
typename Border = BorderMode>
|
|
112
|
-
static std::enable_if_t<std::is_same_v<Border, FullConvolution>, void>
|
|
113
|
-
Convolution(const InMatType& input,
|
|
114
|
-
const FilMatType& filter,
|
|
115
|
-
OutMatType& output,
|
|
116
|
-
const size_t dW = 1,
|
|
117
|
-
const size_t dH = 1,
|
|
118
|
-
const size_t dilationW = 1,
|
|
119
|
-
const size_t dilationH = 1,
|
|
120
|
-
const bool appending = false,
|
|
121
|
-
const typename std::enable_if_t<IsMatrix<InMatType>::value>* = 0)
|
|
122
|
-
{
|
|
123
|
-
// First, compute the necessary padding for the full convolution. It is
|
|
124
|
-
// possible that this might be an overestimate. Note that these variables
|
|
125
|
-
// only hold the padding on one side of the input.
|
|
126
|
-
const size_t filterRows = filter.n_rows * dilationH - (dilationH - 1);
|
|
127
|
-
const size_t filterCols = filter.n_cols * dilationW - (dilationW - 1);
|
|
128
|
-
const size_t paddingRows = filterRows - 1;
|
|
129
|
-
const size_t paddingCols = filterCols - 1;
|
|
130
|
-
|
|
131
|
-
// Pad filter and input to the working output shape.
|
|
132
|
-
InMatType inputPadded(input.n_rows + 2 * paddingRows,
|
|
133
|
-
input.n_cols + 2 * paddingCols);
|
|
134
|
-
inputPadded.submat(paddingRows, paddingCols, paddingRows + input.n_rows - 1,
|
|
135
|
-
paddingCols + input.n_cols - 1) = input;
|
|
136
|
-
|
|
137
|
-
NaiveConvolution<ValidConvolution>::Convolution(inputPadded, filter,
|
|
138
|
-
output, dW, dH, dilationW, dilationH, appending);
|
|
139
|
-
}
|
|
140
|
-
|
|
141
|
-
/**
|
|
142
|
-
* Perform a convolution using 3rd order tensors.
|
|
37
|
+
* Perform a convolution using 3rd order tensors. Expects that `filter` has
|
|
38
|
+
* `input.n_slices * output.n_slices` slices. The Nth `input.n_slices` filters
|
|
39
|
+
* are applied to all input slices and output to the Nth output slice.
|
|
40
|
+
* eg. 2 input slices: filter 0 applies to input 0, output 0,
|
|
41
|
+
* fil 1 * in 1 = out 0, fil 2 * in 0 = out 1, fil 3 * in 1 = out 1,
|
|
42
|
+
* fil 4 * in 0 = out 2, fil 5 * in 1 = out 2, etc.
|
|
143
43
|
*
|
|
144
44
|
* @param input Input used to perform the convolution.
|
|
145
45
|
* @param filter Filter used to perform the convolution.
|
|
@@ -164,19 +64,26 @@ class NaiveConvolution
|
|
|
164
64
|
const typename std::enable_if_t<IsCube<CubeType>::value>* = 0)
|
|
165
65
|
{
|
|
166
66
|
using MatType = typename GetDenseMatType<CubeType>::type;
|
|
167
|
-
MatType convOutput;
|
|
168
|
-
NaiveConvolution<BorderMode>::Convolution(input.slice(0), filter.slice(0),
|
|
169
|
-
convOutput, dW, dH, dilationW, dilationH, appending);
|
|
170
67
|
|
|
171
|
-
|
|
172
|
-
|
|
68
|
+
CubeType inputPadded;
|
|
69
|
+
NaiveConvolution::PadInput(input, filter, inputPadded, dilationW,
|
|
70
|
+
dilationH);
|
|
71
|
+
|
|
72
|
+
const size_t inMaps = input.n_slices;
|
|
73
|
+
const size_t outMaps = filter.n_slices / inMaps;
|
|
173
74
|
|
|
174
|
-
|
|
75
|
+
if (!appending)
|
|
76
|
+
NaiveConvolution::InitalizeOutput(inputPadded, filter, output, dW, dH,
|
|
77
|
+
dilationW, dilationH, outMaps);
|
|
175
78
|
|
|
176
|
-
for (size_t i =
|
|
79
|
+
for (size_t i = 0; i < inMaps; i++)
|
|
177
80
|
{
|
|
178
|
-
|
|
179
|
-
|
|
81
|
+
MatType& inputSlice = inputPadded.slice(i);
|
|
82
|
+
for (size_t j = 0; j < outMaps; j++)
|
|
83
|
+
{
|
|
84
|
+
Conv2(inputSlice, filter.slice(j * inMaps + i), output.slice(j),
|
|
85
|
+
dW, dH, dilationW, dilationH);
|
|
86
|
+
}
|
|
180
87
|
}
|
|
181
88
|
}
|
|
182
89
|
|
|
@@ -207,25 +114,23 @@ class NaiveConvolution
|
|
|
207
114
|
const typename std::enable_if_t<IsMatrix<MatType>::value>* = 0,
|
|
208
115
|
const typename std::enable_if_t<IsCube<CubeType>::value>* = 0)
|
|
209
116
|
{
|
|
210
|
-
MatType
|
|
211
|
-
NaiveConvolution
|
|
212
|
-
|
|
117
|
+
MatType inputPadded;
|
|
118
|
+
NaiveConvolution::PadInput(input, filter, inputPadded, dilationW,
|
|
119
|
+
dilationH);
|
|
213
120
|
|
|
214
121
|
if (!appending)
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
output.slice(0) = convOutput;
|
|
122
|
+
NaiveConvolution::InitalizeOutput(inputPadded, filter, output, dW, dH,
|
|
123
|
+
dilationW, dilationH, filter.n_slices);
|
|
218
124
|
|
|
219
|
-
for (size_t
|
|
125
|
+
for (size_t s = 0; s < filter.n_slices; s++)
|
|
220
126
|
{
|
|
221
|
-
|
|
222
|
-
|
|
127
|
+
Conv2(inputPadded, filter.slice(s), output.slice(s), dW, dH,
|
|
128
|
+
dilationW, dilationH);
|
|
223
129
|
}
|
|
224
130
|
}
|
|
225
|
-
|
|
131
|
+
private:
|
|
226
132
|
/**
|
|
227
|
-
* Perform a convolution
|
|
228
|
-
* dense matrix as filter.
|
|
133
|
+
* Perform a valid convolution.
|
|
229
134
|
*
|
|
230
135
|
* @param input Input used to perform the convolution.
|
|
231
136
|
* @param filter Filter used to perform the convolution.
|
|
@@ -233,38 +138,88 @@ class NaiveConvolution
|
|
|
233
138
|
* @param dW Stride of filter application in the x direction.
|
|
234
139
|
* @param dH Stride of filter application in the y direction.
|
|
235
140
|
* @param dilationW The dilation factor in x direction.
|
|
236
|
-
* @param dilationH The dilation factor in y direction.
|
|
237
|
-
* @param appending If true, it will not initialize the output. Instead,
|
|
238
|
-
* it will append the results to the output.
|
|
141
|
+
* @param dilationH The dilation factor in y direction.
|
|
239
142
|
*/
|
|
240
|
-
template<typename MatType
|
|
241
|
-
static void
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
const bool appending = false,
|
|
250
|
-
const typename std::enable_if_t<IsMatrix<MatType>::value>* = 0,
|
|
251
|
-
const typename std::enable_if_t<IsCube<CubeType>::value>* = 0)
|
|
143
|
+
template<typename MatType>
|
|
144
|
+
static void Conv2(const MatType& input,
|
|
145
|
+
const MatType& filter,
|
|
146
|
+
MatType& output,
|
|
147
|
+
const size_t dW,
|
|
148
|
+
const size_t dH,
|
|
149
|
+
const size_t dilationW,
|
|
150
|
+
const size_t dilationH,
|
|
151
|
+
const std::enable_if_t<IsArma<MatType>::value>* = 0)
|
|
252
152
|
{
|
|
253
|
-
MatType
|
|
254
|
-
NaiveConvolution<BorderMode>::Convolution(input.slice(0), filter,
|
|
255
|
-
convOutput, dW, dH, dilationW, dilationH, appending);
|
|
153
|
+
using eT = typename MatType::elem_type;
|
|
256
154
|
|
|
257
|
-
|
|
258
|
-
|
|
155
|
+
// It seems to be about 3.5 times faster to use pointers instead of
|
|
156
|
+
// filter(ki, kj) * input(leftInput + ki, topInput + kj) and output(i, j).
|
|
157
|
+
eT* outputPtr = output.memptr();
|
|
158
|
+
|
|
159
|
+
for (size_t j = 0; j < output.n_cols; ++j)
|
|
160
|
+
{
|
|
161
|
+
for (size_t i = 0; i < output.n_rows; ++i, outputPtr++)
|
|
162
|
+
{
|
|
163
|
+
const eT* kernelPtr = filter.memptr();
|
|
164
|
+
for (size_t kj = 0; kj < filter.n_cols; ++kj)
|
|
165
|
+
{
|
|
166
|
+
const eT* inputPtr = input.colptr(kj * dilationH + j * dH) + i * dW;
|
|
167
|
+
for (size_t ki = 0; ki < filter.n_rows; ++ki, ++kernelPtr,
|
|
168
|
+
inputPtr += dilationW)
|
|
169
|
+
*outputPtr += *kernelPtr * (*inputPtr);
|
|
170
|
+
}
|
|
171
|
+
}
|
|
172
|
+
}
|
|
173
|
+
}
|
|
259
174
|
|
|
260
|
-
|
|
175
|
+
#if defined(MLPACK_HAS_COOT)
|
|
261
176
|
|
|
262
|
-
|
|
177
|
+
/**
|
|
178
|
+
* Perform a valid convolution on Bandicoot matrices.
|
|
179
|
+
*
|
|
180
|
+
* @param input Input used to perform the convolution.
|
|
181
|
+
* @param filter Filter used to perform the convolution.
|
|
182
|
+
* @param output Output data that contains the results of the convolution.
|
|
183
|
+
* @param dW Stride of filter application in the x direction.
|
|
184
|
+
* @param dH Stride of filter application in the y direction.
|
|
185
|
+
* @param dilationW The dilation factor in x direction.
|
|
186
|
+
* @param dilationH The dilation factor in y direction.
|
|
187
|
+
*/
|
|
188
|
+
template<typename MatType>
|
|
189
|
+
static void Conv2(const MatType& input,
|
|
190
|
+
const MatType& filter,
|
|
191
|
+
MatType& output,
|
|
192
|
+
const size_t dW,
|
|
193
|
+
const size_t dH,
|
|
194
|
+
const size_t dilationW,
|
|
195
|
+
const size_t dilationH,
|
|
196
|
+
const std::enable_if_t<IsCoot<MatType>::value>* = 0)
|
|
197
|
+
{
|
|
198
|
+
bool useDilation = (dilationW != 1) || (dilationH != 1);
|
|
199
|
+
MatType dilatedFilter;
|
|
200
|
+
const size_t filterRows = filter.n_rows * dilationW - (dilationW - 1);
|
|
201
|
+
const size_t filterCols = filter.n_cols * dilationH - (dilationH - 1);
|
|
202
|
+
if (useDilation)
|
|
263
203
|
{
|
|
264
|
-
|
|
265
|
-
|
|
204
|
+
using UVecType = typename GetURowType<MatType>::type;
|
|
205
|
+
// Dilate the kernel by setting the non-zero rows and columns.
|
|
206
|
+
dilatedFilter.zeros(filterRows, filterCols);
|
|
207
|
+
dilatedFilter.submat(linspace<UVecType>(0, filterRows - 1, filter.n_rows),
|
|
208
|
+
linspace<UVecType>(0, filterCols - 1, filter.n_cols)) = filter;
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
// Apply convolution.
|
|
212
|
+
for (size_t j = 0; j < output.n_cols; ++j)
|
|
213
|
+
{
|
|
214
|
+
for (size_t i = 0; i < output.n_rows; ++i)
|
|
215
|
+
{
|
|
216
|
+
output.at(i, j) = accu((useDilation ? dilatedFilter : filter) %
|
|
217
|
+
input.submat(i * dW, j * dH, i * dW + filterRows - 1,
|
|
218
|
+
j * dH + filterCols - 1));
|
|
219
|
+
}
|
|
266
220
|
}
|
|
267
221
|
}
|
|
222
|
+
#endif // defined(MLPACK_HAS_COOT)
|
|
268
223
|
}; // class NaiveConvolution
|
|
269
224
|
|
|
270
225
|
} // namespace mlpack
|