mlpack 4.6.2__cp313-cp313-win_amd64.whl → 4.7.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (415) hide show
  1. mlpack/__init__.py +4 -4
  2. mlpack/adaboost_classify.cp313-win_amd64.pyd +0 -0
  3. mlpack/adaboost_probabilities.cp313-win_amd64.pyd +0 -0
  4. mlpack/adaboost_train.cp313-win_amd64.pyd +0 -0
  5. mlpack/approx_kfn.cp313-win_amd64.pyd +0 -0
  6. mlpack/arma_numpy.cp313-win_amd64.pyd +0 -0
  7. mlpack/bayesian_linear_regression.cp313-win_amd64.pyd +0 -0
  8. mlpack/cf.cp313-win_amd64.pyd +0 -0
  9. mlpack/dbscan.cp313-win_amd64.pyd +0 -0
  10. mlpack/decision_tree.cp313-win_amd64.pyd +0 -0
  11. mlpack/det.cp313-win_amd64.pyd +0 -0
  12. mlpack/emst.cp313-win_amd64.pyd +0 -0
  13. mlpack/fastmks.cp313-win_amd64.pyd +0 -0
  14. mlpack/gmm_generate.cp313-win_amd64.pyd +0 -0
  15. mlpack/gmm_probability.cp313-win_amd64.pyd +0 -0
  16. mlpack/gmm_train.cp313-win_amd64.pyd +0 -0
  17. mlpack/hmm_generate.cp313-win_amd64.pyd +0 -0
  18. mlpack/hmm_loglik.cp313-win_amd64.pyd +0 -0
  19. mlpack/hmm_train.cp313-win_amd64.pyd +0 -0
  20. mlpack/hmm_viterbi.cp313-win_amd64.pyd +0 -0
  21. mlpack/hoeffding_tree.cp313-win_amd64.pyd +0 -0
  22. mlpack/image_converter.cp313-win_amd64.pyd +0 -0
  23. mlpack/include/mlpack/base.hpp +1 -0
  24. mlpack/include/mlpack/core/arma_extend/find_nan.hpp +63 -0
  25. mlpack/include/mlpack/core/cereal/low_precision.hpp +48 -0
  26. mlpack/include/mlpack/core/cv/cv_base.hpp +11 -11
  27. mlpack/include/mlpack/core/cv/cv_base_impl.hpp +7 -7
  28. mlpack/include/mlpack/core/cv/k_fold_cv.hpp +4 -4
  29. mlpack/include/mlpack/core/cv/k_fold_cv_impl.hpp +4 -4
  30. mlpack/include/mlpack/core/cv/meta_info_extractor.hpp +10 -10
  31. mlpack/include/mlpack/core/cv/metrics/f1_impl.hpp +1 -1
  32. mlpack/include/mlpack/core/cv/metrics/facilities.hpp +2 -1
  33. mlpack/include/mlpack/core/cv/metrics/precision_impl.hpp +1 -1
  34. mlpack/include/mlpack/core/cv/metrics/r2_score_impl.hpp +1 -1
  35. mlpack/include/mlpack/core/cv/metrics/silhouette_score_impl.hpp +1 -1
  36. mlpack/include/mlpack/core/cv/simple_cv.hpp +4 -4
  37. mlpack/include/mlpack/core/cv/simple_cv_impl.hpp +2 -2
  38. mlpack/include/mlpack/core/data/binarize.hpp +0 -2
  39. mlpack/include/mlpack/core/data/check_categorical_param.hpp +0 -2
  40. mlpack/include/mlpack/core/data/combine_options.hpp +151 -0
  41. mlpack/include/mlpack/core/data/confusion_matrix.hpp +0 -2
  42. mlpack/include/mlpack/core/data/confusion_matrix_impl.hpp +0 -2
  43. mlpack/include/mlpack/core/data/data.hpp +6 -4
  44. mlpack/include/mlpack/core/data/data_options.hpp +341 -18
  45. mlpack/include/mlpack/core/data/dataset_mapper.hpp +3 -5
  46. mlpack/include/mlpack/core/data/dataset_mapper_impl.hpp +0 -2
  47. mlpack/include/mlpack/core/data/detect_file_type.hpp +34 -5
  48. mlpack/include/mlpack/core/data/detect_file_type_impl.hpp +185 -11
  49. mlpack/include/mlpack/core/data/extension.hpp +2 -4
  50. mlpack/include/mlpack/core/data/font8x8_basic.h +152 -0
  51. mlpack/include/mlpack/core/data/has_serialize.hpp +0 -2
  52. mlpack/include/mlpack/core/data/image_bounding_box.hpp +36 -0
  53. mlpack/include/mlpack/core/data/image_bounding_box_impl.hpp +155 -0
  54. mlpack/include/mlpack/core/data/image_layout.hpp +63 -0
  55. mlpack/include/mlpack/core/data/image_layout_impl.hpp +75 -0
  56. mlpack/include/mlpack/core/data/image_letterbox.hpp +116 -0
  57. mlpack/include/mlpack/core/data/image_options.hpp +257 -0
  58. mlpack/include/mlpack/core/data/image_resize_crop.hpp +113 -48
  59. mlpack/include/mlpack/core/data/imputation_methods/custom_imputation.hpp +16 -32
  60. mlpack/include/mlpack/core/data/imputation_methods/listwise_deletion.hpp +19 -29
  61. mlpack/include/mlpack/core/data/imputation_methods/mean_imputation.hpp +113 -44
  62. mlpack/include/mlpack/core/data/imputation_methods/median_imputation.hpp +44 -43
  63. mlpack/include/mlpack/core/data/imputer.hpp +41 -49
  64. mlpack/include/mlpack/core/data/is_naninf.hpp +0 -2
  65. mlpack/include/mlpack/core/data/load.hpp +49 -233
  66. mlpack/include/mlpack/core/data/load_arff.hpp +0 -2
  67. mlpack/include/mlpack/core/data/load_arff_impl.hpp +2 -4
  68. mlpack/include/mlpack/core/data/load_categorical.hpp +1 -4
  69. mlpack/include/mlpack/core/data/load_categorical_impl.hpp +10 -26
  70. mlpack/include/mlpack/core/data/load_dense.hpp +279 -0
  71. mlpack/include/mlpack/core/data/load_deprecated.hpp +466 -0
  72. mlpack/include/mlpack/core/data/load_image.hpp +71 -43
  73. mlpack/include/mlpack/core/data/load_impl.hpp +95 -274
  74. mlpack/include/mlpack/core/data/load_model.hpp +62 -0
  75. mlpack/include/mlpack/core/data/load_numeric.hpp +124 -87
  76. mlpack/include/mlpack/core/data/load_sparse.hpp +91 -0
  77. mlpack/include/mlpack/core/data/map_policies/datatype.hpp +0 -2
  78. mlpack/include/mlpack/core/data/map_policies/increment_policy.hpp +0 -2
  79. mlpack/include/mlpack/core/data/map_policies/map_policies.hpp +0 -1
  80. mlpack/include/mlpack/core/data/matrix_options.hpp +152 -20
  81. mlpack/include/mlpack/core/data/normalize_labels.hpp +0 -2
  82. mlpack/include/mlpack/core/data/normalize_labels_impl.hpp +0 -2
  83. mlpack/include/mlpack/core/data/one_hot_encoding.hpp +2 -4
  84. mlpack/include/mlpack/core/data/one_hot_encoding_impl.hpp +3 -5
  85. mlpack/include/mlpack/core/data/save.hpp +26 -120
  86. mlpack/include/mlpack/core/data/save_dense.hpp +42 -0
  87. mlpack/include/mlpack/core/data/save_deprecated.hpp +308 -0
  88. mlpack/include/mlpack/core/data/save_image.hpp +82 -42
  89. mlpack/include/mlpack/core/data/save_impl.hpp +60 -245
  90. mlpack/include/mlpack/core/data/save_matrix.hpp +45 -0
  91. mlpack/include/mlpack/core/data/save_model.hpp +61 -0
  92. mlpack/include/mlpack/core/data/save_numeric.hpp +60 -0
  93. mlpack/include/mlpack/core/data/save_sparse.hpp +44 -0
  94. mlpack/include/mlpack/core/data/scaler_methods/max_abs_scaler.hpp +0 -2
  95. mlpack/include/mlpack/core/data/scaler_methods/mean_normalization.hpp +2 -4
  96. mlpack/include/mlpack/core/data/scaler_methods/min_max_scaler.hpp +0 -2
  97. mlpack/include/mlpack/core/data/scaler_methods/pca_whitening.hpp +1 -3
  98. mlpack/include/mlpack/core/data/scaler_methods/standard_scaler.hpp +2 -4
  99. mlpack/include/mlpack/core/data/scaler_methods/zca_whitening.hpp +0 -2
  100. mlpack/include/mlpack/core/data/split_data.hpp +6 -8
  101. mlpack/include/mlpack/core/data/string_algorithms.hpp +0 -2
  102. mlpack/include/mlpack/core/data/string_encoding.hpp +0 -2
  103. mlpack/include/mlpack/core/data/string_encoding_dictionary.hpp +0 -2
  104. mlpack/include/mlpack/core/data/string_encoding_impl.hpp +0 -2
  105. mlpack/include/mlpack/core/data/string_encoding_policies/bag_of_words_encoding_policy.hpp +0 -2
  106. mlpack/include/mlpack/core/data/string_encoding_policies/dictionary_encoding_policy.hpp +0 -2
  107. mlpack/include/mlpack/core/data/string_encoding_policies/policy_traits.hpp +0 -2
  108. mlpack/include/mlpack/core/data/string_encoding_policies/tf_idf_encoding_policy.hpp +0 -2
  109. mlpack/include/mlpack/core/data/text_options.hpp +91 -53
  110. mlpack/include/mlpack/core/data/tokenizers/char_extract.hpp +0 -2
  111. mlpack/include/mlpack/core/data/tokenizers/split_by_any_of.hpp +0 -2
  112. mlpack/include/mlpack/core/distributions/gamma_distribution_impl.hpp +4 -4
  113. mlpack/include/mlpack/core/distributions/laplace_distribution.hpp +9 -9
  114. mlpack/include/mlpack/core/distributions/laplace_distribution_impl.hpp +7 -7
  115. mlpack/include/mlpack/core/hpt/cv_function.hpp +2 -2
  116. mlpack/include/mlpack/core/hpt/cv_function_impl.hpp +2 -2
  117. mlpack/include/mlpack/core/hpt/hpt.hpp +4 -4
  118. mlpack/include/mlpack/core/hpt/hpt_impl.hpp +9 -9
  119. mlpack/include/mlpack/core/math/make_alias.hpp +7 -5
  120. mlpack/include/mlpack/core/math/random.hpp +19 -5
  121. mlpack/include/mlpack/core/math/shuffle_data.hpp +79 -245
  122. mlpack/include/mlpack/core/metrics/non_maximal_suppression_impl.hpp +9 -10
  123. mlpack/include/mlpack/core/stb/bundled/stb_image_resize2.h +291 -239
  124. mlpack/include/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp +7 -7
  125. mlpack/include/mlpack/core/tree/cellbound.hpp +2 -2
  126. mlpack/include/mlpack/core/tree/cosine_tree/cosine_tree_impl.hpp +10 -10
  127. mlpack/include/mlpack/core/tree/octree/octree.hpp +10 -0
  128. mlpack/include/mlpack/core/tree/octree/octree_impl.hpp +14 -4
  129. mlpack/include/mlpack/core/util/arma_traits.hpp +25 -38
  130. mlpack/include/mlpack/core/util/coot_traits.hpp +97 -0
  131. mlpack/include/mlpack/core/util/forward.hpp +0 -2
  132. mlpack/include/mlpack/core/util/param.hpp +4 -4
  133. mlpack/include/mlpack/core/util/params_impl.hpp +2 -2
  134. mlpack/include/mlpack/core/util/using.hpp +29 -2
  135. mlpack/include/mlpack/core/util/version.hpp +5 -3
  136. mlpack/include/mlpack/core/util/version_impl.hpp +3 -6
  137. mlpack/include/mlpack/methods/adaboost/adaboost_classify_main.cpp +1 -1
  138. mlpack/include/mlpack/methods/adaboost/adaboost_main.cpp +3 -3
  139. mlpack/include/mlpack/methods/adaboost/adaboost_train_main.cpp +2 -2
  140. mlpack/include/mlpack/methods/ann/activation_functions/activation_functions.hpp +1 -0
  141. mlpack/include/mlpack/methods/ann/activation_functions/bipolar_sigmoid_function.hpp +6 -4
  142. mlpack/include/mlpack/methods/ann/activation_functions/elish_function.hpp +17 -12
  143. mlpack/include/mlpack/methods/ann/activation_functions/elliot_function.hpp +9 -7
  144. mlpack/include/mlpack/methods/ann/activation_functions/gaussian_function.hpp +7 -6
  145. mlpack/include/mlpack/methods/ann/activation_functions/gelu_exact_function.hpp +73 -0
  146. mlpack/include/mlpack/methods/ann/activation_functions/gelu_function.hpp +27 -16
  147. mlpack/include/mlpack/methods/ann/activation_functions/hard_sigmoid_function.hpp +8 -6
  148. mlpack/include/mlpack/methods/ann/activation_functions/hard_swish_function.hpp +6 -4
  149. mlpack/include/mlpack/methods/ann/activation_functions/hyper_sinh_function.hpp +13 -8
  150. mlpack/include/mlpack/methods/ann/activation_functions/identity_function.hpp +6 -4
  151. mlpack/include/mlpack/methods/ann/activation_functions/inverse_quadratic_function.hpp +8 -6
  152. mlpack/include/mlpack/methods/ann/activation_functions/lisht_function.hpp +7 -5
  153. mlpack/include/mlpack/methods/ann/activation_functions/logistic_function.hpp +14 -12
  154. mlpack/include/mlpack/methods/ann/activation_functions/mish_function.hpp +7 -5
  155. mlpack/include/mlpack/methods/ann/activation_functions/multi_quadratic_function.hpp +6 -4
  156. mlpack/include/mlpack/methods/ann/activation_functions/poisson1_function.hpp +4 -2
  157. mlpack/include/mlpack/methods/ann/activation_functions/quadratic_function.hpp +6 -4
  158. mlpack/include/mlpack/methods/ann/activation_functions/rectifier_function.hpp +10 -10
  159. mlpack/include/mlpack/methods/ann/activation_functions/silu_function.hpp +10 -8
  160. mlpack/include/mlpack/methods/ann/activation_functions/softplus_function.hpp +12 -9
  161. mlpack/include/mlpack/methods/ann/activation_functions/softsign_function.hpp +15 -23
  162. mlpack/include/mlpack/methods/ann/activation_functions/spline_function.hpp +9 -7
  163. mlpack/include/mlpack/methods/ann/activation_functions/swish_function.hpp +11 -9
  164. mlpack/include/mlpack/methods/ann/activation_functions/tanh_exponential_function.hpp +9 -7
  165. mlpack/include/mlpack/methods/ann/activation_functions/tanh_function.hpp +10 -7
  166. mlpack/include/mlpack/methods/ann/ann.hpp +3 -0
  167. mlpack/include/mlpack/methods/ann/convolution_rules/base_convolution.hpp +197 -0
  168. mlpack/include/mlpack/methods/ann/convolution_rules/convolution_rules.hpp +1 -2
  169. mlpack/include/mlpack/methods/ann/convolution_rules/im2col_convolution.hpp +215 -0
  170. mlpack/include/mlpack/methods/ann/convolution_rules/naive_convolution.hpp +109 -154
  171. mlpack/include/mlpack/methods/ann/dag_network.hpp +728 -0
  172. mlpack/include/mlpack/methods/ann/dag_network_impl.hpp +1640 -0
  173. mlpack/include/mlpack/methods/ann/dists/bernoulli_distribution_impl.hpp +1 -1
  174. mlpack/include/mlpack/methods/ann/dists/normal_distribution_impl.hpp +7 -2
  175. mlpack/include/mlpack/methods/ann/ffn.hpp +39 -3
  176. mlpack/include/mlpack/methods/ann/ffn_impl.hpp +14 -32
  177. mlpack/include/mlpack/methods/ann/init_rules/const_init.hpp +4 -4
  178. mlpack/include/mlpack/methods/ann/init_rules/gaussian_init.hpp +6 -2
  179. mlpack/include/mlpack/methods/ann/init_rules/he_init.hpp +4 -2
  180. mlpack/include/mlpack/methods/ann/init_rules/kathirvalavakumar_subavathi_init.hpp +3 -3
  181. mlpack/include/mlpack/methods/ann/init_rules/lecun_normal_init.hpp +4 -2
  182. mlpack/include/mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp +2 -2
  183. mlpack/include/mlpack/methods/ann/init_rules/oivs_init.hpp +2 -2
  184. mlpack/include/mlpack/methods/ann/init_rules/orthogonal_init.hpp +2 -2
  185. mlpack/include/mlpack/methods/ann/init_rules/random_init.hpp +8 -4
  186. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling.hpp +21 -23
  187. mlpack/include/mlpack/methods/ann/layer/adaptive_max_pooling_impl.hpp +15 -15
  188. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling.hpp +21 -23
  189. mlpack/include/mlpack/methods/ann/layer/adaptive_mean_pooling_impl.hpp +16 -16
  190. mlpack/include/mlpack/methods/ann/layer/add.hpp +18 -18
  191. mlpack/include/mlpack/methods/ann/layer/add_impl.hpp +13 -13
  192. mlpack/include/mlpack/methods/ann/layer/add_merge.hpp +19 -18
  193. mlpack/include/mlpack/methods/ann/layer/add_merge_impl.hpp +13 -13
  194. mlpack/include/mlpack/methods/ann/layer/alpha_dropout.hpp +17 -16
  195. mlpack/include/mlpack/methods/ann/layer/alpha_dropout_impl.hpp +14 -13
  196. mlpack/include/mlpack/methods/ann/layer/base_layer.hpp +28 -51
  197. mlpack/include/mlpack/methods/ann/layer/batch_norm.hpp +16 -18
  198. mlpack/include/mlpack/methods/ann/layer/batch_norm_impl.hpp +55 -54
  199. mlpack/include/mlpack/methods/ann/layer/c_relu.hpp +18 -20
  200. mlpack/include/mlpack/methods/ann/layer/c_relu_impl.hpp +20 -25
  201. mlpack/include/mlpack/methods/ann/layer/celu.hpp +14 -19
  202. mlpack/include/mlpack/methods/ann/layer/celu_impl.hpp +25 -34
  203. mlpack/include/mlpack/methods/ann/layer/concat.hpp +18 -18
  204. mlpack/include/mlpack/methods/ann/layer/concat_impl.hpp +13 -13
  205. mlpack/include/mlpack/methods/ann/layer/concatenate.hpp +18 -18
  206. mlpack/include/mlpack/methods/ann/layer/concatenate_impl.hpp +14 -14
  207. mlpack/include/mlpack/methods/ann/layer/convolution.hpp +42 -47
  208. mlpack/include/mlpack/methods/ann/layer/convolution_impl.hpp +170 -159
  209. mlpack/include/mlpack/methods/ann/layer/dropconnect.hpp +18 -20
  210. mlpack/include/mlpack/methods/ann/layer/dropconnect_impl.hpp +20 -20
  211. mlpack/include/mlpack/methods/ann/layer/dropout.hpp +17 -19
  212. mlpack/include/mlpack/methods/ann/layer/dropout_impl.hpp +14 -21
  213. mlpack/include/mlpack/methods/ann/layer/elu.hpp +23 -25
  214. mlpack/include/mlpack/methods/ann/layer/elu_impl.hpp +20 -27
  215. mlpack/include/mlpack/methods/ann/layer/embedding.hpp +160 -0
  216. mlpack/include/mlpack/methods/ann/layer/embedding_impl.hpp +189 -0
  217. mlpack/include/mlpack/methods/ann/layer/flexible_relu.hpp +17 -19
  218. mlpack/include/mlpack/methods/ann/layer/flexible_relu_impl.hpp +20 -20
  219. mlpack/include/mlpack/methods/ann/layer/ftswish.hpp +17 -18
  220. mlpack/include/mlpack/methods/ann/layer/ftswish_impl.hpp +17 -35
  221. mlpack/include/mlpack/methods/ann/layer/grouped_convolution.hpp +27 -33
  222. mlpack/include/mlpack/methods/ann/layer/grouped_convolution_impl.hpp +170 -163
  223. mlpack/include/mlpack/methods/ann/layer/gru.hpp +195 -0
  224. mlpack/include/mlpack/methods/ann/layer/gru_impl.hpp +325 -0
  225. mlpack/include/mlpack/methods/ann/layer/hard_tanh.hpp +13 -15
  226. mlpack/include/mlpack/methods/ann/layer/hard_tanh_impl.hpp +12 -12
  227. mlpack/include/mlpack/methods/ann/layer/identity.hpp +19 -20
  228. mlpack/include/mlpack/methods/ann/layer/identity_impl.hpp +12 -12
  229. mlpack/include/mlpack/methods/ann/layer/layer.hpp +37 -33
  230. mlpack/include/mlpack/methods/ann/layer/layer_norm.hpp +11 -13
  231. mlpack/include/mlpack/methods/ann/layer/layer_norm_impl.hpp +16 -16
  232. mlpack/include/mlpack/methods/ann/layer/layer_types.hpp +4 -1
  233. mlpack/include/mlpack/methods/ann/layer/leaky_relu.hpp +20 -23
  234. mlpack/include/mlpack/methods/ann/layer/leaky_relu_impl.hpp +12 -13
  235. mlpack/include/mlpack/methods/ann/layer/linear.hpp +16 -18
  236. mlpack/include/mlpack/methods/ann/layer/linear3d.hpp +18 -18
  237. mlpack/include/mlpack/methods/ann/layer/linear3d_impl.hpp +18 -18
  238. mlpack/include/mlpack/methods/ann/layer/linear_impl.hpp +15 -15
  239. mlpack/include/mlpack/methods/ann/layer/linear_no_bias.hpp +15 -17
  240. mlpack/include/mlpack/methods/ann/layer/linear_no_bias_impl.hpp +20 -20
  241. mlpack/include/mlpack/methods/ann/layer/linear_recurrent.hpp +25 -14
  242. mlpack/include/mlpack/methods/ann/layer/linear_recurrent_impl.hpp +60 -31
  243. mlpack/include/mlpack/methods/ann/layer/log_softmax.hpp +17 -36
  244. mlpack/include/mlpack/methods/ann/layer/log_softmax_impl.hpp +58 -74
  245. mlpack/include/mlpack/methods/ann/layer/lstm.hpp +26 -29
  246. mlpack/include/mlpack/methods/ann/layer/lstm_impl.hpp +128 -124
  247. mlpack/include/mlpack/methods/ann/layer/max_pooling.hpp +19 -19
  248. mlpack/include/mlpack/methods/ann/layer/max_pooling_impl.hpp +14 -14
  249. mlpack/include/mlpack/methods/ann/layer/mean_pooling.hpp +24 -24
  250. mlpack/include/mlpack/methods/ann/layer/mean_pooling_impl.hpp +16 -16
  251. mlpack/include/mlpack/methods/ann/layer/multi_layer.hpp +36 -6
  252. mlpack/include/mlpack/methods/ann/layer/multi_layer_impl.hpp +6 -2
  253. mlpack/include/mlpack/methods/ann/layer/multihead_attention.hpp +26 -22
  254. mlpack/include/mlpack/methods/ann/layer/multihead_attention_impl.hpp +161 -64
  255. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation.hpp +28 -25
  256. mlpack/include/mlpack/methods/ann/layer/nearest_interpolation_impl.hpp +36 -37
  257. mlpack/include/mlpack/methods/ann/layer/noisylinear.hpp +39 -42
  258. mlpack/include/mlpack/methods/ann/layer/noisylinear_impl.hpp +18 -18
  259. mlpack/include/mlpack/methods/ann/layer/padding.hpp +21 -17
  260. mlpack/include/mlpack/methods/ann/layer/padding_impl.hpp +33 -19
  261. mlpack/include/mlpack/methods/ann/layer/parametric_relu.hpp +26 -28
  262. mlpack/include/mlpack/methods/ann/layer/parametric_relu_impl.hpp +18 -18
  263. mlpack/include/mlpack/methods/ann/layer/radial_basis_function.hpp +41 -28
  264. mlpack/include/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +42 -17
  265. mlpack/include/mlpack/methods/ann/layer/recurrent_layer.hpp +13 -0
  266. mlpack/include/mlpack/methods/ann/layer/relu6.hpp +19 -21
  267. mlpack/include/mlpack/methods/ann/layer/relu6_impl.hpp +14 -14
  268. mlpack/include/mlpack/methods/ann/layer/repeat.hpp +24 -25
  269. mlpack/include/mlpack/methods/ann/layer/repeat_impl.hpp +10 -10
  270. mlpack/include/mlpack/methods/ann/layer/serialization.hpp +64 -54
  271. mlpack/include/mlpack/methods/ann/layer/softmax.hpp +20 -20
  272. mlpack/include/mlpack/methods/ann/layer/softmax_impl.hpp +10 -10
  273. mlpack/include/mlpack/methods/ann/layer/softmin.hpp +20 -23
  274. mlpack/include/mlpack/methods/ann/layer/softmin_impl.hpp +10 -10
  275. mlpack/include/mlpack/methods/ann/layer/sum_reduce.hpp +103 -0
  276. mlpack/include/mlpack/methods/ann/layer/sum_reduce_impl.hpp +143 -0
  277. mlpack/include/mlpack/methods/ann/loss_functions/cosine_embedding_loss_impl.hpp +3 -3
  278. mlpack/include/mlpack/methods/ann/loss_functions/mean_bias_error_impl.hpp +1 -1
  279. mlpack/include/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp +1 -1
  280. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood.hpp +2 -2
  281. mlpack/include/mlpack/methods/ann/loss_functions/negative_log_likelihood_impl.hpp +29 -15
  282. mlpack/include/mlpack/methods/ann/loss_functions/poisson_nll_loss_impl.hpp +1 -1
  283. mlpack/include/mlpack/methods/ann/models/models.hpp +17 -0
  284. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer.hpp +151 -0
  285. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_layer_impl.hpp +265 -0
  286. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny.hpp +187 -0
  287. mlpack/include/mlpack/methods/ann/models/yolov3/yolov3_tiny_impl.hpp +206 -0
  288. mlpack/include/mlpack/methods/ann/regularizer/orthogonal_regularizer_impl.hpp +5 -3
  289. mlpack/include/mlpack/methods/ann/rnn.hpp +136 -42
  290. mlpack/include/mlpack/methods/ann/rnn_impl.hpp +230 -38
  291. mlpack/include/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +1 -1
  292. mlpack/include/mlpack/methods/bayesian_linear_regression/bayesian_linear_regression_main.cpp +1 -1
  293. mlpack/include/mlpack/methods/bias_svd/bias_svd_function_impl.hpp +1 -1
  294. mlpack/include/mlpack/methods/cf/cf_model.hpp +1 -1
  295. mlpack/include/mlpack/methods/decision_tree/decision_tree.hpp +6 -6
  296. mlpack/include/mlpack/methods/decision_tree/decision_tree_impl.hpp +12 -12
  297. mlpack/include/mlpack/methods/decision_tree/decision_tree_main.cpp +0 -1
  298. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor.hpp +6 -6
  299. mlpack/include/mlpack/methods/decision_tree/decision_tree_regressor_impl.hpp +12 -12
  300. mlpack/include/mlpack/methods/det/det_main.cpp +1 -1
  301. mlpack/include/mlpack/methods/hmm/hmm_train_main.cpp +4 -4
  302. mlpack/include/mlpack/methods/hmm/hmm_util_impl.hpp +2 -2
  303. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp +6 -6
  304. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp +31 -31
  305. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp +1 -2
  306. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp +2 -2
  307. mlpack/include/mlpack/methods/hoeffding_trees/hoeffding_tree_model_impl.hpp +1 -1
  308. mlpack/include/mlpack/methods/kde/kde_rules_impl.hpp +6 -6
  309. mlpack/include/mlpack/methods/lars/lars_impl.hpp +3 -3
  310. mlpack/include/mlpack/methods/linear_svm/linear_svm_function_impl.hpp +4 -4
  311. mlpack/include/mlpack/methods/linear_svm/linear_svm_main.cpp +3 -3
  312. mlpack/include/mlpack/methods/lmnn/lmnn_main.cpp +1 -1
  313. mlpack/include/mlpack/methods/lsh/lsh_main.cpp +1 -1
  314. mlpack/include/mlpack/methods/matrix_completion/matrix_completion_impl.hpp +1 -1
  315. mlpack/include/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp +1 -1
  316. mlpack/include/mlpack/methods/naive_bayes/nbc_main.cpp +3 -3
  317. mlpack/include/mlpack/methods/nca/nca_main.cpp +1 -1
  318. mlpack/include/mlpack/methods/neighbor_search/kfn_main.cpp +8 -8
  319. mlpack/include/mlpack/methods/neighbor_search/knn_main.cpp +8 -8
  320. mlpack/include/mlpack/methods/neighbor_search/neighbor_search.hpp +154 -34
  321. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +190 -51
  322. mlpack/include/mlpack/methods/neighbor_search/neighbor_search_stat.hpp +10 -0
  323. mlpack/include/mlpack/methods/neighbor_search/ns_model.hpp +15 -15
  324. mlpack/include/mlpack/methods/neighbor_search/ns_model_impl.hpp +55 -46
  325. mlpack/include/mlpack/methods/neighbor_search/typedef.hpp +42 -2
  326. mlpack/include/mlpack/methods/pca/pca_impl.hpp +2 -2
  327. mlpack/include/mlpack/methods/perceptron/perceptron.hpp +2 -2
  328. mlpack/include/mlpack/methods/perceptron/perceptron_impl.hpp +1 -1
  329. mlpack/include/mlpack/methods/perceptron/perceptron_main.cpp +2 -2
  330. mlpack/include/mlpack/methods/preprocess/image_converter_main.cpp +2 -3
  331. mlpack/include/mlpack/methods/preprocess/preprocess_binarize_main.cpp +2 -2
  332. mlpack/include/mlpack/methods/preprocess/preprocess_describe_main.cpp +0 -1
  333. mlpack/include/mlpack/methods/preprocess/preprocess_imputer_main.cpp +50 -129
  334. mlpack/include/mlpack/methods/preprocess/preprocess_one_hot_encoding_main.cpp +6 -6
  335. mlpack/include/mlpack/methods/preprocess/preprocess_scale_main.cpp +2 -3
  336. mlpack/include/mlpack/methods/preprocess/preprocess_split_main.cpp +3 -4
  337. mlpack/include/mlpack/methods/preprocess/scaling_model.hpp +6 -8
  338. mlpack/include/mlpack/methods/preprocess/scaling_model_impl.hpp +18 -20
  339. mlpack/include/mlpack/methods/random_forest/random_forest.hpp +5 -5
  340. mlpack/include/mlpack/methods/random_forest/random_forest_impl.hpp +9 -9
  341. mlpack/include/mlpack/methods/range_search/range_search_main.cpp +1 -1
  342. mlpack/include/mlpack/methods/rann/krann_main.cpp +1 -1
  343. mlpack/include/mlpack/methods/regularized_svd/regularized_svd_function_impl.hpp +1 -1
  344. mlpack/include/mlpack/methods/reinforcement_learning/async_learning_impl.hpp +8 -8
  345. mlpack/include/mlpack/methods/reinforcement_learning/ddpg_impl.hpp +16 -16
  346. mlpack/include/mlpack/methods/reinforcement_learning/environment/acrobot.hpp +4 -4
  347. mlpack/include/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp +3 -3
  348. mlpack/include/mlpack/methods/reinforcement_learning/environment/cont_double_pole_cart.hpp +6 -5
  349. mlpack/include/mlpack/methods/reinforcement_learning/environment/pendulum.hpp +6 -5
  350. mlpack/include/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp +2 -2
  351. mlpack/include/mlpack/methods/reinforcement_learning/q_learning_impl.hpp +10 -10
  352. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/categorical_dqn.hpp +21 -17
  353. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/dueling_dqn.hpp +69 -77
  354. mlpack/include/mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp +9 -9
  355. mlpack/include/mlpack/methods/reinforcement_learning/sac_impl.hpp +14 -14
  356. mlpack/include/mlpack/methods/reinforcement_learning/td3_impl.hpp +14 -14
  357. mlpack/include/mlpack/methods/softmax_regression/softmax_regression_function_impl.hpp +1 -1
  358. mlpack/include/mlpack/methods/svdplusplus/svdplusplus_function_impl.hpp +1 -1
  359. mlpack/include/mlpack/namespace_compat.hpp +1 -0
  360. mlpack/include/mlpack/prereqs.hpp +1 -0
  361. mlpack/kde.cp313-win_amd64.pyd +0 -0
  362. mlpack/kernel_pca.cp313-win_amd64.pyd +0 -0
  363. mlpack/kfn.cp313-win_amd64.pyd +0 -0
  364. mlpack/kmeans.cp313-win_amd64.pyd +0 -0
  365. mlpack/knn.cp313-win_amd64.pyd +0 -0
  366. mlpack/krann.cp313-win_amd64.pyd +0 -0
  367. mlpack/lars.cp313-win_amd64.pyd +0 -0
  368. mlpack/linear_regression_predict.cp313-win_amd64.pyd +0 -0
  369. mlpack/linear_regression_train.cp313-win_amd64.pyd +0 -0
  370. mlpack/linear_svm.cp313-win_amd64.pyd +0 -0
  371. mlpack/lmnn.cp313-win_amd64.pyd +0 -0
  372. mlpack/local_coordinate_coding.cp313-win_amd64.pyd +0 -0
  373. mlpack/logistic_regression.cp313-win_amd64.pyd +0 -0
  374. mlpack/lsh.cp313-win_amd64.pyd +0 -0
  375. mlpack/mean_shift.cp313-win_amd64.pyd +0 -0
  376. mlpack/nbc.cp313-win_amd64.pyd +0 -0
  377. mlpack/nca.cp313-win_amd64.pyd +0 -0
  378. mlpack/nmf.cp313-win_amd64.pyd +0 -0
  379. mlpack/pca.cp313-win_amd64.pyd +0 -0
  380. mlpack/perceptron.cp313-win_amd64.pyd +0 -0
  381. mlpack/preprocess_binarize.cp313-win_amd64.pyd +0 -0
  382. mlpack/preprocess_describe.cp313-win_amd64.pyd +0 -0
  383. mlpack/preprocess_one_hot_encoding.cp313-win_amd64.pyd +0 -0
  384. mlpack/preprocess_scale.cp313-win_amd64.pyd +0 -0
  385. mlpack/preprocess_split.cp313-win_amd64.pyd +0 -0
  386. mlpack/radical.cp313-win_amd64.pyd +0 -0
  387. mlpack/random_forest.cp313-win_amd64.pyd +0 -0
  388. mlpack/softmax_regression.cp313-win_amd64.pyd +0 -0
  389. mlpack/sparse_coding.cp313-win_amd64.pyd +0 -0
  390. mlpack-4.7.0.dist-info/DELVEWHEEL +2 -0
  391. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/METADATA +2 -2
  392. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/RECORD +396 -377
  393. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/WHEEL +1 -1
  394. mlpack/include/mlpack/core/data/format.hpp +0 -31
  395. mlpack/include/mlpack/core/data/image_info.hpp +0 -102
  396. mlpack/include/mlpack/core/data/image_info_impl.hpp +0 -84
  397. mlpack/include/mlpack/core/data/load_image_impl.hpp +0 -171
  398. mlpack/include/mlpack/core/data/load_model_impl.hpp +0 -115
  399. mlpack/include/mlpack/core/data/load_vec_impl.hpp +0 -154
  400. mlpack/include/mlpack/core/data/map_policies/missing_policy.hpp +0 -148
  401. mlpack/include/mlpack/core/data/save_image_impl.hpp +0 -170
  402. mlpack/include/mlpack/core/data/types.hpp +0 -61
  403. mlpack/include/mlpack/core/data/types_impl.hpp +0 -83
  404. mlpack/include/mlpack/core/data/utilities.hpp +0 -158
  405. mlpack/include/mlpack/core/util/gitversion.hpp +0 -1
  406. mlpack/include/mlpack/methods/ann/convolution_rules/fft_convolution.hpp +0 -213
  407. mlpack/include/mlpack/methods/ann/convolution_rules/svd_convolution.hpp +0 -201
  408. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru.hpp +0 -226
  409. mlpack/include/mlpack/methods/ann/layer/not_adapted/gru_impl.hpp +0 -367
  410. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup.hpp +0 -139
  411. mlpack/include/mlpack/methods/ann/layer/not_adapted/lookup_impl.hpp +0 -98
  412. mlpack-4.6.2.dist-info/DELVEWHEEL +0 -2
  413. {mlpack-4.6.2.dist-info → mlpack-4.7.0.dist-info}/top_level.txt +0 -0
  414. /mlpack.libs/{libopenblas-9e6d070f769e6580e8c55c0cf83b80a5.dll → libopenblas-c7f521b507686ddc25bee7538a80c374.dll} +0 -0
  415. /mlpack.libs/{msvcp140-50208655e42969b9a5ab8a4e0186bbb9.dll → msvcp140-a4c2229bdc2a2a630acdc095b4d86008.dll} +0 -0
