mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
/**
|
|
2
2
|
* @file core/data/save_image.hpp
|
|
3
3
|
* @author Ryan Curtin
|
|
4
|
+
* @author Omar Shrit
|
|
4
5
|
*
|
|
5
|
-
* Implementation of save functionality.
|
|
6
|
+
* Implementation of save image functionality.
|
|
6
7
|
*
|
|
7
8
|
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
8
9
|
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
@@ -13,54 +14,93 @@
|
|
|
13
14
|
#define MLPACK_CORE_DATA_SAVE_IMAGE_HPP
|
|
14
15
|
|
|
15
16
|
#include <mlpack/core/stb/stb.hpp>
|
|
16
|
-
|
|
17
|
-
#include "image_info.hpp"
|
|
17
|
+
#include <mlpack/core/math/make_alias.hpp>
|
|
18
18
|
|
|
19
19
|
namespace mlpack {
|
|
20
|
-
namespace data {
|
|
21
20
|
|
|
22
|
-
/**
|
|
23
|
-
* Save the image file from the given matrix.
|
|
24
|
-
*
|
|
25
|
-
* @param filename Name of the image file.
|
|
26
|
-
* @param matrix Matrix to save the image from.
|
|
27
|
-
* @param info An object of ImageInfo class.
|
|
28
|
-
* @param fatal If an error should be reported as fatal (default false).
|
|
29
|
-
* @return Boolean value indicating success or failure of load.
|
|
30
|
-
*/
|
|
31
21
|
template<typename eT>
|
|
32
|
-
bool
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
22
|
+
bool SaveImage(const std::vector<std::string>& files,
|
|
23
|
+
const arma::Mat<eT>& matrix,
|
|
24
|
+
ImageOptions& opts)
|
|
25
|
+
{
|
|
26
|
+
if (files.empty())
|
|
27
|
+
{
|
|
28
|
+
std::stringstream oss;
|
|
29
|
+
oss << "Save(): vector of image files is empty; nothing to save.";
|
|
30
|
+
return HandleError(oss, opts);
|
|
31
|
+
}
|
|
36
32
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
33
|
+
// Check if we do have any type that is not supported.
|
|
34
|
+
if (opts.Format() == FileType::ImageType ||
|
|
35
|
+
opts.Format() == FileType::AutoDetect)
|
|
36
|
+
{
|
|
37
|
+
for (size_t i = 0; i < files.size() ; ++i)
|
|
38
|
+
{
|
|
39
|
+
if (!opts.saveType.count(Extension(files.at(i))))
|
|
40
|
+
{
|
|
41
|
+
std::stringstream oss;
|
|
42
|
+
oss << "Save(): file type " << opts.FileTypeToString()
|
|
43
|
+
<< " isn't supported. Currently image saving supports: ";
|
|
44
|
+
for (const auto& x : opts.saveType)
|
|
45
|
+
oss << " " << x;
|
|
46
|
+
oss << "." << std::endl;
|
|
47
|
+
return HandleError(oss, opts);
|
|
48
|
+
}
|
|
49
|
+
}
|
|
50
|
+
}
|
|
51
51
|
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
52
|
+
size_t dimension = opts.Width() * opts.Height() * opts.Channels() *
|
|
53
|
+
files.size();
|
|
54
|
+
// We only need to check the rows since it is a matrix.
|
|
55
|
+
if (dimension != matrix.n_rows * matrix.n_cols)
|
|
56
|
+
{
|
|
57
|
+
std::stringstream oss;
|
|
58
|
+
oss << "Save(): The given image dimensions, Width: " << opts.Width()
|
|
59
|
+
<< ", Height: " << opts.Height() << ", Channels: "<< opts.Channels()
|
|
60
|
+
<< " do not match the dimensions of the matrix to be saved!";
|
|
61
|
+
return HandleError(oss, opts);
|
|
62
|
+
}
|
|
63
|
+
// Unfortunately we cannot move because matrix is const.
|
|
64
|
+
arma::Mat<unsigned char> tempMatrix =
|
|
65
|
+
arma::conv_to<arma::Mat<unsigned char>>::from(matrix);
|
|
66
|
+
bool success = false;
|
|
67
|
+
for (size_t i = 0; i < files.size() ; ++i)
|
|
68
|
+
{
|
|
69
|
+
// Update opts.Format() at each iteration.
|
|
70
|
+
DetectFromExtension<arma::Mat<eT>, ImageOptions>(files.at(i), opts);
|
|
71
|
+
if (opts.Format() == FileType::PNG)
|
|
72
|
+
{
|
|
73
|
+
success = stbi_write_png(files.at(i).c_str(), opts.Width(), opts.Height(),
|
|
74
|
+
opts.Channels(), tempMatrix.colptr(i),
|
|
75
|
+
opts.Width() * opts.Channels());
|
|
76
|
+
}
|
|
77
|
+
else if (opts.Format() == FileType::BMP)
|
|
78
|
+
{
|
|
79
|
+
success = stbi_write_bmp(files.at(i).c_str(), opts.Width(), opts.Height(),
|
|
80
|
+
opts.Channels(), tempMatrix.colptr(i));
|
|
81
|
+
}
|
|
82
|
+
else if (opts.Format() == FileType::TGA)
|
|
83
|
+
{
|
|
84
|
+
success = stbi_write_tga(files.at(i).c_str(), opts.Width(), opts.Height(),
|
|
85
|
+
opts.Channels(), tempMatrix.colptr(i));
|
|
86
|
+
}
|
|
87
|
+
else if (opts.Format() == FileType::JPG)
|
|
88
|
+
{
|
|
89
|
+
success = stbi_write_jpg(files.at(i).c_str(), opts.Width(), opts.Height(),
|
|
90
|
+
opts.Channels(), tempMatrix.colptr(i), opts.Quality());
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
if (!success)
|
|
94
|
+
{
|
|
95
|
+
std::stringstream oss;
|
|
96
|
+
oss << "Save(): error saving image to '" << files.at(i) << "'.";
|
|
97
|
+
return HandleError(oss, opts);
|
|
98
|
+
}
|
|
99
|
+
}
|
|
59
100
|
|
|
60
|
-
|
|
61
|
-
}
|
|
101
|
+
return success;
|
|
102
|
+
}
|
|
62
103
|
|
|
63
|
-
//
|
|
64
|
-
#include "save_image_impl.hpp"
|
|
104
|
+
} // namespace mlpack
|
|
65
105
|
|
|
66
106
|
#endif
|
|
@@ -15,164 +15,109 @@
|
|
|
15
15
|
|
|
16
16
|
// In case it hasn't already been included.
|
|
17
17
|
#include "save.hpp"
|
|
18
|
-
#include "extension.hpp"
|
|
19
18
|
|
|
20
19
|
namespace mlpack {
|
|
21
|
-
namespace data {
|
|
22
|
-
|
|
23
|
-
template<typename eT>
|
|
24
|
-
bool Save(const std::string& filename,
|
|
25
|
-
const arma::Col<eT>& vec,
|
|
26
|
-
const bool fatal,
|
|
27
|
-
FileType inputSaveType)
|
|
28
|
-
{
|
|
29
|
-
// Don't transpose: one observation per line (for CSVs at least).
|
|
30
|
-
return Save(filename, vec, fatal, false, inputSaveType);
|
|
31
|
-
}
|
|
32
|
-
|
|
33
|
-
template<typename eT>
|
|
34
|
-
bool Save(const std::string& filename,
|
|
35
|
-
const arma::Row<eT>& rowvec,
|
|
36
|
-
const bool fatal,
|
|
37
|
-
FileType inputSaveType)
|
|
38
|
-
{
|
|
39
|
-
return Save(filename, rowvec, fatal, true, inputSaveType);
|
|
40
|
-
}
|
|
41
|
-
|
|
42
|
-
// Save a Sparse Matrix
|
|
43
|
-
template<typename eT>
|
|
44
|
-
bool Save(const std::string& filename,
|
|
45
|
-
const arma::SpMat<eT>& matrix,
|
|
46
|
-
const bool fatal,
|
|
47
|
-
bool transpose)
|
|
48
|
-
{
|
|
49
|
-
MatrixOptions opts;
|
|
50
|
-
opts.Fatal() = fatal;
|
|
51
|
-
opts.NoTranspose() = !transpose;
|
|
52
|
-
|
|
53
|
-
return Save(filename, matrix, opts);
|
|
54
|
-
}
|
|
55
|
-
|
|
56
|
-
template<typename eT>
|
|
57
|
-
bool Save(const std::string& filename,
|
|
58
|
-
const arma::Mat<eT>& matrix,
|
|
59
|
-
const bool fatal,
|
|
60
|
-
bool transpose,
|
|
61
|
-
FileType inputSaveType)
|
|
62
|
-
{
|
|
63
|
-
MatrixOptions opts;
|
|
64
|
-
opts.Fatal() = fatal;
|
|
65
|
-
opts.NoTranspose() = !transpose;
|
|
66
|
-
opts.Format() = inputSaveType;
|
|
67
|
-
|
|
68
|
-
return Save(filename, matrix, opts);
|
|
69
|
-
}
|
|
70
20
|
|
|
71
21
|
template<typename MatType, typename DataOptionsType>
|
|
72
22
|
bool Save(const std::string& filename,
|
|
73
23
|
const MatType& matrix,
|
|
74
24
|
const DataOptionsType& opts,
|
|
75
|
-
std::enable_if_t<
|
|
76
|
-
|
|
25
|
+
const typename std::enable_if_t<
|
|
26
|
+
IsDataOptions<DataOptionsType>::value>*)
|
|
77
27
|
{
|
|
78
28
|
//! just use default copy ctor with = operator and make a copy.
|
|
79
29
|
DataOptionsType copyOpts(opts);
|
|
80
30
|
return Save(filename, matrix, copyOpts);
|
|
81
31
|
}
|
|
82
32
|
|
|
83
|
-
|
|
84
|
-
* Add this SFINAE in here because the compiler is so stupid that it is not
|
|
85
|
-
* able to distinguish between these two:
|
|
86
|
-
*
|
|
87
|
-
* data::Save(filename, "model", *output);
|
|
88
|
-
*
|
|
89
|
-
* and
|
|
90
|
-
*
|
|
91
|
-
* data::Save(filename, matrix, opts);
|
|
92
|
-
*
|
|
93
|
-
* The second SFINAE is added because the compiler is bot able to see the
|
|
94
|
-
* difference between:
|
|
95
|
-
*
|
|
96
|
-
* data::Save(filename, Row/Col, fatal);
|
|
97
|
-
*
|
|
98
|
-
* and
|
|
99
|
-
*
|
|
100
|
-
* data::Save(filename, Row/Col, Opts);
|
|
101
|
-
*
|
|
102
|
-
* This SFINAE is temporary and must be removed after the integration of stage 3 or
|
|
103
|
-
* when the compiler becomes more intelligent.
|
|
104
|
-
*/
|
|
105
|
-
template<typename MatType, typename DataOptionsType>
|
|
33
|
+
template<typename ObjectType, typename DataOptionsType>
|
|
106
34
|
bool Save(const std::string& filename,
|
|
107
|
-
const
|
|
35
|
+
const ObjectType& matrix,
|
|
108
36
|
DataOptionsType& opts,
|
|
109
|
-
std::enable_if_t<
|
|
110
|
-
|
|
111
|
-
std::enable_if_t<!std::is_same_v<DataOptionsType, bool>>*)
|
|
37
|
+
const typename std::enable_if_t<
|
|
38
|
+
IsDataOptions<DataOptionsType>::value>*)
|
|
112
39
|
{
|
|
113
40
|
Timer::Start("saving_data");
|
|
114
|
-
|
|
115
|
-
|
|
41
|
+
static_assert(!IsArma<ObjectType>::value || !IsSparseMat<ObjectType>::value
|
|
42
|
+
|| !HasSerialize<ObjectType>::value, "mlpack can save Armadillo"
|
|
43
|
+
" matrices or a serialized mlpack model only; please use a known type.");
|
|
44
|
+
const bool isMatrixType = IsArma<ObjectType>::value ||
|
|
45
|
+
IsSparseMat<ObjectType>::value;
|
|
46
|
+
const bool isSerializable = HasSerialize<ObjectType>::value;
|
|
47
|
+
const bool isSparseMatrixType = IsSparseMat<ObjectType>::value;
|
|
48
|
+
|
|
49
|
+
bool success = DetectFileType<ObjectType>(filename, opts, false);
|
|
116
50
|
if (!success)
|
|
117
51
|
{
|
|
118
52
|
Timer::Stop("saving_data");
|
|
119
53
|
return false;
|
|
120
54
|
}
|
|
121
55
|
|
|
56
|
+
const bool isImageFormat = (opts.Format() == FileType::PNG ||
|
|
57
|
+
opts.Format() == FileType::JPG || opts.Format() == FileType::PNM ||
|
|
58
|
+
opts.Format() == FileType::BMP || opts.Format() == FileType::GIF ||
|
|
59
|
+
opts.Format() == FileType::PSD || opts.Format() == FileType::TGA ||
|
|
60
|
+
opts.Format() == FileType::PIC || opts.Format() == FileType::ImageType);
|
|
61
|
+
|
|
122
62
|
std::fstream stream;
|
|
123
|
-
|
|
124
|
-
if (!success)
|
|
63
|
+
if (!isImageFormat)
|
|
125
64
|
{
|
|
126
|
-
|
|
127
|
-
|
|
65
|
+
success = OpenFile(filename, opts, false, stream);
|
|
66
|
+
if (!success)
|
|
67
|
+
{
|
|
68
|
+
Timer::Stop("saving_data");
|
|
69
|
+
return false;
|
|
70
|
+
}
|
|
128
71
|
}
|
|
129
72
|
|
|
130
73
|
// Try to save the file.
|
|
131
74
|
Log::Info << "Saving " << opts.FileTypeToString() << " to '" << filename
|
|
132
75
|
<< "'." << std::endl;
|
|
133
|
-
if constexpr (
|
|
76
|
+
if constexpr (isMatrixType)
|
|
134
77
|
{
|
|
135
|
-
|
|
136
|
-
if constexpr (IsSparseMat<MatType>::value)
|
|
137
|
-
{
|
|
138
|
-
success = SaveSparse(matrix, txtOpts, filename, stream);
|
|
139
|
-
}
|
|
140
|
-
else if constexpr (IsCol<MatType>::value)
|
|
141
|
-
{
|
|
142
|
-
opts.NoTranspose() = true;
|
|
143
|
-
success = SaveDense(matrix, txtOpts, filename, stream);
|
|
144
|
-
}
|
|
145
|
-
else if constexpr (IsRow<MatType>::value)
|
|
78
|
+
if (isImageFormat)
|
|
146
79
|
{
|
|
147
|
-
|
|
148
|
-
|
|
80
|
+
if constexpr (isSparseMatrixType)
|
|
81
|
+
{
|
|
82
|
+
arma::Mat<typename ObjectType::elem_type> tmp =
|
|
83
|
+
arma::conv_to<arma::Mat<
|
|
84
|
+
typename ObjectType::elem_type>>::from(matrix);
|
|
85
|
+
ImageOptions imgOpts(std::move(opts));
|
|
86
|
+
std::vector<std::string> files;
|
|
87
|
+
files.push_back(filename);
|
|
88
|
+
success = SaveImage(files, tmp, imgOpts);
|
|
89
|
+
opts = std::move(imgOpts);
|
|
90
|
+
}
|
|
91
|
+
else
|
|
92
|
+
{
|
|
93
|
+
ImageOptions imgOpts(std::move(opts));
|
|
94
|
+
std::vector<std::string> files;
|
|
95
|
+
files.push_back(filename);
|
|
96
|
+
success = SaveImage(files, matrix, imgOpts);
|
|
97
|
+
opts = std::move(imgOpts);
|
|
98
|
+
}
|
|
149
99
|
}
|
|
150
|
-
else
|
|
100
|
+
else
|
|
151
101
|
{
|
|
152
|
-
success =
|
|
102
|
+
success = SaveNumeric(filename, matrix, stream, opts);
|
|
153
103
|
}
|
|
154
|
-
|
|
104
|
+
}
|
|
105
|
+
else if constexpr (isSerializable)
|
|
106
|
+
{
|
|
107
|
+
success = SaveModel(matrix, opts, stream);
|
|
155
108
|
}
|
|
156
109
|
else
|
|
157
110
|
{
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
<< "or provide specific overloads." << std::endl;
|
|
161
|
-
else
|
|
162
|
-
Log::Warn << "DataOptionsType is unknown! Please use a known type or "
|
|
163
|
-
<< "or provide specific overloads." << std::endl;
|
|
164
|
-
|
|
165
|
-
return false;
|
|
111
|
+
return HandleError("DataOptionsType is unknown! Please use a known type "
|
|
112
|
+
"or provide specific overloads.", opts);
|
|
166
113
|
}
|
|
167
114
|
|
|
168
115
|
if (!success)
|
|
169
116
|
{
|
|
170
117
|
Timer::Stop("saving_data");
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
Log::Warn << "Save to '" << filename << "' failed." << std::endl;
|
|
175
|
-
return false;
|
|
118
|
+
std::stringstream oss;
|
|
119
|
+
oss << "Save to '" << filename << "' failed.";
|
|
120
|
+
return HandleError(oss, opts);
|
|
176
121
|
}
|
|
177
122
|
|
|
178
123
|
Timer::Stop("saving_data");
|
|
@@ -180,136 +125,6 @@ bool Save(const std::string& filename,
|
|
|
180
125
|
return success;
|
|
181
126
|
}
|
|
182
127
|
|
|
183
|
-
template<typename eT>
|
|
184
|
-
bool SaveDense(const arma::Mat<eT>& matrix,
|
|
185
|
-
TextOptions& opts,
|
|
186
|
-
const std::string& filename,
|
|
187
|
-
std::fstream& stream)
|
|
188
|
-
{
|
|
189
|
-
bool success = false;
|
|
190
|
-
arma::Mat<eT> tmp;
|
|
191
|
-
// Transpose the matrix.
|
|
192
|
-
if (!opts.NoTranspose())
|
|
193
|
-
{
|
|
194
|
-
tmp = trans(matrix);
|
|
195
|
-
success = SaveMatrix(tmp, opts, filename, stream);
|
|
196
|
-
}
|
|
197
|
-
else
|
|
198
|
-
success = SaveMatrix(matrix, opts, filename, stream);
|
|
199
|
-
|
|
200
|
-
return success;
|
|
201
|
-
}
|
|
202
|
-
|
|
203
|
-
// Save a Sparse Matrix
|
|
204
|
-
template<typename eT>
|
|
205
|
-
bool SaveSparse(const arma::SpMat<eT>& matrix,
|
|
206
|
-
TextOptions& opts,
|
|
207
|
-
const std::string& filename,
|
|
208
|
-
std::fstream& stream)
|
|
209
|
-
{
|
|
210
|
-
bool success = false;
|
|
211
|
-
arma::SpMat<eT> tmp;
|
|
212
|
-
|
|
213
|
-
// Transpose the matrix.
|
|
214
|
-
if (!opts.NoTranspose())
|
|
215
|
-
{
|
|
216
|
-
arma::SpMat<eT> tmp = trans(matrix);
|
|
217
|
-
success = SaveMatrix(tmp, opts, filename, stream);
|
|
218
|
-
}
|
|
219
|
-
else
|
|
220
|
-
success = SaveMatrix(matrix, opts, filename, stream);
|
|
221
|
-
|
|
222
|
-
return success;
|
|
223
|
-
}
|
|
224
|
-
|
|
225
|
-
//! Save a model to file.
|
|
226
|
-
template<typename T>
|
|
227
|
-
bool Save(const std::string& filename,
|
|
228
|
-
const std::string& name,
|
|
229
|
-
T& t,
|
|
230
|
-
const bool fatal,
|
|
231
|
-
format f,
|
|
232
|
-
std::enable_if_t<HasSerialize<T>::value>*)
|
|
233
|
-
{
|
|
234
|
-
if (f == format::autodetect)
|
|
235
|
-
{
|
|
236
|
-
std::string extension = Extension(filename);
|
|
237
|
-
|
|
238
|
-
if (extension == "xml")
|
|
239
|
-
f = format::xml;
|
|
240
|
-
else if (extension == "bin")
|
|
241
|
-
f = format::binary;
|
|
242
|
-
else if (extension == "json")
|
|
243
|
-
f = format::json;
|
|
244
|
-
else
|
|
245
|
-
{
|
|
246
|
-
if (fatal)
|
|
247
|
-
Log::Fatal << "Unable to detect type of '" << filename << "'; incorrect"
|
|
248
|
-
<< " extension? (allowed: xml/bin/json)" << std::endl;
|
|
249
|
-
else
|
|
250
|
-
Log::Warn << "Unable to detect type of '" << filename << "'; save "
|
|
251
|
-
<< "failed. Incorrect extension? (allowed: xml/bin/json)"
|
|
252
|
-
<< std::endl;
|
|
253
|
-
|
|
254
|
-
return false;
|
|
255
|
-
}
|
|
256
|
-
}
|
|
257
|
-
|
|
258
|
-
// Open the file to save to.
|
|
259
|
-
std::ofstream ofs;
|
|
260
|
-
#ifdef _WIN32
|
|
261
|
-
if (f == format::binary) // Open non-text types in binary mode on Windows.
|
|
262
|
-
ofs.open(filename, std::ofstream::out | std::ofstream::binary);
|
|
263
|
-
else
|
|
264
|
-
ofs.open(filename, std::ofstream::out);
|
|
265
|
-
#else
|
|
266
|
-
ofs.open(filename, std::ofstream::out);
|
|
267
|
-
#endif
|
|
268
|
-
|
|
269
|
-
if (!ofs.is_open())
|
|
270
|
-
{
|
|
271
|
-
if (fatal)
|
|
272
|
-
Log::Fatal << "Unable to open file '" << filename << "' to save object '"
|
|
273
|
-
<< name << "'." << std::endl;
|
|
274
|
-
else
|
|
275
|
-
Log::Warn << "Unable to open file '" << filename << "' to save object '"
|
|
276
|
-
<< name << "'." << std::endl;
|
|
277
|
-
|
|
278
|
-
return false;
|
|
279
|
-
}
|
|
280
|
-
|
|
281
|
-
try
|
|
282
|
-
{
|
|
283
|
-
if (f == format::xml)
|
|
284
|
-
{
|
|
285
|
-
cereal::XMLOutputArchive ar(ofs);
|
|
286
|
-
ar(cereal::make_nvp(name.c_str(), t));
|
|
287
|
-
}
|
|
288
|
-
else if (f == format::json)
|
|
289
|
-
{
|
|
290
|
-
cereal::JSONOutputArchive ar(ofs);
|
|
291
|
-
ar(cereal::make_nvp(name.c_str(), t));
|
|
292
|
-
}
|
|
293
|
-
else if (f == format::binary)
|
|
294
|
-
{
|
|
295
|
-
cereal::BinaryOutputArchive ar(ofs);
|
|
296
|
-
ar(cereal::make_nvp(name.c_str(), t));
|
|
297
|
-
}
|
|
298
|
-
|
|
299
|
-
return true;
|
|
300
|
-
}
|
|
301
|
-
catch (cereal::Exception& e)
|
|
302
|
-
{
|
|
303
|
-
if (fatal)
|
|
304
|
-
Log::Fatal << e.what() << std::endl;
|
|
305
|
-
else
|
|
306
|
-
Log::Warn << e.what() << std::endl;
|
|
307
|
-
|
|
308
|
-
return false;
|
|
309
|
-
}
|
|
310
|
-
}
|
|
311
|
-
|
|
312
|
-
} // namespace data
|
|
313
128
|
} // namespace mlpack
|
|
314
129
|
|
|
315
130
|
#endif
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file core/data/save_matrix.hpp
|
|
3
|
+
* @author Ryan Curtin
|
|
4
|
+
* @author Omar Shrit
|
|
5
|
+
*
|
|
6
|
+
* Internal implementation of matrix save function.
|
|
7
|
+
*
|
|
8
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef MLPACK_CORE_DATA_SAVE_MATRIX_HPP
|
|
14
|
+
#define MLPACK_CORE_DATA_SAVE_MATRIX_HPP
|
|
15
|
+
|
|
16
|
+
namespace mlpack {
|
|
17
|
+
|
|
18
|
+
template<typename MatType, typename DataOptionsType>
|
|
19
|
+
bool SaveMatrix(const MatType& matrix,
|
|
20
|
+
const DataOptionsType& opts,
|
|
21
|
+
#ifdef ARMA_USE_HDF5
|
|
22
|
+
const std::string& filename,
|
|
23
|
+
#else
|
|
24
|
+
const std::string& /* filename */,
|
|
25
|
+
#endif
|
|
26
|
+
std::fstream& stream)
|
|
27
|
+
{
|
|
28
|
+
bool success = false;
|
|
29
|
+
if (opts.Format() == FileType::HDF5Binary)
|
|
30
|
+
{
|
|
31
|
+
#ifdef ARMA_USE_HDF5
|
|
32
|
+
// We can't save with streams for HDF5.
|
|
33
|
+
success = matrix.save(filename, opts.ArmaFormat());
|
|
34
|
+
#endif
|
|
35
|
+
}
|
|
36
|
+
else
|
|
37
|
+
{
|
|
38
|
+
success = matrix.save(stream, opts.ArmaFormat());
|
|
39
|
+
}
|
|
40
|
+
return success;
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
} // namespace mlpack
|
|
44
|
+
|
|
45
|
+
#endif
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file core/data/save_model.hpp
|
|
3
|
+
* @author Ryan Curtin
|
|
4
|
+
* @author Omar Shrit
|
|
5
|
+
*
|
|
6
|
+
* Internal implementation of model save function.
|
|
7
|
+
*
|
|
8
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef MLPACK_CORE_DATA_SAVE_MODEL_HPP
|
|
14
|
+
#define MLPACK_CORE_DATA_SAVE_MODEL_HPP
|
|
15
|
+
|
|
16
|
+
#include <cereal/archives/xml.hpp>
|
|
17
|
+
#include <cereal/archives/binary.hpp>
|
|
18
|
+
#include <cereal/archives/json.hpp>
|
|
19
|
+
|
|
20
|
+
#include "text_options.hpp"
|
|
21
|
+
|
|
22
|
+
namespace mlpack {
|
|
23
|
+
|
|
24
|
+
template<typename Object>
|
|
25
|
+
bool SaveModel(Object& objectToSerialize,
|
|
26
|
+
const DataOptionsBase<PlainDataOptions>& opts,
|
|
27
|
+
std::fstream& stream)
|
|
28
|
+
{
|
|
29
|
+
try
|
|
30
|
+
{
|
|
31
|
+
if (opts.Format() == FileType::XML)
|
|
32
|
+
{
|
|
33
|
+
cereal::XMLOutputArchive ar(stream);
|
|
34
|
+
ar(cereal::make_nvp("model", objectToSerialize));
|
|
35
|
+
}
|
|
36
|
+
else if (opts.Format() == FileType::JSON)
|
|
37
|
+
{
|
|
38
|
+
cereal::JSONOutputArchive ar(stream);
|
|
39
|
+
ar(cereal::make_nvp("model", objectToSerialize));
|
|
40
|
+
}
|
|
41
|
+
else if (opts.Format() == FileType::BIN)
|
|
42
|
+
{
|
|
43
|
+
cereal::BinaryOutputArchive ar(stream);
|
|
44
|
+
ar(cereal::make_nvp("model", objectToSerialize));
|
|
45
|
+
}
|
|
46
|
+
return true;
|
|
47
|
+
}
|
|
48
|
+
catch (cereal::Exception& e)
|
|
49
|
+
{
|
|
50
|
+
if (opts.Fatal())
|
|
51
|
+
Log::Fatal << e.what() << std::endl;
|
|
52
|
+
else
|
|
53
|
+
Log::Warn << e.what() << std::endl;
|
|
54
|
+
|
|
55
|
+
return false;
|
|
56
|
+
}
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
} // namespace mlpack
|
|
60
|
+
|
|
61
|
+
#endif
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file core/data/save_numeric.hpp
|
|
3
|
+
* @author Ryan Curtin
|
|
4
|
+
* @author Omar Shrit
|
|
5
|
+
*
|
|
6
|
+
* Internal implementation of numeric save function.
|
|
7
|
+
*
|
|
8
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef MLPACK_CORE_DATA_SAVE_NUMERIC_HPP
|
|
14
|
+
#define MLPACK_CORE_DATA_SAVE_NUMERIC_HPP
|
|
15
|
+
|
|
16
|
+
#include "text_options.hpp"
|
|
17
|
+
#include "save_sparse.hpp"
|
|
18
|
+
#include "save_dense.hpp"
|
|
19
|
+
|
|
20
|
+
namespace mlpack {
|
|
21
|
+
|
|
22
|
+
template<typename ObjectType, typename DataOptionsType>
|
|
23
|
+
bool SaveNumeric(const std::string& filename,
|
|
24
|
+
const ObjectType& matrix,
|
|
25
|
+
std::fstream& stream,
|
|
26
|
+
DataOptionsBase<DataOptionsType>& opts)
|
|
27
|
+
{
|
|
28
|
+
bool success = false;
|
|
29
|
+
|
|
30
|
+
TextOptions txtOpts(std::move(opts));
|
|
31
|
+
if constexpr (IsSparseMat<ObjectType>::value)
|
|
32
|
+
{
|
|
33
|
+
success = SaveSparse(matrix, txtOpts, filename, stream);
|
|
34
|
+
}
|
|
35
|
+
else if constexpr (IsCol<ObjectType>::value)
|
|
36
|
+
{
|
|
37
|
+
const bool oldNoTranspose = txtOpts.NoTranspose();
|
|
38
|
+
txtOpts.NoTranspose() = true; // Force no transpose for a column.
|
|
39
|
+
success = SaveDense(matrix, txtOpts, filename, stream);
|
|
40
|
+
txtOpts.NoTranspose() = oldNoTranspose;
|
|
41
|
+
}
|
|
42
|
+
else if constexpr (IsRow<ObjectType>::value)
|
|
43
|
+
{
|
|
44
|
+
const bool oldNoTranspose = txtOpts.NoTranspose();
|
|
45
|
+
txtOpts.NoTranspose() = false; // Force transpose for a row.
|
|
46
|
+
success = SaveDense(matrix, txtOpts, filename, stream);
|
|
47
|
+
txtOpts.NoTranspose() = oldNoTranspose;
|
|
48
|
+
}
|
|
49
|
+
else if constexpr (IsDense<ObjectType>::value)
|
|
50
|
+
{
|
|
51
|
+
success = SaveDense(matrix, txtOpts, filename, stream);
|
|
52
|
+
}
|
|
53
|
+
static_cast<DataOptionsType&>(opts) = std::move(txtOpts);
|
|
54
|
+
|
|
55
|
+
return success;
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
} // namespace mlpack
|
|
59
|
+
|
|
60
|
+
#endif
|