mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -0,0 +1,728 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file methods/ann/dag_network.hpp
|
|
3
|
+
* @author Andrew Furey
|
|
4
|
+
*
|
|
5
|
+
* Definition of the DAGNetwork class, which allows users to describe a
|
|
6
|
+
* computational graph to build arbitrary neural networks with skip
|
|
7
|
+
* connections. These skip connections can consist of concatenations or
|
|
8
|
+
* element-wise addition of the input tensors for residual connections.
|
|
9
|
+
*
|
|
10
|
+
* mlpack is free software; you may redistribute it and/or modify it under the
|
|
11
|
+
* terms of the 3-clause BSD license. You should have received a copy of the
|
|
12
|
+
* 3-clause BSD license along with mlpack. If not, see
|
|
13
|
+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
|
|
14
|
+
*/
|
|
15
|
+
#ifndef MLPACK_METHODS_ANN_DAG_NETWORK_HPP
|
|
16
|
+
#define MLPACK_METHODS_ANN_DAG_NETWORK_HPP
|
|
17
|
+
|
|
18
|
+
#include <mlpack/core.hpp>
|
|
19
|
+
|
|
20
|
+
#include "init_rules/init_rules.hpp"
|
|
21
|
+
|
|
22
|
+
#include <ensmallen.hpp>
|
|
23
|
+
|
|
24
|
+
/**
 * How a DAGNetwork node with multiple parents combines its parents' outputs.
 *
 * NOTE(review): this is an unscoped enum declared at global scope (outside
 * namespace mlpack), so its enumerators are visible globally. It is left this
 * way for source compatibility with existing callers.
 */
