mlpack 4.6.2-cp313-cp313-win_amd64.whl → 4.7.0-cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
mlpack/include/mlpack/methods/ann/rnn_impl.hpp
CHANGED
@@ -161,7 +161,7 @@ typename MatType::elem_type RNN<
     OptimizerType& optimizer,
     CallbackTypes&&... callbacks)
 {
-  ResetData(std::move(predictors), std::move(responses),
+  ResetData(std::move(predictors), std::move(responses), URowType());
 
   network.WarnMessageMaxIterations(optimizer, this->predictors.n_cols);
 
@@ -170,8 +170,8 @@ typename MatType::elem_type RNN<
 
   // Train the model.
   Timer::Start("rnn_optimization");
-  const
-
+  const ElemType out = optimizer.Optimize(*this, network.Parameters(),
+      callbacks...);
   Timer::Stop("rnn_optimization");
 
   Log::Info << "RNN::Train(): final objective of trained model is " << out
@@ -212,7 +212,7 @@ typename MatType::elem_type RNN<
 >::Train(
     CubeType predictors,
     CubeType responses,
-
+    URowType sequenceLengths,
     OptimizerType& optimizer,
     CallbackTypes&&... callbacks)
 {
@@ -226,8 +226,8 @@ typename MatType::elem_type RNN<
 
   // Train the model.
   Timer::Start("rnn_optimization");
-  const
-
+  const ElemType out = optimizer.Optimize(*this, network.Parameters(),
+      callbacks...);
   Timer::Stop("rnn_optimization");
 
   Log::Info << "RNN::Train(): final objective of trained model is " << out
@@ -248,7 +248,7 @@ typename MatType::elem_type RNN<
 >::Train(
     CubeType predictors,
     CubeType responses,
-
+    URowType sequenceLengths,
     CallbackTypes&&... callbacks)
 {
   OptimizerType optimizer;
@@ -289,7 +289,8 @@ void RNN<
   // Iterate over all time steps.
   for (size_t t = 0; t < predictors.n_slices; ++t)
   {
-    SetCurrentStep(t, (t == predictors.n_slices - 1)
+    SetCurrentStep(t, (t == predictors.n_slices - 1), effectiveBatchSize,
+        effectiveBatchSize);
 
     // Create aliases for the input and output. If we are in single mode, we
     // always output into the same slice.
@@ -315,7 +316,7 @@ void RNN<
 >::Predict(
     const CubeType& predictors,
     CubeType& results,
-    const
+    const URowType& sequenceLengths)
 {
   // Ensure that the network is configured correctly.
   network.CheckNetwork("RNN::Predict()", predictors.n_rows, true, false);
@@ -334,7 +335,7 @@ void RNN<
     const size_t steps = sequenceLengths[i];
     for (size_t t = 0; t < steps; ++t)
     {
-      SetCurrentStep(t, (t == steps - 1));
+      SetCurrentStep(t, (t == steps - 1), 1, 1);
 
       // Create aliases for the input and output. If we are in single mode, we
       // always output into the same slice.
@@ -375,6 +376,132 @@ void RNN<
   }
 }
 
+template<
+    typename OutputLayerType,
+    typename InitializationRuleType,
+    typename MatType
+>
+typename MatType::elem_type RNN<
+    OutputLayerType,
+    InitializationRuleType,
+    MatType
+>::Evaluate(
+    const CubeType& predictors,
+    const CubeType& responses)
+{
+  // Ensure that the network is configured correctly.
+  network.CheckNetwork("RNN::Evaluate()", predictors.n_rows);
+
+  // Add the loss of the network unrelated to output.
+  ElemType lossSum = ElemType(network.network.Loss());
+
+  // Reset recurrent memory state.
+  ResetMemoryState(0, predictors.n_cols);
+
+  // Iterate over all time slices.
+  MatType forwardOutput;
+  for (size_t t = 0; t < predictors.n_slices; t++)
+  {
+    SetCurrentStep(t, (t == predictors.n_slices - 1), predictors.n_cols,
+        predictors.n_cols);
+    // Do a forward pass and calculate the loss.
+    network.Forward(predictors.slice(t), forwardOutput);
+    if (!single || t == predictors.n_slices - 1)
+      lossSum += network.outputLayer.Forward(forwardOutput,
+          responses.slice(single ? 0 : t));
+  }
+
+  return lossSum;
+}
+
+template<
+    typename OutputLayerType,
+    typename InitializationRuleType,
+    typename MatType
+>
+typename MatType::elem_type RNN<
+    OutputLayerType,
+    InitializationRuleType,
+    MatType
+>::Evaluate(
+    const CubeType& predictors,
+    const CubeType& responses,
+    const URowType& sequenceLengths,
+    const size_t batchSize)
+{
+  // Ensure that the network is configured correctly.
+  network.CheckNetwork("RNN::Evaluate()", predictors.n_rows);
+
+  if (sequenceLengths.n_elem > 0 && batchSize != 1 && single)
+    throw std::invalid_argument(
+        "Batch size must be 1 for ragged sequences in single mode!");
+
+  // Add the loss of the network unrelated to output.
+  ElemType lossSum = ElemType(network.network.Loss());
+
+  CubeType reordPredictors, reordResponses;
+  URowType reordLengths;
+  if (batchSize > 1)
+  {
+    // Make copies of the arguments so they can be reordered while the orignal
+    // arguments stay constant.
+    reordPredictors = predictors;
+    reordResponses = responses;
+    reordLengths = sequenceLengths;
+  }
+  else
+  {
+    // Reordering isn't actually needed so we just make aliases.
+    MakeAlias(reordPredictors, predictors, predictors.n_rows,
+        predictors.n_cols, predictors.n_slices);
+    MakeAlias(reordResponses, responses, responses.n_rows,
+        responses.n_cols, responses.n_slices);
+    MakeAlias(reordLengths, sequenceLengths, sequenceLengths.n_elem);
+  }
+
+  for (size_t i = 0; i < predictors.n_cols; i += batchSize)
+  {
+    size_t effectiveBatchSize = std::min(batchSize,
+        size_t(predictors.n_cols) - i);
+
+    // Reset recurrent memory state.
+    ResetMemoryState(0, effectiveBatchSize);
+
+    // Reorder the data so the sequence lengths are in descending order.
+    if (batchSize > 1)
+      ReorderBatch(i, effectiveBatchSize, reordPredictors, reordResponses,
+          reordLengths);
+
+    MatType forwardOutput, inputAlias, responseAlias;
+    // Iterate over all time slices.
+    size_t slices = reordLengths
+        .subvec(i, i + effectiveBatchSize - 1).max();
+    size_t activeBatchSize = effectiveBatchSize;
+    for (size_t t = 0; t < slices; t++)
+    {
+      // Calculate the number of active points.
+      CalculateActivePoints(activeBatchSize, i, reordLengths, t);
+
+      SetCurrentStep(t, (t == slices - 1), effectiveBatchSize, activeBatchSize);
+
+      // Get the input and response data.
+      MakeAlias(inputAlias, reordPredictors.slice(t), predictors.n_rows,
+          activeBatchSize, i * predictors.n_rows);
+
+      // Do a forward pass and calculate the loss.
+      network.Forward(inputAlias, forwardOutput);
+      if (!single || t == slices - 1)
+      {
+        MakeAlias(responseAlias, reordResponses.slice((single ? 0 : t)),
+            responses.n_rows, activeBatchSize, i * responses.n_rows);
+        lossSum += network.outputLayer.Forward(forwardOutput, responseAlias);
+      }
+    }
+  }
+
+  return lossSum;
+}
+
 template<
     typename OutputLayerType,
     typename InitializationRuleType,
@@ -431,32 +558,43 @@ typename MatType::elem_type RNN<
   // Ensure the network is valid.
   network.CheckNetwork("RNN::Evaluate()", predictors.n_rows);
 
+  if (sequenceLengths.n_elem > 0 && batchSize != 1 && single)
+    throw std::invalid_argument(
+        "Batch size must be 1 for ragged sequences in single mode!");
+
   // The core of the computation here is to pass through each step. Since we
   // are not computing the gradient, we can be "clever" and use only one memory
   // cell---we don't need to know about the past.
-  ResetMemoryState(
+  ResetMemoryState(0, batchSize);
   MatType output(network.network.OutputSize(), batchSize);
 
-
-
+  // Reorder the data so the sequence lengths are in descending order.
+  if (sequenceLengths.n_elem > 0 && batchSize > 1)
+    ReorderBatch(begin, batchSize, predictors, responses, sequenceLengths);
 
-
+  ElemType loss = 0;
   MatType stepData, responseData;
   const size_t steps = (sequenceLengths.n_elem == 0) ? predictors.n_slices :
-      sequenceLengths
+      sequenceLengths.subvec(begin, begin + batchSize - 1).max();
+  size_t activeBatchSize = batchSize;
   for (size_t t = 0; t < steps; ++t)
   {
+    // Calculate the number of active points.
+    if (sequenceLengths.n_elem > 0)
+      CalculateActivePoints(activeBatchSize, begin, sequenceLengths, t);
+
    // Manually reset the data of the network to be an alias of the current time
    // step.
-    SetCurrentStep(t, (t == steps));
+    SetCurrentStep(t, (t == steps), batchSize, activeBatchSize);
+
    MakeAlias(network.predictors, predictors.slice(t), predictors.n_rows,
-
+        activeBatchSize, begin * predictors.slice(t).n_rows);
    const size_t responseStep = (single) ? 0 : t;
    MakeAlias(network.responses, responses.slice(responseStep),
-        responses.n_rows,
+        responses.n_rows, activeBatchSize,
        begin * responses.slice(responseStep).n_rows);
 
-    loss += network.Evaluate(output, begin,
+    loss += network.Evaluate(output, begin, activeBatchSize);
   }
 
   return loss;
@@ -497,10 +635,11 @@ typename MatType::elem_type RNN<
 {
   network.CheckNetwork("RNN::EvaluateWithGradient()", predictors.n_rows);
 
-  if (sequenceLengths.n_elem > 0 && batchSize != 1)
-    throw std::invalid_argument(
+  if (sequenceLengths.n_elem > 0 && batchSize != 1 && single)
+    throw std::invalid_argument(
+        "Batch size must be 1 for ragged sequences in single mode!");
 
-
+  ElemType loss = 0;
 
   // We must save anywhere between 1 and `bpttSteps` states, but we are limited
   // by `predictors.n_slices`.
@@ -523,22 +662,31 @@ typename MatType::elem_type RNN<
   // Add loss (this is not dependent on time steps, and should only be added
   // once). This is, e.g., regularizer loss, and other additive losses not
   // having to do with the output layer.
-  loss += network.network.Loss();
+  loss += ElemType(network.network.Loss());
+
+  // Reorder the data so the sequence lengths are in descending order.
+  if (sequenceLengths.n_elem > 0 && batchSize > 1)
+    ReorderBatch(begin, batchSize, predictors, responses, sequenceLengths);
 
   // For backpropagation through time, we must backpropagate for every
   // subsequence of length `bpttSteps`. Before we've taken `bpttSteps` though,
   // we will be backpropagating shorter sequences.
   const size_t steps = (sequenceLengths.n_elem == 0) ? predictors.n_slices :
-      sequenceLengths
+      sequenceLengths.subvec(begin, begin + batchSize - 1).max();
+  size_t activeBatchSize = batchSize;
   for (size_t t = 0; t < steps; ++t)
   {
-
+    // Calculate the number of active points.
+    if (sequenceLengths.n_elem > 0)
+      CalculateActivePoints(activeBatchSize, begin, sequenceLengths, t);
+
+    SetCurrentStep(t, (t == (steps - 1)), batchSize, activeBatchSize);
 
     // Make an alias of the step's data for the forward pass.
-    MakeAlias(stepData, predictors.slice(t), predictors.n_rows,
-        begin * predictors.
+    MakeAlias(stepData, predictors.slice(t), predictors.n_rows, activeBatchSize,
+        begin * predictors.n_rows);
     MakeAlias(outputData, outputs.slice(t % effectiveBPTTSteps), outputs.n_rows,
-
+        activeBatchSize);
     network.network.Forward(stepData, outputData);
 
     // Determine what the response should be. If we are in single mode but not
@@ -553,7 +701,7 @@ typename MatType::elem_type RNN<
     MatType error;
     for (size_t step = 0; step < std::min(t + 1, effectiveBPTTSteps); ++step)
     {
-      SetCurrentStep(t - step, (step == 0));
+      SetCurrentStep(t - step, (step == 0), batchSize, activeBatchSize, true);
 
       if (step > 0)
       {
@@ -561,9 +709,9 @@ typename MatType::elem_type RNN<
        error.zeros();
 
        MakeAlias(stepData, predictors.slice(t - step), predictors.n_rows,
-
+            activeBatchSize, begin * predictors.slice(t - step).n_rows);
        MakeAlias(outputData, outputs.slice((t - step) % effectiveBPTTSteps),
-            outputs.n_rows,
+            outputs.n_rows, activeBatchSize);
       }
       else
       {
@@ -571,11 +719,11 @@ typename MatType::elem_type RNN<
        // error.
        const size_t responseStep = (single) ? 0 : t - step;
        MakeAlias(stepData, predictors.slice(t - step), predictors.n_rows,
-
+            activeBatchSize, begin * predictors.n_rows);
        MakeAlias(responseData, responses.slice(responseStep), responses.n_rows,
-
+            activeBatchSize, begin * responses.n_rows);
        MakeAlias(outputData, outputs.slice((t - step) % effectiveBPTTSteps),
-            outputs.n_rows,
+            outputs.n_rows, activeBatchSize);
 
        // We only need to do this on the first time step of BPTT.
        loss += network.outputLayer.Forward(outputData, responseData);
@@ -590,7 +738,8 @@ typename MatType::elem_type RNN<
      // TODO: note that we could avoid the copy of currentGradient by having
      // each layer *add* its gradient to `gradient`. However that would
      // require some amount of refactoring.
-      MatType networkDelta
+      MatType networkDelta(predictors.n_rows, activeBatchSize,
+          GetFillType<MatType>::none);
      GradType currentGradient(gradient.n_rows, gradient.n_cols,
          GetFillType<MatType>::zeros);
      network.network.Backward(stepData, outputData, error, networkDelta);
@@ -656,7 +805,7 @@ void RNN<
 >::ResetData(
     CubeType predictors,
     CubeType responses,
-
+    URowType sequenceLengths)
 {
   this->predictors = std::move(predictors);
   this->responses = std::move(responses);
@@ -694,17 +843,60 @@ void RNN<
     OutputLayerType,
     InitializationRuleType,
     MatType
->::SetCurrentStep(const size_t step,
+>::SetCurrentStep(const size_t step,
+                  const bool end,
+                  size_t batchSize,
+                  size_t activeBatchSize,
+                  bool backwards)
 {
-  // Iterate over all layers and set the
+  // Iterate over all layers and set the current step.
   for (Layer<MatType>* l : network.Network())
   {
     // We can only call CurrentStep() on RecurrentLayers.
     RecurrentLayer<MatType>* r =
         dynamic_cast<RecurrentLayer<MatType>*>(l);
     if (r != nullptr)
+    {
       r->CurrentStep(step, end);
+      r->OnStepChanged(step, batchSize, activeBatchSize, backwards);
+    }
+  }
+}
+
+template<
+    typename OutputLayerType,
+    typename InitializationRuleType,
+    typename MatType
+>
+void RNN<
+    OutputLayerType,
+    InitializationRuleType,
+    MatType
+>::ReorderBatch(const size_t begin,
+                const size_t batchSize,
+                CubeType& predictors,
+                CubeType& responses,
+                URowType& sequenceLengths)
+{
+  using UColType = typename GetUColType<MatType>::type;
+  URowType batchLengths;
+  MakeAlias(batchLengths, sequenceLengths, batchSize, begin);
+
+  // Get the new ordering of this batch.
+  UColType ordering = sort_index(batchLengths, "descending");
+
+  // Reorder all slices to use the new ordering.
+  MatType batchPredictors, batchResponses;
+  for (size_t i = 0; i < predictors.n_slices; i++)
+  {
+    MakeAlias(batchPredictors, predictors.slice(i), predictors.n_rows,
+        batchSize, begin * predictors.n_rows);
+    MakeAlias(batchResponses, responses.slice(i), responses.n_rows,
+        batchSize, begin * responses.n_rows);
+    batchPredictors = batchPredictors.cols(ordering);
+    batchResponses = batchResponses.cols(ordering);
   }
+  batchLengths = batchLengths.cols(ordering);
 }
 
 } // namespace mlpack
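Read together, the rnn_impl.hpp hunks above add ragged-sequence support to RNN: Train(), Predict(), Evaluate(), and ResetData() gain overloads that accept a URowType of per-sequence lengths, batches are reordered by descending length via the new ReorderBatch(), and SetCurrentStep() now also carries the batch size and the number of still-active points at each time step. Below is a minimal caller-side sketch of how the new overloads might be used; only the Train()/Predict() signatures are taken from the diff, while the layer choices, sizes, and the use of arma::urowvec for URowType are illustrative assumptions.

// Hypothetical usage sketch of the ragged-sequence RNN overloads added in
// 4.7.0. Only the Train()/Predict() signatures come from the diff above;
// everything else (layers, dimensions, urowvec for URowType) is assumed.
#include <mlpack.hpp>

using namespace mlpack;

int main()
{
  const size_t inputDim = 3, outputDim = 1, numSeqs = 10, maxLen = 8;

  // Each slice is a time step; each column is one sequence (mlpack's layout).
  arma::cube predictors(inputDim, numSeqs, maxLen, arma::fill::randu);
  arma::cube responses(outputDim, numSeqs, maxLen, arma::fill::randu);

  // Per-sequence lengths; sequences shorter than maxLen are "ragged".
  arma::urowvec sequenceLengths(numSeqs);
  sequenceLengths.fill(maxLen);
  sequenceLengths[0] = 5;
  sequenceLengths[3] = 6;

  RNN<MeanSquaredError> model(maxLen /* bpttSteps */);
  model.Add<LSTM>(16);
  model.Add<Linear>(outputDim);

  // New overloads: pass the sequence lengths along with the data.
  model.Train(predictors, responses, sequenceLengths);

  arma::cube results;
  model.Predict(predictors, results, sequenceLengths);

  return 0;
}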
mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
CHANGED
@@ -80,7 +80,7 @@ void DrusillaSelect<MatType>::Train(
   candidateSet.set_size(referenceSet.n_rows, l * m);
   candidateIndices.set_size(l * m);
 
-  arma::vec dataMean(
+  arma::vec dataMean(mean(referenceSet, 1));
   arma::vec norms(referenceSet.n_cols);
 
   MatType refCopy(referenceSet.n_rows, referenceSet.n_cols);
mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp
CHANGED
@@ -100,7 +100,7 @@ BINDING_SEE_ALSO("Bayesian Interpolation",
     "https://cs.uwaterloo.ca/~mannr/cs886-w10/mackay-bayesian.pdf");
 BINDING_SEE_ALSO("Bayesian Linear Regression, Section 3.3",
     // I wonder how long this full text PDF will remain available...
-    "https://www.microsoft.com/en-us/research/uploads/
+    "https://www.microsoft.com/en-us/research/wp-content/uploads/2006/01/"
     "Bishop-Pattern-Recognition-and-Machine-Learning-2006.pdf");
 BINDING_SEE_ALSO("BayesianLinearRegression C++ class documentation",
     "@doc/user/methods/bayesian_linear_regression.md");
@@ -313,7 +313,7 @@ inline double ParallelSGD<ExponentialBackoff>::Optimize(
     // Get the stepsize for this iteration
     double stepSize = decayPolicy.StepSize(i);
 
-    if (
+    if (Shuffle()) // Determine order of visitation.
       std::shuffle(visitationOrder.begin(), visitationOrder.end(),
           mlpack::RandGen());
 
mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp
CHANGED
@@ -64,7 +64,7 @@ class DecisionTree :
    */
   template<typename MatType, typename LabelsType>
   DecisionTree(MatType data,
-               const
+               const DatasetInfo& datasetInfo,
                LabelsType labels,
                const size_t numClasses,
                const size_t minimumLeafSize = 10,
@@ -121,7 +121,7 @@ class DecisionTree :
   template<typename MatType, typename LabelsType, typename WeightsType>
   DecisionTree(
       MatType data,
-      const
+      const DatasetInfo& datasetInfo,
       LabelsType labels,
       const size_t numClasses,
       WeightsType weights,
@@ -186,7 +186,7 @@ class DecisionTree :
   DecisionTree(
       const DecisionTree& other,
       MatType data,
-      const
+      const DatasetInfo& datasetInfo,
       LabelsType labels,
       const size_t numClasses,
       WeightsType weights,
@@ -291,7 +291,7 @@ class DecisionTree :
    */
   template<typename MatType, typename LabelsType>
   double Train(MatType data,
-               const
+               const DatasetInfo& datasetInfo,
               LabelsType labels,
               const size_t numClasses,
               const size_t minimumLeafSize = 10,
@@ -350,7 +350,7 @@ class DecisionTree :
    */
   template<typename MatType, typename LabelsType, typename WeightsType>
   double Train(MatType data,
-               const
+               const DatasetInfo& datasetInfo,
               LabelsType labels,
               const size_t numClasses,
               WeightsType weights,
@@ -540,7 +540,7 @@ class DecisionTree :
   double Train(MatType& data,
               const size_t begin,
               const size_t count,
-               const
+               const DatasetInfo& datasetInfo,
               arma::Row<size_t>& labels,
               const size_t numClasses,
               WeightsType& weights,
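The decision_tree.hpp hunks above (and the matching decision_tree_impl.hpp hunks below) spell out the const DatasetInfo& parameter through which DecisionTree learns which dimensions are categorical and which are numeric. As a reminder of how that parameter is used from the caller's side, here is a minimal, hypothetical sketch based on mlpack's long-standing Load()/DatasetInfo interface rather than on anything specific to this diff; the file names are placeholders.

// Hypothetical sketch: training a DecisionTree on mixed categorical/numeric
// data. "mixed_data.csv" and "mixed_labels.csv" are placeholder file names.
#include <mlpack.hpp>

using namespace mlpack;

int main()
{
  arma::mat dataset;
  data::DatasetInfo info;
  // Load() fills `info` with the detected type (categorical or numeric) of
  // each dimension.
  data::Load("mixed_data.csv", dataset, info, true /* fatal on failure */);

  arma::Row<size_t> labels;
  data::Load("mixed_labels.csv", labels, true);

  const size_t numClasses = arma::max(labels) + 1;

  // The DatasetInfo parameter tells the tree which dimensions need a
  // categorical split instead of a numeric one.
  DecisionTree<> tree(dataset, info, labels, numClasses);

  arma::Row<size_t> predictions;
  tree.Classify(dataset, predictions);

  return 0;
}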
mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp
CHANGED
@@ -29,7 +29,7 @@ DecisionTree<FitnessFunction,
              DimensionSelectionType,
              NoRecursion>::DecisionTree(
     MatType data,
-    const
+    const DatasetInfo& datasetInfo,
     LabelsType labels,
     const size_t numClasses,
     const size_t minimumLeafSize,
@@ -103,7 +103,7 @@ DecisionTree<FitnessFunction,
              DimensionSelectionType,
              NoRecursion>::DecisionTree(
     MatType data,
-    const
+    const DatasetInfo& datasetInfo,
     LabelsType labels,
     const size_t numClasses,
     WeightsType weights,
@@ -186,7 +186,7 @@ DecisionTree<FitnessFunction,
              NoRecursion>::DecisionTree(
     const DecisionTree& other,
     MatType data,
-    const
+    const DatasetInfo& datasetInfo,
     LabelsType labels,
     const size_t numClasses,
     WeightsType weights,
@@ -446,7 +446,7 @@ double DecisionTree<FitnessFunction,
                     DimensionSelectionType,
                     NoRecursion>::Train(
     MatType data,
-    const
+    const DatasetInfo& datasetInfo,
     LabelsType labels,
     const size_t numClasses,
     const size_t minimumLeafSize,
@@ -527,7 +527,7 @@ double DecisionTree<FitnessFunction,
                     DimensionSelectionType,
                     NoRecursion>::Train(
     MatType data,
-    const
+    const DatasetInfo& datasetInfo,
     LabelsType labels,
     const size_t numClasses,
     WeightsType weights,
@@ -618,7 +618,7 @@ double DecisionTree<FitnessFunction,
     MatType& data,
     const size_t begin,
     const size_t count,
-    const
+    const DatasetInfo& datasetInfo,
     arma::Row<size_t>& labels,
     const size_t numClasses,
     WeightsType& weights,
@@ -650,7 +650,7 @@ double DecisionTree<FitnessFunction,
        i = dimensionSelector.Next())
   {
     double dimGain = -DBL_MAX;
-    if (datasetInfo.Type(i) ==
+    if (datasetInfo.Type(i) == Datatype::categorical)
     {
       dimGain = CategoricalSplit::template SplitIfBetter<UseWeights>(bestGain,
           data.cols(begin, begin + count - 1).row(i),
@@ -663,7 +663,7 @@ double DecisionTree<FitnessFunction,
           classProbabilities,
           *this);
     }
-    else if (datasetInfo.Type(i) ==
+    else if (datasetInfo.Type(i) == Datatype::numeric)
     {
       dimGain = NumericSplit::template SplitIfBetter<UseWeights>(bestGain,
           data.cols(begin, begin + count - 1).row(i),
@@ -699,14 +699,14 @@ double DecisionTree<FitnessFunction,
 
   // Get the number of children we will have.
   size_t numChildren = 0;
-  if (datasetInfo.Type(bestDim) ==
+  if (datasetInfo.Type(bestDim) == Datatype::categorical)
     numChildren = CategoricalSplit::NumChildren(classProbabilities, *this);
   else
     numChildren = NumericSplit::NumChildren(classProbabilities, *this);
 
   // Calculate all child assignments.
   arma::Row<size_t> childAssignments(count);
-  if (datasetInfo.Type(bestDim) ==
+  if (datasetInfo.Type(bestDim) == Datatype::categorical)
   {
     for (size_t j = begin; j < begin + count; ++j)
       childAssignments[j - begin] = CategoricalSplit::CalculateDirection(
@@ -868,7 +868,7 @@ double DecisionTree<FitnessFunction,
   size_t numChildren =
       NumericSplit::NumChildren(classProbabilities, *this);
   splitDimension = bestDim;
-  dimensionType = (size_t)
+  dimensionType = (size_t) Datatype::numeric;
 
   // Calculate all child assignments.
   arma::Row<size_t> childAssignments(count);
@@ -1099,7 +1099,7 @@ size_t DecisionTree<FitnessFunction,
                     DimensionSelectionType,
                     NoRecursion>::CalculateDirection(const VecType& point) const
 {
-  if ((
+  if ((Datatype) dimensionType == Datatype::categorical)
     return CategoricalSplit::CalculateDirection(point[splitDimension],
         classProbabilities, *this);
   else