mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -17,171 +17,48 @@
|
|
|
17
17
|
// In case it hasn't already been included.
|
|
18
18
|
#include "load.hpp"
|
|
19
19
|
|
|
20
|
-
#include <algorithm>
|
|
21
|
-
#include <exception>
|
|
22
|
-
|
|
23
|
-
#include "extension.hpp"
|
|
24
|
-
#include "string_algorithms.hpp"
|
|
25
|
-
|
|
26
20
|
namespace mlpack {
|
|
27
|
-
namespace data {
|
|
28
|
-
|
|
29
|
-
// The following functions are kept for backward compatibility,
|
|
30
|
-
// Please remove them when we release mlpack 5.
|
|
31
|
-
template<typename eT>
|
|
32
|
-
bool Load(const std::string& filename,
|
|
33
|
-
arma::Mat<eT>& matrix,
|
|
34
|
-
const bool fatal,
|
|
35
|
-
const bool transpose,
|
|
36
|
-
const FileType inputLoadType)
|
|
37
|
-
{
|
|
38
|
-
MatrixOptions opts;
|
|
39
|
-
opts.Fatal() = fatal;
|
|
40
|
-
opts.NoTranspose() = !transpose;
|
|
41
|
-
opts.Format() = inputLoadType;
|
|
42
|
-
|
|
43
|
-
return Load(filename, matrix, opts);
|
|
44
|
-
}
|
|
45
|
-
|
|
46
|
-
// For loading data into sparse matrix
|
|
47
|
-
template <typename eT>
|
|
48
|
-
bool Load(const std::string& filename,
|
|
49
|
-
arma::SpMat<eT>& matrix,
|
|
50
|
-
const bool fatal,
|
|
51
|
-
const bool transpose,
|
|
52
|
-
const FileType inputLoadType)
|
|
53
|
-
{
|
|
54
|
-
MatrixOptions opts;
|
|
55
|
-
opts.Fatal() = fatal;
|
|
56
|
-
opts.NoTranspose() = !transpose;
|
|
57
|
-
opts.Format() = inputLoadType;
|
|
58
|
-
|
|
59
|
-
return Load(filename, matrix, opts);
|
|
60
|
-
}
|
|
61
|
-
|
|
62
|
-
// For loading data into a column vector
|
|
63
|
-
template <typename eT>
|
|
64
|
-
bool Load(const std::string& filename,
|
|
65
|
-
arma::Col<eT>& vec,
|
|
66
|
-
const bool fatal)
|
|
67
|
-
{
|
|
68
|
-
DataOptions opts;
|
|
69
|
-
opts.Fatal() = fatal;
|
|
70
|
-
return Load(filename, vec, opts);
|
|
71
|
-
}
|
|
72
|
-
|
|
73
|
-
// For loading data into a raw vector
|
|
74
|
-
template <typename eT>
|
|
75
|
-
bool Load(const std::string& filename,
|
|
76
|
-
arma::Row<eT>& rowvec,
|
|
77
|
-
const bool fatal)
|
|
78
|
-
{
|
|
79
|
-
DataOptions opts;
|
|
80
|
-
opts.Fatal() = fatal;
|
|
81
|
-
return Load(filename, rowvec, opts);
|
|
82
|
-
}
|
|
83
|
-
|
|
84
|
-
// Load with mappings. Unfortunately we have to implement this ourselves.
|
|
85
|
-
template<typename eT, typename PolicyType>
|
|
86
|
-
bool Load(const std::string& filename,
|
|
87
|
-
arma::Mat<eT>& matrix,
|
|
88
|
-
DatasetMapper<PolicyType>& info,
|
|
89
|
-
const bool fatal,
|
|
90
|
-
const bool transpose)
|
|
91
|
-
{
|
|
92
|
-
TextOptions opts;
|
|
93
|
-
opts.Fatal() = fatal;
|
|
94
|
-
opts.NoTranspose() = !transpose;
|
|
95
|
-
opts.Categorical() = true;
|
|
96
|
-
|
|
97
|
-
if constexpr (std::is_same_v<PolicyType, data::IncrementPolicy>)
|
|
98
|
-
{
|
|
99
|
-
opts.DatasetInfo() = info;
|
|
100
|
-
}
|
|
101
|
-
else if constexpr (std::is_same_v<PolicyType, data::MissingPolicy>)
|
|
102
|
-
{
|
|
103
|
-
opts.MissingPolicy() = true;
|
|
104
|
-
opts.DatasetMissingPolicy() = info;
|
|
105
|
-
}
|
|
106
|
-
|
|
107
|
-
bool success = Load(filename, matrix, opts);
|
|
108
|
-
|
|
109
|
-
if constexpr (std::is_same_v<PolicyType, data::IncrementPolicy>)
|
|
110
|
-
{
|
|
111
|
-
info = opts.DatasetInfo();
|
|
112
|
-
}
|
|
113
|
-
else if constexpr (std::is_same_v<PolicyType, data::MissingPolicy>)
|
|
114
|
-
{
|
|
115
|
-
info = opts.DatasetMissingPolicy();
|
|
116
|
-
}
|
|
117
|
-
|
|
118
|
-
return success;
|
|
119
|
-
}
|
|
120
21
|
|
|
121
22
|
template<typename MatType, typename DataOptionsType>
|
|
122
23
|
bool Load(const std::string& filename,
|
|
123
24
|
MatType& matrix,
|
|
124
25
|
const DataOptionsType& opts,
|
|
125
|
-
std::enable_if_t<
|
|
126
|
-
|
|
127
|
-
std::enable_if_t<!std::is_same_v<DataOptionsType, bool>>*)
|
|
26
|
+
const typename std::enable_if_t<
|
|
27
|
+
IsDataOptions<DataOptionsType>::value>*)
|
|
128
28
|
{
|
|
129
29
|
DataOptionsType tmpOpts(opts);
|
|
130
|
-
return Load(filename, matrix, tmpOpts);
|
|
30
|
+
return Load(filename, matrix, tmpOpts, false);
|
|
131
31
|
}
|
|
132
32
|
|
|
133
|
-
template<typename
|
|
134
|
-
bool
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
33
|
+
template<typename eT, typename DataOptionsType>
|
|
34
|
+
bool Load(const std::vector<std::string>& files,
|
|
35
|
+
arma::Mat<eT>& matrix,
|
|
36
|
+
const DataOptionsType& opts,
|
|
37
|
+
const typename std::enable_if_t<
|
|
38
|
+
IsDataOptions<DataOptionsType>::value>*)
|
|
138
39
|
{
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
{
|
|
142
|
-
success = LoadSparse(filename, matrix, txtOpts, stream);
|
|
143
|
-
}
|
|
144
|
-
else if (txtOpts.Categorical() ||
|
|
145
|
-
(txtOpts.Format() == FileType::ARFFASCII))
|
|
146
|
-
{
|
|
147
|
-
success = LoadCategorical(filename, matrix, txtOpts);
|
|
148
|
-
}
|
|
149
|
-
else if constexpr (IsCol<MatType>::value)
|
|
150
|
-
{
|
|
151
|
-
success = LoadCol(filename, matrix, txtOpts, stream);
|
|
152
|
-
}
|
|
153
|
-
else if constexpr (IsRow<MatType>::value)
|
|
154
|
-
{
|
|
155
|
-
success = LoadRow(filename, matrix, txtOpts, stream);
|
|
156
|
-
}
|
|
157
|
-
else if constexpr (IsDense<MatType>::value)
|
|
158
|
-
{
|
|
159
|
-
success = LoadDense(filename, matrix, txtOpts, stream);
|
|
160
|
-
}
|
|
161
|
-
else
|
|
162
|
-
{
|
|
163
|
-
if (txtOpts.Fatal())
|
|
164
|
-
Log::Fatal << "data::Load(): unknown matrix-like type given!"
|
|
165
|
-
<< std::endl;
|
|
166
|
-
else
|
|
167
|
-
Log::Warn << "data::Load(): unknown matrix-like type given!"
|
|
168
|
-
<< std::endl;
|
|
169
|
-
|
|
170
|
-
return false;
|
|
171
|
-
}
|
|
172
|
-
return success;
|
|
40
|
+
DataOptionsType tmpOpts(opts);
|
|
41
|
+
return Load(files, matrix, tmpOpts, false);
|
|
173
42
|
}
|
|
174
43
|
|
|
175
|
-
template<typename
|
|
44
|
+
template<typename ObjectType, typename DataOptionsType>
|
|
176
45
|
bool Load(const std::string& filename,
|
|
177
|
-
|
|
46
|
+
ObjectType& matrix,
|
|
178
47
|
DataOptionsType& opts,
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
48
|
+
const bool copyBack,
|
|
49
|
+
const typename std::enable_if_t<
|
|
50
|
+
IsDataOptions<DataOptionsType>::value>*)
|
|
182
51
|
{
|
|
183
52
|
Timer::Start("loading_data");
|
|
184
53
|
|
|
54
|
+
static_assert(!IsArma<ObjectType>::value || !IsSparseMat<ObjectType>::value
|
|
55
|
+
|| !HasSerialize<ObjectType>::value, "mlpack can load Armadillo"
|
|
56
|
+
" matrices or serialized mlpack models only; please use a known type.");
|
|
57
|
+
const bool isMatrixType = IsArma<ObjectType>::value ||
|
|
58
|
+
IsSparseMat<ObjectType>::value;
|
|
59
|
+
const bool isSerializable = HasSerialize<ObjectType>::value;
|
|
60
|
+
const bool isSparseMatrixType = IsSparseMat<ObjectType>::value;
|
|
61
|
+
|
|
185
62
|
std::fstream stream;
|
|
186
63
|
bool success = OpenFile(filename, opts, true, stream);
|
|
187
64
|
if (!success)
|
|
@@ -190,43 +67,65 @@ bool Load(const std::string& filename,
|
|
|
190
67
|
return false;
|
|
191
68
|
}
|
|
192
69
|
|
|
193
|
-
success = DetectFileType<
|
|
70
|
+
success = DetectFileType<ObjectType>(filename, opts, true, &stream);
|
|
194
71
|
if (!success)
|
|
195
72
|
{
|
|
196
73
|
Timer::Stop("loading_data");
|
|
197
74
|
return false;
|
|
198
75
|
}
|
|
76
|
+
const bool isImageFormat = (opts.Format() == FileType::PNG ||
|
|
77
|
+
opts.Format() == FileType::JPG || opts.Format() == FileType::PNM ||
|
|
78
|
+
opts.Format() == FileType::BMP || opts.Format() == FileType::GIF ||
|
|
79
|
+
opts.Format() == FileType::PSD || opts.Format() == FileType::TGA ||
|
|
80
|
+
opts.Format() == FileType::PIC || opts.Format() == FileType::ImageType);
|
|
199
81
|
|
|
200
|
-
if constexpr (
|
|
82
|
+
if constexpr (isMatrixType)
|
|
201
83
|
{
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
84
|
+
if (isImageFormat)
|
|
85
|
+
{
|
|
86
|
+
if constexpr (isSparseMatrixType)
|
|
87
|
+
{
|
|
88
|
+
return HandleError("Cannot load image data into a sparse matrix. "
|
|
89
|
+
"Please use dense matrix instead.", opts);
|
|
90
|
+
}
|
|
91
|
+
else
|
|
92
|
+
{
|
|
93
|
+
ImageOptions imgOpts(std::move(opts));
|
|
94
|
+
std::vector<std::string> files;
|
|
95
|
+
files.push_back(filename);
|
|
96
|
+
success = LoadImage(files, matrix, imgOpts);
|
|
97
|
+
if (copyBack)
|
|
98
|
+
opts = std::move(imgOpts);
|
|
99
|
+
}
|
|
100
|
+
}
|
|
101
|
+
else
|
|
102
|
+
{
|
|
103
|
+
TextOptions txtOpts(std::move(opts));
|
|
104
|
+
success = LoadNumeric(filename, matrix, stream, txtOpts);
|
|
105
|
+
if (copyBack)
|
|
106
|
+
opts = std::move(txtOpts);
|
|
107
|
+
}
|
|
108
|
+
}
|
|
109
|
+
else if constexpr (isSerializable)
|
|
110
|
+
{
|
|
111
|
+
success = LoadModel(matrix, opts, stream);
|
|
205
112
|
}
|
|
206
113
|
else
|
|
207
114
|
{
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
<< "or provide specific overloads." << std::endl;
|
|
211
|
-
else
|
|
212
|
-
Log::Warn << "DataOptionsType is unknown! Please use a known type "
|
|
213
|
-
<< "or provide specific overloads." << std::endl;
|
|
214
|
-
return false;
|
|
115
|
+
return HandleError("DataOptionsType is unknown! Please use a known type "
|
|
116
|
+
"or provide specific overloads.", opts);
|
|
215
117
|
}
|
|
216
118
|
|
|
217
119
|
if (!success)
|
|
218
120
|
{
|
|
219
121
|
Timer::Stop("loading_data");
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
Log::Warn << "Loading from '" << filename << "' failed." << std::endl;
|
|
224
|
-
|
|
225
|
-
return false;
|
|
122
|
+
std::stringstream oss;
|
|
123
|
+
oss << "Loading from '" << filename << "' failed.";
|
|
124
|
+
return HandleError(oss, opts);
|
|
226
125
|
}
|
|
227
126
|
else
|
|
228
127
|
{
|
|
229
|
-
if constexpr (IsArma<
|
|
128
|
+
if constexpr (IsArma<ObjectType>::value)
|
|
230
129
|
{
|
|
231
130
|
Log::Info << "Size is " << matrix.n_rows << " x "
|
|
232
131
|
<< matrix.n_cols << ".\n";
|
|
@@ -238,115 +137,42 @@ bool Load(const std::string& filename,
|
|
|
238
137
|
return success;
|
|
239
138
|
}
|
|
240
139
|
|
|
241
|
-
template<typename
|
|
242
|
-
bool
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
if (opts.Format() != FileType::RawBinary)
|
|
249
|
-
Log::Info << "Loading '" << filename << "' as "
|
|
250
|
-
<< opts.FileTypeToString() << ". " << std::flush;
|
|
251
|
-
|
|
252
|
-
// We can't use the stream if the type is HDF5.
|
|
253
|
-
if (opts.Format() == FileType::HDF5Binary)
|
|
254
|
-
{
|
|
255
|
-
success = LoadHDF5(filename, matrix, opts);
|
|
256
|
-
}
|
|
257
|
-
else if (opts.Format() == FileType::CSVASCII)
|
|
258
|
-
{
|
|
259
|
-
success = LoadCSVASCII(filename, matrix, opts);
|
|
260
|
-
|
|
261
|
-
if (matrix.col(0).is_zero())
|
|
262
|
-
Log::Warn << "data::Load(): the first line in '" << filename << "' was "
|
|
263
|
-
<< "loaded as all zeros; if the first row is headers, specify "
|
|
264
|
-
<< "`HasHeaders() = true` in the given DataOptions." << std::endl;
|
|
265
|
-
}
|
|
266
|
-
else
|
|
267
|
-
{
|
|
268
|
-
if (opts.Format() == FileType::RawBinary)
|
|
269
|
-
Log::Warn << "Loading '" << filename << "' as "
|
|
270
|
-
<< opts.FileTypeToString() << "; "
|
|
271
|
-
<< "but this may not be the actual filetype!" << std::endl;
|
|
272
|
-
|
|
273
|
-
success = matrix.load(stream, ToArmaFileType(opts.Format()));
|
|
274
|
-
if (!opts.NoTranspose())
|
|
275
|
-
inplace_trans(matrix);
|
|
276
|
-
}
|
|
277
|
-
return success;
|
|
278
|
-
}
|
|
279
|
-
|
|
280
|
-
template <typename eT>
|
|
281
|
-
bool LoadSparse(const std::string& filename,
|
|
282
|
-
arma::SpMat<eT>& matrix,
|
|
283
|
-
TextOptions& opts,
|
|
284
|
-
std::fstream& stream)
|
|
140
|
+
template<typename eT, typename DataOptionsType>
|
|
141
|
+
bool Load(const std::vector<std::string>& files,
|
|
142
|
+
arma::Mat<eT>& matrix,
|
|
143
|
+
DataOptionsType& opts,
|
|
144
|
+
const bool copyBack,
|
|
145
|
+
const typename std::enable_if_t<
|
|
146
|
+
IsDataOptions<DataOptionsType>::value>*)
|
|
285
147
|
{
|
|
286
|
-
bool success;
|
|
287
|
-
|
|
288
|
-
// if we got a text type, it could be a coordinate list. We will make an
|
|
289
|
-
// educated guess based on the shape of the input.
|
|
290
|
-
if (opts.Format() == FileType::RawASCII)
|
|
148
|
+
bool success = false;
|
|
149
|
+
if (files.empty())
|
|
291
150
|
{
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
const size_t cols = CountCols(stream);
|
|
295
|
-
if (cols == 3)
|
|
296
|
-
{
|
|
297
|
-
// We have the right number of columns, so assume the type is a
|
|
298
|
-
// coordinate list.
|
|
299
|
-
opts.Format() = FileType::CoordASCII;
|
|
300
|
-
}
|
|
151
|
+
return HandleError("Load(): given set of filenames is empty;"
|
|
152
|
+
" loading failed.", opts);
|
|
301
153
|
}
|
|
302
154
|
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
(opts.Format() == FileType::
|
|
308
|
-
|
|
309
|
-
if (opts.Fatal())
|
|
310
|
-
Log::Fatal << "Cannot load '" << filename << "' with type "
|
|
311
|
-
<< opts.FileTypeToString() << " into a sparse matrix; format is "
|
|
312
|
-
<< "only supported for dense matrices." << std::endl;
|
|
313
|
-
else
|
|
314
|
-
Log::Warn << "Cannot load '" << filename << "' with type "
|
|
315
|
-
<< opts.FileTypeToString() << " into a sparse matrix; format is "
|
|
316
|
-
<< "only supported for dense matrices; load failed." << std::endl;
|
|
155
|
+
DetectFromExtension<arma::Mat<eT>>(files.back(), opts);
|
|
156
|
+
const bool isImageFormat = (opts.Format() == FileType::PNG ||
|
|
157
|
+
opts.Format() == FileType::JPG || opts.Format() == FileType::PNM ||
|
|
158
|
+
opts.Format() == FileType::BMP || opts.Format() == FileType::GIF ||
|
|
159
|
+
opts.Format() == FileType::PSD || opts.Format() == FileType::TGA ||
|
|
160
|
+
opts.Format() == FileType::PIC || opts.Format() == FileType::ImageType);
|
|
317
161
|
|
|
318
|
-
|
|
319
|
-
}
|
|
320
|
-
else if (opts.Format() == FileType::CSVASCII)
|
|
162
|
+
if (isImageFormat)
|
|
321
163
|
{
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
success = dense.load(stream, ToArmaFileType(opts.Format()));
|
|
327
|
-
if (dense.n_cols == 3)
|
|
328
|
-
{
|
|
329
|
-
arma::umat locations = arma::conv_to<arma::umat>::from(
|
|
330
|
-
dense.cols(0, 1).t());
|
|
331
|
-
matrix = arma::SpMat<eT>(locations, dense.col(2));
|
|
332
|
-
}
|
|
333
|
-
else
|
|
334
|
-
{
|
|
335
|
-
matrix = arma::conv_to<arma::SpMat<eT>>::from(dense);
|
|
336
|
-
}
|
|
164
|
+
ImageOptions imgOpts(std::move(opts));
|
|
165
|
+
success = LoadImage(files, matrix, imgOpts);
|
|
166
|
+
if (copyBack)
|
|
167
|
+
opts = std::move(imgOpts);
|
|
337
168
|
}
|
|
338
169
|
else
|
|
339
170
|
{
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
{
|
|
345
|
-
// It seems that there is no direct way to use inplace_trans() on
|
|
346
|
-
// sparse matrices.
|
|
347
|
-
matrix = matrix.t();
|
|
171
|
+
TextOptions txtOpts(std::move(opts));
|
|
172
|
+
success = LoadNumericMultifile(files, matrix, txtOpts);
|
|
173
|
+
if (copyBack)
|
|
174
|
+
opts = std::move(txtOpts);
|
|
348
175
|
}
|
|
349
|
-
|
|
350
176
|
return success;
|
|
351
177
|
}
|
|
352
178
|
|
|
@@ -393,14 +219,10 @@ bool LoadCategorical(const std::string& filename,
|
|
|
393
219
|
{
|
|
394
220
|
// The type is unknown.
|
|
395
221
|
Timer::Stop("loading_data");
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
<< "Incorrect extension?"
|
|
399
|
-
|
|
400
|
-
Log::Warn << "Unable to detect type of '" << filename << "'; load failed."
|
|
401
|
-
<< " Incorrect extension?" << std::endl;
|
|
402
|
-
|
|
403
|
-
return false;
|
|
222
|
+
std::stringstream oss;
|
|
223
|
+
oss << "Unable to detect type of '" << filename << "'; "
|
|
224
|
+
<< "Incorrect extension?";
|
|
225
|
+
return HandleError(oss, opts);
|
|
404
226
|
}
|
|
405
227
|
|
|
406
228
|
Log::Info << "Size is " << matrix.n_rows << " x " << matrix.n_cols << ".\n";
|
|
@@ -410,7 +232,6 @@ bool LoadCategorical(const std::string& filename,
|
|
|
410
232
|
return true;
|
|
411
233
|
}
|
|
412
234
|
|
|
413
|
-
} // namespace data
|
|
414
235
|
} // namespace mlpack
|
|
415
236
|
|
|
416
237
|
#endif
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file core/data/load_model.hpp
|
|
3
|
+
* @author Ryan Curtin
|
|
4
|
+
* @author Omar Shrit
|
|
5
|
+
*
|
|
6
|
+
* Intenal implementation of model-specific Load() function.
|
|
7
|
+
*
|
|
8
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
9
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
10
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
11
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef MLPACK_CORE_DATA_LOAD_MODEL_HPP
|
|
14
|
+
#define MLPACK_CORE_DATA_LOAD_MODEL_HPP
|
|
15
|
+
|
|
16
|
+
#include <cereal/archives/xml.hpp>
|
|
17
|
+
#include <cereal/archives/binary.hpp>
|
|
18
|
+
#include <cereal/archives/json.hpp>
|
|
19
|
+
|
|
20
|
+
#include "text_options.hpp"
|
|
21
|
+
|
|
22
|
+
namespace mlpack {
|
|
23
|
+
|
|
24
|
+
template<typename Object>
|
|
25
|
+
bool LoadModel(Object& objectToSerialize,
|
|
26
|
+
DataOptionsBase<PlainDataOptions>& opts,
|
|
27
|
+
std::fstream& stream)
|
|
28
|
+
{
|
|
29
|
+
try
|
|
30
|
+
{
|
|
31
|
+
if (opts.Format() == FileType::XML)
|
|
32
|
+
{
|
|
33
|
+
cereal::XMLInputArchive ar(stream);
|
|
34
|
+
ar(cereal::make_nvp("model", objectToSerialize));
|
|
35
|
+
}
|
|
36
|
+
else if (opts.Format() == FileType::JSON)
|
|
37
|
+
{
|
|
38
|
+
cereal::JSONInputArchive ar(stream);
|
|
39
|
+
ar(cereal::make_nvp("model", objectToSerialize));
|
|
40
|
+
}
|
|
41
|
+
else if (opts.Format() == FileType::BIN)
|
|
42
|
+
{
|
|
43
|
+
cereal::BinaryInputArchive ar(stream);
|
|
44
|
+
ar(cereal::make_nvp("model", objectToSerialize));
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
return true;
|
|
48
|
+
}
|
|
49
|
+
catch (cereal::Exception& e)
|
|
50
|
+
{
|
|
51
|
+
if (opts.Fatal())
|
|
52
|
+
Log::Fatal << e.what() << std::endl;
|
|
53
|
+
else
|
|
54
|
+
Log::Warn << e.what() << std::endl;
|
|
55
|
+
|
|
56
|
+
return false;
|
|
57
|
+
}
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
} // namespace mlpack
|
|
61
|
+
|
|
62
|
+
#endif
|