enum ConnectionTypes
{
  //! Concatenate the parents' outputs along an axis.
  CONCATENATE = 0,
  //! Combine the parents' outputs with element-wise addition (e.g. for
  //! residual connections).
  ADDITION = 1
};
|
|
29
|
+
|
|
30
|
+
namespace mlpack {
|
|
31
|
+
|
|
32
|
+
/**
|
|
33
|
+
* Implementation of a directed acyclic graph. Any layer that inherits
|
|
34
|
+
* from the base `Layer` class can be added to this model.
|
|
35
|
+
*
|
|
36
|
+
* A network can be created by using the `Add()` method to add
|
|
37
|
+
* layers to the network. Each layer is then linked using `Connect()`.
|
|
38
|
+
*
|
|
39
|
+
* A node with multiple parents will either concatenate the output of its
|
|
40
|
+
* parents along a specified axis, or accumulate using element-wise addition.
|
|
41
|
+
* You can specify the type of connection using `SetConnection`. If your
|
|
42
|
+
* connection is a concatenation, you can specify the axis with `SetAxis`.
|
|
43
|
+
*
|
|
44
|
+
* If no connection type or axis is set, by default the connection will be
|
|
45
|
+
* concatenation over the last axis of that layer.
|
|
46
|
+
*
|
|
47
|
+
* A DAGNetwork cannot have any cycles. Creating a network with a cycle will
|
|
48
|
+
* result in an error. A DAGNetwork can only have one input layer and one
|
|
49
|
+
* output layer.
|
|
50
|
+
*
|
|
51
|
+
* Although the actual types passed as input will be matrix objects with one
|
|
52
|
+
* data point per column, each data point can be a tensor of arbitrary shape.
|
|
53
|
+
* If data points are not 1-dimensional vectors, then set the shape of the input
|
|
54
|
+
* with `InputDimensions()` before calling `Train()`.
|
|
55
|
+
*
|
|
56
|
+
* More granular functionality is available with `Forward()`, `Backward()`, and
|
|
57
|
+
* `Evaluate()`, or even by accessing the individual layers directly with
|
|
58
|
+
* `Network()`.
|
|
59
|
+
*
|
|
60
|
+
* @tparam OutputLayerType The output layer type used to evaluate the network.
|
|
61
|
+
* @tparam InitializationRuleType Rule used to initialize the weight matrix.
|
|
62
|
+
* @tparam MatType Type of matrix to be given as input to the network.
|
|
63
|
+
*/
|
|
64
|
+
template<
|
|
65
|
+
typename OutputLayerType = NegativeLogLikelihood,
|
|
66
|
+
typename InitializationRuleType = RandomInitialization,
|
|
67
|
+
typename MatType = arma::mat>
|
|
68
|
+
class DAGNetwork
|
|
69
|
+
{
|
|
70
|
+
public:
|
|
71
|
+
/**
|
|
72
|
+
* Create the DAGNetwork object.
|
|
73
|
+
*
|
|
74
|
+
* Optionally, specify which initialization rule and output layer should
|
|
75
|
+
* be used.
|
|
76
|
+
*
|
|
77
|
+
* If you want to pass in a parameter and discard the original parameter
|
|
78
|
+
* object, be sure to use std::move to avoid an unnecessary copy.
|
|
79
|
+
*
|
|
80
|
+
* @param outputLayer Output layer used to evaluate the network.
|
|
81
|
+
* @param initializeRule Optional instantiated InitializationRule object
|
|
82
|
+
* for initializing the network parameter.
|
|
83
|
+
*/
|
|
84
|
+
DAGNetwork(OutputLayerType outputLayer = OutputLayerType(),
|
|
85
|
+
InitializationRuleType initializeRule = InitializationRuleType());
|
|
86
|
+
|
|
87
|
+
// Copy constructor.
|
|
88
|
+
DAGNetwork(const DAGNetwork& other);
|
|
89
|
+
// Move constructor.
|
|
90
|
+
DAGNetwork(DAGNetwork&& other);
|
|
91
|
+
// Copy operator.
|
|
92
|
+
DAGNetwork& operator=(const DAGNetwork& other);
|
|
93
|
+
// Move assignment operator.
|
|
94
|
+
DAGNetwork& operator=(DAGNetwork&& other);
|
|
95
|
+
|
|
96
|
+
// Destructor: delete all layers.
|
|
97
|
+
~DAGNetwork()
|
|
98
|
+
{
|
|
99
|
+
for (size_t i = 0; i < network.size(); i++)
|
|
100
|
+
delete network[i];
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
using CubeType = typename GetCubeType<MatType>::type;
|
|
105
|
+
|
|
106
|
+
/**
|
|
107
|
+
* Add a new layer to the model. Note that any trainable weights of this
|
|
108
|
+
* layer will be reset! (Any constant parameters are kept.) This layer
|
|
109
|
+
* should only receive input from one layer.
|
|
110
|
+
*
|
|
111
|
+
* @param layer The Layer to be added to the model.
|
|
112
|
+
*
|
|
113
|
+
* returns the index of the layer in `network`, to be used in `Connect()`
|
|
114
|
+
*/
|
|
115
|
+
template <typename LayerType, typename... Args>
|
|
116
|
+
size_t Add(Args&&... args)
|
|
117
|
+
{
|
|
118
|
+
size_t id = network.size();
|
|
119
|
+
network.push_back(new LayerType(std::forward<Args>(args)...));
|
|
120
|
+
AddLayer(id);
|
|
121
|
+
|
|
122
|
+
return id;
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
template <template<typename...> typename LayerType, typename... Args>
|
|
126
|
+
size_t Add(Args&&... args)
|
|
127
|
+
{
|
|
128
|
+
size_t id = network.size();
|
|
129
|
+
network.push_back(new LayerType<MatType>(std::forward<Args>(args)...));
|
|
130
|
+
AddLayer(id);
|
|
131
|
+
|
|
132
|
+
return id;
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
template <typename LayerType>
|
|
136
|
+
size_t Add(LayerType&& layer,
|
|
137
|
+
typename std::enable_if_t<
|
|
138
|
+
!std::is_pointer_v<std::remove_reference_t<LayerType>>>* = 0)
|
|
139
|
+
{
|
|
140
|
+
using NewLayerType =
|
|
141
|
+
typename std::remove_cv_t<std::remove_reference_t<LayerType>>;
|
|
142
|
+
|
|
143
|
+
size_t id = network.size();
|
|
144
|
+
network.push_back(new NewLayerType(std::forward<LayerType>(layer)));
|
|
145
|
+
AddLayer(id);
|
|
146
|
+
|
|
147
|
+
return id;
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
/**
|
|
151
|
+
* Set the connection type for a layer that expects multiple parents.
|
|
152
|
+
*
|
|
153
|
+
* @param layerId The layer to be added to the model.
|
|
154
|
+
* @param connection The connection type.
|
|
155
|
+
*/
|
|
156
|
+
void SetConnection(size_t layerId, ConnectionTypes connection);
|
|
157
|
+
|
|
158
|
+
/**
|
|
159
|
+
* Set the axis to concatenate along for some layer that expects multiple
|
|
160
|
+
* parent.
|
|
161
|
+
*
|
|
162
|
+
* @param layerId The layer to be added to the model.
|
|
163
|
+
* @param concatAxis The axis to concatenate parent node outputs along.
|
|
164
|
+
*/
|
|
165
|
+
void SetAxis(size_t layerId, size_t concatAxis);
|
|
166
|
+
|
|
167
|
+
/**
|
|
168
|
+
* Create an edge between two layers. If the child node expects multiple
|
|
169
|
+
* parents, the child must have been added to the network with an axis.
|
|
170
|
+
*
|
|
171
|
+
* @param inputLayer The parent node whose output is the input to `outputLayer`
|
|
172
|
+
* @param outputLayer The child node whose input will come from `inputLayer`
|
|
173
|
+
*/
|
|
174
|
+
void Connect(size_t parentNodeId, size_t childNodeId);
|
|
175
|
+
|
|
176
|
+
// Get the layers of the network, in the order the user specified.
|
|
177
|
+
const std::vector<Layer<MatType>*>& Network() const
|
|
178
|
+
{
|
|
179
|
+
return network;
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
// Get the layers of the network, in topological order.
|
|
183
|
+
const std::vector<Layer<MatType>*> SortedNetwork()
|
|
184
|
+
{
|
|
185
|
+
if (!graphIsSet)
|
|
186
|
+
CheckGraph();
|
|
187
|
+
|
|
188
|
+
std::vector<Layer<MatType>*> sortedLayers;
|
|
189
|
+
for (size_t i = 0; i < sortedNetwork.size(); i++)
|
|
190
|
+
{
|
|
191
|
+
size_t layerIndex = sortedNetwork[i];
|
|
192
|
+
sortedLayers.push_back(network[layerIndex]);
|
|
193
|
+
}
|
|
194
|
+
return sortedLayers;
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
template<typename OptimizerType, typename... CallbackTypes>
|
|
200
|
+
typename MatType::elem_type Train(MatType predictors,
|
|
201
|
+
MatType responses,
|
|
202
|
+
OptimizerType& optimizer,
|
|
203
|
+
CallbackTypes&&... callbacks);
|
|
204
|
+
|
|
205
|
+
template<typename OptimizerType = ens::RMSProp, typename... CallbackTypes>
|
|
206
|
+
typename MatType::elem_type Train(MatType predictors,
|
|
207
|
+
MatType responses,
|
|
208
|
+
CallbackTypes&&... callbacks);
|
|
209
|
+
|
|
210
|
+
/**
|
|
211
|
+
* Predict the responses to a given set of predictors. The responses will be
|
|
212
|
+
* the output of the output layer when `predictors` is passed through the
|
|
213
|
+
* whole network (`OutputLayerType`).
|
|
214
|
+
*
|
|
215
|
+
* @param predictors Input predictors.
|
|
216
|
+
* @param results Matrix to put output predictions of responses into.
|
|
217
|
+
* @param batchSize Batch size to use for prediction.
|
|
218
|
+
*/
|
|
219
|
+
void Predict(const MatType& predictors,
|
|
220
|
+
MatType& results,
|
|
221
|
+
const size_t batchSize = 128);
|
|
222
|
+
|
|
223
|
+
// Return the number of weights in the model.
|
|
224
|
+
size_t WeightSize();
|
|
225
|
+
|
|
226
|
+
/**
|
|
227
|
+
* Set the logical dimensions of the input. `Train()` and `Predict()` expect
|
|
228
|
+
* data to be passed such that one point corresponds to one column, but this
|
|
229
|
+
* data is allowed to be an arbitrary higher-order tensor.
|
|
230
|
+
*
|
|
231
|
+
* So, if the input is meant to be 28x28x3 images, then the
|
|
232
|
+
* input data to `Train()` or `Predict()` should have 28*28*3 = 2352 rows, and
|
|
233
|
+
* `InputDimensions()` should be set to `{ 28, 28, 3 }`. Then, the layers of
|
|
234
|
+
* the network will interpret each input point as a 3-dimensional image
|
|
235
|
+
* instead of a 1-dimensional vector.
|
|
236
|
+
*
|
|
237
|
+
* If `InputDimensions()` is left unset before training, the data will be
|
|
238
|
+
* assumed to be a 1-dimensional vector.
|
|
239
|
+
*/
|
|
240
|
+
std::vector<size_t>& InputDimensions()
|
|
241
|
+
{
|
|
242
|
+
validOutputDimensions = false;
|
|
243
|
+
graphIsSet = false;
|
|
244
|
+
layerMemoryIsSet = false;
|
|
245
|
+
|
|
246
|
+
return inputDimensions;
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
// Get the logical dimensions of the input.
|
|
250
|
+
const std::vector<size_t>& InputDimensions() const { return inputDimensions; }
|
|
251
|
+
|
|
252
|
+
const std::vector<size_t>& OutputDimensions()
|
|
253
|
+
{
|
|
254
|
+
if (!graphIsSet)
|
|
255
|
+
CheckGraph();
|
|
256
|
+
|
|
257
|
+
if (!validOutputDimensions)
|
|
258
|
+
UpdateDimensions("DAGNetwork::OutputDimensions()");
|
|
259
|
+
|
|
260
|
+
size_t lastLayer = sortedNetwork.back();
|
|
261
|
+
return network[lastLayer]->OutputDimensions();
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
// Return the current set of weights. These are linearized: this contains
|
|
265
|
+
// the weights of every layer.
|
|
266
|
+
const MatType& Parameters() const { return parameters; }
|
|
267
|
+
// Modify the current set of weights. These are linearized: this contains
|
|
268
|
+
// the weights of every layer. Be careful! If you change the shape of
|
|
269
|
+
// `parameters` to something incorrect, it may be re-initialized the next
|
|
270
|
+
// time a forward pass is done.
|
|
271
|
+
MatType& Parameters() { return parameters; }
|
|
272
|
+
|
|
273
|
+
/**
|
|
274
|
+
* Reset the stored data of the network entirely. This resets all weights of
|
|
275
|
+
* each layer using `InitializationRuleType`, and prepares the network to
|
|
276
|
+
* accept a (flat 1-d) input size of `inputDimensionality` (if passed), or
|
|
277
|
+
* whatever input size has been set with `InputDimensions()`.
|
|
278
|
+
*
|
|
279
|
+
* If no input size has been set with `InputDimensions()`, and
|
|
280
|
+
* `inputDimensionality` is 0, an exception will be thrown, since an empty
|
|
281
|
+
* input size is invalid.
|
|
282
|
+
*
|
|
283
|
+
* This also resets the mode of the network to prediction mode (not training
|
|
284
|
+
* mode). See `SetNetworkMode()` for more information.
|
|
285
|
+
*/
|
|
286
|
+
void Reset(const size_t inputDimensionality = 0);
|
|
287
|
+
|
|
288
|
+
/**
|
|
289
|
+
* Set all the layers in the network to training mode, if `training` is
|
|
290
|
+
* `true`, or set all the layers in the network to testing mode, if `training`
|
|
291
|
+
* is `false`.
|
|
292
|
+
*/
|
|
293
|
+
void SetNetworkMode(const bool training);
|
|
294
|
+
|
|
295
|
+
/**
|
|
296
|
+
* Perform a manual forward pass of the data.
|
|
297
|
+
*
|
|
298
|
+
* `Forward()` and `Backward()` should be used as a pair, and they are
|
|
299
|
+
* designed mainly for advanced users. You should try to use `Predict()` and
|
|
300
|
+
* `Train()`, if you can.
|
|
301
|
+
*
|
|
302
|
+
* @param inputs The input data.
|
|
303
|
+
* @param results The predicted results.
|
|
304
|
+
*/
|
|
305
|
+
void Forward(const MatType& input, MatType& output);
|
|
306
|
+
|
|
307
|
+
/**
|
|
308
|
+
* Perform a manual backward pass of the data.
|
|
309
|
+
*
|
|
310
|
+
* `Forward()` and `Backward()` should be used as a pair, and they are
|
|
311
|
+
* designed mainly for advanced users. You should try to use `Predict()` and
|
|
312
|
+
* `Train()` instead, if you can.
|
|
313
|
+
*
|
|
314
|
+
* @param input Input of the network
|
|
315
|
+
* @param output Output of the network
|
|
316
|
+
* @param error Error from loss function.
|
|
317
|
+
* @param gradients Computed gradients.
|
|
318
|
+
*/
|
|
319
|
+
void Backward(const MatType& input,
|
|
320
|
+
const MatType& output,
|
|
321
|
+
const MatType& error,
|
|
322
|
+
MatType& gradients);
|
|
323
|
+
|
|
324
|
+
/**
|
|
325
|
+
* Evaluate the network with the given predictors and responses.
|
|
326
|
+
* This functions is usually used to monitor progress while training.
|
|
327
|
+
*
|
|
328
|
+
* @param predictors Input variables.
|
|
329
|
+
* @param responses Target outputs for input variables.
|
|
330
|
+
*/
|
|
331
|
+
typename MatType::elem_type Evaluate(const MatType& predictors,
|
|
332
|
+
const MatType& responses);
|
|
333
|
+
|
|
334
|
+
//! Serialize the model.
|
|
335
|
+
template<typename Archive>
|
|
336
|
+
void serialize(Archive& ar, const uint32_t /* version */);
|
|
337
|
+
|
|
338
|
+
//
|
|
339
|
+
// Only ensmallen utility functions for training are found below here.
|
|
340
|
+
// They aren't generally useful otherwise.
|
|
341
|
+
//
|
|
342
|
+
|
|
343
|
+
/**
|
|
344
|
+
* Note: this function is implemented so that it can be used by ensmallen's
|
|
345
|
+
* optimizers. It's not generally meant to be used otherwise.
|
|
346
|
+
*
|
|
347
|
+
* Evaluate the network with the given parameters.
|
|
348
|
+
*
|
|
349
|
+
* @param parameters Matrix model parameters.
|
|
350
|
+
*/
|
|
351
|
+
typename MatType::elem_type Evaluate(const MatType& parameters);
|
|
352
|
+
|
|
353
|
+
/**
|
|
354
|
+
* Note: this function is implemented so that it can be used by ensmallen's
|
|
355
|
+
* optimizers. It's not generally meant to be used otherwise.
|
|
356
|
+
*
|
|
357
|
+
* Evaluate the network with the given parameters, but using only
|
|
358
|
+
* a number of data points. This is useful for optimizers such as SGD, which
|
|
359
|
+
* require a separable objective function.
|
|
360
|
+
*
|
|
361
|
+
* Note that the network may return different results depending on the mode it
|
|
362
|
+
* is in (see `SetNetworkMode()`).
|
|
363
|
+
*
|
|
364
|
+
* @param parameters Matrix model parameters.
|
|
365
|
+
* @param begin Index of the starting point to use for objective function
|
|
366
|
+
* evaluation.
|
|
367
|
+
* @param batchSize Number of points to be passed at a time to use for
|
|
368
|
+
* objective function evaluation.
|
|
369
|
+
*/
|
|
370
|
+
typename MatType::elem_type Evaluate(const MatType& parameters,
|
|
371
|
+
const size_t begin,
|
|
372
|
+
const size_t batchSize);
|
|
373
|
+
|
|
374
|
+
/**
|
|
375
|
+
* Note: this function is implemented so that it can be used by ensmallen's
|
|
376
|
+
* optimizers. It's not generally meant to be used otherwise.
|
|
377
|
+
*
|
|
378
|
+
* Evaluate the network with the given parameters.
|
|
379
|
+
* This function is usually called by the optimizer to train the model.
|
|
380
|
+
* This just calls the overload of EvaluateWithGradient() with batchSize = 1.
|
|
381
|
+
*
|
|
382
|
+
* @param parameters Matrix model parameters.
|
|
383
|
+
* @param gradient Matrix to output gradient into.
|
|
384
|
+
*/
|
|
385
|
+
typename MatType::elem_type EvaluateWithGradient(const MatType& parameters,
|
|
386
|
+
MatType& gradient);
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
/**
|
|
390
|
+
* Note: this function is implemented so that it can be used by ensmallen's
|
|
391
|
+
* optimizers. It's not generally meant to be used otherwise.
|
|
392
|
+
*
|
|
393
|
+
* Evaluate the network with the given parameters, but using only
|
|
394
|
+
* a number of data points. This is useful for optimizers such as SGD, which
|
|
395
|
+
* require a separable objective function.
|
|
396
|
+
*
|
|
397
|
+
* @param parameters Matrix model parameters.
|
|
398
|
+
* @param begin Index of the starting point to use for objective function
|
|
399
|
+
* evaluation.
|
|
400
|
+
* @param gradient Matrix to output gradient into.
|
|
401
|
+
* @param batchSize Number of points to be passed at a time to use for
|
|
402
|
+
* objective function evaluation.
|
|
403
|
+
*/
|
|
404
|
+
|
|
405
|
+
typename MatType::elem_type EvaluateWithGradient(const MatType& parameters,
|
|
406
|
+
const size_t begin,
|
|
407
|
+
MatType& gradient,
|
|
408
|
+
const size_t batchSize);
|
|
409
|
+
|
|
410
|
+
/**
|
|
411
|
+
* Note: this function is implemented so that it can be used by ensmallen's
|
|
412
|
+
* optimizers. It's not generally meant to be used otherwise.
|
|
413
|
+
*
|
|
414
|
+
* Evaluate the gradient of the network with the given parameters,
|
|
415
|
+
* and with respect to only a number of points in the dataset. This is useful
|
|
416
|
+
* for optimizers such as SGD, which require a separable objective function.
|
|
417
|
+
*
|
|
418
|
+
* @param parameters Matrix of the model parameters to be optimized.
|
|
419
|
+
* @param begin Index of the starting point to use for objective function
|
|
420
|
+
* gradient evaluation.
|
|
421
|
+
* @param gradient Matrix to output gradient into.
|
|
422
|
+
* @param batchSize Number of points to be processed as a batch for objective
|
|
423
|
+
* function gradient evaluation.
|
|
424
|
+
*/
|
|
425
|
+
void Gradient(const MatType& parameters,
|
|
426
|
+
const size_t begin,
|
|
427
|
+
MatType& gradient,
|
|
428
|
+
const size_t batchSize);
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
/**
|
|
432
|
+
* Note: this function is implemented so that it can be used by ensmallen's
|
|
433
|
+
* optimizers. It's not generally meant to be used otherwise.
|
|
434
|
+
*
|
|
435
|
+
* Return the number of separable functions (the number of predictor points).
|
|
436
|
+
*/
|
|
437
|
+
size_t NumFunctions() const { return responses.n_cols; }
|
|
438
|
+
|
|
439
|
+
/**
|
|
440
|
+
* Note: this function is implemented so that it can be used by ensmallen's
|
|
441
|
+
* optimizers. It's not generally meant to be used otherwise.
|
|
442
|
+
*
|
|
443
|
+
* Shuffle the order of function visitation. (This is equivalent to shuffling
|
|
444
|
+
* the dataset during training.)
|
|
445
|
+
*/
|
|
446
|
+
void Shuffle();
|
|
447
|
+
|
|
448
|
+
/**
|
|
449
|
+
* Prepare the network for training on the given data.
|
|
450
|
+
*
|
|
451
|
+
* This function won't actually trigger the training process, and is
|
|
452
|
+
* generally only useful internally.
|
|
453
|
+
*
|
|
454
|
+
* @param predictors Input data variables.
|
|
455
|
+
* @param responses Outputs results from input data variables.
|
|
456
|
+
*/
|
|
457
|
+
void ResetData(MatType predictors, MatType responses);
|
|
458
|
+
|
|
459
|
+
private:
|
|
460
|
+
// Helper functions.
|
|
461
|
+
|
|
462
|
+
void AddLayer(size_t nodeId)
|
|
463
|
+
{
|
|
464
|
+
layerGradients.push_back(MatType());
|
|
465
|
+
childrenList.insert({ nodeId, {} });
|
|
466
|
+
parentsList.insert({ nodeId, {} });
|
|
467
|
+
|
|
468
|
+
if (network.size() > 1)
|
|
469
|
+
{
|
|
470
|
+
layerOutputs.push_back(MatType());
|
|
471
|
+
layerInputs.push_back(MatType());
|
|
472
|
+
layerDeltas.push_back(MatType());
|
|
473
|
+
}
|
|
474
|
+
|
|
475
|
+
validOutputDimensions = false;
|
|
476
|
+
graphIsSet = false;
|
|
477
|
+
layerMemoryIsSet = false;
|
|
478
|
+
}
|
|
479
|
+
|
|
480
|
+
// Use the InitializationPolicy to initialize all the weights in the network.
|
|
481
|
+
void InitializeWeights();
|
|
482
|
+
|
|
483
|
+
// Call each layers `CustomInitialize`
|
|
484
|
+
void CustomInitialize(MatType& W, const size_t elements);
|
|
485
|
+
|
|
486
|
+
// Make the memory of each layer point to the right place, by calling
|
|
487
|
+
// SetWeights() on each layer.
|
|
488
|
+
void SetLayerMemory();
|
|
489
|
+
|
|
490
|
+
/**
|
|
491
|
+
* Ensure that all the there are no cycles in the graph and that the graph
|
|
492
|
+
* has one input and one output only. It will topologically sort the network
|
|
493
|
+
* for the forward and backward passes. This will set `graphIsSet` to true
|
|
494
|
+
* only if the graph is valid and topologically sorted.
|
|
495
|
+
*/
|
|
496
|
+
void CheckGraph();
|
|
497
|
+
|
|
498
|
+
/**
|
|
499
|
+
* Ensure that all the locally-cached information about the network is valid,
|
|
500
|
+
* all parameter memory is initialized, and we can make forward and backward
|
|
501
|
+
* passes.
|
|
502
|
+
*
|
|
503
|
+
* @param functionName Name of function to use if an exception is thrown.
|
|
504
|
+
* @param inputDimensionality Given dimensionality of the input data.
|
|
505
|
+
* @param setMode If true, the mode of the network will be set to the
|
|
506
|
+
* parameter given in `training`. Otherwise the mode of the network is
|
|
507
|
+
* left unmodified.
|
|
508
|
+
* @param training Mode to set the network to; `true` indicates the network
|
|
509
|
+
* should be set to training mode; `false` indicates testing mode.
|
|
510
|
+
*/
|
|
511
|
+
void CheckNetwork(const std::string& functionName,
|
|
512
|
+
const size_t inputDimensionality,
|
|
513
|
+
const bool setMode = false,
|
|
514
|
+
const bool training = false);
|
|
515
|
+
|
|
516
|
+
/**
|
|
517
|
+
* This computes the dimensions of each layer held by the network, and the
|
|
518
|
+
* output dimensions are set to the output dimensions of the last layer.
|
|
519
|
+
*
|
|
520
|
+
* Input dimensions of nodes that have multiple parents are also
|
|
521
|
+
* calculated here, based on their connection type.
|
|
522
|
+
*/
|
|
523
|
+
void ComputeOutputDimensions();
|
|
524
|
+
|
|
525
|
+
/**
|
|
526
|
+
* Compute the input dimensions of a concatenation layer with multiple
|
|
527
|
+
* parents.
|
|
528
|
+
*
|
|
529
|
+
* @param layerId The layer that has multiple parent layers.
|
|
530
|
+
*/
|
|
531
|
+
void ComputeConcatDimensions(size_t layerId);
|
|
532
|
+
|
|
533
|
+
/**
|
|
534
|
+
* Compute the input dimensions of an addition layer with multiple parents.
|
|
535
|
+
*
|
|
536
|
+
* @param layerId The layer that has multiple parent layers.
|
|
537
|
+
*/
|
|
538
|
+
void ComputeAdditionDimensions(size_t layerId);
|
|
539
|
+
|
|
540
|
+
/**
|
|
541
|
+
* Set the input and output dimensions of each layer in the network correctly.
|
|
542
|
+
* The size of the input is taken, in case `inputDimensions` has not been set
|
|
543
|
+
* otherwise (e.g. via `InputDimensions()`). If `InputDimensions()` is not
|
|
544
|
+
* empty, then `inputDimensionality` is ignored.
|
|
545
|
+
*/
|
|
546
|
+
void UpdateDimensions(const std::string& functionName,
|
|
547
|
+
const size_t inputDimensionality = 0);
|
|
548
|
+
|
|
549
|
+
/**
|
|
550
|
+
* Set the weights of the layers
|
|
551
|
+
*/
|
|
552
|
+
void SetWeights(const MatType& weightsIn);
|
|
553
|
+
|
|
554
|
+
/**
|
|
555
|
+
* Initialize memory that will be used by each layer for the forward pass,
|
|
556
|
+
* assuming that the input will have the given `batchSize`. When `Forward()`
|
|
557
|
+
* is called, `layerOutputMatrix` is allocated with enough memory to fit
|
|
558
|
+
* the outputs of each layer and to hold concatenations of output layers
|
|
559
|
+
* as inputs into layers as specified by `Add()` and `Connect()`.
|
|
560
|
+
*/
|
|
561
|
+
void InitializeForwardPassMemory(const size_t batchSize);
|
|
562
|
+
|
|
563
|
+
/**
|
|
564
|
+
* TODO: explain how the backward pass memory works.
|
|
565
|
+
*/
|
|
566
|
+
void InitializeBackwardPassMemory(const size_t batchSize);
|
|
567
|
+
|
|
568
|
+
/**
|
|
569
|
+
* Initialize memory for the gradient pass. This sets the internal aliases
|
|
570
|
+
* `layerGradients` appropriately using the memory from the given `gradient`,
|
|
571
|
+
* such that each layer will output its gradient (via its `Gradient()` method)
|
|
572
|
+
* into the appropriate member of `layerGradients`.
|
|
573
|
+
*/
|
|
574
|
+
void InitializeGradientPassMemory(MatType& gradient);
|
|
575
|
+
|
|
576
|
+
/**
|
|
577
|
+
* Compute the loss that should be added to the objective for each layer.
|
|
578
|
+
*/
|
|
579
|
+
double Loss() const;
|
|
580
|
+
|
|
581
|
+
/**
|
|
582
|
+
* Check if the optimizer has MaxIterations() parameter, if it does then check
|
|
583
|
+
* if its value is less than the number of datapoints in the dataset.
|
|
584
|
+
*
|
|
585
|
+
* @tparam OptimizerType Type of optimizer to use to train the model.
|
|
586
|
+
* @param optimizer optimizer used in the training process.
|
|
587
|
+
* @param samples Number of datapoints in the dataset.
|
|
588
|
+
*/
|
|
589
|
+
template<typename OptimizerType>
|
|
590
|
+
std::enable_if_t<
|
|
591
|
+
ens::traits::HasMaxIterationsSignature<OptimizerType>::value, void>
|
|
592
|
+
WarnMessageMaxIterations(OptimizerType& optimizer, size_t samples) const;
|
|
593
|
+
|
|
594
|
+
/**
|
|
595
|
+
* Check if the optimizer has MaxIterations() parameter; if it doesn't then
|
|
596
|
+
* simply return from the function.
|
|
597
|
+
*
|
|
598
|
+
* @tparam OptimizerType Type of optimizer to use to train the model.
|
|
599
|
+
* @param optimizer optimizer used in the training process.
|
|
600
|
+
* @param samples Number of datapoints in the dataset.
|
|
601
|
+
*/
|
|
602
|
+
template<typename OptimizerType>
|
|
603
|
+
std::enable_if_t<
|
|
604
|
+
!ens::traits::HasMaxIterationsSignature<OptimizerType>::value, void>
|
|
605
|
+
WarnMessageMaxIterations(OptimizerType& optimizer, size_t samples) const;
|
|
606
|
+
|
|
607
|
+
// Instantiated output layer used to evaluate the network.
|
|
608
|
+
OutputLayerType outputLayer;
|
|
609
|
+
|
|
610
|
+
// Instantiated InitializationRule object for initializing the network
|
|
611
|
+
// parameter.
|
|
612
|
+
InitializationRuleType initializeRule;
|
|
613
|
+
|
|
614
|
+
// The internally-held network, sorted in the order that the user
|
|
615
|
+
// specified when using `Add()`
|
|
616
|
+
std::vector<Layer<MatType>*> network;
|
|
617
|
+
|
|
618
|
+
// The internally-held network, sorted topologically when `CheckNetwork`
|
|
619
|
+
// is called if the graph is valid.
|
|
620
|
+
std::vector<size_t> sortedNetwork;
|
|
621
|
+
|
|
622
|
+
// The internally-held map of nodes that holds its edges to outgoing nodes.
|
|
623
|
+
// Uses network indices as keys.
|
|
624
|
+
std::unordered_map<size_t, std::vector<size_t>> childrenList;
|
|
625
|
+
|
|
626
|
+
// The internally-held map of nodes that holds its edges to incoming nodes.
|
|
627
|
+
// Uses network indices as keys.
|
|
628
|
+
std::unordered_map<size_t, std::vector<size_t>> parentsList;
|
|
629
|
+
|
|
630
|
+
// The internally-held map of what axes to concatenate along for each layer
|
|
631
|
+
// with multiple inputs
|
|
632
|
+
// Uses network indices as keys.
|
|
633
|
+
std::unordered_map<size_t, size_t> layerAxes;
|
|
634
|
+
|
|
635
|
+
// Connection type for some node that should have multiple parent nodes.
|
|
636
|
+
// If this exists for a layer with <= 1 parent, it gets ignored.
|
|
637
|
+
std::unordered_map<size_t, ConnectionTypes> layerConnections;
|
|
638
|
+
|
|
639
|
+
// Map layer index in network to layer index in sortedNetwork
|
|
640
|
+
// Uses network indices as keys.
|
|
641
|
+
std::unordered_map<size_t, size_t> sortedIndices;
|
|
642
|
+
|
|
643
|
+
/**
|
|
644
|
+
* Matrix of (trainable) parameters. Each weight here corresponds to a layer,
|
|
645
|
+
* and each layer's `parameters` member is an alias pointing to parameters in
|
|
646
|
+
* this matrix.
|
|
647
|
+
*
|
|
648
|
+
* Note: although each layer may have its own MatType and MatType,
|
|
649
|
+
* ensmallen optimization requires everything to be stored in one matrix
|
|
650
|
+
* object, so we have chosen MatType. This could be made more flexible
|
|
651
|
+
* with a "wrapper" class implementing the Armadillo API.
|
|
652
|
+
*/
|
|
653
|
+
MatType parameters;
|
|
654
|
+
|
|
655
|
+
// Dimensions of input data.
|
|
656
|
+
std::vector<size_t> inputDimensions;
|
|
657
|
+
|
|
658
|
+
//! The matrix of data points (predictors). This member is empty, except
|
|
659
|
+
//! during training---we must store a local copy of the training data since
|
|
660
|
+
//! the ensmallen optimizer will not provide training data.
|
|
661
|
+
MatType predictors;
|
|
662
|
+
|
|
663
|
+
//! The matrix of responses to the input data points. This member is empty,
|
|
664
|
+
//! except during training.
|
|
665
|
+
MatType responses;
|
|
666
|
+
|
|
667
|
+
// Locally-stored output of the network from a forward pass; used by the
|
|
668
|
+
// backward pass.
|
|
669
|
+
MatType networkOutput;
|
|
670
|
+
//! Locally-stored output of the backward pass; used by the gradient pass.
|
|
671
|
+
MatType error;
|
|
672
|
+
|
|
673
|
+
// This matrix stores all of the outputs of each layer when `Forward()` is
|
|
674
|
+
// called. See `InitializeForwardPassMemory()`.
|
|
675
|
+
MatType layerOutputMatrix;
|
|
676
|
+
// These are aliases of `layerOutputMatrix` for the input of each layer
|
|
677
|
+
// Ordered in the same way as `sortedNetwork`.
|
|
678
|
+
std::vector<MatType> layerInputs;
|
|
679
|
+
// These are aliases of `layerOutputMatrix` for the output of each layer.
|
|
680
|
+
// Ordered in the same way as `sortedNetwork`.
|
|
681
|
+
std::vector<MatType> layerOutputs;
|
|
682
|
+
|
|
683
|
+
// Memory for the backward pass.
|
|
684
|
+
MatType layerDeltaMatrix;
|
|
685
|
+
|
|
686
|
+
// Needed in case the first layer is a `MultiLayer` so that its
|
|
687
|
+
// gradients are calculated.
|
|
688
|
+
MatType networkDelta;
|
|
689
|
+
|
|
690
|
+
// A layers delta Loss w.r.t delta Outputs.
|
|
691
|
+
std::vector<MatType> layerDeltas;
|
|
692
|
+
|
|
693
|
+
// A layers output deltas. Uses sortedNetwork indices as keys.
|
|
694
|
+
std::unordered_map<size_t, MatType> outputDeltas;
|
|
695
|
+
|
|
696
|
+
// A layers input deltas. Uses sortedNetwork indices as keys.
|
|
697
|
+
std::unordered_map<size_t, MatType> inputDeltas;
|
|
698
|
+
|
|
699
|
+
// A layers accumulated deltas, for layers with multiple children.
|
|
700
|
+
// Uses sortedNetwork indices as keys.
|
|
701
|
+
std::unordered_map<size_t, MatType> accumulatedDeltas;
|
|
702
|
+
|
|
703
|
+
// Gradient aliases for each layer.
|
|
704
|
+
std::vector<MatType> layerGradients;
|
|
705
|
+
|
|
706
|
+
// Cache of rows for concatenation. Useful for forward / backward passes.
|
|
707
|
+
std::unordered_map<size_t, size_t> rowsCache;
|
|
708
|
+
// Cache of slices for concatenation.
|
|
709
|
+
std::unordered_map<size_t, size_t> slicesCache;
|
|
710
|
+
|
|
711
|
+
// If true, each layer has its inputDimensions properly set.
|
|
712
|
+
bool validOutputDimensions;
|
|
713
|
+
|
|
714
|
+
// If true, the graph is valid and has been topologically sorted.
|
|
715
|
+
bool graphIsSet;
|
|
716
|
+
|
|
717
|
+
// If true, each layer has its activation/gradient memory properly set
|
|
718
|
+
// for the forward/backward pass.
|
|
719
|
+
bool layerMemoryIsSet;
|
|
720
|
+
|
|
721
|
+
bool extraDeltasAllocated;
|
|
722
|
+
};
|
|
723
|
+
|
|
724
|
+
} // namespace mlpack
|
|
725
|
+
|
|
726
|
+
#include "dag_network_impl.hpp"
|
|
727
|
+
|
|
728
|
+
#endif
|