mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlpack/__init__.py +4 -4
- mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
- mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
- mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
- mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
- mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
- mlpack/cf.cp313-win_amd64.pyd +0 -0
- mlpack/dbscan.cp313-win_amd64.pyd +0 -0
- mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
- mlpack/det.cp313-win_amd64.pyd +0 -0
- mlpack/emst.cp313-win_amd64.pyd +0 -0
- mlpack/fastmks.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
- mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
- mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
- mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
- mlpack/image_converter.cp313-win_amd64.pyd +0 -0
- mlpack/include/mlpack/base.hpp +1 -0
- mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
- mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
- mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
- mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
- mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
- mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
- mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
- mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
- mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
- mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
- mlpack/include/mlpack/core/data/binarize.hpp +0 -2
- mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
- mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
- mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
- mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/data.hpp +6 -4
- mlpack/include/mlpack/core/data/data_options.hpp +341 -18
- mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
- mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
- mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
- mlpack/include/mlpack/core/data/extension.hpp +2 -4
- mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
- mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
- mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
- mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
- mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
- mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
- mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
- mlpack/include/mlpack/core/data/image_options.hpp +257 -0
- mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
- mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
- mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
- mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
- mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
- mlpack/include/mlpack/core/data/imputer.hpp +41 -49
- mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
- mlpack/include/mlpack/core/data/load.hpp +49 -233
- mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
- mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
- mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
- mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
- mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
- mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
- mlpack/include/mlpack/core/data/load_image.hpp +71 -43
- mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
- mlpack/include/mlpack/core/data/load_model.hpp +62 -0
- mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
- mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
- mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
- mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
- mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
- mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
- mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
- mlpack/include/mlpack/core/data/save.hpp +26 -120
- mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
- mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
- mlpack/include/mlpack/core/data/save_image.hpp +82 -42
- mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
- mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
- mlpack/include/mlpack/core/data/save_model.hpp +61 -0
- mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
- mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
- mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
- mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
- mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
- mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
- mlpack/include/mlpack/core/data/split_data.hpp +6 -8
- mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
- mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
- mlpack/include/mlpack/core/data/text_options.hpp +91 -53
- mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
- mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
- mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
- mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
- mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
- mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
- mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
- mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
- mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
- mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
- mlpack/include/mlpack/core/math/random.hpp +19 -5
- mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
- mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
- mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
- mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
- mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
- mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
- mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
- mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
- mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
- mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
- mlpack/include/mlpack/core/util/forward.hpp +0 -2
- mlpack/include/mlpack/core/util/param.hpp +4 -4
- mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
- mlpack/include/mlpack/core/util/using.hpp +29 -2
- mlpack/include/mlpack/core/util/version.hpp +5 -3
- mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
- mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
- mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
- mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
- mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
- mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
- mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
- mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
- mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
- mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
- mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
- mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
- mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
- mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
- mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
- mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
- mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
- mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
- mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
- mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
- mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
- mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
- mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
- mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
- mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
- mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
- mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
- mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
- mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
- mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
- mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
- mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
- mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
- mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
- mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
- mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
- mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
- mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
- mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
- mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
- mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
- mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
- mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
- mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
- mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
- mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
- mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
- mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
- mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
- mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
- mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
- mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
- mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
- mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
- mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
- mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
- mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
- mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
- mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
- mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
- mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
- mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
- mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
- mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
- mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
- mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
- mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
- mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
- mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
- mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
- mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
- mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
- mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
- mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
- mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
- mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
- mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
- mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
- mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
- mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
- mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
- mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
- mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
- mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
- mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
- mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
- mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
- mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
- mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
- mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
- mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
- mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
- mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
- mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
- mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
- mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
- mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
- mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
- mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
- mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
- mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
- mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
- mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
- mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
- mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
- mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
- mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
- mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
- mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
- mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
- mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
- mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
- mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
- mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
- mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
- mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
- mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
- mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
- mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
- mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
- mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
- mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
- mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
- mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
- mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
- mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
- mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
- mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
- mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
- mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
- mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
- mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
- mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
- mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
- mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
- mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
- mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
- mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
- mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
- mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
- mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
- mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
- mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
- mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
- mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
- mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
- mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
- mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
- mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
- mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
- mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
- mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
- mlpack/include/mlpack/namespace_compat.hpp +1 -0
- mlpack/include/mlpack/prereqs.hpp +1 -0
- mlpack/kde.cp313-win_amd64.pyd +0 -0
- mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
- mlpack/kfn.cp313-win_amd64.pyd +0 -0
- mlpack/kmeans.cp313-win_amd64.pyd +0 -0
- mlpack/knn.cp313-win_amd64.pyd +0 -0
- mlpack/krann.cp313-win_amd64.pyd +0 -0
- mlpack/lars.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
- mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
- mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
- mlpack/lmnn.cp313-win_amd64.pyd +0 -0
- mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
- mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
- mlpack/lsh.cp313-win_amd64.pyd +0 -0
- mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
- mlpack/nbc.cp313-win_amd64.pyd +0 -0
- mlpack/nca.cp313-win_amd64.pyd +0 -0
- mlpack/nmf.cp313-win_amd64.pyd +0 -0
- mlpack/pca.cp313-win_amd64.pyd +0 -0
- mlpack/perceptron.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
- mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
- mlpack/radical.cp313-win_amd64.pyd +0 -0
- mlpack/random_forest.cp313-win_amd64.pyd +0 -0
- mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
- mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
- mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
- mlpack/include/mlpack/core/data/format.hpp +0 -31
- mlpack/include/mlpack/core/data/image_info.hpp +0 -102
- mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
- mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
- mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
- mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
- mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
- mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
- mlpack/include/mlpack/core/data/types.hpp +0 -61
- mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
- mlpack/include/mlpack/core/data/utilities.hpp +0 -158
- mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
- mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
- mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
- mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
- mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
- mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
- {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
- /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
- /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
|
@@ -34,28 +34,28 @@ DDPG<
|
|
|
34
34
|
NoiseType,
|
|
35
35
|
UpdaterType,
|
|
36
36
|
ReplayType
|
|
37
|
-
>::DDPG(TrainingConfig&
|
|
38
|
-
QNetworkType&
|
|
39
|
-
PolicyNetworkType&
|
|
40
|
-
NoiseType&
|
|
41
|
-
ReplayType&
|
|
42
|
-
UpdaterType
|
|
43
|
-
UpdaterType
|
|
44
|
-
EnvironmentType
|
|
45
|
-
config(
|
|
46
|
-
learningQNetwork(
|
|
47
|
-
policyNetwork(
|
|
48
|
-
noise(
|
|
49
|
-
replayMethod(
|
|
50
|
-
qNetworkUpdater(std::move(
|
|
37
|
+
>::DDPG(TrainingConfig& configIn,
|
|
38
|
+
QNetworkType& learningQNetworkIn,
|
|
39
|
+
PolicyNetworkType& policyNetworkIn,
|
|
40
|
+
NoiseType& noiseIn,
|
|
41
|
+
ReplayType& replayMethodIn,
|
|
42
|
+
UpdaterType qNetworkUpdaterIn,
|
|
43
|
+
UpdaterType policyNetworkUpdaterIn,
|
|
44
|
+
EnvironmentType environmentIn):
|
|
45
|
+
config(configIn),
|
|
46
|
+
learningQNetwork(learningQNetworkIn),
|
|
47
|
+
policyNetwork(policyNetworkIn),
|
|
48
|
+
noise(noiseIn),
|
|
49
|
+
replayMethod(replayMethodIn),
|
|
50
|
+
qNetworkUpdater(std::move(qNetworkUpdaterIn)),
|
|
51
51
|
#if ENS_VERSION_MAJOR >= 2
|
|
52
52
|
qNetworkUpdatePolicy(NULL),
|
|
53
53
|
#endif
|
|
54
|
-
policyNetworkUpdater(std::move(
|
|
54
|
+
policyNetworkUpdater(std::move(policyNetworkUpdaterIn)),
|
|
55
55
|
#if ENS_VERSION_MAJOR >= 2
|
|
56
56
|
policyNetworkUpdatePolicy(NULL),
|
|
57
57
|
#endif
|
|
58
|
-
environment(std::move(
|
|
58
|
+
environment(std::move(environmentIn)),
|
|
59
59
|
totalSteps(0),
|
|
60
60
|
deterministic(false)
|
|
61
61
|
{
|
|
@@ -121,7 +121,7 @@ class Acrobot
|
|
|
121
121
|
Acrobot(const size_t maxSteps = 500,
|
|
122
122
|
const double gravity = 9.81,
|
|
123
123
|
const double linkLength1 = 1.0,
|
|
124
|
-
const double linkLength2 = 1.0,
|
|
124
|
+
const double /* linkLength2 */ = 1.0,
|
|
125
125
|
const double linkMass1 = 1.0,
|
|
126
126
|
const double linkMass2 = 1.0,
|
|
127
127
|
const double linkCom1 = 0.5,
|
|
@@ -134,7 +134,7 @@ class Acrobot
|
|
|
134
134
|
maxSteps(maxSteps),
|
|
135
135
|
gravity(gravity),
|
|
136
136
|
linkLength1(linkLength1),
|
|
137
|
-
linkLength2(linkLength2),
|
|
137
|
+
//linkLength2(linkLength2),
|
|
138
138
|
linkMass1(linkMass1),
|
|
139
139
|
linkMass2(linkMass2),
|
|
140
140
|
linkCom1(linkCom1),
|
|
@@ -360,8 +360,8 @@ class Acrobot
|
|
|
360
360
|
//! Locally-stored length of link 1.
|
|
361
361
|
double linkLength1;
|
|
362
362
|
|
|
363
|
-
//! Locally-stored length of link 2.
|
|
364
|
-
double linkLength2;
|
|
363
|
+
//! Locally-stored length of link 2. (NOTE: not currently used)
|
|
364
|
+
//double linkLength2;
|
|
365
365
|
|
|
366
366
|
//! Locally-stored mass of link 1.
|
|
367
367
|
double linkMass1;
|
|
@@ -125,7 +125,7 @@ class CartPole
|
|
|
125
125
|
const double doneReward = 1.0) :
|
|
126
126
|
maxSteps(maxSteps),
|
|
127
127
|
gravity(gravity),
|
|
128
|
-
massCart(massCart),
|
|
128
|
+
//massCart(massCart),
|
|
129
129
|
massPole(massPole),
|
|
130
130
|
totalMass(massCart + massPole),
|
|
131
131
|
length(length),
|
|
@@ -247,8 +247,8 @@ class CartPole
|
|
|
247
247
|
//! Locally-stored gravity.
|
|
248
248
|
double gravity;
|
|
249
249
|
|
|
250
|
-
//! Locally-stored mass of the cart.
|
|
251
|
-
double massCart;
|
|
250
|
+
//! Locally-stored mass of the cart. NOTE: not currently used.
|
|
251
|
+
//double massCart;
|
|
252
252
|
|
|
253
253
|
//! Locally-stored mass of the pole.
|
|
254
254
|
double massPole;
|
|
@@ -104,7 +104,8 @@ class ContinuousDoublePoleCart
|
|
|
104
104
|
* @param l2 The length of the second pole.
|
|
105
105
|
* @param gravity The gravity constant.
|
|
106
106
|
* @param massCart The mass of the cart.
|
|
107
|
-
* @param forceMag The magnitude of the applied force.
|
|
107
|
+
* @param forceMag The magnitude of the applied force. NOTE: not currently
|
|
108
|
+
* used.
|
|
108
109
|
* @param tau The time interval.
|
|
109
110
|
* @param thetaThresholdRadians The maximum angle.
|
|
110
111
|
* @param xThreshold The maximum position.
|
|
@@ -118,7 +119,7 @@ class ContinuousDoublePoleCart
|
|
|
118
119
|
const double l2 = 0.05,
|
|
119
120
|
const double gravity = 9.8,
|
|
120
121
|
const double massCart = 1.0,
|
|
121
|
-
const double forceMag = 10.0,
|
|
122
|
+
const double /* forceMag */ = 10.0,
|
|
122
123
|
const double tau = 0.02,
|
|
123
124
|
const double thetaThresholdRadians = 36 * 2 *
|
|
124
125
|
3.1416 / 360,
|
|
@@ -131,7 +132,7 @@ class ContinuousDoublePoleCart
|
|
|
131
132
|
l2(l2),
|
|
132
133
|
gravity(gravity),
|
|
133
134
|
massCart(massCart),
|
|
134
|
-
forceMag(forceMag),
|
|
135
|
+
//forceMag(forceMag),
|
|
135
136
|
tau(tau),
|
|
136
137
|
thetaThresholdRadians(thetaThresholdRadians),
|
|
137
138
|
xThreshold(xThreshold),
|
|
@@ -340,8 +341,8 @@ class ContinuousDoublePoleCart
|
|
|
340
341
|
//! Locally-stored mass of the cart.
|
|
341
342
|
double massCart;
|
|
342
343
|
|
|
343
|
-
//! Locally-stored magnitude of the applied force.
|
|
344
|
-
double forceMag;
|
|
344
|
+
//! Locally-stored magnitude of the applied force. NOTE: not currently used.
|
|
345
|
+
//double forceMag;
|
|
345
346
|
|
|
346
347
|
//! Locally-stored time interval.
|
|
347
348
|
double tau;
|
|
@@ -111,18 +111,19 @@ class Pendulum
|
|
|
111
111
|
* @param maxAngularVelocity Maximum angular velocity.
|
|
112
112
|
* @param maxTorque Maximum torque.
|
|
113
113
|
* @param dt The differential value.
|
|
114
|
-
* @param doneReward The reward recieved by the agent on success.
|
|
114
|
+
* @param doneReward The reward recieved by the agent on success. NOTE: not
|
|
115
|
+
* currently used.
|
|
115
116
|
*/
|
|
116
117
|
Pendulum(const size_t maxSteps = 200,
|
|
117
118
|
const double maxAngularVelocity = 8,
|
|
118
119
|
const double maxTorque = 2.0,
|
|
119
120
|
const double dt = 0.05,
|
|
120
|
-
const double doneReward = 0.0) :
|
|
121
|
+
const double /* doneReward */ = 0.0) :
|
|
121
122
|
maxSteps(maxSteps),
|
|
122
123
|
maxAngularVelocity(maxAngularVelocity),
|
|
123
124
|
maxTorque(maxTorque),
|
|
124
125
|
dt(dt),
|
|
125
|
-
doneReward(doneReward),
|
|
126
|
+
//doneReward(doneReward),
|
|
126
127
|
stepsPerformed(0)
|
|
127
128
|
{ /* Nothing to do here */ }
|
|
128
129
|
|
|
@@ -254,8 +255,8 @@ class Pendulum
|
|
|
254
255
|
//! Locally-stored dt.
|
|
255
256
|
double dt;
|
|
256
257
|
|
|
257
|
-
//! Locally-stored done reward.
|
|
258
|
-
double doneReward;
|
|
258
|
+
//! Locally-stored done reward. NOTE: not currently used.
|
|
259
|
+
//double doneReward;
|
|
259
260
|
|
|
260
261
|
//! Locally-stored number of steps performed.
|
|
261
262
|
size_t stepsPerformed;
|
|
@@ -35,9 +35,9 @@ class AggregatedPolicy
|
|
|
35
35
|
* User should make sure its size is same as the number of policies
|
|
36
36
|
* and the sum of its element is equal to 1.
|
|
37
37
|
*/
|
|
38
|
-
AggregatedPolicy(std::vector<PolicyType>
|
|
38
|
+
AggregatedPolicy(std::vector<PolicyType> policiesIn,
|
|
39
39
|
const arma::colvec& distribution) :
|
|
40
|
-
policies(std::move(
|
|
40
|
+
policies(std::move(policiesIn)),
|
|
41
41
|
sampler({distribution})
|
|
42
42
|
{ /* Nothing to do here. */ };
|
|
43
43
|
|
|
@@ -29,21 +29,21 @@ QLearning<
|
|
|
29
29
|
UpdaterType,
|
|
30
30
|
PolicyType,
|
|
31
31
|
ReplayType
|
|
32
|
-
>::QLearning(TrainingConfig&
|
|
32
|
+
>::QLearning(TrainingConfig& configIn,
|
|
33
33
|
NetworkType& network,
|
|
34
|
-
PolicyType&
|
|
35
|
-
ReplayType&
|
|
36
|
-
UpdaterType
|
|
37
|
-
EnvironmentType
|
|
38
|
-
config(
|
|
34
|
+
PolicyType& policyIn,
|
|
35
|
+
ReplayType& replayMethodIn,
|
|
36
|
+
UpdaterType updaterIn,
|
|
37
|
+
EnvironmentType environmentIn):
|
|
38
|
+
config(configIn),
|
|
39
39
|
learningNetwork(network),
|
|
40
|
-
policy(
|
|
41
|
-
replayMethod(
|
|
42
|
-
updater(std::move(
|
|
40
|
+
policy(policyIn),
|
|
41
|
+
replayMethod(replayMethodIn),
|
|
42
|
+
updater(std::move(updaterIn)),
|
|
43
43
|
#if ENS_VERSION_MAJOR >= 2
|
|
44
44
|
updatePolicy(NULL),
|
|
45
45
|
#endif
|
|
46
|
-
environment(std::move(
|
|
46
|
+
environment(std::move(environmentIn)),
|
|
47
47
|
totalSteps(0),
|
|
48
48
|
deterministic(false)
|
|
49
49
|
{
|
|
@@ -78,21 +78,23 @@ class CategoricalDQN
|
|
|
78
78
|
vMax(config.VMax()),
|
|
79
79
|
isNoisy(isNoisy)
|
|
80
80
|
{
|
|
81
|
-
network.Add
|
|
82
|
-
network.Add
|
|
81
|
+
network.template Add<Linear>(h1);
|
|
82
|
+
network.template Add<ReLU>();
|
|
83
83
|
if (isNoisy)
|
|
84
84
|
{
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
network.Add
|
|
85
|
+
network.template Add<NoisyLinear>(h2);
|
|
86
|
+
noisyLayers.push_back(
|
|
87
|
+
dynamic_cast<NoisyLinear<>*>(network.Network().back()));
|
|
88
|
+
network.template Add<ReLU>();
|
|
89
|
+
network.template Add<NoisyLinear>(outputDim * atomSize);
|
|
90
|
+
noisyLayers.push_back(
|
|
91
|
+
dynamic_cast<NoisyLinear<>*>(network.Network().back()));
|
|
90
92
|
}
|
|
91
93
|
else
|
|
92
94
|
{
|
|
93
|
-
network.Add
|
|
94
|
-
network.Add
|
|
95
|
-
network.Add
|
|
95
|
+
network.template Add<Linear>(h2);
|
|
96
|
+
network.template Add<ReLU>();
|
|
97
|
+
network.template Add<Linear>(outputDim * atomSize);
|
|
96
98
|
}
|
|
97
99
|
}
|
|
98
100
|
|
|
@@ -104,16 +106,19 @@ class CategoricalDQN
|
|
|
104
106
|
* @param config Hyper-parameters for categorical dqn.
|
|
105
107
|
* @param isNoisy Specifies whether the network needs to be of type noisy.
|
|
106
108
|
*/
|
|
107
|
-
CategoricalDQN(NetworkType&
|
|
109
|
+
CategoricalDQN(NetworkType& networkIn,
|
|
108
110
|
TrainingConfig config,
|
|
109
111
|
const bool isNoisy = false):
|
|
110
|
-
network(std::move(
|
|
112
|
+
network(std::move(networkIn)),
|
|
111
113
|
atomSize(config.AtomSize()),
|
|
112
114
|
vMin(config.VMin()),
|
|
113
115
|
vMax(config.VMax()),
|
|
114
116
|
isNoisy(isNoisy)
|
|
115
117
|
{ /* Nothing to do here. */ }
|
|
116
118
|
|
|
119
|
+
// TODO: implement copy constructor and operator
|
|
120
|
+
CategoricalDQN(const CategoricalDQN& other) = delete;
|
|
121
|
+
|
|
117
122
|
/**
|
|
118
123
|
* Predict the responses to a given set of predictors. The responses will
|
|
119
124
|
* reflect the output of the given output layer as returned by the
|
|
@@ -176,10 +181,9 @@ class CategoricalDQN
|
|
|
176
181
|
*/
|
|
177
182
|
void ResetNoise()
|
|
178
183
|
{
|
|
179
|
-
for (size_t i = 0; i <
|
|
184
|
+
for (size_t i = 0; i < noisyLayers.size(); ++i)
|
|
180
185
|
{
|
|
181
|
-
|
|
182
|
-
(network.Network()[noisyLayerIndex[i]]))->ResetNoise();
|
|
186
|
+
noisyLayers[i]->ResetNoise();
|
|
183
187
|
}
|
|
184
188
|
}
|
|
185
189
|
|
|
@@ -228,10 +232,10 @@ class CategoricalDQN
|
|
|
228
232
|
bool isNoisy;
|
|
229
233
|
|
|
230
234
|
//! Locally-stored indexes of noisy layers in the network.
|
|
231
|
-
std::vector<
|
|
235
|
+
std::vector<NoisyLinear<>*> noisyLayers;
|
|
232
236
|
|
|
233
237
|
//! Locally-stored softmax activation function.
|
|
234
|
-
Softmax softMax;
|
|
238
|
+
Softmax<> softMax;
|
|
235
239
|
|
|
236
240
|
//! Locally-stored activations from softMax.
|
|
237
241
|
arma::mat activations;
|
|
@@ -56,17 +56,15 @@ class DuelingDQN
|
|
|
56
56
|
//! Default constructor.
|
|
57
57
|
DuelingDQN() : isNoisy(false)
|
|
58
58
|
{
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
advantageNetwork = new MultiLayer<arma::mat>();
|
|
64
|
-
concat = new Concat();
|
|
59
|
+
MultiLayer<arma::mat> featureNetwork;
|
|
60
|
+
MultiLayer<arma::mat> valueNetwork;
|
|
61
|
+
MultiLayer<arma::mat> advantageNetwork;
|
|
62
|
+
Concat concat;
|
|
65
63
|
|
|
66
|
-
concat
|
|
67
|
-
concat
|
|
64
|
+
concat.Add(std::move(valueNetwork));
|
|
65
|
+
concat.Add(std::move(advantageNetwork));
|
|
68
66
|
completeNetwork.Add(featureNetwork);
|
|
69
|
-
completeNetwork.Add(concat);
|
|
67
|
+
completeNetwork.Add(std::move(concat));
|
|
70
68
|
}
|
|
71
69
|
|
|
72
70
|
/**
|
|
@@ -88,43 +86,52 @@ class DuelingDQN
|
|
|
88
86
|
completeNetwork(outputLayer, init),
|
|
89
87
|
isNoisy(isNoisy)
|
|
90
88
|
{
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
featureNetwork
|
|
89
|
+
// TODO: this really ought to use a DAG network, but that's not implemented
|
|
90
|
+
// yet.
|
|
91
|
+
MultiLayer<arma::mat> featureNetwork;
|
|
92
|
+
featureNetwork.template Add<Linear>(h1);
|
|
93
|
+
featureNetwork.template Add<ReLU<>>();
|
|
94
94
|
|
|
95
|
-
|
|
96
|
-
|
|
95
|
+
MultiLayer<arma::mat> valueNetwork;
|
|
96
|
+
MultiLayer<arma::mat> advantageNetwork;
|
|
97
97
|
|
|
98
98
|
if (isNoisy)
|
|
99
99
|
{
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
valueNetwork
|
|
109
|
-
advantageNetwork
|
|
100
|
+
valueNetwork.Add<NoisyLinear>(h2);
|
|
101
|
+
noisyLayers.push_back(dynamic_cast<NoisyLinear<>*>(
|
|
102
|
+
valueNetwork.Network().back()));
|
|
103
|
+
|
|
104
|
+
advantageNetwork.Add<NoisyLinear>(h2);
|
|
105
|
+
noisyLayers.push_back(dynamic_cast<NoisyLinear<>*>(
|
|
106
|
+
advantageNetwork.Network().back()));
|
|
107
|
+
|
|
108
|
+
valueNetwork.template Add<ReLU>();
|
|
109
|
+
advantageNetwork.template Add<ReLU>();
|
|
110
|
+
|
|
111
|
+
valueNetwork.template Add<NoisyLinear>(1);
|
|
112
|
+
noisyLayers.push_back(dynamic_cast<NoisyLinear<>*>(
|
|
113
|
+
valueNetwork.Network().back()));
|
|
114
|
+
advantageNetwork.template Add<NoisyLinear>(outputDim);
|
|
115
|
+
noisyLayers.push_back(dynamic_cast<NoisyLinear<>*>(
|
|
116
|
+
advantageNetwork.Network().back()));
|
|
110
117
|
}
|
|
111
118
|
else
|
|
112
119
|
{
|
|
113
|
-
valueNetwork
|
|
114
|
-
valueNetwork
|
|
115
|
-
valueNetwork
|
|
120
|
+
valueNetwork.template Add<Linear>(h2);
|
|
121
|
+
valueNetwork.template Add<ReLU>();
|
|
122
|
+
valueNetwork.template Add<Linear>(1);
|
|
116
123
|
|
|
117
|
-
advantageNetwork
|
|
118
|
-
advantageNetwork
|
|
119
|
-
advantageNetwork
|
|
124
|
+
advantageNetwork.template Add<Linear>(h2);
|
|
125
|
+
advantageNetwork.template Add<ReLU>();
|
|
126
|
+
advantageNetwork.template Add<Linear>(outputDim);
|
|
120
127
|
}
|
|
121
128
|
|
|
122
|
-
concat
|
|
123
|
-
concat
|
|
124
|
-
concat
|
|
129
|
+
Concat concat;
|
|
130
|
+
concat.Add(std::move(valueNetwork));
|
|
131
|
+
concat.Add(std::move(advantageNetwork));
|
|
125
132
|
|
|
126
|
-
completeNetwork.Add(featureNetwork);
|
|
127
|
-
completeNetwork.Add(concat);
|
|
133
|
+
completeNetwork.Add(std::move(featureNetwork));
|
|
134
|
+
completeNetwork.Add(std::move(concat));
|
|
128
135
|
}
|
|
129
136
|
|
|
130
137
|
/**
|
|
@@ -135,35 +142,35 @@ class DuelingDQN
|
|
|
135
142
|
* @param valueNetwork The value network to be used by DuelingDQN class.
|
|
136
143
|
* @param isNoisy Specifies whether the network needs to be of type noisy.
|
|
137
144
|
*/
|
|
138
|
-
DuelingDQN(FeatureNetworkType
|
|
139
|
-
AdvantageNetworkType
|
|
140
|
-
ValueNetworkType
|
|
145
|
+
DuelingDQN(FeatureNetworkType&& featureNetwork,
|
|
146
|
+
AdvantageNetworkType&& advantageNetwork,
|
|
147
|
+
ValueNetworkType&& valueNetwork,
|
|
141
148
|
const bool isNoisy = false):
|
|
142
|
-
featureNetwork(featureNetwork),
|
|
143
|
-
advantageNetwork(advantageNetwork),
|
|
144
|
-
valueNetwork(valueNetwork),
|
|
145
149
|
isNoisy(isNoisy)
|
|
146
150
|
{
|
|
147
|
-
concat
|
|
148
|
-
concat
|
|
149
|
-
concat
|
|
150
|
-
completeNetwork.Add(featureNetwork);
|
|
151
|
-
completeNetwork.Add(concat);
|
|
151
|
+
Concat concat;
|
|
152
|
+
concat.Add(std::move(valueNetwork));
|
|
153
|
+
concat.Add(std::move(advantageNetwork));
|
|
154
|
+
completeNetwork.Add(std::move(featureNetwork));
|
|
155
|
+
completeNetwork.Add(std::move(concat));
|
|
152
156
|
}
|
|
153
157
|
|
|
154
|
-
|
|
155
|
-
DuelingDQN(const DuelingDQN&
|
|
156
|
-
{
|
|
158
|
+
// Copy constructor.
|
|
159
|
+
//DuelingDQN(const DuelingDQN& model) : isNoisy(false)
|
|
160
|
+
// {
|
|
161
|
+
// // Use copy operator.
|
|
162
|
+
// *this = model;
|
|
163
|
+
// }
|
|
157
164
|
|
|
158
|
-
|
|
159
|
-
void operator
|
|
160
|
-
{
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
165
|
+
// Copy assignment operator.
|
|
166
|
+
// void operator=(const DuelingDQN& model)
|
|
167
|
+
// {
|
|
168
|
+
// completeNetwork = model.completeNetwork;
|
|
169
|
+
|
|
170
|
+
// isNoisy = model.isNoisy;
|
|
171
|
+
// }
|
|
172
|
+
|
|
173
|
+
DuelingDQN(const DuelingDQN& model) = delete;
|
|
167
174
|
|
|
168
175
|
/**
|
|
169
176
|
* Predict the responses to a given set of predictors. The responses will
|
|
@@ -234,12 +241,9 @@ class DuelingDQN
|
|
|
234
241
|
*/
|
|
235
242
|
void ResetNoise()
|
|
236
243
|
{
|
|
237
|
-
for (size_t i = 0; i <
|
|
244
|
+
for (size_t i = 0; i < noisyLayers.size(); i++)
|
|
238
245
|
{
|
|
239
|
-
|
|
240
|
-
(valueNetwork->Network()[noisyLayerIndex[i]]))->ResetNoise();
|
|
241
|
-
dynamic_cast<NoisyLinear*>(
|
|
242
|
-
(advantageNetwork->Network()[noisyLayerIndex[i]]))->ResetNoise();
|
|
246
|
+
noisyLayers[i]->ResetNoise();
|
|
243
247
|
}
|
|
244
248
|
}
|
|
245
249
|
|
|
@@ -252,24 +256,12 @@ class DuelingDQN
|
|
|
252
256
|
//! Locally-stored complete network.
|
|
253
257
|
CompleteNetworkType completeNetwork;
|
|
254
258
|
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
//! Locally-stored feature network.
|
|
259
|
-
FeatureNetworkType* featureNetwork;
|
|
260
|
-
|
|
261
|
-
//! Locally-stored advantage network.
|
|
262
|
-
AdvantageNetworkType* advantageNetwork;
|
|
263
|
-
|
|
264
|
-
//! Locally-stored value network.
|
|
265
|
-
ValueNetworkType* valueNetwork;
|
|
259
|
+
// Pointers to noisy layers.
|
|
260
|
+
std::vector<NoisyLinear<>*> noisyLayers;
|
|
266
261
|
|
|
267
262
|
//! Locally-stored check for noisy network.
|
|
268
263
|
bool isNoisy;
|
|
269
264
|
|
|
270
|
-
//! Locally-stored indexes of noisy layers in the network.
|
|
271
|
-
std::vector<size_t> noisyLayerIndex;
|
|
272
|
-
|
|
273
265
|
//! Locally-stored actionValues of the network.
|
|
274
266
|
arma::mat actionValues;
|
|
275
267
|
|
|
@@ -58,21 +58,21 @@ class SimpleDQN
|
|
|
58
58
|
network(outputLayer, init),
|
|
59
59
|
isNoisy(isNoisy)
|
|
60
60
|
{
|
|
61
|
-
network.Add
|
|
62
|
-
network.Add
|
|
61
|
+
network.template Add<Linear>(h1);
|
|
62
|
+
network.template Add<ReLU>();
|
|
63
63
|
if (isNoisy)
|
|
64
64
|
{
|
|
65
65
|
noisyLayerIndex.push_back(network.Network().size());
|
|
66
|
-
network.Add
|
|
67
|
-
network.Add
|
|
66
|
+
network.template Add<NoisyLinear>(h2);
|
|
67
|
+
network.template Add<ReLU>();
|
|
68
68
|
noisyLayerIndex.push_back(network.Network().size());
|
|
69
|
-
network.Add
|
|
69
|
+
network.template Add<NoisyLinear>(outputDim);
|
|
70
70
|
}
|
|
71
71
|
else
|
|
72
72
|
{
|
|
73
|
-
network.Add
|
|
74
|
-
network.Add
|
|
75
|
-
network.Add
|
|
73
|
+
network.template Add<Linear>(h2);
|
|
74
|
+
network.template Add<ReLU>();
|
|
75
|
+
network.template Add<Linear>(outputDim);
|
|
76
76
|
}
|
|
77
77
|
}
|
|
78
78
|
|
|
@@ -129,7 +129,7 @@ class SimpleDQN
|
|
|
129
129
|
{
|
|
130
130
|
for (size_t i = 0; i < noisyLayerIndex.size(); i++)
|
|
131
131
|
{
|
|
132
|
-
dynamic_cast<NoisyLinear
|
|
132
|
+
dynamic_cast<NoisyLinear<>*>(
|
|
133
133
|
network.Network()[noisyLayerIndex[i]])->ResetNoise();
|
|
134
134
|
}
|
|
135
135
|
}
|
|
@@ -32,26 +32,26 @@ SAC<
|
|
|
32
32
|
PolicyNetworkType,
|
|
33
33
|
UpdaterType,
|
|
34
34
|
ReplayType
|
|
35
|
-
>::SAC(TrainingConfig&
|
|
36
|
-
QNetworkType&
|
|
37
|
-
PolicyNetworkType&
|
|
38
|
-
ReplayType&
|
|
39
|
-
UpdaterType
|
|
40
|
-
UpdaterType
|
|
41
|
-
EnvironmentType
|
|
42
|
-
config(
|
|
43
|
-
learningQ1Network(
|
|
44
|
-
policyNetwork(
|
|
45
|
-
replayMethod(
|
|
46
|
-
qNetworkUpdater(std::move(
|
|
35
|
+
>::SAC(TrainingConfig& configIn,
|
|
36
|
+
QNetworkType& learningQ1NetworkIn,
|
|
37
|
+
PolicyNetworkType& policyNetworkIn,
|
|
38
|
+
ReplayType& replayMethodIn,
|
|
39
|
+
UpdaterType qNetworkUpdaterIn,
|
|
40
|
+
UpdaterType policyNetworkUpdaterIn,
|
|
41
|
+
EnvironmentType environmentIn):
|
|
42
|
+
config(configIn),
|
|
43
|
+
learningQ1Network(learningQ1NetworkIn),
|
|
44
|
+
policyNetwork(policyNetworkIn),
|
|
45
|
+
replayMethod(replayMethodIn),
|
|
46
|
+
qNetworkUpdater(std::move(qNetworkUpdaterIn)),
|
|
47
47
|
#if ENS_VERSION_MAJOR >= 2
|
|
48
48
|
qNetworkUpdatePolicy(NULL),
|
|
49
49
|
#endif
|
|
50
|
-
policyNetworkUpdater(std::move(
|
|
50
|
+
policyNetworkUpdater(std::move(policyNetworkUpdaterIn)),
|
|
51
51
|
#if ENS_VERSION_MAJOR >= 2
|
|
52
52
|
policyNetworkUpdatePolicy(NULL),
|
|
53
53
|
#endif
|
|
54
|
-
environment(std::move(
|
|
54
|
+
environment(std::move(environmentIn)),
|
|
55
55
|
totalSteps(0),
|
|
56
56
|
deterministic(false)
|
|
57
57
|
{
|
|
@@ -32,26 +32,26 @@ TD3<
|
|
|
32
32
|
PolicyNetworkType,
|
|
33
33
|
UpdaterType,
|
|
34
34
|
ReplayType
|
|
35
|
-
>::TD3(TrainingConfig&
|
|
36
|
-
QNetworkType&
|
|
37
|
-
PolicyNetworkType&
|
|
38
|
-
ReplayType&
|
|
39
|
-
UpdaterType
|
|
40
|
-
UpdaterType
|
|
41
|
-
EnvironmentType
|
|
42
|
-
config(
|
|
43
|
-
learningQ1Network(
|
|
44
|
-
policyNetwork(
|
|
45
|
-
replayMethod(
|
|
46
|
-
qNetworkUpdater(std::move(
|
|
35
|
+
>::TD3(TrainingConfig& configIn,
|
|
36
|
+
QNetworkType& learningQ1NetworkIn,
|
|
37
|
+
PolicyNetworkType& policyNetworkIn,
|
|
38
|
+
ReplayType& replayMethodIn,
|
|
39
|
+
UpdaterType qNetworkUpdaterIn,
|
|
40
|
+
UpdaterType policyNetworkUpdaterIn,
|
|
41
|
+
EnvironmentType environmentIn):
|
|
42
|
+
config(configIn),
|
|
43
|
+
learningQ1Network(learningQ1NetworkIn),
|
|
44
|
+
policyNetwork(policyNetworkIn),
|
|
45
|
+
replayMethod(replayMethodIn),
|
|
46
|
+
qNetworkUpdater(std::move(qNetworkUpdaterIn)),
|
|
47
47
|
#if ENS_VERSION_MAJOR >= 2
|
|
48
48
|
qNetworkUpdatePolicy(NULL),
|
|
49
49
|
#endif
|
|
50
|
-
policyNetworkUpdater(std::move(
|
|
50
|
+
policyNetworkUpdater(std::move(policyNetworkUpdaterIn)),
|
|
51
51
|
#if ENS_VERSION_MAJOR >= 2
|
|
52
52
|
policyNetworkUpdatePolicy(NULL),
|
|
53
53
|
#endif
|
|
54
|
-
environment(std::move(
|
|
54
|
+
environment(std::move(environmentIn)),
|
|
55
55
|
totalSteps(0),
|
|
56
56
|
deterministic(false)
|
|
57
57
|
{
|
|
@@ -335,7 +335,7 @@ inline void SoftmaxRegressionFunction<MatType>::PartialGradient(
|
|
|
335
335
|
const size_t j,
|
|
336
336
|
GradType& gradient) const
|
|
337
337
|
{
|
|
338
|
-
gradient.zeros(
|
|
338
|
+
gradient.zeros(size(parameters));
|
|
339
339
|
|
|
340
340
|
DenseMatType probabilities;
|
|
341
341
|
GetProbabilitiesMatrix(parameters, probabilities, 0, data.n_cols);
|
|
@@ -451,7 +451,7 @@ inline double ParallelSGD<ExponentialBackoff>::Optimize(
|
|
|
451
451
|
// Get the stepsize for this iteration
|
|
452
452
|
double stepSize = decayPolicy.StepSize(i);
|
|
453
453
|
|
|
454
|
-
if (
|
|
454
|
+
if (Shuffle()) // Determine order of visitation.
|
|
455
455
|
std::shuffle(visitationOrder.begin(), visitationOrder.end(),
|
|
456
456
|
mlpack::RandGen());
|
|
457
457
|
|
|
@@ -31,6 +31,7 @@ namespace adaboost { using namespace mlpack; }
|
|
|
31
31
|
namespace amf { using namespace mlpack; }
|
|
32
32
|
namespace ann { using namespace mlpack; }
|
|
33
33
|
namespace cf { using namespace mlpack; }
|
|
34
|
+
namespace data { using namespace mlpack; }
|
|
34
35
|
namespace dbscan { using namespace mlpack; }
|
|
35
36
|
namespace det { using namespace mlpack; }
|
|
36
37
|
namespace emst { using namespace mlpack; }
|
|
@@ -32,6 +32,7 @@
|
|
|
32
32
|
#include <mlpack/core/cereal/pointer_vector_wrapper.hpp>
|
|
33
33
|
#include <mlpack/core/cereal/pointer_wrapper.hpp>
|
|
34
34
|
#include <mlpack/core/cereal/template_class_version.hpp>
|
|
35
|
+
#include <mlpack/core/cereal/low_precision.hpp>
|
|
35
36
|
#include <mlpack/core/data/has_serialize.hpp>
|
|
36
37
|
|
|
37
38
|
// Include ready to use utility function to check sizes of datasets.
|