mlpack 4.6.2__cp38-cp38-win_amd64.whl → 4.7.0__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +3 -3
- mlpack/adaboost_classify.cp38-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp38-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp38-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp38-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp38-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp38-win_amd64.pyd +0 -0
- mlpack/cf.cp38-win_amd64.pyd +0 -0
- mlpack/dbscan.cp38-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp38-win_amd64.pyd +0 -0
- mlpack/det.cp38-win_amd64.pyd +0 -0
- mlpack/emst.cp38-win_amd64.pyd +0 -0
- mlpack/fastmks.cp38-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp38-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp38-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp38-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp38-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp38-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp38-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp38-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp38-win_amd64.pyd +0 -0
- mlpack/image_converter.cp38-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp38-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp38-win_amd64.pyd +0 -0
- mlpack/kfn.cp38-win_amd64.pyd +0 -0
- mlpack/kmeans.cp38-win_amd64.pyd +0 -0
- mlpack/knn.cp38-win_amd64.pyd +0 -0
- mlpack/krann.cp38-win_amd64.pyd +0 -0
- mlpack/lars.cp38-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp38-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp38-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp38-win_amd64.pyd +0 -0
- mlpack/lmnn.cp38-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp38-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp38-win_amd64.pyd +0 -0
- mlpack/lsh.cp38-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp38-win_amd64.pyd +0 -0
- mlpack/nbc.cp38-win_amd64.pyd +0 -0
- mlpack/nca.cp38-win_amd64.pyd +0 -0
- mlpack/nmf.cp38-win_amd64.pyd +0 -0
- mlpack/pca.cp38-win_amd64.pyd +0 -0
- mlpack/perceptron.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp38-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp38-win_amd64.pyd +0 -0
- mlpack/radical.cp38-win_amd64.pyd +0 -0
- mlpack/random_forest.cp38-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp38-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp38-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +5 -5
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +395 -376
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{.load-order-mlpack-4.6.2 → .load-order-mlpack-4.7.0} +0 -0
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
#include "one_hot_encoding.hpp"
|
|
18
18
|
|
|
19
19
|
namespace mlpack {
|
|
20
|
-
namespace data {
|
|
21
20
|
|
|
22
21
|
/**
|
|
23
22
|
* Given a set of labels of a particular datatype, convert them to binary
|
|
@@ -147,7 +146,7 @@ void OneHotEncoding(const arma::Mat<eT>& input,
|
|
|
147
146
|
* Overloaded function for the above function, which takes a matrix as input
|
|
148
147
|
* and also a DatasetInfo object and outputs a matrix.
|
|
149
148
|
* This function encodes all the dimensions marked `Datatype::categorical`
|
|
150
|
-
* in the
|
|
149
|
+
* in the DatasetInfo.
|
|
151
150
|
*
|
|
152
151
|
* @param input Input dataset to be encoded.
|
|
153
152
|
* @param output Encoded matrix.
|
|
@@ -156,12 +155,12 @@ void OneHotEncoding(const arma::Mat<eT>& input,
|
|
|
156
155
|
template<typename eT>
|
|
157
156
|
void OneHotEncoding(const arma::Mat<eT>& input,
|
|
158
157
|
arma::Mat<eT>& output,
|
|
159
|
-
const
|
|
158
|
+
const DatasetInfo& datasetInfo)
|
|
160
159
|
{
|
|
161
160
|
std::vector<size_t> indices;
|
|
162
161
|
for (size_t i = 0; i < datasetInfo.Dimensionality(); ++i)
|
|
163
162
|
{
|
|
164
|
-
if (datasetInfo.Type(i) ==
|
|
163
|
+
if (datasetInfo.Type(i) == Datatype::categorical)
|
|
165
164
|
{
|
|
166
165
|
indices.push_back(i);
|
|
167
166
|
}
|
|
@@ -169,7 +168,6 @@ void OneHotEncoding(const arma::Mat<eT>& input,
|
|
|
169
168
|
OneHotEncoding(input, arma::Col<size_t>(indices), output);
|
|
170
169
|
}
|
|
171
170
|
|
|
172
|
-
} // namespace data
|
|
173
171
|
} // namespace mlpack
|
|
174
172
|
|
|
175
173
|
#endif
|
|
@@ -15,123 +15,16 @@
|
|
|
15
15
|
#define MLPACK_CORE_DATA_SAVE_HPP
|
|
16
16
|
|
|
17
17
|
#include <mlpack/prereqs.hpp>
|
|
18
|
-
#include <mlpack/core/util/log.hpp>
|
|
19
18
|
|
|
19
|
+
#include "image_options.hpp"
|
|
20
20
|
#include "text_options.hpp"
|
|
21
|
-
#include "format.hpp"
|
|
22
|
-
#include "image_info.hpp"
|
|
23
21
|
#include "detect_file_type.hpp"
|
|
22
|
+
#include "save_deprecated.hpp"
|
|
23
|
+
#include "save_numeric.hpp"
|
|
24
|
+
#include "save_model.hpp"
|
|
24
25
|
#include "save_image.hpp"
|
|
25
|
-
#include "utilities.hpp"
|
|
26
26
|
|
|
27
27
|
namespace mlpack {
|
|
28
|
-
namespace data /** Functions to load and save matrices. */ {
|
|
29
|
-
|
|
30
|
-
/**
|
|
31
|
-
* Saves a matrix to file, guessing the filetype from the extension. This
|
|
32
|
-
* will transpose the matrix at save time. If the filetype cannot be
|
|
33
|
-
* determined, an error will be given.
|
|
34
|
-
*
|
|
35
|
-
* The supported types of files are the same as found in Armadillo:
|
|
36
|
-
*
|
|
37
|
-
* - CSV (arma::csv_ascii), denoted by .csv, or optionally .txt
|
|
38
|
-
* - ASCII (arma::raw_ascii), denoted by .txt
|
|
39
|
-
* - Armadillo ASCII (arma::arma_ascii), also denoted by .txt
|
|
40
|
-
* - PGM (arma::pgm_binary), denoted by .pgm
|
|
41
|
-
* - PPM (arma::ppm_binary), denoted by .ppm
|
|
42
|
-
* - Raw binary (arma::raw_binary), denoted by .bin
|
|
43
|
-
* - Armadillo binary (arma::arma_binary), denoted by .bin
|
|
44
|
-
* - HDF5 (arma::hdf5_binary), denoted by .hdf5, .hdf, .h5, or .he5
|
|
45
|
-
*
|
|
46
|
-
* By default, this function will try to automatically determine the format to
|
|
47
|
-
* save with based only on the filename's extension. If you would prefer to
|
|
48
|
-
* specify a file type manually, override the default
|
|
49
|
-
* `inputSaveType` parameter with the correct type above (e.g.
|
|
50
|
-
* `arma::csv_ascii`.)
|
|
51
|
-
*
|
|
52
|
-
* If the 'fatal' parameter is set to true, a std::runtime_error exception will
|
|
53
|
-
* be thrown upon failure. If the 'transpose' parameter is set to true, the
|
|
54
|
-
* matrix will be transposed before saving. Generally, because mlpack stores
|
|
55
|
-
* matrices in a column-major format and most datasets are stored on disk as
|
|
56
|
-
* row-major, this parameter should be left at its default value of 'true'.
|
|
57
|
-
*
|
|
58
|
-
* @param filename Name of file to save to.
|
|
59
|
-
* @param matrix Matrix to save into file.
|
|
60
|
-
* @param fatal If an error should be reported as fatal (default false).
|
|
61
|
-
* @param transpose If true, transpose the matrix before saving (default true).
|
|
62
|
-
* @param inputSaveType File type to save to (defaults to arma::auto_detect).
|
|
63
|
-
* @return Boolean value indicating success or failure of save.
|
|
64
|
-
*/
|
|
65
|
-
template<typename eT>
bool Save(const std::string& filename,
          const arma::Mat<eT>& matrix,
          const bool fatal = false,  // Throw on failure instead of warning.
          bool transpose = true,     // Transpose the matrix before writing.
          FileType inputSaveType = FileType::AutoDetect);  // Explicit type.
|
|
71
|
-
|
|
72
|
-
/**
 * Saves a sparse matrix to file, guessing the filetype from the
 * extension. By default the matrix is transposed at save time (see the
 * 'transpose' parameter). If the filetype cannot be determined, an error will
 * be given.
 *
 * The supported types of files are the same as found in Armadillo:
 *
 *  - TSV (coord_ascii), denoted by .tsv or .txt
 *  - TXT (coord_ascii), denoted by .txt
 *  - Raw binary (raw_binary), denoted by .bin
 *  - Armadillo binary (arma_binary), denoted by .bin
 *
 * If the file extension is not one of those types, an error will be given. If
 * the 'fatal' parameter is set to true, a std::runtime_error exception will be
 * thrown upon failure. If the 'transpose' parameter is set to true, the matrix
 * will be transposed before saving. Generally, because mlpack stores matrices
 * in a column-major format and most datasets are stored on disk as row-major,
 * this parameter should be left at its default value of 'true'.
 *
 * @param filename Name of file to save to.
 * @param matrix Sparse matrix to save into file.
 * @param fatal If an error should be reported as fatal (default false).
 * @param transpose If true, transpose the matrix before saving (default true).
 * @return Boolean value indicating success or failure of save.
 */
template<typename eT>
bool Save(const std::string& filename,
          const arma::SpMat<eT>& matrix,
          const bool fatal = false,
          bool transpose = true);
|
|
102
|
-
|
|
103
|
-
/**
 * Saves a model to file, guessing the filetype from the extension, or,
 * optionally, saving the specified format. If automatic extension detection is
 * used and the filetype cannot be determined, an error will be given.
 *
 * The supported types of files are the same as what is supported by the
 * cereal library:
 *
 *  - json, denoted by .json
 *  - xml, denoted by .xml
 *  - binary, denoted by .bin
 *
 * The format parameter can take any of the values in the 'format' enum:
 * 'format::autodetect', 'format::json', 'format::xml', and 'format::binary'.
 * The autodetect functionality operates on the file extension (so, "file.json"
 * would be autodetected as JSON); an extension not listed above cannot be
 * autodetected and results in an error.
 *
 * The name parameter should be specified to indicate the name of the structure
 * to be saved. If Load() is later called on the generated file, the name used
 * to load should be the same as the name used for this call to Save().
 *
 * If the parameter 'fatal' is set to true, then an exception will be thrown in
 * the event of a save failure. Otherwise, the method will return false and the
 * relevant error information will be printed to Log::Warn.
 */
template<typename T>
bool Save(const std::string& filename,
          const std::string& name,
          T& t,
          const bool fatal = false,
          format f = format::autodetect,
          std::enable_if_t<HasSerialize<T>::value>* = 0);
|
|
135
28
|
|
|
136
29
|
/**
|
|
137
30
|
* This function defines a unified data saving interface for the library.
|
|
@@ -145,23 +38,36 @@ bool Save(const std::string& filename,
|
|
|
145
38
|
* @param opts DataOptions to be passed to the function
|
|
146
39
|
* @return Boolean value indicating success or failure of Save.
|
|
147
40
|
*/
|
|
41
|
+
|
|
148
42
|
// Overload for a mutable options object; only participates in overload
// resolution when IsDataOptions<DataOptionsType> holds.
template<typename MatType, typename DataOptionsType>
bool Save(const std::string& filename,
          const MatType& matrix,
          DataOptionsType& opts,
          const typename std::enable_if_t<
              IsDataOptions<DataOptionsType>::value>* = 0);
|
|
155
48
|
|
|
156
|
-
// Overload for a const (or defaulted) options object; when no options are
// given, a default-constructed PlainDataOptions is used.
template<typename MatType, typename DataOptionsType = PlainDataOptions>
bool Save(const std::string& filename,
          const MatType& matrix,
          const DataOptionsType& opts = DataOptionsType(),
          const typename std::enable_if_t<
              IsDataOptions<DataOptionsType>::value>* = 0);
|
|
55
|
+
|
|
56
|
+
// Declaration of the multi-file image Save() overload; its implementation
// forwards `files`, `matrix`, and `opts` to SaveImage().
template<typename eT>
bool Save(const std::vector<std::string>& files,
          const arma::Mat<eT>& matrix,
          ImageOptions& opts);
|
|
60
|
+
|
|
61
|
+
// Image saving API for multiple files.
|
|
62
|
+
template<typename eT>
|
|
63
|
+
bool Save(const std::vector<std::string>& files,
|
|
64
|
+
const arma::Mat<eT>& matrix,
|
|
65
|
+
ImageOptions& opts)
|
|
66
|
+
|
|
67
|
+
{
|
|
68
|
+
return SaveImage(files, matrix, opts);
|
|
69
|
+
}
|
|
163
70
|
|
|
164
|
-
} // namespace data
|
|
165
71
|
} // namespace mlpack
|
|
166
72
|
|
|
167
73
|
// Include implementation.
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file core/data/save_dense.hpp
|
|
3
|
+
* @author Ryan Curtin
|
|
4
|
+
* @author Omar Shrit
|
|
5
|
+
*
|
|
6
|
+
* Internal implementation of dense matrix save function.
|
|
7
|
+
*
|
|
8
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef MLPACK_CORE_DATA_SAVE_DENSE_HPP
|
|
14
|
+
#define MLPACK_CORE_DATA_SAVE_DENSE_HPP
|
|
15
|
+
|
|
16
|
+
#include "save_matrix.hpp"
|
|
17
|
+
|
|
18
|
+
namespace mlpack {
|
|
19
|
+
|
|
20
|
+
template<typename eT>
|
|
21
|
+
bool SaveDense(const arma::Mat<eT>& matrix,
|
|
22
|
+
TextOptions& opts,
|
|
23
|
+
const std::string& filename,
|
|
24
|
+
std::fstream& stream)
|
|
25
|
+
{
|
|
26
|
+
bool success = false;
|
|
27
|
+
arma::Mat<eT> tmp;
|
|
28
|
+
// Transpose the matrix.
|
|
29
|
+
if (!opts.NoTranspose())
|
|
30
|
+
{
|
|
31
|
+
tmp = trans(matrix);
|
|
32
|
+
success = SaveMatrix(tmp, opts, filename, stream);
|
|
33
|
+
}
|
|
34
|
+
else
|
|
35
|
+
success = SaveMatrix(matrix, opts, filename, stream);
|
|
36
|
+
|
|
37
|
+
return success;
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
} // namespace mlpack
|
|
41
|
+
|
|
42
|
+
#endif
|
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file core/data/save_deprecated.hpp
|
|
3
|
+
* @author Omar Shrit
|
|
4
|
+
*
|
|
5
|
+
* Contains declaration and implementation of old deprecated save function.
|
|
6
|
+
* This should be removed when releasing mlpack 5.0.0.
|
|
7
|
+
*
|
|
8
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef MLPACK_CORE_DATA_SAVE_DEPRECATED_HPP
|
|
14
|
+
#define MLPACK_CORE_DATA_SAVE_DEPRECATED_HPP
|
|
15
|
+
|
|
16
|
+
// In case it hasn't already been included.
|
|
17
|
+
#include "save.hpp"
|
|
18
|
+
#include "extension.hpp"
|
|
19
|
+
|
|
20
|
+
namespace mlpack {
|
|
21
|
+
|
|
22
|
+
/**
 * Saves a matrix to file, guessing the filetype from the extension. By
 * default the matrix is transposed at save time (see the 'transpose'
 * parameter). If the filetype cannot be determined, an error will be given.
 *
 * The supported types of files are the same as found in Armadillo:
 *
 *  - CSV (arma::csv_ascii), denoted by .csv, or optionally .txt
 *  - ASCII (arma::raw_ascii), denoted by .txt
 *  - Armadillo ASCII (arma::arma_ascii), also denoted by .txt
 *  - PGM (arma::pgm_binary), denoted by .pgm
 *  - PPM (arma::ppm_binary), denoted by .ppm
 *  - Raw binary (arma::raw_binary), denoted by .bin
 *  - Armadillo binary (arma::arma_binary), denoted by .bin
 *  - HDF5 (arma::hdf5_binary), denoted by .hdf5, .hdf, .h5, or .he5
 *
 * By default, this function will try to automatically determine the format to
 * save with based only on the filename's extension. If you would prefer to
 * specify a file type manually, pass the FileType value for one of the
 * formats above as the `inputSaveType` parameter instead of the default
 * FileType::AutoDetect.
 *
 * If the 'fatal' parameter is set to true, a std::runtime_error exception will
 * be thrown upon failure. If the 'transpose' parameter is set to true, the
 * matrix will be transposed before saving. Generally, because mlpack stores
 * matrices in a column-major format and most datasets are stored on disk as
 * row-major, this parameter should be left at its default value of 'true'.
 *
 * @param filename Name of file to save to.
 * @param matrix Matrix to save into file.
 * @param fatal If an error should be reported as fatal (default false).
 * @param transpose If true, transpose the matrix before saving (default true).
 * @param inputSaveType File type to save to (defaults to FileType::AutoDetect).
 * @return Boolean value indicating success or failure of save.
 */
template<typename eT>
bool Save(const std::string& filename,
          const arma::Mat<eT>& matrix,
          const bool fatal = false,
          bool transpose = true,
          FileType inputSaveType = FileType::AutoDetect);
|
|
63
|
+
|
|
64
|
+
/**
 * Saves a sparse matrix to file, guessing the filetype from the
 * extension. By default the matrix is transposed at save time (see the
 * 'transpose' parameter). If the filetype cannot be determined, an error will
 * be given.
 *
 * The supported types of files are the same as found in Armadillo:
 *
 *  - TSV (coord_ascii), denoted by .tsv or .txt
 *  - TXT (coord_ascii), denoted by .txt
 *  - Raw binary (raw_binary), denoted by .bin
 *  - Armadillo binary (arma_binary), denoted by .bin
 *
 * If the file extension is not one of those types, an error will be given. If
 * the 'fatal' parameter is set to true, a std::runtime_error exception will be
 * thrown upon failure. If the 'transpose' parameter is set to true, the matrix
 * will be transposed before saving. Generally, because mlpack stores matrices
 * in a column-major format and most datasets are stored on disk as row-major,
 * this parameter should be left at its default value of 'true'.
 *
 * @param filename Name of file to save to.
 * @param matrix Sparse matrix to save into file.
 * @param fatal If an error should be reported as fatal (default false).
 * @param transpose If true, transpose the matrix before saving (default true).
 * @return Boolean value indicating success or failure of save.
 */
template<typename eT>
bool Save(const std::string& filename,
          const arma::SpMat<eT>& matrix,
          const bool fatal = false,
          bool transpose = true);
|
|
94
|
+
|
|
95
|
+
/**
 * Saves a model to file, guessing the filetype from the extension, or,
 * optionally, saving the specified format. If automatic extension detection is
 * used and the filetype cannot be determined, an error will be given.
 *
 * The supported types of files are the same as what is supported by the
 * cereal library:
 *
 *  - JSON, denoted by .json
 *  - XML, denoted by .xml
 *  - BIN, denoted by .bin
 *
 * The format parameter 'f' can take any of the values in the 'format' enum:
 * 'format::autodetect', 'format::json', 'format::xml', and 'format::binary'.
 * The autodetect functionality operates on the file extension (so, "file.json"
 * would be autodetected as JSON); any extension other than .json, .xml, or
 * .bin cannot be autodetected and results in an error.
 *
 * The name parameter should be specified to indicate the name of the structure
 * to be saved. If Load() is later called on the generated file, the name used
 * to load should be the same as the name used for this call to Save().
 *
 * If the parameter 'fatal' is set to true, then an exception will be thrown in
 * the event of a save failure. Otherwise, the method will return false and the
 * relevant error information will be printed to Log::Warn.
 */
template<typename T>
bool Save(const std::string& filename,
          const std::string& name,
          T& t,
          const bool fatal = false,
          format f = format::autodetect,
          std::enable_if_t<HasSerialize<T>::value>* = 0);
|
|
127
|
+
|
|
128
|
+
template<typename eT>
|
|
129
|
+
bool Save(const std::string& filename,
|
|
130
|
+
const arma::Col<eT>& vec,
|
|
131
|
+
const bool fatal,
|
|
132
|
+
FileType inputSaveType)
|
|
133
|
+
{
|
|
134
|
+
// Don't transpose: one observation per line (for CSVs at least).
|
|
135
|
+
return Save(filename, vec, fatal, false, inputSaveType);
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
template<typename eT>
|
|
139
|
+
bool Save(const std::string& filename,
|
|
140
|
+
const arma::Row<eT>& rowvec,
|
|
141
|
+
const bool fatal,
|
|
142
|
+
FileType inputSaveType)
|
|
143
|
+
{
|
|
144
|
+
return Save(filename, rowvec, fatal, true, inputSaveType);
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
// Save a Sparse Matrix
|
|
148
|
+
template<typename eT>
|
|
149
|
+
bool Save(const std::string& filename,
|
|
150
|
+
const arma::SpMat<eT>& matrix,
|
|
151
|
+
const bool fatal,
|
|
152
|
+
bool transpose)
|
|
153
|
+
{
|
|
154
|
+
MatrixOptions opts;
|
|
155
|
+
opts.Fatal() = fatal;
|
|
156
|
+
opts.NoTranspose() = !transpose;
|
|
157
|
+
|
|
158
|
+
return Save(filename, matrix, opts);
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
template<typename eT>
|
|
162
|
+
bool Save(const std::string& filename,
|
|
163
|
+
const arma::Mat<eT>& matrix,
|
|
164
|
+
const bool fatal,
|
|
165
|
+
bool transpose,
|
|
166
|
+
FileType inputSaveType)
|
|
167
|
+
{
|
|
168
|
+
MatrixOptions opts;
|
|
169
|
+
opts.Fatal() = fatal;
|
|
170
|
+
opts.NoTranspose() = !transpose;
|
|
171
|
+
opts.Format() = inputSaveType;
|
|
172
|
+
|
|
173
|
+
return Save(filename, matrix, opts);
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
/**
|
|
177
|
+
* Save the image file from the given matrix.
|
|
178
|
+
*
|
|
179
|
+
* @param filename Name of the image file.
|
|
180
|
+
* @param matrix Matrix to save the image from.
|
|
181
|
+
* @param info An object of ImageInfo class.
|
|
182
|
+
* @param fatal If an error should be reported as fatal (default false).
|
|
183
|
+
* @return Boolean value indicating success or failure of load.
|
|
184
|
+
*/
|
|
185
|
+
template<typename eT>
|
|
186
|
+
bool Save(const std::string& filename,
|
|
187
|
+
const arma::Mat<eT>& matrix,
|
|
188
|
+
ImageInfo& opts,
|
|
189
|
+
const bool fatal)
|
|
190
|
+
{
|
|
191
|
+
opts.Fatal() = fatal;
|
|
192
|
+
opts.Format() = FileType::ImageType;
|
|
193
|
+
std::vector<std::string> files;
|
|
194
|
+
files.push_back(filename);
|
|
195
|
+
return SaveImage(files, matrix, opts);
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
/**
|
|
199
|
+
* Save the image file from the given matrix.
|
|
200
|
+
*
|
|
201
|
+
* @param files A vector consisting of filenames.
|
|
202
|
+
* @param matrix Matrix to save the image from.
|
|
203
|
+
* @param info An object of ImageInfo class.
|
|
204
|
+
* @param fatal If an error should be reported as fatal (default false).
|
|
205
|
+
* @return Boolean value indicating success or failure of load.
|
|
206
|
+
*/
|
|
207
|
+
template<typename eT>
|
|
208
|
+
bool Save(const std::vector<std::string>& files,
|
|
209
|
+
const arma::Mat<eT>& matrix,
|
|
210
|
+
ImageInfo& opts,
|
|
211
|
+
const bool fatal)
|
|
212
|
+
{
|
|
213
|
+
opts.Fatal() = fatal;
|
|
214
|
+
opts.Format() = FileType::ImageType;
|
|
215
|
+
return SaveImage(files, matrix, opts);
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
// Save a model to file.
|
|
219
|
+
// Keep this implementation until mlpack 5.0.0 Then we can remove it.
|
|
220
|
+
template<typename T>
|
|
221
|
+
bool Save(const std::string& filename,
|
|
222
|
+
const std::string& name,
|
|
223
|
+
T& t,
|
|
224
|
+
const bool fatal,
|
|
225
|
+
format f,
|
|
226
|
+
std::enable_if_t<HasSerialize<T>::value>*)
|
|
227
|
+
{
|
|
228
|
+
if (f == format::autodetect)
|
|
229
|
+
{
|
|
230
|
+
std::string extension = Extension(filename);
|
|
231
|
+
|
|
232
|
+
if (extension == "xml")
|
|
233
|
+
f = format::xml;
|
|
234
|
+
else if (extension == "bin")
|
|
235
|
+
f = format::binary;
|
|
236
|
+
else if (extension == "json")
|
|
237
|
+
f = format::json;
|
|
238
|
+
else
|
|
239
|
+
{
|
|
240
|
+
if (fatal)
|
|
241
|
+
Log::Fatal << "Unable to detect type of '" << filename << "'; incorrect"
|
|
242
|
+
<< " extension? (allowed: xml/bin/json)" << std::endl;
|
|
243
|
+
else
|
|
244
|
+
Log::Warn << "Unable to detect type of '" << filename << "'; save "
|
|
245
|
+
<< "failed. Incorrect extension? (allowed: xml/bin/json)"
|
|
246
|
+
<< std::endl;
|
|
247
|
+
|
|
248
|
+
return false;
|
|
249
|
+
}
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
// Open the file to save to.
|
|
253
|
+
std::ofstream ofs;
|
|
254
|
+
#ifdef _WIN32
|
|
255
|
+
if (f == format::binary) // Open non-text types in binary mode on Windows.
|
|
256
|
+
ofs.open(filename, std::ofstream::out | std::ofstream::binary);
|
|
257
|
+
else
|
|
258
|
+
ofs.open(filename, std::ofstream::out);
|
|
259
|
+
#else
|
|
260
|
+
ofs.open(filename, std::ofstream::out);
|
|
261
|
+
#endif
|
|
262
|
+
|
|
263
|
+
if (!ofs.is_open())
|
|
264
|
+
{
|
|
265
|
+
if (fatal)
|
|
266
|
+
Log::Fatal << "Unable to open file '" << filename << "' to save object '"
|
|
267
|
+
<< name << "'." << std::endl;
|
|
268
|
+
else
|
|
269
|
+
Log::Warn << "Unable to open file '" << filename << "' to save object '"
|
|
270
|
+
<< name << "'." << std::endl;
|
|
271
|
+
|
|
272
|
+
return false;
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
try
|
|
276
|
+
{
|
|
277
|
+
if (f == format::xml)
|
|
278
|
+
{
|
|
279
|
+
cereal::XMLOutputArchive ar(ofs);
|
|
280
|
+
ar(cereal::make_nvp(name.c_str(), t));
|
|
281
|
+
}
|
|
282
|
+
else if (f == format::json)
|
|
283
|
+
{
|
|
284
|
+
cereal::JSONOutputArchive ar(ofs);
|
|
285
|
+
ar(cereal::make_nvp(name.c_str(), t));
|
|
286
|
+
}
|
|
287
|
+
else if (f == format::binary)
|
|
288
|
+
{
|
|
289
|
+
cereal::BinaryOutputArchive ar(ofs);
|
|
290
|
+
ar(cereal::make_nvp(name.c_str(), t));
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
return true;
|
|
294
|
+
}
|
|
295
|
+
catch (cereal::Exception& e)
|
|
296
|
+
{
|
|
297
|
+
if (fatal)
|
|
298
|
+
Log::Fatal << e.what() << std::endl;
|
|
299
|
+
else
|
|
300
|
+
Log::Warn << e.what() << std::endl;
|
|
301
|
+
|
|
302
|
+
return false;
|
|
303
|
+
}
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
} // namespace mlpack
|
|
307
|
+
|
|
308
|
+
#endif
|