mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -37,7 +37,11 @@ template<
|
|
|
37
37
|
class RNN
|
|
38
38
|
{
|
|
39
39
|
public:
|
|
40
|
+
// Convenience typedefs.
|
|
41
|
+
using ElemType = typename MatType::elem_type;
|
|
40
42
|
using CubeType = typename GetCubeType<MatType>::type;
|
|
43
|
+
using URowType = typename GetURowType<MatType>::type;
|
|
44
|
+
|
|
41
45
|
/**
|
|
42
46
|
* Create the RNN object.
|
|
43
47
|
*
|
|
@@ -76,18 +80,52 @@ class RNN
|
|
|
76
80
|
/**
|
|
77
81
|
* Add a new module to the model.
|
|
78
82
|
*
|
|
79
|
-
* @param args The layer
|
|
83
|
+
* @param args The parameters to pass to the constructor of the layer.
|
|
84
|
+
*/
|
|
85
|
+
template<typename LayerType, typename... Args>
|
|
86
|
+
void Add(Args&&... args)
|
|
87
|
+
{
|
|
88
|
+
network.template Add<LayerType>(std::forward<Args>(args)...);
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
/**
|
|
92
|
+
* Add a new layer to the model, without specifying the matrix type of the
|
|
93
|
+
* layer as a template parameter.
|
|
94
|
+
*
|
|
95
|
+
* @param args The parameters to pass to the constructor of the layer.
|
|
80
96
|
*/
|
|
81
|
-
template
|
|
82
|
-
|
|
97
|
+
template<template<typename...> typename LayerType,
|
|
98
|
+
typename... Args>
|
|
99
|
+
void Add(Args&&... args)
|
|
100
|
+
{
|
|
101
|
+
network.template Add<LayerType<MatType>>(std::forward<Args>(args)...);
|
|
102
|
+
}
|
|
83
103
|
|
|
84
104
|
/**
|
|
85
105
|
* Add a new module to the model.
|
|
86
106
|
*
|
|
87
107
|
* @param layer The Layer to be added to the model.
|
|
88
108
|
*/
|
|
109
|
+
[[deprecated("Will be removed in mlpack 5.0.0. Use Add(std::move(layer)).")]]
|
|
89
110
|
void Add(Layer<MatType>* layer) { network.Add(layer); }
|
|
90
111
|
|
|
112
|
+
/**
|
|
113
|
+
* Add a new layer to the model by copying/moving the parameters of the given
|
|
114
|
+
* layer. Note that any trainable weights of this layer will be reset!
|
|
115
|
+
* (Constant parameters are kept.) Preferably, pass the layer with
|
|
116
|
+
* std::move().
|
|
117
|
+
*
|
|
118
|
+
* @param layer The layer to be added to the model.
|
|
119
|
+
*/
|
|
120
|
+
template<typename LayerType>
|
|
121
|
+
void Add(LayerType&& layer,
|
|
122
|
+
// This SFINAE can be removed in mlpack 5.0.0.
|
|
123
|
+
typename std::enable_if<!std::is_pointer_v<
|
|
124
|
+
std::remove_reference_t<LayerType>>>::type* = 0)
|
|
125
|
+
{
|
|
126
|
+
network.Add(std::forward<LayerType>(layer));
|
|
127
|
+
}
|
|
128
|
+
|
|
91
129
|
//! Get the network model.
|
|
92
130
|
const std::vector<Layer<MatType>*>& Network() const
|
|
93
131
|
{
|
|
@@ -128,11 +166,10 @@ class RNN
|
|
|
128
166
|
* @return The final objective of the trained model (NaN or Inf on error).
|
|
129
167
|
*/
|
|
130
168
|
template<typename OptimizerType, typename... CallbackTypes>
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
CallbackTypes&&... callbacks);
|
|
169
|
+
ElemType Train(CubeType predictors,
|
|
170
|
+
CubeType responses,
|
|
171
|
+
OptimizerType& optimizer,
|
|
172
|
+
CallbackTypes&&... callbacks);
|
|
136
173
|
|
|
137
174
|
/**
|
|
138
175
|
* Train the recurrent network on the given input data. By default, the
|
|
@@ -156,10 +193,9 @@ class RNN
|
|
|
156
193
|
* @return The final objective of the trained model (NaN or Inf on error).
|
|
157
194
|
*/
|
|
158
195
|
template<typename OptimizerType = ens::RMSProp, typename... CallbackTypes>
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
CallbackTypes&&... callbacks);
|
|
196
|
+
ElemType Train(CubeType predictors,
|
|
197
|
+
CubeType responses,
|
|
198
|
+
CallbackTypes&&... callbacks);
|
|
163
199
|
|
|
164
200
|
/**
|
|
165
201
|
* Train the recurrent network on the given input data using the given
|
|
@@ -186,12 +222,11 @@ class RNN
|
|
|
186
222
|
* @return The final objective of the trained model (NaN or Inf on error).
|
|
187
223
|
*/
|
|
188
224
|
template<typename OptimizerType, typename... CallbackTypes>
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
CallbackTypes&&... callbacks);
|
|
225
|
+
ElemType Train(CubeType predictors,
|
|
226
|
+
CubeType responses,
|
|
227
|
+
URowType sequenceLengths,
|
|
228
|
+
OptimizerType& optimizer,
|
|
229
|
+
CallbackTypes&&... callbacks);
|
|
195
230
|
|
|
196
231
|
/**
|
|
197
232
|
* Train the recurrent network on the given input data, given that each input
|
|
@@ -222,11 +257,10 @@ class RNN
|
|
|
222
257
|
* @return The final objective of the trained model (NaN or Inf on error).
|
|
223
258
|
*/
|
|
224
259
|
template<typename OptimizerType = ens::RMSProp, typename... CallbackTypes>
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
CallbackTypes&&... callbacks);
|
|
260
|
+
ElemType Train(CubeType predictors,
|
|
261
|
+
CubeType responses,
|
|
262
|
+
URowType sequenceLengths,
|
|
263
|
+
CallbackTypes&&... callbacks);
|
|
230
264
|
|
|
231
265
|
/**
|
|
232
266
|
* Predict the responses to a given set of predictors. The responses will
|
|
@@ -257,7 +291,7 @@ class RNN
|
|
|
257
291
|
*/
|
|
258
292
|
void Predict(const CubeType& predictors,
|
|
259
293
|
CubeType& results,
|
|
260
|
-
const
|
|
294
|
+
const URowType& sequenceLengths);
|
|
261
295
|
|
|
262
296
|
// Return the number of weights in the model.
|
|
263
297
|
size_t WeightSize() { return network.WeightSize(); }
|
|
@@ -318,11 +352,26 @@ class RNN
|
|
|
318
352
|
* @param predictors Input variables.
|
|
319
353
|
* @param responses Target outputs for input variables.
|
|
320
354
|
*/
|
|
321
|
-
|
|
322
|
-
const CubeType& predictors,
|
|
323
|
-
const CubeType& responses);
|
|
355
|
+
ElemType Evaluate(const CubeType& predictors, const CubeType& responses);
|
|
324
356
|
|
|
325
|
-
|
|
357
|
+
/**
|
|
358
|
+
* Evaluate the recurrent network with the given predictors and responses.
|
|
359
|
+
* This function is usually used to monitor progress while training.
|
|
360
|
+
*
|
|
361
|
+
* @param predictors Input variables.
|
|
362
|
+
* @param responses Target outputs for input variables.
|
|
363
|
+
* @param sequenceLengths Length of each input sequence. Should have size
|
|
364
|
+
* `predictors.n_cols`, and all values should be less than or equal to
|
|
365
|
+
* `predictors.n_slices`.
|
|
366
|
+
* @param batchSize Number of points to be passed at a time to use for
|
|
367
|
+
* objective function evaluation.
|
|
368
|
+
*/
|
|
369
|
+
ElemType Evaluate(const CubeType& predictors,
|
|
370
|
+
const CubeType& responses,
|
|
371
|
+
const URowType& sequenceLengths,
|
|
372
|
+
const size_t batchSize);
|
|
373
|
+
|
|
374
|
+
// Serialize the model.
|
|
326
375
|
template<typename Archive>
|
|
327
376
|
void serialize(Archive& ar, const uint32_t /* version */);
|
|
328
377
|
|
|
@@ -337,7 +386,7 @@ class RNN
|
|
|
337
386
|
*
|
|
338
387
|
* @param parameters Matrix model parameters.
|
|
339
388
|
*/
|
|
340
|
-
|
|
389
|
+
ElemType Evaluate(const MatType& parameters);
|
|
341
390
|
|
|
342
391
|
/**
|
|
343
392
|
* Evaluate the recurrent network with the given parameters, but using only
|
|
@@ -353,9 +402,9 @@ class RNN
|
|
|
353
402
|
* @param batchSize Number of points to be passed at a time to use for
|
|
354
403
|
* objective function evaluation.
|
|
355
404
|
*/
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
405
|
+
ElemType Evaluate(const MatType& parameters,
|
|
406
|
+
const size_t begin,
|
|
407
|
+
const size_t batchSize);
|
|
359
408
|
|
|
360
409
|
/**
|
|
361
410
|
* Evaluate the recurrent network with the given parameters.
|
|
@@ -366,8 +415,8 @@ class RNN
|
|
|
366
415
|
* @param gradient Matrix to output gradient into.
|
|
367
416
|
*/
|
|
368
417
|
template<typename GradType>
|
|
369
|
-
|
|
370
|
-
|
|
418
|
+
ElemType EvaluateWithGradient(const MatType& parameters,
|
|
419
|
+
GradType& gradient);
|
|
371
420
|
|
|
372
421
|
/**
|
|
373
422
|
* Evaluate the recurrent network with the given parameters, but using only
|
|
@@ -382,10 +431,10 @@ class RNN
|
|
|
382
431
|
* objective function evaluation.
|
|
383
432
|
*/
|
|
384
433
|
template<typename GradType>
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
434
|
+
ElemType EvaluateWithGradient(const MatType& parameters,
|
|
435
|
+
const size_t begin,
|
|
436
|
+
GradType& gradient,
|
|
437
|
+
const size_t batchSize);
|
|
389
438
|
|
|
390
439
|
/**
|
|
391
440
|
* Evaluate the gradient of the recurrent network with the given parameters,
|
|
@@ -428,7 +477,7 @@ class RNN
|
|
|
428
477
|
*/
|
|
429
478
|
void ResetData(CubeType predictors,
|
|
430
479
|
CubeType responses,
|
|
431
|
-
|
|
480
|
+
URowType sequenceLengths = URowType());
|
|
432
481
|
|
|
433
482
|
private:
|
|
434
483
|
// Helper functions.
|
|
@@ -441,7 +490,52 @@ class RNN
|
|
|
441
490
|
void ResetMemoryState(const size_t memorySize, const size_t batchSize);
|
|
442
491
|
|
|
443
492
|
//! Set the current step index of all recurrent layers to `step`.
|
|
444
|
-
void SetCurrentStep(const size_t step,
|
|
493
|
+
void SetCurrentStep(const size_t step,
|
|
494
|
+
const bool end,
|
|
495
|
+
size_t batchSize,
|
|
496
|
+
size_t activeBatchSize,
|
|
497
|
+
bool backwards = false);
|
|
498
|
+
|
|
499
|
+
// Reorders the data in a batch to have sequence lengths in descending order.
|
|
500
|
+
void ReorderBatch(const size_t begin,
|
|
501
|
+
const size_t batchSize,
|
|
502
|
+
CubeType& predictors,
|
|
503
|
+
CubeType& responses,
|
|
504
|
+
URowType& sequenceLengths);
|
|
505
|
+
|
|
506
|
+
// Calculates the number of active points in the batch.
|
|
507
|
+
void CalculateActivePoints(size_t& activeBatchSize,
|
|
508
|
+
const size_t begin,
|
|
509
|
+
URowType& sequenceLengths,
|
|
510
|
+
const size_t step,
|
|
511
|
+
const std::enable_if_t<
|
|
512
|
+
IsArma<URowType>::value>* = 0)
|
|
513
|
+
{
|
|
514
|
+
// Since we know that `sequenceLengths` is sorted in order of descending
|
|
515
|
+
// lengths and `activeBatchSize` only decreases as `step` increases, we
|
|
516
|
+
// can just decrease `activeBatchSize` until we find the sequence length
|
|
517
|
+
// that is greater than the current step.
|
|
518
|
+
while (activeBatchSize > 0 &&
|
|
519
|
+
sequenceLengths[begin + activeBatchSize - 1] <= step)
|
|
520
|
+
activeBatchSize--;
|
|
521
|
+
}
|
|
522
|
+
|
|
523
|
+
#if defined(MLPACK_HAS_COOT)
|
|
524
|
+
|
|
525
|
+
void CalculateActivePoints(size_t& activeBatchSize,
|
|
526
|
+
const size_t begin,
|
|
527
|
+
const URowType& sequenceLengths,
|
|
528
|
+
const size_t step,
|
|
529
|
+
const std::enable_if_t<
|
|
530
|
+
IsCoot<URowType>::value>* = 0)
|
|
531
|
+
{
|
|
532
|
+
// Individual element access is probably slower if `URowType` is a
|
|
533
|
+
// Bandicoot type so we don't use the optimized version.
|
|
534
|
+
activeBatchSize = accu(sequenceLengths
|
|
535
|
+
.subvec(begin, begin + activeBatchSize - 1) > step);
|
|
536
|
+
}
|
|
537
|
+
|
|
538
|
+
#endif // defined(MLPACK_HAS_COOT)
|
|
445
539
|
|
|
446
540
|
//! Number of timesteps to consider for backpropagation through time (BPTT).
|
|
447
541
|
size_t bpttSteps;
|
|
@@ -464,8 +558,8 @@ class RNN
|
|
|
464
558
|
CubeType responses;
|
|
465
559
|
|
|
466
560
|
// The length of each input sequence. If this is empty, then every sequence
|
|
467
|
-
// is
|
|
468
|
-
|
|
561
|
+
// is assumed to have the same length (`predictors.n_slices`).
|
|
562
|
+
URowType sequenceLengths;
|
|
469
563
|
}; // class RNNType
|
|
470
564
|
|
|
471
565
|
} // namespace mlpack
|