@@ -32,14 +32,16 @@ template<typename SortPolicy,
32
32
  NeighborSearch<SortPolicy, DistanceType, MatType, TreeType,
33
33
  DualTreeTraversalType, SingleTreeTraversalType>::
34
34
  NeighborSearch(MatType referenceSetIn,
35
- const NeighborSearchMode mode,
35
+ const NeighborSearchStrategy strategy,
36
36
  const double epsilon,
37
37
  const DistanceType distance) :
38
- referenceTree(mode == NAIVE_MODE ? NULL :
38
+ referenceTree(strategy == NAIVE ? NULL :
39
39
  BuildTree<Tree>(std::move(referenceSetIn), oldFromNewReferences)),
40
- referenceSet(mode == NAIVE_MODE ? new MatType(std::move(referenceSetIn)) :
41
- &referenceTree->Dataset()),
42
- searchMode(mode),
40
+ referenceSet(strategy == NAIVE ?
41
+ new MatType(std::move(referenceSetIn)) : &referenceTree->Dataset()),
42
+ searchMode(StrategyToMode(strategy)),
43
+ searchModeMod(false),
44
+ searchStrategy(strategy),
43
45
  epsilon(epsilon),
44
46
  distance(distance),
45
47
  baseCases(0),
@@ -51,6 +53,34 @@ NeighborSearch(MatType referenceSetIn,
51
53
  }
52
54
 
53
55
  // Construct the object.
56
+ template<typename SortPolicy,
57
+ typename DistanceType,
58
+ typename MatType,
59
+ template<typename TreeDistanceType,
60
+ typename TreeStatType,
61
+ typename TreeMatType> class TreeType,
62
+ template<typename> class DualTreeTraversalType,
63
+ template<typename> class SingleTreeTraversalType>
64
+ NeighborSearch<SortPolicy, DistanceType, MatType, TreeType,
65
+ DualTreeTraversalType, SingleTreeTraversalType>::
66
+ NeighborSearch(Tree referenceTree,
67
+ const NeighborSearchStrategy strategy,
68
+ const double epsilon) :
69
+ referenceTree(new Tree(std::move(referenceTree))),
70
+ referenceSet(&this->referenceTree->Dataset()),
71
+ searchMode(StrategyToMode(strategy)),
72
+ searchModeMod(false),
73
+ searchStrategy(strategy),
74
+ epsilon(epsilon),
75
+ distance(this->referenceTree->Distance()),
76
+ baseCases(0),
77
+ scores(0),
78
+ treeNeedsReset(false)
79
+ {
80
+ if (epsilon < 0)
81
+ throw std::invalid_argument("epsilon must be non-negative");
82
+ }
83
+
54
84
  template<typename SortPolicy,
55
85
  typename DistanceType,
56
86
  typename MatType,
@@ -68,6 +98,8 @@ NeighborSearch(Tree referenceTree,
68
98
  referenceTree(new Tree(std::move(referenceTree))),
69
99
  referenceSet(&this->referenceTree->Dataset()),
70
100
  searchMode(mode),
101
+ searchModeMod(false),
102
+ searchStrategy(ModeToStrategy(mode)),
71
103
  epsilon(epsilon),
72
104
  distance(distance),
73
105
  baseCases(0),
@@ -89,12 +121,14 @@ template<typename SortPolicy,
89
121
  template<typename> class SingleTreeTraversalType>
90
122
  NeighborSearch<SortPolicy, DistanceType, MatType, TreeType,
91
123
  DualTreeTraversalType, SingleTreeTraversalType>::
92
- NeighborSearch(const NeighborSearchMode mode,
124
+ NeighborSearch(const NeighborSearchStrategy strategy,
93
125
  const double epsilon,
94
126
  const DistanceType distance) :
95
127
  referenceTree(NULL),
96
- referenceSet(mode == NAIVE_MODE ? new MatType() : NULL), // Empty matrix.
97
- searchMode(mode),
128
+ referenceSet(strategy == NAIVE ? new MatType() : NULL),
129
+ searchMode(StrategyToMode(strategy)),
130
+ searchModeMod(false),
131
+ searchStrategy(strategy),
98
132
  epsilon(epsilon),
99
133
  distance(distance),
100
134
  baseCases(0),
@@ -105,7 +139,7 @@ NeighborSearch(const NeighborSearchMode mode,
105
139
  throw std::invalid_argument("epsilon must be non-negative");
106
140
 
107
141
  // Build the tree on the empty dataset, if necessary.
108
- if (mode != NAIVE_MODE)
142
+ if (strategy != NAIVE)
109
143
  {
110
144
  referenceTree = BuildTree<Tree>(std::move(MatType()),
111
145
  oldFromNewReferences);
@@ -130,6 +164,8 @@ NeighborSearch(const NeighborSearch& other) :
130
164
  referenceSet(other.referenceTree ? &referenceTree->Dataset() :
131
165
  new MatType(*other.referenceSet)),
132
166
  searchMode(other.searchMode),
167
+ searchModeMod(other.searchModeMod),
168
+ searchStrategy(other.searchStrategy),
133
169
  epsilon(other.epsilon),
134
170
  distance(other.distance),
135
171
  baseCases(other.baseCases),
@@ -155,6 +191,8 @@ NeighborSearch(NeighborSearch&& other) :
155
191
  referenceTree(other.referenceTree),
156
192
  referenceSet(other.referenceSet),
157
193
  searchMode(other.searchMode),
194
+ searchModeMod(other.searchModeMod),
195
+ searchStrategy(other.searchStrategy),
158
196
  epsilon(other.epsilon),
159
197
  distance(std::move(other.distance)),
160
198
  baseCases(other.baseCases),
@@ -165,7 +203,9 @@ NeighborSearch(NeighborSearch&& other) :
165
203
  other.referenceTree = BuildTree<Tree>(std::move(MatType()),
166
204
  other.oldFromNewReferences);
167
205
  other.referenceSet = &other.referenceTree->Dataset();
168
- other.searchMode = DUAL_TREE_MODE,
206
+ other.searchMode = DUAL_TREE_MODE;
207
+ other.searchModeMod = false;
208
+ other.searchStrategy = DUAL_TREE;
169
209
  other.epsilon = 0.0;
170
210
  other.baseCases = 0;
171
211
  other.scores = 0;
@@ -208,6 +248,8 @@ NeighborSearch<SortPolicy,
208
248
  referenceSet = other.referenceTree ? &referenceTree->Dataset() :
209
249
  new MatType(*other.referenceSet);
210
250
  searchMode = other.searchMode;
251
+ searchModeMod = other.searchModeMod;
252
+ searchStrategy = other.searchStrategy;
211
253
  epsilon = other.epsilon;
212
254
  distance = other.distance;
213
255
  baseCases = other.baseCases;
@@ -250,6 +292,8 @@ NeighborSearch<SortPolicy,
250
292
  referenceTree = other.referenceTree;
251
293
  referenceSet = other.referenceSet;
252
294
  searchMode = other.searchMode;
295
+ searchModeMod = other.searchModeMod;
296
+ searchStrategy = other.searchStrategy;
253
297
  epsilon = other.epsilon;
254
298
  distance = other.distance;
255
299
  baseCases = other.baseCases;
@@ -263,7 +307,9 @@ NeighborSearch<SortPolicy,
263
307
  other.referenceTree = BuildTree<Tree>(std::move(MatType()),
264
308
  other.oldFromNewReferences);
265
309
  other.referenceSet = &other.referenceTree->Dataset();
266
- other.searchMode = DUAL_TREE_MODE,
310
+ other.searchMode = DUAL_TREE_MODE;
311
+ other.searchModeMod = false;
312
+ other.searchStrategy = DUAL_TREE;
267
313
  other.epsilon = 0.0;
268
314
  other.baseCases = 0;
269
315
  other.scores = 0;
@@ -300,6 +346,17 @@ void NeighborSearch<SortPolicy, DistanceType, MatType, TreeType,
300
346
  DualTreeTraversalType, SingleTreeTraversalType>::
301
347
  Train(MatType referenceSetIn)
302
348
  {
349
+ // For reverse compatibility; can be removed in mlpack 5.0.0.
350
+ if (searchModeMod)
351
+ {
352
+ searchStrategy = ModeToStrategy(searchMode);
353
+ searchModeMod = false;
354
+ }
355
+ else
356
+ {
357
+ searchMode = StrategyToMode(searchStrategy);
358
+ }
359
+
303
360
  // Clean up the old tree, if we built one.
304
361
  if (referenceTree)
305
362
  {
@@ -313,7 +370,7 @@ Train(MatType referenceSetIn)
313
370
  }
314
371
 
315
372
  // We may need to rebuild the tree.
316
- if (searchMode != NAIVE_MODE)
373
+ if (searchStrategy != NAIVE)
317
374
  {
318
375
  referenceTree = BuildTree<Tree>(std::move(referenceSetIn),
319
376
  oldFromNewReferences);
@@ -336,7 +393,18 @@ template<typename SortPolicy,
336
393
  void NeighborSearch<SortPolicy, DistanceType, MatType, TreeType,
337
394
  DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree referenceTree)
338
395
  {
339
- if (searchMode == NAIVE_MODE)
396
+ // For reverse compatibility; can be removed in mlpack 5.0.0.
397
+ if (searchModeMod)
398
+ {
399
+ searchStrategy = ModeToStrategy(searchMode);
400
+ searchModeMod = false;
401
+ }
402
+ else
403
+ {
404
+ searchMode = StrategyToMode(searchStrategy);
405
+ }
406
+
407
+ if (searchStrategy == NAIVE)
340
408
  throw std::invalid_argument("cannot train on given reference tree when "
341
409
  "naive search (without trees) is desired");
342
410
 
@@ -374,6 +442,17 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
374
442
  arma::Mat<IndexType>& neighbors,
375
443
  arma::Mat<ElemType>& distances)
376
444
  {
445
+ // For reverse compatibility; can be removed in mlpack 5.0.0.
446
+ if (searchModeMod)
447
+ {
448
+ searchStrategy = ModeToStrategy(searchMode);
449
+ searchModeMod = false;
450
+ }
451
+ else
452
+ {
453
+ searchMode = StrategyToMode(searchStrategy);
454
+ }
455
+
377
456
  if (k > referenceSet->n_cols)
378
457
  {
379
458
  std::stringstream ss;
@@ -398,7 +477,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
398
477
  // Mapping is only necessary if the tree rearranges points.
399
478
  if (TreeTraits<Tree>::RearrangesDataset)
400
479
  {
401
- if (searchMode == DUAL_TREE_MODE)
480
+ if (searchStrategy == DUAL_TREE)
402
481
  {
403
482
  distancePtr = new arma::Mat<ElemType>; // Query indices need to be mapped.
404
483
  neighborPtr = new arma::Mat<IndexType>;
@@ -413,9 +492,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
413
492
 
414
493
  using RuleType = NeighborSearchRules<SortPolicy, DistanceType, Tree>;
415
494
 
416
- switch (searchMode)
495
+ switch (searchStrategy)
417
496
  {
418
- case NAIVE_MODE:
497
+ case NAIVE:
419
498
  {
420
499
  // Create the helper object for the tree traversal.
421
500
  RuleType rules(*referenceSet, querySet, k, distance, epsilon);
@@ -430,7 +509,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
430
509
  rules.GetResults(*neighborPtr, *distancePtr);
431
510
  break;
432
511
  }
433
- case SINGLE_TREE_MODE:
512
+ case SINGLE_TREE:
434
513
  {
435
514
  // Create the helper object for the tree traversal.
436
515
  RuleType rules(*referenceSet, querySet, k, distance, epsilon);
@@ -453,7 +532,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
453
532
  rules.GetResults(*neighborPtr, *distancePtr);
454
533
  break;
455
534
  }
456
- case DUAL_TREE_MODE:
535
+ case DUAL_TREE:
457
536
  {
458
537
  // Build the query tree.
459
538
  Tree* queryTree = BuildTree<Tree>(querySet, oldFromNewQueries);
@@ -479,7 +558,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
479
558
  delete queryTree;
480
559
  break;
481
560
  }
482
- case GREEDY_SINGLE_TREE_MODE:
561
+ case GREEDY_SINGLE_TREE:
483
562
  {
484
563
  // Create the helper object for the tree traversal.
485
564
  RuleType rules(*referenceSet, querySet, k, distance);
@@ -507,7 +586,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
507
586
  // Map points back to original indices, if necessary.
508
587
  if (TreeTraits<Tree>::RearrangesDataset)
509
588
  {
510
- if (searchMode == DUAL_TREE_MODE && !oldFromNewReferences.empty())
589
+ if (searchStrategy == DUAL_TREE && !oldFromNewReferences.empty())
511
590
  {
512
591
  // We must map both query and reference indices.
513
592
  neighbors.set_size(k, querySet.n_cols);
@@ -530,7 +609,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
530
609
  delete neighborPtr;
531
610
  delete distancePtr;
532
611
  }
533
- else if (searchMode == DUAL_TREE_MODE)
612
+ else if (searchStrategy == DUAL_TREE)
534
613
  {
535
614
  // We must map query indices only.
536
615
  neighbors.set_size(k, querySet.n_cols);
@@ -581,6 +660,17 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
581
660
  arma::Mat<ElemType>& distances,
582
661
  bool sameSet)
583
662
  {
663
+ // For reverse compatibility; can be removed in mlpack 5.0.0.
664
+ if (searchModeMod)
665
+ {
666
+ searchStrategy = ModeToStrategy(searchMode);
667
+ searchModeMod = false;
668
+ }
669
+ else
670
+ {
671
+ searchMode = StrategyToMode(searchStrategy);
672
+ }
673
+
584
674
  if (k > referenceSet->n_cols)
585
675
  {
586
676
  std::stringstream ss;
@@ -590,9 +680,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
590
680
  }
591
681
 
592
682
  // Make sure we are in dual-tree mode.
593
- if (searchMode != DUAL_TREE_MODE)
594
- throw std::invalid_argument("cannot call NeighborSearch::Search() with a "
595
- "query tree when naive or singleMode are set to true");
683
+ if (searchStrategy != DUAL_TREE)
684
+ throw std::invalid_argument("Cannot call NeighborSearch::Search() with a "
685
+ "query tree when search strategy is not DUAL_TREE!");
596
686
 
597
687
  baseCases = 0;
598
688
  scores = 0;
@@ -659,6 +749,17 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
659
749
  arma::Mat<IndexType>& neighbors,
660
750
  arma::Mat<ElemType>& distances)
661
751
  {
752
+ // For reverse compatibility; can be removed in mlpack 5.0.0.
753
+ if (searchModeMod)
754
+ {
755
+ searchStrategy = ModeToStrategy(searchMode);
756
+ searchModeMod = false;
757
+ }
758
+ else
759
+ {
760
+ searchMode = StrategyToMode(searchStrategy);
761
+ }
762
+
662
763
  if (k > referenceSet->n_cols)
663
764
  {
664
765
  std::stringstream ss;
@@ -697,9 +798,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
697
798
  RuleType rules(*referenceSet, *referenceSet, k, distance, epsilon,
698
799
  true /* don't return the same point as nearest neighbor */);
699
800
 
700
- switch (searchMode)
801
+ switch (searchStrategy)
701
802
  {
702
- case NAIVE_MODE:
803
+ case NAIVE:
703
804
  {
704
805
  // The naive brute-force solution.
705
806
  for (size_t i = 0; i < referenceSet->n_cols; ++i)
@@ -709,7 +810,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
709
810
  baseCases += referenceSet->n_cols * referenceSet->n_cols;
710
811
  break;
711
812
  }
712
- case SINGLE_TREE_MODE:
813
+ case SINGLE_TREE:
713
814
  {
714
815
  // Create the traverser.
715
816
  SingleTreeTraversalType<RuleType> traverser(rules);
@@ -727,27 +828,12 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
727
828
  << std::endl;
728
829
  break;
729
830
  }
730
- case DUAL_TREE_MODE:
831
+ case DUAL_TREE:
731
832
  {
732
833
  // The dual-tree monochromatic search case may require resetting the
733
834
  // bounds in the tree.
734
835
  if (treeNeedsReset)
735
- {
736
- std::stack<Tree*> nodes;
737
- nodes.push(referenceTree);
738
- while (!nodes.empty())
739
- {
740
- Tree* node = nodes.top();
741
- nodes.pop();
742
-
743
- // Reset bounds of this node.
744
- node->Stat().Reset();
745
-
746
- // Then add the children.
747
- for (size_t i = 0; i < node->NumChildren(); ++i)
748
- nodes.push(&node->Child(i));
749
- }
750
- }
836
+ ResetTree(*referenceTree);
751
837
 
752
838
  // Create the traverser.
753
839
  DualTreeTraversalType<RuleType> traverser(rules);
@@ -762,8 +848,6 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
762
848
  else
763
849
  {
764
850
  traverser.Traverse(*referenceTree, *referenceTree);
765
- // Next time we perform this search, we'll need to reset the tree.
766
- treeNeedsReset = true;
767
851
  }
768
852
 
769
853
  scores += rules.Scores();
@@ -778,7 +862,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
778
862
  treeNeedsReset = true;
779
863
  break;
780
864
  }
781
- case GREEDY_SINGLE_TREE_MODE:
865
+ case GREEDY_SINGLE_TREE:
782
866
  {
783
867
  // Create the traverser.
784
868
  GreedySingleTreeTraverser<Tree, RuleType> traverser(rules);
@@ -855,7 +939,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::EffectiveError(
855
939
  }
856
940
  }
857
941
 
858
- if (numCases)
942
+ if (numCases > 0)
859
943
  effectiveError /= numCases;
860
944
 
861
945
  return effectiveError;
@@ -893,6 +977,35 @@ DualTreeTraversalType, SingleTreeTraversalType>::Recall(
893
977
  return ((double) found) / realNeighbors.n_elem;
894
978
  }
895
979
 
980
+ template<typename SortPolicy,
981
+ typename DistanceType,
982
+ typename MatType,
983
+ template<typename TreeDistanceType,
984
+ typename TreeStatType,
985
+ typename TreeMatType> class TreeType,
986
+ template<typename> class DualTreeTraversalType,
987
+ template<typename> class SingleTreeTraversalType>
988
+ void NeighborSearch<
989
+ SortPolicy, DistanceType, MatType, TreeType,
990
+ DualTreeTraversalType, SingleTreeTraversalType
991
+ >::ResetTree(Tree& tree)
992
+ {
993
+ std::stack<Tree*> nodes;
994
+ nodes.push(&tree);
995
+ while (!nodes.empty())
996
+ {
997
+ Tree* node = nodes.top();
998
+ nodes.pop();
999
+
1000
+ // Reset bounds of this node.
1001
+ node->Stat().Reset();
1002
+
1003
+ // Then add the children.
1004
+ for (size_t i = 0; i < node->NumChildren(); ++i)
1005
+ nodes.push(&node->Child(i));
1006
+ }
1007
+ }
1008
+
896
1009
  //! Serialize the NeighborSearch model.
897
1010
  template<typename SortPolicy,
898
1011
  typename DistanceType,
@@ -905,15 +1018,41 @@ template<typename SortPolicy,
905
1018
  template<typename Archive>
906
1019
  void NeighborSearch<SortPolicy, DistanceType, MatType, TreeType,
907
1020
  DualTreeTraversalType, SingleTreeTraversalType>::serialize(
908
- Archive& ar, const uint32_t /* version */)
1021
+ Archive& ar, const uint32_t version)
909
1022
  {
1023
+ if (cereal::is_saving<Archive>())
1024
+ {
1025
+ // For reverse compatibility; can be removed in mlpack 5.0.0.
1026
+ if (searchModeMod)
1027
+ {
1028
+ searchStrategy = ModeToStrategy(searchMode);
1029
+ searchModeMod = false;
1030
+ }
1031
+ else
1032
+ {
1033
+ searchMode = StrategyToMode(searchStrategy);
1034
+ }
1035
+ }
1036
+
910
1037
  // Serialize preferences for search.
911
- ar(CEREAL_NVP(searchMode));
1038
+ if (version == 0)
1039
+ {
1040
+ ar(CEREAL_NVP(searchMode));
1041
+ searchModeMod = false;
1042
+ searchStrategy = ModeToStrategy(searchMode);
1043
+ }
1044
+ else
1045
+ {
1046
+ ar(CEREAL_NVP(searchStrategy));
1047
+ searchModeMod = false;
1048
+ searchMode = StrategyToMode(searchStrategy);
1049
+ }
1050
+
912
1051
  ar(CEREAL_NVP(treeNeedsReset));
913
1052
 
914
1053
  // If we are doing naive search, we serialize the dataset. Otherwise we
915
1054
  // serialize the tree.
916
- if (searchMode == NAIVE_MODE)
1055
+ if (searchStrategy == NAIVE)
917
1056
  {
918
1057
  // Delete the current reference set, if necessary and if we are loading.
919
1058
  if (cereal::is_loading<Archive>() && referenceSet)
@@ -14,6 +14,8 @@
14
14
  #define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_STAT_HPP
15
15
 
16
16
  #include <mlpack/prereqs.hpp>
17
+ #include "sort_policies/nearest_neighbor_sort.hpp"
18
+ #include "sort_policies/furthest_neighbor_sort.hpp"
17
19
 
18
20
  namespace mlpack {
19
21
 
@@ -100,6 +102,14 @@ class NeighborSearchStat
100
102
  }
101
103
  };
102
104
 
105
+ // This is the type that must be used as the StatisticType for
106
+ // k-nearest-neighbor search (e.g. the KNN or KNNType<> class).
107
+ using NearestNeighborStat = NeighborSearchStat<NearestNeighborSort>;
108
+
109
+ // This is the type that must be used as the StatisticType for
110
+ // k-furthest-neighbor search (e.g. the KFN or KFNType<> class).
111
+ using FurthestNeighborStat = NeighborSearchStat<FurthestNeighborSort>;
112
+
103
113
  } // namespace mlpack
104
114
 
105
115
  #endif
@@ -49,9 +49,9 @@ class NSWrapperBase
49
49
  virtual const arma::mat& Dataset() const = 0;
50
50
 
51
51
  //! Get the search mode.
52
- virtual NeighborSearchMode SearchMode() const = 0;
52
+ virtual NeighborSearchStrategy SearchStrategy() const = 0;
53
53
  //! Modify the search modem
54
- virtual NeighborSearchMode& SearchMode() = 0;
54
+ virtual NeighborSearchStrategy& SearchStrategy() = 0;
55
55
 
56
56
  //! Get the approximation parameter epsilon.
57
57
  virtual double Epsilon() const = 0;
@@ -103,9 +103,9 @@ class NSWrapper : public NSWrapperBase
103
103
  public:
104
104
  //! Construct the NSWrapper object, initializing the internally-held
105
105
  //! NeighborSearch object.
106
- NSWrapper(const NeighborSearchMode searchMode,
106
+ NSWrapper(const NeighborSearchStrategy searchStrategy,
107
107
  const double epsilon) :
108
- ns(searchMode, epsilon)
108
+ ns(searchStrategy, epsilon)
109
109
  {
110
110
  // Nothing else to do.
111
111
  }
@@ -121,9 +121,9 @@ class NSWrapper : public NSWrapperBase
121
121
  const arma::mat& Dataset() const { return ns.ReferenceSet(); }
122
122
 
123
123
  //! Get the search mode.
124
- NeighborSearchMode SearchMode() const { return ns.SearchMode(); }
124
+ NeighborSearchStrategy SearchStrategy() const { return ns.SearchStrategy(); }
125
125
  //! Modify the search mode.
126
- NeighborSearchMode& SearchMode() { return ns.SearchMode(); }
126
+ NeighborSearchStrategy& SearchStrategy() { return ns.SearchStrategy(); }
127
127
 
128
128
  //! Get epsilon, the approximation parameter.
129
129
  double Epsilon() const { return ns.Epsilon(); }
@@ -201,12 +201,12 @@ class LeafSizeNSWrapper :
201
201
  public:
202
202
  //! Construct the LeafSizeNSWrapper by delegating to the NSWrapper
203
203
  //! constructor.
204
- LeafSizeNSWrapper(const NeighborSearchMode searchMode,
204
+ LeafSizeNSWrapper(const NeighborSearchStrategy searchStrategy,
205
205
  const double epsilon) :
206
206
  NSWrapper<SortPolicy,
207
207
  TreeType,
208
208
  DualTreeTraversalType,
209
- SingleTreeTraversalType>(searchMode, epsilon)
209
+ SingleTreeTraversalType>(searchStrategy, epsilon)
210
210
  {
211
211
  // Nothing to do.
212
212
  }
@@ -270,7 +270,7 @@ class SpillNSWrapper :
270
270
  {
271
271
  public:
272
272
  //! Construct the SpillNSWrapper.
273
- SpillNSWrapper(const NeighborSearchMode searchMode,
273
+ SpillNSWrapper(const NeighborSearchStrategy searchStrategy,
274
274
  const double epsilon) :
275
275
  NSWrapper<
276
276
  SortPolicy,
@@ -281,7 +281,7 @@ class SpillNSWrapper :
281
281
  SPTree<EuclideanDistance,
282
282
  NeighborSearchStat<SortPolicy>,
283
283
  arma::mat>::template DefeatistSingleTreeTraverser>(
284
- searchMode, epsilon)
284
+ searchStrategy, epsilon)
285
285
  {
286
286
  // Nothing to do.
287
287
  }
@@ -430,9 +430,9 @@ class NSModel
430
430
  //! Expose the dataset.
431
431
  const arma::mat& Dataset() const;
432
432
 
433
- //! Expose SearchMode.
434
- NeighborSearchMode SearchMode() const;
435
- NeighborSearchMode& SearchMode();
433
+ //! Expose search strategy..
434
+ NeighborSearchStrategy SearchStrategy() const;
435
+ NeighborSearchStrategy& SearchStrategy();
436
436
 
437
437
  //! Expose LeafSize.
438
438
  size_t LeafSize() const { return leafSize; }
@@ -459,13 +459,13 @@ class NSModel
459
459
  bool& RandomBasis() { return randomBasis; }
460
460
 
461
461
  //! Initialize the model type. (This does not perform any training.)
462
- void InitializeModel(const NeighborSearchMode searchMode,
462
+ void InitializeModel(const NeighborSearchStrategy searchStrategy,
463
463
  const double epsilon);
464
464
 
465
465
  //! Build the reference tree.
466
466
  void BuildModel(util::Timers& timers,
467
467
  arma::mat&& referenceSet,
468
- const NeighborSearchMode searchMode,
468
+ const NeighborSearchStrategy searchStrategy,
469
469
  const double epsilon = 0);
470
470
 
471
471
  //! Perform neighbor search. The query set will be reordered.