mlpack 4.6.2__cp38-cp38-win_amd64.whl → 4.7.0__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +3 -3
- mlpack/adaboost_classify.cp38-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp38-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp38-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp38-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp38-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp38-win_amd64.pyd +0 -0
- mlpack/cf.cp38-win_amd64.pyd +0 -0
- mlpack/dbscan.cp38-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp38-win_amd64.pyd +0 -0
- mlpack/det.cp38-win_amd64.pyd +0 -0
- mlpack/emst.cp38-win_amd64.pyd +0 -0
- mlpack/fastmks.cp38-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp38-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp38-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp38-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp38-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp38-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp38-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp38-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp38-win_amd64.pyd +0 -0
- mlpack/image_converter.cp38-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp38-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp38-win_amd64.pyd +0 -0
- mlpack/kfn.cp38-win_amd64.pyd +0 -0
- mlpack/kmeans.cp38-win_amd64.pyd +0 -0
- mlpack/knn.cp38-win_amd64.pyd +0 -0
- mlpack/krann.cp38-win_amd64.pyd +0 -0
- mlpack/lars.cp38-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp38-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp38-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp38-win_amd64.pyd +0 -0
- mlpack/lmnn.cp38-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp38-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp38-win_amd64.pyd +0 -0
- mlpack/lsh.cp38-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp38-win_amd64.pyd +0 -0
- mlpack/nbc.cp38-win_amd64.pyd +0 -0
- mlpack/nca.cp38-win_amd64.pyd +0 -0
- mlpack/nmf.cp38-win_amd64.pyd +0 -0
- mlpack/pca.cp38-win_amd64.pyd +0 -0
- mlpack/perceptron.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp38-win_amd64.pyd +0 -0
- mlpack/radical.cp38-win_amd64.pyd +0 -0
- mlpack/random_forest.cp38-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp38-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp38-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +5 -5
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +395 -376
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{.load-order-mlpack-4.6.2 → .load-order-mlpack-4.7.0} +0 -0
|
@@ -68,7 +68,7 @@ class DecisionTreeRegressor :
|
|
|
68
68
|
*/
|
|
69
69
|
template<typename MatType, typename ResponsesType>
|
|
70
70
|
DecisionTreeRegressor(MatType data,
|
|
71
|
-
const
|
|
71
|
+
const DatasetInfo& datasetInfo,
|
|
72
72
|
ResponsesType responses,
|
|
73
73
|
const size_t minimumLeafSize = 10,
|
|
74
74
|
const double minimumGainSplit = 1e-7,
|
|
@@ -121,7 +121,7 @@ class DecisionTreeRegressor :
|
|
|
121
121
|
template<typename MatType, typename ResponsesType, typename WeightsType>
|
|
122
122
|
DecisionTreeRegressor(
|
|
123
123
|
MatType data,
|
|
124
|
-
const
|
|
124
|
+
const DatasetInfo& datasetInfo,
|
|
125
125
|
ResponsesType responses,
|
|
126
126
|
WeightsType weights,
|
|
127
127
|
const size_t minimumLeafSize = 10,
|
|
@@ -182,7 +182,7 @@ class DecisionTreeRegressor :
|
|
|
182
182
|
DecisionTreeRegressor(
|
|
183
183
|
const DecisionTreeRegressor& other,
|
|
184
184
|
MatType data,
|
|
185
|
-
const
|
|
185
|
+
const DatasetInfo& datasetInfo,
|
|
186
186
|
ResponsesType responses,
|
|
187
187
|
WeightsType weights,
|
|
188
188
|
const size_t minimumLeafSize = 10,
|
|
@@ -277,7 +277,7 @@ class DecisionTreeRegressor :
|
|
|
277
277
|
*/
|
|
278
278
|
template<typename MatType, typename ResponsesType>
|
|
279
279
|
double Train(MatType data,
|
|
280
|
-
const
|
|
280
|
+
const DatasetInfo& datasetInfo,
|
|
281
281
|
ResponsesType responses,
|
|
282
282
|
const size_t minimumLeafSize = 10,
|
|
283
283
|
const double minimumGainSplit = 1e-7,
|
|
@@ -338,7 +338,7 @@ class DecisionTreeRegressor :
|
|
|
338
338
|
*/
|
|
339
339
|
template<typename MatType, typename ResponsesType, typename WeightsType>
|
|
340
340
|
double Train(MatType data,
|
|
341
|
-
const
|
|
341
|
+
const DatasetInfo& datasetInfo,
|
|
342
342
|
ResponsesType responses,
|
|
343
343
|
WeightsType weights,
|
|
344
344
|
const size_t minimumLeafSize = 10,
|
|
@@ -481,7 +481,7 @@ class DecisionTreeRegressor :
|
|
|
481
481
|
double Train(MatType& data,
|
|
482
482
|
const size_t begin,
|
|
483
483
|
const size_t count,
|
|
484
|
-
const
|
|
484
|
+
const DatasetInfo& datasetInfo,
|
|
485
485
|
ResponsesType& responses,
|
|
486
486
|
arma::rowvec& weights,
|
|
487
487
|
const size_t minimumLeafSize,
|
|
@@ -45,7 +45,7 @@ DecisionTreeRegressor<FitnessFunction,
|
|
|
45
45
|
DimensionSelectionType,
|
|
46
46
|
NoRecursion>::DecisionTreeRegressor(
|
|
47
47
|
MatType data,
|
|
48
|
-
const
|
|
48
|
+
const DatasetInfo& datasetInfo,
|
|
49
49
|
ResponsesType responses,
|
|
50
50
|
const size_t minimumLeafSize,
|
|
51
51
|
const double minimumGainSplit,
|
|
@@ -117,7 +117,7 @@ DecisionTreeRegressor<FitnessFunction,
|
|
|
117
117
|
DimensionSelectionType,
|
|
118
118
|
NoRecursion>::DecisionTreeRegressor(
|
|
119
119
|
MatType data,
|
|
120
|
-
const
|
|
120
|
+
const DatasetInfo& datasetInfo,
|
|
121
121
|
ResponsesType responses,
|
|
122
122
|
WeightsType weights,
|
|
123
123
|
const size_t minimumLeafSize,
|
|
@@ -199,7 +199,7 @@ DecisionTreeRegressor<FitnessFunction,
|
|
|
199
199
|
NoRecursion>::DecisionTreeRegressor(
|
|
200
200
|
const DecisionTreeRegressor& other,
|
|
201
201
|
MatType data,
|
|
202
|
-
const
|
|
202
|
+
const DatasetInfo& datasetInfo,
|
|
203
203
|
ResponsesType responses,
|
|
204
204
|
WeightsType weights,
|
|
205
205
|
const size_t minimumLeafSize,
|
|
@@ -429,7 +429,7 @@ double DecisionTreeRegressor<FitnessFunction,
|
|
|
429
429
|
DimensionSelectionType,
|
|
430
430
|
NoRecursion>::Train(
|
|
431
431
|
MatType data,
|
|
432
|
-
const
|
|
432
|
+
const DatasetInfo& datasetInfo,
|
|
433
433
|
ResponsesType responses,
|
|
434
434
|
const size_t minimumLeafSize,
|
|
435
435
|
const double minimumGainSplit,
|
|
@@ -510,7 +510,7 @@ double DecisionTreeRegressor<FitnessFunction,
|
|
|
510
510
|
DimensionSelectionType,
|
|
511
511
|
NoRecursion>::Train(
|
|
512
512
|
MatType data,
|
|
513
|
-
const
|
|
513
|
+
const DatasetInfo& datasetInfo,
|
|
514
514
|
ResponsesType responses,
|
|
515
515
|
WeightsType weights,
|
|
516
516
|
const size_t minimumLeafSize,
|
|
@@ -601,7 +601,7 @@ double DecisionTreeRegressor<FitnessFunction,
|
|
|
601
601
|
MatType& data,
|
|
602
602
|
const size_t begin,
|
|
603
603
|
const size_t count,
|
|
604
|
-
const
|
|
604
|
+
const DatasetInfo& datasetInfo,
|
|
605
605
|
ResponsesType& responses,
|
|
606
606
|
arma::rowvec& weights,
|
|
607
607
|
const size_t minimumLeafSize,
|
|
@@ -630,7 +630,7 @@ double DecisionTreeRegressor<FitnessFunction,
|
|
|
630
630
|
i = dimensionSelector.Next())
|
|
631
631
|
{
|
|
632
632
|
double dimGain = -DBL_MAX;
|
|
633
|
-
if (datasetInfo.Type(i) ==
|
|
633
|
+
if (datasetInfo.Type(i) == Datatype::categorical)
|
|
634
634
|
{
|
|
635
635
|
dimGain = CategoricalSplit::template SplitIfBetter<UseWeights>(bestGain,
|
|
636
636
|
data.cols(begin, begin + count - 1).row(i),
|
|
@@ -643,7 +643,7 @@ double DecisionTreeRegressor<FitnessFunction,
|
|
|
643
643
|
*this,
|
|
644
644
|
fitnessFunction);
|
|
645
645
|
}
|
|
646
|
-
else if (datasetInfo.Type(i) ==
|
|
646
|
+
else if (datasetInfo.Type(i) == Datatype::numeric)
|
|
647
647
|
{
|
|
648
648
|
dimGain = NumericSplit::template SplitIfBetter<UseWeights>(bestGain,
|
|
649
649
|
data.cols(begin, begin + count - 1).row(i),
|
|
@@ -679,14 +679,14 @@ double DecisionTreeRegressor<FitnessFunction,
|
|
|
679
679
|
|
|
680
680
|
// Get the number of children we will have.
|
|
681
681
|
size_t numChildren = 0;
|
|
682
|
-
if (datasetInfo.Type(bestDim) ==
|
|
682
|
+
if (datasetInfo.Type(bestDim) == Datatype::categorical)
|
|
683
683
|
numChildren = CategoricalSplit::NumChildren(splitInfo, *this);
|
|
684
684
|
else
|
|
685
685
|
numChildren = NumericSplit::NumChildren(splitInfo, *this);
|
|
686
686
|
|
|
687
687
|
// Calculate all child assignments.
|
|
688
688
|
arma::Row<size_t> childAssignments(count);
|
|
689
|
-
if (datasetInfo.Type(bestDim) ==
|
|
689
|
+
if (datasetInfo.Type(bestDim) == Datatype::categorical)
|
|
690
690
|
{
|
|
691
691
|
for (size_t j = begin; j < begin + count; ++j)
|
|
692
692
|
childAssignments[j - begin] = CategoricalSplit::CalculateDirection(
|
|
@@ -844,7 +844,7 @@ double DecisionTreeRegressor<FitnessFunction,
|
|
|
844
844
|
// We know that the split is numeric.
|
|
845
845
|
size_t numChildren = NumericSplit::NumChildren(splitInfo, *this);
|
|
846
846
|
splitDimension = bestDim;
|
|
847
|
-
dimensionType = (size_t)
|
|
847
|
+
dimensionType = (size_t) Datatype::numeric;
|
|
848
848
|
|
|
849
849
|
// Calculate all child assignments.
|
|
850
850
|
arma::Row<size_t> childAssignments(count);
|
|
@@ -986,7 +986,7 @@ size_t DecisionTreeRegressor<FitnessFunction,
|
|
|
986
986
|
NoRecursion
|
|
987
987
|
>::CalculateDirection(const VecType& point) const
|
|
988
988
|
{
|
|
989
|
-
if ((
|
|
989
|
+
if ((Datatype) dimensionType == Datatype::categorical)
|
|
990
990
|
return CategoricalSplit::CalculateDirection(point[splitDimension],
|
|
991
991
|
splitInfo, *this);
|
|
992
992
|
else
|
|
@@ -297,7 +297,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
297
297
|
}
|
|
298
298
|
|
|
299
299
|
if (params.Has("tag_counters_file"))
|
|
300
|
-
|
|
300
|
+
Save(params.Get<string>("tag_counters_file"), counters);
|
|
301
301
|
}
|
|
302
302
|
|
|
303
303
|
timers.Stop("det_test_set_tagging");
|
|
@@ -356,7 +356,7 @@ struct Train
|
|
|
356
356
|
|
|
357
357
|
// Now read the matrix.
|
|
358
358
|
Mat<size_t> label;
|
|
359
|
-
|
|
359
|
+
Load(lineBuf, label, Fatal);
|
|
360
360
|
|
|
361
361
|
// Ensure that matrix only has one row.
|
|
362
362
|
if (label.n_cols == 1)
|
|
@@ -387,7 +387,7 @@ struct Train
|
|
|
387
387
|
else
|
|
388
388
|
{
|
|
389
389
|
Mat<size_t> label;
|
|
390
|
-
|
|
390
|
+
Load(labelsFile, label, Fatal);
|
|
391
391
|
|
|
392
392
|
// Ensure that matrix only has one row.
|
|
393
393
|
if (label.n_cols == 1)
|
|
@@ -498,7 +498,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& /* timers */)
|
|
|
498
498
|
|
|
499
499
|
// Now read the matrix.
|
|
500
500
|
trainSeq.push_back(mat());
|
|
501
|
-
|
|
501
|
+
Load(lineBuf, trainSeq.back(), Fatal);
|
|
502
502
|
|
|
503
503
|
// See if we need to transpose the data.
|
|
504
504
|
if (type == "discrete")
|
|
@@ -516,7 +516,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& /* timers */)
|
|
|
516
516
|
{
|
|
517
517
|
// Only one input file.
|
|
518
518
|
trainSeq.resize(1);
|
|
519
|
-
|
|
519
|
+
Load(inputFile, trainSeq[0], Fatal);
|
|
520
520
|
}
|
|
521
521
|
|
|
522
522
|
// Get the type.
|
|
@@ -38,7 +38,7 @@ template<typename ActionType, typename ExtraInfoType>
|
|
|
38
38
|
void LoadHMMAndPerformAction(const std::string& modelFile,
|
|
39
39
|
ExtraInfoType* x)
|
|
40
40
|
{
|
|
41
|
-
const std::string extension =
|
|
41
|
+
const std::string extension = Extension(modelFile);
|
|
42
42
|
if (extension == "xml")
|
|
43
43
|
{
|
|
44
44
|
LoadHMMAndPerformActionHelper<ActionType, cereal::XMLInputArchive>(
|
|
@@ -126,7 +126,7 @@ char GetHMMType();
|
|
|
126
126
|
template<typename HMMType>
|
|
127
127
|
void SaveHMM(HMMType& hmm, const std::string& modelFile)
|
|
128
128
|
{
|
|
129
|
-
const std::string extension =
|
|
129
|
+
const std::string extension = Extension(modelFile);
|
|
130
130
|
if (extension == "xml")
|
|
131
131
|
SaveHMMHelper<cereal::XMLOutputArchive>(hmm, modelFile);
|
|
132
132
|
else if (extension == "bin")
|
|
@@ -133,7 +133,7 @@ class HoeffdingTree
|
|
|
133
133
|
* @param copyDatasetInfo If true, then a copy of the datasetInfo will be
|
|
134
134
|
* made.
|
|
135
135
|
*/
|
|
136
|
-
HoeffdingTree(const
|
|
136
|
+
HoeffdingTree(const DatasetInfo& datasetInfo,
|
|
137
137
|
const size_t numClasses,
|
|
138
138
|
const double successProbability = 0.95,
|
|
139
139
|
const size_t maxSamples = 0,
|
|
@@ -211,7 +211,7 @@ class HoeffdingTree
|
|
|
211
211
|
*/
|
|
212
212
|
template<typename MatType>
|
|
213
213
|
HoeffdingTree(const MatType& data,
|
|
214
|
-
const
|
|
214
|
+
const DatasetInfo& datasetInfo,
|
|
215
215
|
const arma::Row<size_t>& labels,
|
|
216
216
|
const size_t numClasses,
|
|
217
217
|
const bool batchTraining = true,
|
|
@@ -330,7 +330,7 @@ class HoeffdingTree
|
|
|
330
330
|
*/
|
|
331
331
|
template<typename MatType>
|
|
332
332
|
void Train(const MatType& data,
|
|
333
|
-
const
|
|
333
|
+
const DatasetInfo& info,
|
|
334
334
|
const arma::Row<size_t>& labels,
|
|
335
335
|
const size_t numClasses = 0,
|
|
336
336
|
const bool batchTraining = true,
|
|
@@ -340,7 +340,7 @@ class HoeffdingTree
|
|
|
340
340
|
|
|
341
341
|
template<typename MatType>
|
|
342
342
|
void Train(const MatType& data,
|
|
343
|
-
const
|
|
343
|
+
const DatasetInfo& info,
|
|
344
344
|
const arma::Row<size_t>& labels,
|
|
345
345
|
const size_t numClasses,
|
|
346
346
|
const bool batchTraining,
|
|
@@ -497,7 +497,7 @@ class HoeffdingTree
|
|
|
497
497
|
/**
|
|
498
498
|
* Reset the tree, setting a new number of classes and a new datasetInfo.
|
|
499
499
|
*/
|
|
500
|
-
void Reset(const
|
|
500
|
+
void Reset(const DatasetInfo& datasetInfo, const size_t numClasses);
|
|
501
501
|
|
|
502
502
|
//! Serialize the split.
|
|
503
503
|
template<typename Archive>
|
|
@@ -527,7 +527,7 @@ class HoeffdingTree
|
|
|
527
527
|
//! The minimum number of samples for splitting.
|
|
528
528
|
size_t minSamples;
|
|
529
529
|
//! The dataset information.
|
|
530
|
-
const
|
|
530
|
+
const DatasetInfo* datasetInfo;
|
|
531
531
|
//! Whether or not we own the dataset information.
|
|
532
532
|
bool ownsInfo;
|
|
533
533
|
//! The required probability of success for a split to be performed.
|
|
@@ -33,7 +33,7 @@ HoeffdingTree<
|
|
|
33
33
|
maxSamples(size_t(-1)),
|
|
34
34
|
checkInterval(100),
|
|
35
35
|
minSamples(100),
|
|
36
|
-
datasetInfo(new
|
|
36
|
+
datasetInfo(new DatasetInfo()),
|
|
37
37
|
ownsInfo(true),
|
|
38
38
|
successProbability(0.95),
|
|
39
39
|
splitDimension(size_t(-1)),
|
|
@@ -71,7 +71,7 @@ HoeffdingTree<
|
|
|
71
71
|
maxSamples((maxSamples == 0) ? size_t(-1) : maxSamples),
|
|
72
72
|
checkInterval(checkInterval),
|
|
73
73
|
minSamples(minSamples),
|
|
74
|
-
datasetInfo(new
|
|
74
|
+
datasetInfo(new DatasetInfo(dimensionality)),
|
|
75
75
|
ownsInfo(true),
|
|
76
76
|
successProbability(successProbability),
|
|
77
77
|
splitDimension(size_t(-1)),
|
|
@@ -103,7 +103,7 @@ HoeffdingTree<
|
|
|
103
103
|
FitnessFunction,
|
|
104
104
|
NumericSplitType,
|
|
105
105
|
CategoricalSplitType
|
|
106
|
-
>::HoeffdingTree(const
|
|
106
|
+
>::HoeffdingTree(const DatasetInfo& datasetInfo,
|
|
107
107
|
const size_t numClasses,
|
|
108
108
|
const double successProbability,
|
|
109
109
|
const size_t maxSamples,
|
|
@@ -123,7 +123,7 @@ HoeffdingTree<
|
|
|
123
123
|
maxSamples((maxSamples == 0) ? size_t(-1) : maxSamples),
|
|
124
124
|
checkInterval(checkInterval),
|
|
125
125
|
minSamples(minSamples),
|
|
126
|
-
datasetInfo(copyDatasetInfo ? new
|
|
126
|
+
datasetInfo(copyDatasetInfo ? new DatasetInfo(datasetInfo) :
|
|
127
127
|
&datasetInfo),
|
|
128
128
|
ownsInfo(copyDatasetInfo),
|
|
129
129
|
successProbability(successProbability),
|
|
@@ -142,7 +142,7 @@ HoeffdingTree<
|
|
|
142
142
|
{
|
|
143
143
|
for (size_t i = 0; i < datasetInfo.Dimensionality(); ++i)
|
|
144
144
|
{
|
|
145
|
-
if (datasetInfo.Type(i) ==
|
|
145
|
+
if (datasetInfo.Type(i) == Datatype::categorical)
|
|
146
146
|
{
|
|
147
147
|
categoricalSplits.push_back(CategoricalSplitType<FitnessFunction>(
|
|
148
148
|
datasetInfo.NumMappings(i), numClasses, categoricalSplitIn));
|
|
@@ -182,7 +182,7 @@ HoeffdingTree<
|
|
|
182
182
|
maxSamples((maxSamples == 0) ? size_t(-1) : maxSamples),
|
|
183
183
|
checkInterval(checkInterval),
|
|
184
184
|
minSamples(minSamples),
|
|
185
|
-
datasetInfo(new
|
|
185
|
+
datasetInfo(new DatasetInfo(data.n_rows)),
|
|
186
186
|
ownsInfo(true),
|
|
187
187
|
successProbability(successProbability),
|
|
188
188
|
splitDimension(size_t(-1)),
|
|
@@ -207,7 +207,7 @@ HoeffdingTree<
|
|
|
207
207
|
NumericSplitType,
|
|
208
208
|
CategoricalSplitType
|
|
209
209
|
>::HoeffdingTree(const MatType& data,
|
|
210
|
-
const
|
|
210
|
+
const DatasetInfo& datasetInfoIn,
|
|
211
211
|
const arma::Row<size_t>& labels,
|
|
212
212
|
const size_t numClasses,
|
|
213
213
|
const bool batchTraining,
|
|
@@ -225,7 +225,7 @@ HoeffdingTree<
|
|
|
225
225
|
maxSamples((maxSamples == 0) ? size_t(-1) : maxSamples),
|
|
226
226
|
checkInterval(checkInterval),
|
|
227
227
|
minSamples(minSamples),
|
|
228
|
-
datasetInfo(new
|
|
228
|
+
datasetInfo(new DatasetInfo(datasetInfoIn)),
|
|
229
229
|
ownsInfo(true),
|
|
230
230
|
successProbability(successProbability),
|
|
231
231
|
splitDimension(size_t(-1)),
|
|
@@ -257,7 +257,7 @@ HoeffdingTree<FitnessFunction, NumericSplitType, CategoricalSplitType>::
|
|
|
257
257
|
maxSamples(other.maxSamples),
|
|
258
258
|
checkInterval(other.checkInterval),
|
|
259
259
|
minSamples(other.minSamples),
|
|
260
|
-
datasetInfo(new
|
|
260
|
+
datasetInfo(new DatasetInfo(*other.datasetInfo)),
|
|
261
261
|
ownsInfo(true),
|
|
262
262
|
successProbability(other.successProbability),
|
|
263
263
|
splitDimension(other.splitDimension),
|
|
@@ -341,7 +341,7 @@ HoeffdingTree<FitnessFunction, NumericSplitType, CategoricalSplitType>&
|
|
|
341
341
|
maxSamples = other.maxSamples;
|
|
342
342
|
checkInterval = other.checkInterval;
|
|
343
343
|
minSamples = other.minSamples;
|
|
344
|
-
datasetInfo = new
|
|
344
|
+
datasetInfo = new DatasetInfo(*other.datasetInfo);
|
|
345
345
|
ownsInfo = true;
|
|
346
346
|
successProbability = other.successProbability;
|
|
347
347
|
splitDimension = other.splitDimension;
|
|
@@ -482,7 +482,7 @@ void HoeffdingTree<
|
|
|
482
482
|
// Create a new datasetInfo, which assumes that all features are numeric.
|
|
483
483
|
if (ownsInfo)
|
|
484
484
|
delete datasetInfo;
|
|
485
|
-
datasetInfo = new
|
|
485
|
+
datasetInfo = new DatasetInfo(data.n_rows);
|
|
486
486
|
ownsInfo = true;
|
|
487
487
|
|
|
488
488
|
// Set the number of classes correctly.
|
|
@@ -510,7 +510,7 @@ void HoeffdingTree<
|
|
|
510
510
|
NumericSplitType,
|
|
511
511
|
CategoricalSplitType
|
|
512
512
|
>::Train(const MatType& data,
|
|
513
|
-
const
|
|
513
|
+
const DatasetInfo& info,
|
|
514
514
|
const arma::Row<size_t>& labels,
|
|
515
515
|
const size_t numClasses,
|
|
516
516
|
const bool batchTraining,
|
|
@@ -535,7 +535,7 @@ void HoeffdingTree<
|
|
|
535
535
|
NumericSplitType,
|
|
536
536
|
CategoricalSplitType
|
|
537
537
|
>::Train(const MatType& data,
|
|
538
|
-
const
|
|
538
|
+
const DatasetInfo& info,
|
|
539
539
|
const arma::Row<size_t>& labels,
|
|
540
540
|
const size_t numClasses,
|
|
541
541
|
const bool batchTraining,
|
|
@@ -596,9 +596,9 @@ void HoeffdingTree<
|
|
|
596
596
|
size_t categoricalIndex = 0;
|
|
597
597
|
for (size_t i = 0; i < point.n_rows; ++i)
|
|
598
598
|
{
|
|
599
|
-
if (datasetInfo->Type(i) ==
|
|
599
|
+
if (datasetInfo->Type(i) == Datatype::categorical)
|
|
600
600
|
categoricalSplits[categoricalIndex++].Train(point[i], label);
|
|
601
|
-
else if (datasetInfo->Type(i) ==
|
|
601
|
+
else if (datasetInfo->Type(i) == Datatype::numeric)
|
|
602
602
|
numericSplits[numericIndex++].Train(point[i], label);
|
|
603
603
|
}
|
|
604
604
|
|
|
@@ -673,10 +673,10 @@ size_t HoeffdingTree<
|
|
|
673
673
|
// best two splits that can be done in every network.
|
|
674
674
|
double bestGain = 0.0;
|
|
675
675
|
double secondBestGain = 0.0;
|
|
676
|
-
if (type ==
|
|
676
|
+
if (type == Datatype::categorical)
|
|
677
677
|
categoricalSplits[index].EvaluateFitnessFunction(bestGain,
|
|
678
678
|
secondBestGain);
|
|
679
|
-
else if (type ==
|
|
679
|
+
else if (type == Datatype::numeric)
|
|
680
680
|
numericSplits[index].EvaluateFitnessFunction(bestGain, secondBestGain);
|
|
681
681
|
|
|
682
682
|
// See if these gains are better than the previous.
|
|
@@ -706,7 +706,7 @@ size_t HoeffdingTree<
|
|
|
706
706
|
splitDimension = largestIndex;
|
|
707
707
|
const size_t type = dimensionMappings->at(largestIndex).first;
|
|
708
708
|
const size_t index = dimensionMappings->at(largestIndex).second;
|
|
709
|
-
if (type ==
|
|
709
|
+
if (type == Datatype::categorical)
|
|
710
710
|
{
|
|
711
711
|
// I don't know if this should be here.
|
|
712
712
|
majorityClass = categoricalSplits[index].MajorityClass();
|
|
@@ -801,9 +801,9 @@ size_t HoeffdingTree<
|
|
|
801
801
|
>::CalculateDirection(const VecType& point) const
|
|
802
802
|
{
|
|
803
803
|
// Don't call this before the node is split...
|
|
804
|
-
if (datasetInfo->Type(splitDimension) ==
|
|
804
|
+
if (datasetInfo->Type(splitDimension) == Datatype::numeric)
|
|
805
805
|
return numericSplit.CalculateDirection(point[splitDimension]);
|
|
806
|
-
else if (datasetInfo->Type(splitDimension) ==
|
|
806
|
+
else if (datasetInfo->Type(splitDimension) == Datatype::categorical)
|
|
807
807
|
return categoricalSplit.CalculateDirection(point[splitDimension]);
|
|
808
808
|
else
|
|
809
809
|
return 0; // Not sure what to do here...
|
|
@@ -938,13 +938,13 @@ void HoeffdingTree<
|
|
|
938
938
|
// Create the children.
|
|
939
939
|
arma::Col<size_t> childMajorities;
|
|
940
940
|
if (dimensionMappings->at(splitDimension).first ==
|
|
941
|
-
|
|
941
|
+
Datatype::categorical)
|
|
942
942
|
{
|
|
943
943
|
categoricalSplits[dimensionMappings->at(splitDimension).second].Split(
|
|
944
944
|
childMajorities, categoricalSplit);
|
|
945
945
|
}
|
|
946
946
|
else if (dimensionMappings->at(splitDimension).first ==
|
|
947
|
-
|
|
947
|
+
Datatype::numeric)
|
|
948
948
|
{
|
|
949
949
|
numericSplits[dimensionMappings->at(splitDimension).second].Split(
|
|
950
950
|
childMajorities, numericSplit);
|
|
@@ -1016,7 +1016,7 @@ void HoeffdingTree<
|
|
|
1016
1016
|
{
|
|
1017
1017
|
if (ownsInfo)
|
|
1018
1018
|
delete datasetInfo;
|
|
1019
|
-
datasetInfo = new
|
|
1019
|
+
datasetInfo = new DatasetInfo(dimensionality); // All features numeric.
|
|
1020
1020
|
ownsInfo = true;
|
|
1021
1021
|
|
|
1022
1022
|
this->numClasses = numClasses;
|
|
@@ -1033,7 +1033,7 @@ void HoeffdingTree<
|
|
|
1033
1033
|
FitnessFunction,
|
|
1034
1034
|
NumericSplitType,
|
|
1035
1035
|
CategoricalSplitType
|
|
1036
|
-
>::Reset(const
|
|
1036
|
+
>::Reset(const DatasetInfo& info, const size_t numClasses)
|
|
1037
1037
|
{
|
|
1038
1038
|
if (ownsInfo)
|
|
1039
1039
|
delete datasetInfo;
|
|
@@ -1066,9 +1066,9 @@ void HoeffdingTree<
|
|
|
1066
1066
|
ar(CEREAL_POINTER(dimensionMappings));
|
|
1067
1067
|
|
|
1068
1068
|
// Special handling for const object.
|
|
1069
|
-
|
|
1069
|
+
DatasetInfo* d = NULL;
|
|
1070
1070
|
if (cereal::is_saving<Archive>())
|
|
1071
|
-
d = const_cast<
|
|
1071
|
+
d = const_cast<DatasetInfo*>(datasetInfo);
|
|
1072
1072
|
ar(CEREAL_POINTER(d));
|
|
1073
1073
|
|
|
1074
1074
|
if (cereal::is_loading<Archive>())
|
|
@@ -1108,7 +1108,7 @@ void HoeffdingTree<
|
|
|
1108
1108
|
categoricalSplits.clear();
|
|
1109
1109
|
for (size_t i = 0; i < datasetInfo->Dimensionality(); ++i)
|
|
1110
1110
|
{
|
|
1111
|
-
if (datasetInfo->Type(i) ==
|
|
1111
|
+
if (datasetInfo->Type(i) == Datatype::categorical)
|
|
1112
1112
|
categoricalSplits.push_back(CategoricalSplitType<FitnessFunction>(
|
|
1113
1113
|
datasetInfo->NumMappings(i), numClasses));
|
|
1114
1114
|
else
|
|
@@ -1136,7 +1136,7 @@ void HoeffdingTree<
|
|
|
1136
1136
|
else
|
|
1137
1137
|
{
|
|
1138
1138
|
// We have split, so we only need to save the split and the children.
|
|
1139
|
-
if (datasetInfo->Type(splitDimension) ==
|
|
1139
|
+
if (datasetInfo->Type(splitDimension) == Datatype::categorical)
|
|
1140
1140
|
ar(CEREAL_NVP(categoricalSplit));
|
|
1141
1141
|
else
|
|
1142
1142
|
ar(CEREAL_NVP(numericSplit));
|
|
@@ -1280,18 +1280,18 @@ void HoeffdingTree<
|
|
|
1280
1280
|
ownsMappings = true;
|
|
1281
1281
|
for (size_t i = 0; i < datasetInfo->Dimensionality(); ++i)
|
|
1282
1282
|
{
|
|
1283
|
-
if (datasetInfo->Type(i) ==
|
|
1283
|
+
if (datasetInfo->Type(i) == Datatype::categorical)
|
|
1284
1284
|
{
|
|
1285
1285
|
categoricalSplits.push_back(CategoricalSplitType<FitnessFunction>(
|
|
1286
1286
|
datasetInfo->NumMappings(i), numClasses, categoricalSplitIn));
|
|
1287
|
-
(*dimensionMappings)[i] = std::make_pair(
|
|
1287
|
+
(*dimensionMappings)[i] = std::make_pair(Datatype::categorical,
|
|
1288
1288
|
categoricalSplits.size() - 1);
|
|
1289
1289
|
}
|
|
1290
1290
|
else
|
|
1291
1291
|
{
|
|
1292
1292
|
numericSplits.push_back(NumericSplitType<FitnessFunction>(numClasses,
|
|
1293
1293
|
numericSplitIn));
|
|
1294
|
-
(*dimensionMappings)[i] = std::make_pair(
|
|
1294
|
+
(*dimensionMappings)[i] = std::make_pair(Datatype::numeric,
|
|
1295
1295
|
numericSplits.size() - 1);
|
|
1296
1296
|
}
|
|
1297
1297
|
}
|
|
@@ -23,7 +23,6 @@
|
|
|
23
23
|
|
|
24
24
|
using namespace std;
|
|
25
25
|
using namespace mlpack;
|
|
26
|
-
using namespace mlpack::data;
|
|
27
26
|
using namespace mlpack::util;
|
|
28
27
|
|
|
29
28
|
// Program Name.
|
|
@@ -276,7 +275,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
|
|
|
276
275
|
if (params.Has("test"))
|
|
277
276
|
{
|
|
278
277
|
// Before loading, pre-set the dataset info by getting the raw parameter
|
|
279
|
-
// (that doesn't call
|
|
278
|
+
// (that doesn't call Load()).
|
|
280
279
|
std::get<0>(params.GetRaw<TupleType>("test")) = datasetInfo;
|
|
281
280
|
arma::mat testSet = std::get<1>(params.Get<TupleType>("test"));
|
|
282
281
|
|
|
@@ -112,7 +112,7 @@ class HoeffdingTreeModel
|
|
|
112
112
|
* Hoeffding numeric split.
|
|
113
113
|
*/
|
|
114
114
|
void BuildModel(const arma::mat& dataset,
|
|
115
|
-
const
|
|
115
|
+
const DatasetInfo& datasetInfo,
|
|
116
116
|
const arma::Row<size_t>& labels,
|
|
117
117
|
const size_t numClasses,
|
|
118
118
|
const bool batchTraining,
|
|
@@ -185,7 +185,7 @@ class HoeffdingTreeModel
|
|
|
185
185
|
ar(CEREAL_NVP(type));
|
|
186
186
|
|
|
187
187
|
// Fake dataset info may be needed to create fake trees.
|
|
188
|
-
|
|
188
|
+
DatasetInfo info;
|
|
189
189
|
if (type == GINI_HOEFFDING)
|
|
190
190
|
ar(CEREAL_POINTER(giniHoeffdingTree));
|
|
191
191
|
else if (type == GINI_BINARY)
|
|
@@ -129,7 +129,7 @@ inline HoeffdingTreeModel::~HoeffdingTreeModel()
|
|
|
129
129
|
// Create the model.
|
|
130
130
|
inline void HoeffdingTreeModel::BuildModel(
|
|
131
131
|
const arma::mat& dataset,
|
|
132
|
-
const
|
|
132
|
+
const DatasetInfo& datasetInfo,
|
|
133
133
|
const arma::Row<size_t>& labels,
|
|
134
134
|
const size_t numClasses,
|
|
135
135
|
const bool batchTraining,
|
|
@@ -226,10 +226,10 @@ Score(const size_t queryIndex, TreeType& referenceNode)
|
|
|
226
226
|
sample(oldSize + i) =
|
|
227
227
|
EvaluateKernel(queryIndex, referenceNode.Descendant(randomPoint));
|
|
228
228
|
}
|
|
229
|
-
meanSample =
|
|
230
|
-
const double
|
|
229
|
+
meanSample = mean(sample);
|
|
230
|
+
const double sampleStddev = stddev(sample);
|
|
231
231
|
const double mThreshBase =
|
|
232
|
-
z *
|
|
232
|
+
z * sampleStddev * (1 + relError) / (relError * meanSample);
|
|
233
233
|
const size_t mThresh = std::ceil(mThreshBase * mThreshBase);
|
|
234
234
|
|
|
235
235
|
if (sample.size() < mThresh)
|
|
@@ -441,10 +441,10 @@ Score(TreeType& queryNode, TreeType& referenceNode)
|
|
|
441
441
|
sample(oldSize + i) =
|
|
442
442
|
EvaluateKernel(queryIndex, referenceNode.Descendant(randomPoint));
|
|
443
443
|
}
|
|
444
|
-
meanSample =
|
|
445
|
-
const double
|
|
444
|
+
meanSample = mean(sample);
|
|
445
|
+
const double sampleStddev = stddev(sample);
|
|
446
446
|
const double mThreshBase =
|
|
447
|
-
z *
|
|
447
|
+
z * sampleStddev * (1 + relError) / (relError * meanSample);
|
|
448
448
|
const size_t mThresh = std::ceil(mThreshBase * mThreshBase);
|
|
449
449
|
|
|
450
450
|
if (sample.size() < mThresh)
|
|
@@ -510,7 +510,7 @@ LARS<ModelMatType>::Train(const MatType& matX,
|
|
|
510
510
|
{
|
|
511
511
|
if (fitIntercept)
|
|
512
512
|
{
|
|
513
|
-
offsetX =
|
|
513
|
+
offsetX = mean(matX, 1);
|
|
514
514
|
dataTrans = (matX.each_col() - offsetX).t();
|
|
515
515
|
}
|
|
516
516
|
|
|
@@ -536,7 +536,7 @@ LARS<ModelMatType>::Train(const MatType& matX,
|
|
|
536
536
|
// We don't need to transpose the data---it's already in row-major form.
|
|
537
537
|
if (fitIntercept)
|
|
538
538
|
{
|
|
539
|
-
offsetX =
|
|
539
|
+
offsetX = mean(matX, 0).t();
|
|
540
540
|
dataTrans = (matX.each_row() - offsetX.t());
|
|
541
541
|
}
|
|
542
542
|
|
|
@@ -558,7 +558,7 @@ LARS<ModelMatType>::Train(const MatType& matX,
|
|
|
558
558
|
|
|
559
559
|
if (fitIntercept)
|
|
560
560
|
{
|
|
561
|
-
this->offsetY =
|
|
561
|
+
this->offsetY = mean(y);
|
|
562
562
|
yCentered = y - this->offsetY;
|
|
563
563
|
}
|
|
564
564
|
|
|
@@ -288,7 +288,7 @@ void LinearSVMFunction<MatType, ParametersType>::Gradient(
|
|
|
288
288
|
}
|
|
289
289
|
else
|
|
290
290
|
{
|
|
291
|
-
gradient.set_size(
|
|
291
|
+
gradient.set_size(size(parameters));
|
|
292
292
|
gradient.submat(0, 0, parameters.n_rows - 2, parameters.n_cols - 1) =
|
|
293
293
|
dataset * difference.t();
|
|
294
294
|
gradient.row(parameters.n_rows - 1) =
|
|
@@ -345,7 +345,7 @@ void LinearSVMFunction<MatType, ParametersType>::Gradient(
|
|
|
345
345
|
}
|
|
346
346
|
else
|
|
347
347
|
{
|
|
348
|
-
gradient.set_size(
|
|
348
|
+
gradient.set_size(size(parameters));
|
|
349
349
|
gradient.submat(0, 0, parameters.n_rows - 2, parameters.n_cols - 1) =
|
|
350
350
|
dataset.cols(firstId, lastId) * difference.t();
|
|
351
351
|
gradient.row(parameters.n_rows - 1) =
|
|
@@ -400,7 +400,7 @@ LinearSVMFunction<MatType, ParametersType>::EvaluateWithGradient(
|
|
|
400
400
|
}
|
|
401
401
|
else
|
|
402
402
|
{
|
|
403
|
-
gradient.set_size(
|
|
403
|
+
gradient.set_size(size(parameters));
|
|
404
404
|
gradient.submat(0, 0, parameters.n_rows - 2, parameters.n_cols - 1) =
|
|
405
405
|
dataset * difference.t();
|
|
406
406
|
gradient.row(parameters.n_rows - 1) =
|
|
@@ -472,7 +472,7 @@ LinearSVMFunction<MatType, ParametersType>::EvaluateWithGradient(
|
|
|
472
472
|
}
|
|
473
473
|
else
|
|
474
474
|
{
|
|
475
|
-
gradient.set_size(
|
|
475
|
+
gradient.set_size(size(parameters));
|
|
476
476
|
gradient.submat(0, 0, parameters.n_rows - 2, parameters.n_cols - 1) =
|
|
477
477
|
dataset.cols(firstId, lastId) * difference.t();
|
|
478
478
|
gradient.row(parameters.n_rows - 1) =
|