isotree 0.2.2 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (151) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2111 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +62 -0
  14. data/vendor/isotree/src/RcppExports.cpp +525 -52
  15. data/vendor/isotree/src/Rwrapper.cpp +1931 -268
  16. data/vendor/isotree/src/c_interface.cpp +953 -0
  17. data/vendor/isotree/src/crit.hpp +4232 -0
  18. data/vendor/isotree/src/dist.hpp +1886 -0
  19. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  20. data/vendor/isotree/src/extended.hpp +1444 -0
  21. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  22. data/vendor/isotree/src/fit_model.hpp +2401 -0
  23. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  24. data/vendor/isotree/src/helpers_iforest.hpp +813 -0
  25. data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
  26. data/vendor/isotree/src/indexer.cpp +515 -0
  27. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  28. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  29. data/vendor/isotree/src/isoforest.hpp +1659 -0
  30. data/vendor/isotree/src/isotree.hpp +1804 -392
  31. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  32. data/vendor/isotree/src/merge_models.cpp +159 -16
  33. data/vendor/isotree/src/mult.hpp +1321 -0
  34. data/vendor/isotree/src/oop_interface.cpp +842 -0
  35. data/vendor/isotree/src/oop_interface.hpp +278 -0
  36. data/vendor/isotree/src/other_helpers.hpp +219 -0
  37. data/vendor/isotree/src/predict.hpp +1932 -0
  38. data/vendor/isotree/src/python_helpers.hpp +134 -0
  39. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  40. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  41. data/vendor/isotree/src/robinmap/README.md +483 -0
  42. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  46. data/vendor/isotree/src/serialize.cpp +4300 -139
  47. data/vendor/isotree/src/sql.cpp +141 -59
  48. data/vendor/isotree/src/subset_models.cpp +174 -0
  49. data/vendor/isotree/src/utils.hpp +3808 -0
  50. data/vendor/isotree/src/xoshiro.hpp +467 -0
  51. data/vendor/isotree/src/ziggurat.hpp +405 -0
  52. metadata +38 -104
  53. data/vendor/cereal/LICENSE +0 -24
  54. data/vendor/cereal/README.md +0 -85
  55. data/vendor/cereal/include/cereal/access.hpp +0 -351
  56. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  57. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  58. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  59. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  60. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  61. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  62. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  63. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  65. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  66. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  67. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  68. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  69. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  70. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  71. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  72. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  74. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  76. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  77. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  78. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  79. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  91. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  92. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  94. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  96. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  97. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  98. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  99. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  100. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  101. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  102. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  103. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  104. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  105. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  106. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  107. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  111. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  112. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  113. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  114. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  115. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  116. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  117. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  118. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  119. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  120. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  121. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  122. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  123. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  124. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  125. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  126. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  127. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  128. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  129. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  130. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  131. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  132. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  133. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  134. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  135. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  136. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  137. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  138. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  139. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  140. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  141. data/vendor/cereal/include/cereal/version.hpp +0 -52
  142. data/vendor/isotree/src/Makevars +0 -4
  143. data/vendor/isotree/src/crit.cpp +0 -912
  144. data/vendor/isotree/src/dist.cpp +0 -749
  145. data/vendor/isotree/src/extended.cpp +0 -790
  146. data/vendor/isotree/src/fit_model.cpp +0 -1090
  147. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  148. data/vendor/isotree/src/isoforest.cpp +0 -771
  149. data/vendor/isotree/src/mult.cpp +0 -607
  150. data/vendor/isotree/src/predict.cpp +0 -853
  151. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -0,0 +1,1444 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4.5: programs for machine learning. Elsevier, 2014.
21
+ * [8] Cortes, David.
22
+ * "Distance approximation using Isolation Forests."
23
+ * arXiv preprint arXiv:1910.12362 (2019).
24
+ * [9] Cortes, David.
25
+ * "Imputing missing values with unsupervised random trees."
26
+ * arXiv preprint arXiv:1911.06646 (2019).
27
+ * [10] https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom
28
+ * [11] Cortes, David.
29
+ * "Revisiting randomized choices in isolation forests."
30
+ * arXiv preprint arXiv:2110.13402 (2021).
31
+ * [12] Guha, Sudipto, et al.
32
+ * "Robust random cut forest based anomaly detection on streams."
33
+ * International conference on machine learning. PMLR, 2016.
34
+ * [13] Cortes, David.
35
+ * "Isolation forests: looking beyond tree depth."
36
+ * arXiv preprint arXiv:2111.11639 (2021).
37
+ * [14] Ting, Kai Ming, Yue Zhu, and Zhi-Hua Zhou.
38
+ * "Isolation kernel and its effect on SVM."
39
+ * Proceedings of the 24th ACM SIGKDD
40
+ * International Conference on Knowledge Discovery & Data Mining. 2018.
41
+ *
42
+ * BSD 2-Clause License
43
+ * Copyright (c) 2019-2022, David Cortes
44
+ * All rights reserved.
45
+ * Redistribution and use in source and binary forms, with or without
46
+ * modification, are permitted provided that the following conditions are met:
47
+ * * Redistributions of source code must retain the above copyright notice, this
48
+ * list of conditions and the following disclaimer.
49
+ * * Redistributions in binary form must reproduce the above copyright notice,
50
+ * this list of conditions and the following disclaimer in the documentation
51
+ * and/or other materials provided with the distribution.
52
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
53
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
54
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
55
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
56
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
57
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
58
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
59
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
60
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
61
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
62
+ */
63
+ #include "isotree.hpp"
64
+
65
+ template <class InputData, class WorkerMemory, class ldouble_safe>
66
+ void split_hplane_recursive(std::vector<IsoHPlane> &hplanes,
67
+ WorkerMemory &workspace,
68
+ InputData &input_data,
69
+ ModelParams &model_params,
70
+ std::vector<ImputeNode> *impute_nodes,
71
+ size_t curr_depth)
72
+ {
73
+ if (interrupt_switch) return;
74
+ ldouble_safe sum_weight = -HUGE_VAL;
75
+ size_t hplane_from = hplanes.size() - 1;
76
+ std::unique_ptr<RecursionState> recursion_state;
77
+ std::vector<bool> col_is_taken;
78
+ hashed_set<size_t> col_is_taken_s;
79
+
80
+ /* calculate imputation statistics if desired */
81
+ if (impute_nodes != NULL)
82
+ {
83
+ if (input_data.Xc_indptr != NULL)
84
+ std::sort(workspace.ix_arr.begin() + workspace.st,
85
+ workspace.ix_arr.begin() + workspace.end + 1);
86
+ build_impute_node<decltype(input_data), decltype(workspace), ldouble_safe>(
87
+ impute_nodes->back(), workspace,
88
+ input_data, model_params,
89
+ *impute_nodes, curr_depth,
90
+ model_params.min_imp_obs);
91
+ }
92
+
93
+ /* check for potential isolated leafs or unique splits */
94
+ if (workspace.end == workspace.st || (workspace.end - workspace.st) == 1 || curr_depth >= model_params.max_depth)
95
+ goto terminal_statistics;
96
+
97
+ /* when using weights, the split should stop when the sum of weights is <= 1 */
98
+ sum_weight = calculate_sum_weights<ldouble_safe>(
99
+ workspace.ix_arr, workspace.st, workspace.end, curr_depth,
100
+ workspace.weights_arr, workspace.weights_map);
101
+
102
+ if (curr_depth > 0 && (!workspace.weights_arr.empty() || !workspace.weights_map.empty()) && sum_weight <= 1)
103
+ goto terminal_statistics;
104
+
105
+ /* for sparse matrices, need to sort the indices */
106
+ if (input_data.Xc_indptr != NULL && impute_nodes == NULL)
107
+ std::sort(workspace.ix_arr.begin() + workspace.st, workspace.ix_arr.begin() + workspace.end + 1);
108
+
109
+ /* pick column to split according to criteria */
110
+ workspace.prob_split_type = workspace.rbin(workspace.rnd_generator);
111
+
112
+ if (
113
+ workspace.prob_split_type
114
+ < (
115
+ model_params.prob_pick_by_gain_avg +
116
+ model_params.prob_pick_by_gain_pl +
117
+ model_params.prob_pick_by_full_gain +
118
+ model_params.prob_pick_by_dens
119
+ )
120
+ )
121
+ {
122
+ workspace.ntry = model_params.ntry;
123
+ hplanes.back().score = -HUGE_VAL; /* this keeps track of the gain */
124
+ if (workspace.prob_split_type < model_params.prob_pick_by_gain_avg)
125
+ workspace.criterion = Averaged;
126
+ else if (workspace.prob_split_type < model_params.prob_pick_by_gain_avg +
127
+ model_params.prob_pick_by_gain_pl)
128
+ workspace.criterion = Pooled;
129
+ else if (workspace.prob_split_type < model_params.prob_pick_by_gain_avg +
130
+ model_params.prob_pick_by_gain_pl +
131
+ model_params.prob_pick_by_full_gain)
132
+ workspace.criterion = FullGain;
133
+ else
134
+ workspace.criterion = DensityCrit;
135
+ }
136
+
137
+ else
138
+ {
139
+ workspace.criterion = NoCrit;
140
+ workspace.ntry = 1;
141
+ }
142
+
143
+ /* pick column selection method also according to criteria */
144
+ if (
145
+ (workspace.criterion != NoCrit &&
146
+ std::max(workspace.ntry, (size_t)1) >= workspace.col_sampler.get_remaining_cols())
147
+ ||
148
+ (workspace.col_sampler.get_remaining_cols() <= model_params.ndim)
149
+ ) {
150
+ workspace.prob_split_type = 0;
151
+ }
152
+ else {
153
+ workspace.prob_split_type = workspace.rbin(workspace.rnd_generator);
154
+ }
155
+
156
+ if (
157
+ workspace.prob_split_type
158
+ < model_params.prob_pick_col_by_range
159
+ )
160
+ {
161
+ workspace.col_criterion = ByRange;
162
+ if (curr_depth == 0 && is_boxed_metric(model_params.scoring_metric))
163
+ {
164
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
165
+ workspace.node_col_weights[col] = workspace.density_calculator.box_high[col]
166
+ - workspace.density_calculator.box_low[col];
167
+ add_col_weights_to_ranges:
168
+ if (workspace.tree_kurtoses != NULL)
169
+ {
170
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
171
+ {
172
+ if (workspace.node_col_weights[col] <= 0) continue;
173
+ workspace.node_col_weights[col] *= workspace.tree_kurtoses[col];
174
+ workspace.node_col_weights[col] = std::fmax(workspace.node_col_weights[col], 1e-100);
175
+ }
176
+ }
177
+ else if (input_data.col_weights != NULL)
178
+ {
179
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
180
+ {
181
+ if (workspace.node_col_weights[col] <= 0) continue;
182
+ workspace.node_col_weights[col] *= input_data.col_weights[col];
183
+ workspace.node_col_weights[col] = std::fmax(workspace.node_col_weights[col], 1e-100);
184
+ }
185
+ }
186
+ }
187
+
188
+ else if (curr_depth == 0 &&
189
+ model_params.sample_size == input_data.nrows &&
190
+ !model_params.with_replacement &&
191
+ input_data.range_low != NULL &&
192
+ model_params.ncols_per_tree == input_data.ncols_tot)
193
+ {
194
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
195
+ workspace.node_col_weights[col] = input_data.range_high[col]
196
+ - input_data.range_low[col];
197
+ goto add_col_weights_to_ranges;
198
+ }
199
+
200
+ else
201
+ {
202
+ calc_ranges_all_cols(input_data, workspace, model_params, workspace.node_col_weights.data(),
203
+ NULL,
204
+ NULL);
205
+ }
206
+
207
+ workspace.has_saved_stats = false;
208
+ }
209
+
210
+ else if (
211
+ workspace.prob_split_type
212
+ < (model_params.prob_pick_col_by_range +
213
+ model_params.prob_pick_col_by_var)
214
+ )
215
+ {
216
+ workspace.col_criterion = ByVar;
217
+ workspace.has_saved_stats = model_params.standardize_data || model_params.missing_action != Fail;
218
+ calc_var_all_cols<InputData, WorkerMemory, ldouble_safe>(
219
+ input_data, workspace, model_params,
220
+ workspace.node_col_weights.data(),
221
+ NULL, NULL,
222
+ workspace.has_saved_stats? workspace.saved_stat1.data() : NULL,
223
+ workspace.has_saved_stats? workspace.saved_stat2.data() : NULL);
224
+ }
225
+
226
+ else if (
227
+ workspace.prob_split_type
228
+ < (model_params.prob_pick_col_by_range +
229
+ model_params.prob_pick_col_by_var +
230
+ model_params.prob_pick_col_by_kurt)
231
+ )
232
+ {
233
+ workspace.col_criterion = ByKurt;
234
+ calc_kurt_all_cols<decltype(input_data), decltype(workspace), ldouble_safe>(
235
+ input_data, workspace, model_params, workspace.node_col_weights.data(),
236
+ NULL,
237
+ NULL);
238
+ workspace.has_saved_stats = false;
239
+ }
240
+
241
+ else
242
+ {
243
+ workspace.col_criterion = Uniformly;
244
+ workspace.has_saved_stats = false;
245
+ }
246
+
247
+ if (workspace.col_criterion != Uniformly)
248
+ {
249
+ if (!workspace.node_col_sampler.initialize(workspace.node_col_weights.data(),
250
+ &workspace.col_sampler.col_indices,
251
+ workspace.col_sampler.curr_pos,
252
+ model_params.ndim,
253
+ model_params.ntry > 1))
254
+ {
255
+ goto terminal_statistics;
256
+ }
257
+
258
+ if (model_params.ntry > 1)
259
+ {
260
+ workspace.node_col_sampler.backup(workspace.node_col_sampler_backup, input_data.ncols_tot);
261
+ }
262
+ }
263
+
264
+
265
+ if (workspace.criterion != NoCrit && (!workspace.weights_arr.empty() || !workspace.weights_map.empty()))
266
+ {
267
+ if (!workspace.weights_arr.empty())
268
+ {
269
+ for (size_t row = workspace.st; row <= workspace.end; row++)
270
+ workspace.sample_weights[row-workspace.st] = workspace.weights_arr[workspace.ix_arr[row]];
271
+ }
272
+
273
+ else
274
+ {
275
+ for (size_t row = workspace.st; row <= workspace.end; row++)
276
+ workspace.sample_weights[row-workspace.st] = workspace.weights_map[workspace.ix_arr[row]];
277
+ }
278
+ }
279
+
280
+ if (workspace.criterion == FullGain)
281
+ {
282
+ workspace.col_sampler.get_array_remaining_cols(workspace.col_indices);
283
+ }
284
+
285
+ workspace.ntaken_best = 0;
286
+
287
+ for (size_t attempt = 0; attempt < workspace.ntry; attempt++)
288
+ {
289
+ if (attempt > 0 && workspace.col_criterion != Uniformly)
290
+ {
291
+ workspace.node_col_sampler.restore(workspace.node_col_sampler_backup);
292
+ }
293
+
294
+ if (workspace.col_criterion == Uniformly)
295
+ {
296
+ if (input_data.ncols_tot < 1e5 ||
297
+ ((ldouble_safe)model_params.ndim / (ldouble_safe)workspace.col_sampler.get_remaining_cols()) > .25
298
+ )
299
+ {
300
+ if (!col_is_taken.size())
301
+ col_is_taken.resize(input_data.ncols_tot, false);
302
+ else
303
+ col_is_taken.assign(input_data.ncols_tot, false);
304
+ }
305
+ else {
306
+ col_is_taken_s.clear();
307
+ col_is_taken_s.reserve(model_params.ndim);
308
+ }
309
+ }
310
+
311
+ workspace.ntaken = 0;
312
+ workspace.ntried = 0;
313
+ std::fill(workspace.comb_val.begin(),
314
+ workspace.comb_val.begin() + (workspace.end - workspace.st + 1),
315
+ (double)0);
316
+
317
+ if (model_params.ndim >= input_data.ncols_tot)
318
+ workspace.col_sampler.prepare_full_pass();
319
+ else if (workspace.try_all && workspace.col_criterion == Uniformly)
320
+ workspace.col_sampler.shuffle_remainder(workspace.rnd_generator);
321
+ size_t threshold_shuffle = (workspace.col_sampler.get_remaining_cols() + 1) / 2;
322
+
323
+ while (
324
+ (workspace.col_criterion != Uniformly)?
325
+ workspace.node_col_sampler.sample_col(workspace.col_chosen, workspace.rnd_generator)
326
+ :
327
+ (workspace.try_all?
328
+ workspace.col_sampler.sample_col(workspace.col_chosen)
329
+ :
330
+ workspace.col_sampler.sample_col(workspace.col_chosen, workspace.rnd_generator))
331
+ )
332
+ {
333
+ if (interrupt_switch) return;
334
+
335
+ if (workspace.col_criterion != Uniformly) goto add_this_col;
336
+
337
+ workspace.ntried++;
338
+ if (!workspace.try_all && workspace.ntried >= threshold_shuffle)
339
+ {
340
+ workspace.try_all = true;
341
+ workspace.col_sampler.shuffle_remainder(workspace.rnd_generator);
342
+ }
343
+
344
+ if (is_col_taken(col_is_taken, col_is_taken_s, workspace.col_chosen))
345
+ continue;
346
+
347
+ get_split_range(workspace, input_data, model_params);
348
+ if (workspace.unsplittable)
349
+ {
350
+ if (workspace.col_criterion != Uniformly) /* <- used 'node_col_sampler' */
351
+ unexpected_error();
352
+ workspace.col_sampler.drop_col(workspace.col_chosen
353
+ +
354
+ ((workspace.col_type == Numeric)?
355
+ (size_t)0 : input_data.ncols_numeric));
356
+ }
357
+
358
+ else
359
+ {
360
+ add_this_col:
361
+ add_chosen_column<decltype(input_data), decltype(workspace), ldouble_safe>(
362
+ workspace, input_data, model_params, col_is_taken, col_is_taken_s
363
+ );
364
+ if (++workspace.ntaken >= model_params.ndim)
365
+ break;
366
+ }
367
+ }
368
+
369
+ if (!workspace.ntaken && !workspace.ntaken_best)
370
+ goto terminal_statistics;
371
+ else if (!workspace.ntaken)
372
+ break;
373
+
374
+
375
+ /* evaluate gain if necessary */
376
+ if (workspace.criterion != NoCrit)
377
+ {
378
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
379
+ workspace.this_gain = eval_guided_crit<ldouble_safe>(
380
+ workspace.comb_val.data(), workspace.end - workspace.st + 1,
381
+ workspace.criterion, model_params.min_gain, workspace.ntry == 1,
382
+ workspace.buffer_dbl.data(), workspace.this_split_point,
383
+ workspace.xmin, workspace.xmax,
384
+ workspace.ix_arr.data() + workspace.st,
385
+ workspace.col_indices.data(),
386
+ workspace.col_sampler.get_remaining_cols(),
387
+ model_params.ncols_per_tree < input_data.ncols_numeric,
388
+ input_data.X_row_major.data(),
389
+ input_data.ncols_numeric,
390
+ input_data.Xr.data(),
391
+ input_data.Xr_ind.data(),
392
+ input_data.Xr_indptr.data());
393
+ else if (!workspace.weights_arr.empty())
394
+ workspace.this_gain = eval_guided_crit_weighted<ldouble_safe>(
395
+ workspace.comb_val.data(), workspace.end - workspace.st + 1,
396
+ workspace.criterion, model_params.min_gain, workspace.ntry == 1,
397
+ workspace.buffer_dbl.data(), workspace.this_split_point,
398
+ workspace.xmin, workspace.xmax,
399
+ workspace.sample_weights.data(), workspace.buffer_szt.data(),
400
+ workspace.ix_arr.data() + workspace.st,
401
+ workspace.col_indices.data(),
402
+ workspace.col_sampler.get_remaining_cols(),
403
+ model_params.ncols_per_tree < input_data.ncols_numeric,
404
+ input_data.X_row_major.data(),
405
+ input_data.ncols_numeric,
406
+ input_data.Xr.data(),
407
+ input_data.Xr_ind.data(),
408
+ input_data.Xr_indptr.data());
409
+ else
410
+ workspace.this_gain = eval_guided_crit_weighted<ldouble_safe>(
411
+ workspace.comb_val.data(), workspace.end - workspace.st + 1,
412
+ workspace.criterion, model_params.min_gain, workspace.ntry == 1,
413
+ workspace.buffer_dbl.data(), workspace.this_split_point,
414
+ workspace.xmin, workspace.xmax,
415
+ workspace.sample_weights.data(), workspace.buffer_szt.data(),
416
+ workspace.ix_arr.data() + workspace.st,
417
+ workspace.col_indices.data(),
418
+ workspace.col_sampler.get_remaining_cols(),
419
+ model_params.ncols_per_tree < input_data.ncols_numeric,
420
+ input_data.X_row_major.data(),
421
+ input_data.ncols_numeric,
422
+ input_data.Xr.data(),
423
+ input_data.Xr_ind.data(),
424
+ input_data.Xr_indptr.data());
425
+ }
426
+
427
+ /* pass to the output object */
428
+ if (workspace.ntry == 1 || workspace.this_gain > hplanes.back().score)
429
+ {
430
+ /* these should be shrunk later according to what ends up used */
431
+ hplanes.back().score = workspace.this_gain;
432
+ workspace.ntaken_best = workspace.ntaken;
433
+ if (workspace.criterion != NoCrit)
434
+ {
435
+ hplanes.back().split_point = workspace.this_split_point;
436
+ if (model_params.penalize_range)
437
+ {
438
+ hplanes.back().range_low = workspace.xmin - workspace.xmax + hplanes.back().split_point;
439
+ hplanes.back().range_high = workspace.xmax - workspace.xmin + hplanes.back().split_point;
440
+ }
441
+ }
442
+ hplanes.back().col_num.assign(workspace.col_take.begin(), workspace.col_take.begin() + workspace.ntaken);
443
+ hplanes.back().col_type.assign(workspace.col_take_type.begin(), workspace.col_take_type.begin() + workspace.ntaken);
444
+ if (input_data.ncols_numeric)
445
+ {
446
+ hplanes.back().coef.assign(workspace.ext_coef.begin(), workspace.ext_coef.begin() + workspace.ntaken);
447
+ hplanes.back().mean.assign(workspace.ext_mean.begin(), workspace.ext_mean.begin() + workspace.ntaken);
448
+ }
449
+
450
+ if (model_params.missing_action != Fail)
451
+ hplanes.back().fill_val.assign(workspace.ext_fill_val.begin(), workspace.ext_fill_val.begin() + workspace.ntaken);
452
+
453
+ if (model_params.scoring_metric != Depth && !is_boxed_metric(model_params.scoring_metric))
454
+ {
455
+ workspace.density_calculator.save_range(workspace.xmin, workspace.xmax);
456
+ }
457
+
458
+ if (input_data.ncols_categ)
459
+ {
460
+ hplanes.back().fill_new.assign(workspace.ext_fill_new.begin(), workspace.ext_fill_new.begin() + workspace.ntaken);
461
+ switch(model_params.cat_split_type)
462
+ {
463
+ case SingleCateg:
464
+ {
465
+ hplanes.back().chosen_cat.assign(workspace.chosen_cat.begin(),
466
+ workspace.chosen_cat.begin() + workspace.ntaken);
467
+ break;
468
+ }
469
+
470
+ case SubSet:
471
+ {
472
+ if (hplanes.back().cat_coef.size() < workspace.ntaken)
473
+ hplanes.back().cat_coef.assign(workspace.ext_cat_coef.begin(),
474
+ workspace.ext_cat_coef.begin() + workspace.ntaken);
475
+ else
476
+ for (size_t col = 0; col < workspace.ntaken_best; col++)
477
+ std::copy(workspace.ext_cat_coef[col].begin(),
478
+ workspace.ext_cat_coef[col].end(),
479
+ hplanes.back().cat_coef[col].begin());
480
+ break;
481
+ }
482
+ }
483
+ }
484
+ }
485
+
486
+ }
487
+
488
+ col_is_taken.clear();
489
+ col_is_taken.shrink_to_fit();
490
+ col_is_taken_s.clear();
491
+
492
+ /* if the best split is not good enough, don't split any further */
493
+ if (workspace.criterion != NoCrit && hplanes.back().score <= 0)
494
+ goto terminal_statistics;
495
+
496
+ /* now need to reproduce the same split from before */
497
+ if (workspace.criterion != NoCrit)
498
+ {
499
+ std::fill(workspace.comb_val.begin(),
500
+ workspace.comb_val.begin() + (workspace.end - workspace.st + 1),
501
+ (double)0);
502
+ for (size_t col = 0; col < workspace.ntaken_best; col++)
503
+ {
504
+ switch(hplanes.back().col_type[col])
505
+ {
506
+ case Numeric:
507
+ {
508
+ if (input_data.Xc_indptr == NULL)
509
+ {
510
+ add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
511
+ input_data.numeric_data + hplanes.back().col_num[col] * input_data.nrows,
512
+ hplanes.back().coef[col], (double)0, hplanes.back().mean[col],
513
+ hplanes.back().fill_val.size()? hplanes.back().fill_val[col] : workspace.this_split_point, /* second case is not used */
514
+ model_params.missing_action, NULL, NULL, false);
515
+ }
516
+
517
+ else
518
+ {
519
+ add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end,
520
+ hplanes.back().col_num[col], workspace.comb_val.data(),
521
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
522
+ hplanes.back().coef[col], (double)0, hplanes.back().mean[col],
523
+ hplanes.back().fill_val.size()? hplanes.back().fill_val[col] : workspace.this_split_point, /* second case is not used */
524
+ model_params.missing_action, NULL, NULL, false);
525
+ }
526
+
527
+ break;
528
+ }
529
+
530
+ case Categorical:
531
+ {
532
+ add_linear_comb<ldouble_safe>(
533
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
534
+ input_data.categ_data + hplanes.back().col_num[col] * input_data.nrows,
535
+ input_data.ncat[hplanes.back().col_num[col]],
536
+ (model_params.cat_split_type == SubSet)? hplanes.back().cat_coef[col].data() : NULL,
537
+ (model_params.cat_split_type == SingleCateg)? hplanes.back().fill_new[col] : (double)0,
538
+ (model_params.cat_split_type == SingleCateg)? hplanes.back().chosen_cat[col] : 0,
539
+ (hplanes.back().fill_val.size())? hplanes.back().fill_val[col] : workspace.this_split_point, /* second case is not used */
540
+ (model_params.cat_split_type == SubSet)? hplanes.back().fill_new[col] : workspace.this_split_point, /* second case is not used */
541
+ NULL, NULL, model_params.new_cat_action, model_params.missing_action,
542
+ model_params.cat_split_type, false);
543
+ break;
544
+ }
545
+
546
+ default:
547
+ {
548
+ unexpected_error();
549
+ break;
550
+ }
551
+ }
552
+ }
553
+ }
554
+
555
+ /* get the range */
556
+ if (workspace.criterion == NoCrit)
557
+ {
558
+ workspace.xmin = HUGE_VAL;
559
+ workspace.xmax = -HUGE_VAL;
560
+ for (size_t row = 0; row < (workspace.end - workspace.st + 1); row++)
561
+ {
562
+ workspace.xmin = (workspace.xmin > workspace.comb_val[row])? workspace.comb_val[row] : workspace.xmin;
563
+ workspace.xmax = (workspace.xmax < workspace.comb_val[row])? workspace.comb_val[row] : workspace.xmax;
564
+ }
565
+ if (workspace.xmin == workspace.xmax)
566
+ goto terminal_statistics; /* in theory, could try again too, this could just be an unlucky case */
567
+
568
+ hplanes.back().split_point = sample_random_uniform(workspace.xmin, workspace.xmax, workspace.rnd_generator);
569
+
570
+ /* determine acceptable range */
571
+ if (model_params.penalize_range)
572
+ {
573
+ hplanes.back().range_low = workspace.xmin - workspace.xmax + hplanes.back().split_point;
574
+ hplanes.back().range_high = workspace.xmax - workspace.xmin + hplanes.back().split_point;
575
+ }
576
+ }
577
+
578
+ if (model_params.missing_action == Fail && is_na_or_inf(hplanes.back().split_point))
579
+ throw std::runtime_error("Data has missing values. Try using a different value for 'missing_action'.\n");
580
+
581
+ /* divide */
582
+ workspace.split_ix = divide_subset_split(workspace.ix_arr.data(), workspace.comb_val.data(),
583
+ workspace.st, workspace.end, hplanes.back().split_point);
584
+
585
+ /* set as non-terminal */
586
+ hplanes.back().score = -1;
587
+
588
+ /* add another round of separation depth for distance */
589
+ if (model_params.calc_dist && curr_depth > 0)
590
+ add_separation_step(workspace, input_data, (double)(-1));
591
+
592
+ /* simplify vectors according to what ends up used */
593
+ if (input_data.ncols_categ || workspace.ntaken_best < model_params.ndim)
594
+ simplify_hplane(hplanes.back(), workspace, input_data, model_params);
595
+
596
+ shrink_to_fit_hplane(hplanes.back(), false);
597
+
598
+ /* if using a custom scoring metric, need to calculate it now */
599
+ if (model_params.scoring_metric != Depth)
600
+ {
601
+ if (workspace.criterion != NoCrit)
602
+ workspace.density_calculator.restore_range(workspace.xmin, workspace.xmax);
603
+
604
+ if (model_params.scoring_metric == Density)
605
+ {
606
+ workspace.density_calculator.push_density(workspace.xmin, workspace.xmax, hplanes.back().split_point);
607
+ }
608
+
609
+ else if (is_boxed_metric(model_params.scoring_metric))
610
+ {
611
+ workspace.density_calculator.push_bdens_ext(hplanes.back(), model_params);
612
+ }
613
+
614
+ else
615
+ {
616
+ double pct_tree_left;
617
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
618
+ {
619
+ pct_tree_left = (ldouble_safe)(workspace.split_ix - workspace.st)
620
+ / (ldouble_safe)(workspace.end - workspace.st + 1);
621
+ }
622
+
623
+ else
624
+ {
625
+ ldouble_safe wtot = 0;
626
+ ldouble_safe wleft = 0;
627
+ if (!workspace.weights_arr.empty())
628
+ {
629
+ for (size_t ix = workspace.st; ix < workspace.split_ix; ix++)
630
+ wtot += workspace.weights_arr[workspace.ix_arr[ix]];
631
+ wleft = wtot;
632
+ for (size_t ix = workspace.split_ix; ix <= workspace.end; ix++)
633
+ wtot += workspace.weights_arr[workspace.ix_arr[ix]];
634
+ }
635
+
636
+ else
637
+ {
638
+ for (size_t ix = workspace.st; ix < workspace.split_ix; ix++)
639
+ wtot += workspace.weights_map[workspace.ix_arr[ix]];
640
+ wleft = wtot;
641
+ for (size_t ix = workspace.split_ix; ix <= workspace.end; ix++)
642
+ wtot += workspace.weights_map[workspace.ix_arr[ix]];
643
+ }
644
+
645
+ pct_tree_left = wleft / wtot;
646
+ }
647
+
648
+ workspace.density_calculator.push_adj(workspace.xmin, workspace.xmax,
649
+ hplanes.back().split_point, pct_tree_left,
650
+ model_params.scoring_metric);
651
+ }
652
+ }
653
+
654
+ /* now split */
655
+
656
+ /* back-up where it was */
657
+ recursion_state = std::unique_ptr<RecursionState>(new RecursionState(workspace, true));
658
+
659
+ /* follow left branch */
660
+ hplanes[hplane_from].hplane_left = hplanes.size();
661
+ hplanes.emplace_back();
662
+ if (impute_nodes != NULL) impute_nodes->emplace_back(hplane_from);
663
+ workspace.end = workspace.split_ix - 1;
664
+ split_hplane_recursive<InputData, WorkerMemory, ldouble_safe>(
665
+ hplanes,
666
+ workspace,
667
+ input_data,
668
+ model_params,
669
+ impute_nodes,
670
+ curr_depth + 1);
671
+
672
+
673
+ /* follow right branch */
674
+ hplanes[hplane_from].hplane_right = hplanes.size();
675
+ recursion_state->restore_state(workspace);
676
+ hplanes.emplace_back();
677
+
678
+ if (impute_nodes != NULL) impute_nodes->emplace_back(hplane_from);
679
+ if (is_boxed_metric(model_params.scoring_metric)) {
680
+ workspace.density_calculator.pop_bdens_ext();
681
+ }
682
+ else if (model_params.scoring_metric != Depth) {
683
+ workspace.density_calculator.pop();
684
+ }
685
+ workspace.st = workspace.split_ix;
686
+ split_hplane_recursive<InputData, WorkerMemory, ldouble_safe>(
687
+ hplanes,
688
+ workspace,
689
+ input_data,
690
+ model_params,
691
+ impute_nodes,
692
+ curr_depth + 1);
693
+ if (is_boxed_metric(model_params.scoring_metric)) {
694
+ workspace.density_calculator.pop_bdens_ext_right();
695
+ }
696
+ else if (model_params.scoring_metric != Depth) {
697
+ workspace.density_calculator.pop_right();
698
+ }
699
+
700
+ return;
701
+
702
+ terminal_statistics:
703
+ {
704
+ hplanes.back().hplane_left = 0;
705
+
706
+ bool has_weights = !workspace.weights_arr.empty() || !workspace.weights_map.empty();
707
+ if (has_weights)
708
+ {
709
+ if (sum_weight == -HUGE_VAL)
710
+ sum_weight = calculate_sum_weights<ldouble_safe>(
711
+ workspace.ix_arr, workspace.st, workspace.end, curr_depth,
712
+ workspace.weights_arr, workspace.weights_map);
713
+ }
714
+
715
+ switch (model_params.scoring_metric)
716
+ {
717
+ case Depth:
718
+ {
719
+ if (!has_weights)
720
+ hplanes.back().score = curr_depth + expected_avg_depth<ldouble_safe>(workspace.end - workspace.st + 1);
721
+ else
722
+ hplanes.back().score = curr_depth + expected_avg_depth<ldouble_safe>(sum_weight);
723
+ break;
724
+ }
725
+
726
+ case AdjDepth:
727
+ {
728
+ if (!has_weights)
729
+ hplanes.back().score = workspace.density_calculator.calc_adj_depth() + expected_avg_depth<ldouble_safe>(workspace.end - workspace.st + 1);
730
+ else
731
+ hplanes.back().score = workspace.density_calculator.calc_adj_depth() + expected_avg_depth<ldouble_safe>(sum_weight);
732
+ break;
733
+ }
734
+
735
+ case Density:
736
+ {
737
+ if (!has_weights)
738
+ hplanes.back().score = workspace.density_calculator.calc_density(workspace.end - workspace.st + 1, model_params.sample_size);
739
+ else
740
+ hplanes.back().score = workspace.density_calculator.calc_density(sum_weight, model_params.sample_size);
741
+ break;
742
+ }
743
+
744
+ case AdjDensity:
745
+ {
746
+ hplanes.back().score = workspace.density_calculator.calc_adj_density();
747
+ break;
748
+ }
749
+
750
+ case BoxedRatio:
751
+ {
752
+ hplanes.back().score = workspace.density_calculator.calc_bratio_ext();
753
+ break;
754
+ }
755
+
756
+ case BoxedDensity:
757
+ {
758
+ if (!has_weights)
759
+ hplanes.back().score = workspace.density_calculator.calc_bdens_ext(workspace.end - workspace.st + 1, model_params.sample_size);
760
+ else
761
+ hplanes.back().score = workspace.density_calculator.calc_bdens_ext(sum_weight, model_params.sample_size);
762
+ break;
763
+ }
764
+
765
+ case BoxedDensity2:
766
+ {
767
+ if (!has_weights)
768
+ hplanes.back().score = workspace.density_calculator.calc_bdens_ext(workspace.end - workspace.st + 1, model_params.sample_size);
769
+ else
770
+ hplanes.back().score = workspace.density_calculator.calc_bdens_ext(sum_weight, model_params.sample_size);
771
+ break;
772
+ }
773
+
774
+ }
775
+
776
+ /* don't leave any vector initialized */
777
+ shrink_to_fit_hplane(hplanes.back(), true);
778
+
779
+ hplanes.back().remainder = (!workspace.weights_arr.empty())?
780
+ sum_weight : ((!workspace.weights_map.empty())?
781
+ sum_weight : ((double)(workspace.end - workspace.st + 1))
782
+ );
783
+
784
+ /* for distance, assume also the elements keep being split */
785
+ if (model_params.calc_dist)
786
+ add_remainder_separation_steps<InputData, WorkerMemory, ldouble_safe>(workspace, input_data, sum_weight);
787
+
788
+ /* add this depth right away if requested */
789
+ if (!workspace.row_depths.empty())
790
+ for (size_t row = workspace.st; row <= workspace.end; row++)
791
+ workspace.row_depths[workspace.ix_arr[row]] += hplanes.back().score;
792
+
793
+ /* add imputations from node if requested */
794
+ if (model_params.impute_at_fit)
795
+ add_from_impute_node(impute_nodes->back(), workspace, input_data);
796
+ }
797
+ }
798
+
799
+
800
+ template <class InputData, class WorkerMemory, class ldouble_safe>
801
+ void add_chosen_column(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params,
802
+ std::vector<bool> &col_is_taken, hashed_set<size_t> &col_is_taken_s)
803
+ {
804
+ if (workspace.col_criterion == Uniformly) {
805
+ set_col_as_taken(col_is_taken, col_is_taken_s, input_data, workspace.col_chosen, workspace.col_type);
806
+ }
807
+ else {
808
+ if (workspace.col_chosen < input_data.ncols_numeric) {
809
+ workspace.col_type = Numeric;
810
+ }
811
+ else {
812
+ workspace.col_chosen -= input_data.ncols_numeric;
813
+ workspace.col_type = Categorical;
814
+ }
815
+ }
816
+ workspace.col_take[workspace.ntaken] = workspace.col_chosen;
817
+ workspace.col_take_type[workspace.ntaken] = workspace.col_type;
818
+
819
+ switch(workspace.col_type)
820
+ {
821
+ case Numeric:
822
+ {
823
+ switch(model_params.coef_type)
824
+ {
825
+ case Uniform:
826
+ {
827
+ workspace.ext_coef[workspace.ntaken] = workspace.coef_unif(workspace.rnd_generator);
828
+ break;
829
+ }
830
+
831
+ case Normal:
832
+ {
833
+ workspace.ext_coef[workspace.ntaken] = workspace.coef_norm(workspace.rnd_generator);
834
+ break;
835
+ }
836
+ }
837
+
838
+ if (input_data.Xc_indptr == NULL)
839
+ {
840
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
841
+ {
842
+ if (model_params.missing_action == Fail && !model_params.standardize_data)
843
+ {
844
+ workspace.ext_mean[workspace.ntaken] = 0;
845
+ workspace.ext_sd = 1;
846
+ }
847
+
848
+ else if (!model_params.standardize_data)
849
+ {
850
+ workspace.ext_sd = 1;
851
+ if (workspace.col_criterion != Uniformly && workspace.has_saved_stats)
852
+ workspace.ext_mean[workspace.ntaken] = workspace.saved_stat1[workspace.col_chosen];
853
+ else
854
+ {
855
+ workspace.ext_mean[workspace.ntaken]
856
+ =
857
+ calc_mean_only(workspace.ix_arr.data(), workspace.st, workspace.end,
858
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows);
859
+ }
860
+ }
861
+
862
+ else
863
+ {
864
+ if (workspace.col_criterion != Uniformly && workspace.has_saved_stats)
865
+ {
866
+ workspace.ext_mean[workspace.ntaken] = workspace.saved_stat1[workspace.col_chosen];
867
+ workspace.ext_sd = workspace.saved_stat2[workspace.col_chosen];
868
+ }
869
+
870
+ else
871
+ {
872
+ calc_mean_and_sd<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
873
+ ldouble_safe>(
874
+ workspace.ix_arr.data(), workspace.st, workspace.end,
875
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
876
+ model_params.missing_action, workspace.ext_sd, workspace.ext_mean[workspace.ntaken]);
877
+ }
878
+ }
879
+
880
+ add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
881
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
882
+ workspace.ext_coef[workspace.ntaken], workspace.ext_sd, workspace.ext_mean[workspace.ntaken],
883
+ workspace.ext_fill_val[workspace.ntaken], model_params.missing_action,
884
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true);
885
+ }
886
+ else if (!workspace.weights_arr.empty())
887
+ {
888
+ if (model_params.missing_action == Fail && !model_params.standardize_data)
889
+ {
890
+ workspace.ext_mean[workspace.ntaken] = 0;
891
+ workspace.ext_sd = 1;
892
+ }
893
+
894
+ else if (!model_params.standardize_data)
895
+ {
896
+ workspace.ext_sd = 1;
897
+ if (workspace.col_criterion != Uniformly && workspace.has_saved_stats)
898
+ workspace.ext_mean[workspace.ntaken] = workspace.saved_stat1[workspace.col_chosen];
899
+ else
900
+ {
901
+ workspace.ext_mean[workspace.ntaken]
902
+ =
903
+ calc_mean_only_weighted(workspace.ix_arr.data(), workspace.st, workspace.end,
904
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
905
+ workspace.weights_arr);
906
+ }
907
+ }
908
+
909
+ else
910
+ {
911
+ if (workspace.col_criterion != Uniformly && workspace.has_saved_stats)
912
+ {
913
+ workspace.ext_mean[workspace.ntaken] = workspace.saved_stat1[workspace.col_chosen];
914
+ workspace.ext_sd = workspace.saved_stat2[workspace.col_chosen];
915
+ }
916
+
917
+ else
918
+ {
919
+ calc_mean_and_sd_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
920
+ decltype(workspace.weights_arr), ldouble_safe>(
921
+ workspace.ix_arr.data(), workspace.st, workspace.end,
922
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
923
+ workspace.weights_arr,
924
+ model_params.missing_action, workspace.ext_sd,
925
+ workspace.ext_mean[workspace.ntaken]);
926
+ }
927
+ }
928
+
929
+ add_linear_comb_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
930
+ decltype(workspace.weights_arr), ldouble_safe>(
931
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
932
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
933
+ workspace.ext_coef[workspace.ntaken], workspace.ext_sd, workspace.ext_mean[workspace.ntaken],
934
+ workspace.ext_fill_val[workspace.ntaken], model_params.missing_action,
935
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true,
936
+ workspace.weights_arr);
937
+ }
938
+
939
+ else
940
+ {
941
+ if (model_params.missing_action == Fail && !model_params.standardize_data)
942
+ {
943
+ workspace.ext_mean[workspace.ntaken] = 0;
944
+ workspace.ext_sd = 1;
945
+ }
946
+
947
+ else if (!model_params.standardize_data)
948
+ {
949
+ workspace.ext_sd = 1;
950
+ if (workspace.col_criterion != Uniformly && workspace.has_saved_stats)
951
+ workspace.ext_mean[workspace.ntaken] = workspace.saved_stat1[workspace.col_chosen];
952
+ else
953
+ {
954
+ workspace.ext_mean[workspace.ntaken]
955
+ =
956
+ calc_mean_only_weighted(workspace.ix_arr.data(), workspace.st, workspace.end,
957
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
958
+ workspace.weights_map);
959
+ }
960
+ }
961
+
962
+ else
963
+ {
964
+ if (workspace.col_criterion != Uniformly && workspace.has_saved_stats)
965
+ {
966
+ workspace.ext_mean[workspace.ntaken] = workspace.saved_stat1[workspace.col_chosen];
967
+ workspace.ext_sd = workspace.saved_stat2[workspace.col_chosen];
968
+ }
969
+
970
+ else
971
+ {
972
+ calc_mean_and_sd_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
973
+ decltype(workspace.weights_map), ldouble_safe>(
974
+ workspace.ix_arr.data(), workspace.st, workspace.end,
975
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
976
+ workspace.weights_map,
977
+ model_params.missing_action, workspace.ext_sd,
978
+ workspace.ext_mean[workspace.ntaken]);
979
+ }
980
+ }
981
+
982
+ add_linear_comb_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
983
+ decltype(workspace.weights_map), ldouble_safe>(
984
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
985
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
986
+ workspace.ext_coef[workspace.ntaken], workspace.ext_sd, workspace.ext_mean[workspace.ntaken],
987
+ workspace.ext_fill_val[workspace.ntaken], model_params.missing_action,
988
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true,
989
+ workspace.weights_map);
990
+ }
991
+ }
992
+
993
+ else
994
+ {
995
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
996
+ {
997
+ if (model_params.missing_action == Fail && !model_params.standardize_data)
998
+ {
999
+ workspace.ext_mean[workspace.ntaken] = 0;
1000
+ workspace.ext_sd = 1;
1001
+ }
1002
+
1003
+ else if (!model_params.standardize_data)
1004
+ {
1005
+ workspace.ext_sd = 1;
1006
+ if (workspace.col_criterion != Uniformly && workspace.has_saved_stats)
1007
+ workspace.ext_mean[workspace.ntaken] = workspace.saved_stat1[workspace.col_chosen];
1008
+ else
1009
+ {
1010
+ workspace.ext_mean[workspace.ntaken]
1011
+ =
1012
+ calc_mean_only<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
1013
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
1014
+ ldouble_safe>(
1015
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen,
1016
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr);
1017
+ }
1018
+ }
1019
+
1020
+ else
1021
+ {
1022
+ if (workspace.col_criterion != Uniformly && workspace.has_saved_stats)
1023
+ {
1024
+ workspace.ext_mean[workspace.ntaken] = workspace.saved_stat1[workspace.col_chosen];
1025
+ workspace.ext_sd = workspace.saved_stat2[workspace.col_chosen];
1026
+ }
1027
+
1028
+ else
1029
+ {
1030
+ calc_mean_and_sd<typename std::remove_pointer<decltype(input_data.Xc)>::type,
1031
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
1032
+ ldouble_safe>(
1033
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen,
1034
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1035
+ workspace.ext_sd, workspace.ext_mean[workspace.ntaken]);
1036
+ }
1037
+ }
1038
+
1039
+ add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end,
1040
+ workspace.col_chosen, workspace.comb_val.data(),
1041
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1042
+ workspace.ext_coef[workspace.ntaken], workspace.ext_sd, workspace.ext_mean[workspace.ntaken],
1043
+ workspace.ext_fill_val[workspace.ntaken], model_params.missing_action,
1044
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true);
1045
+ }
1046
+
1047
+ else if (!workspace.weights_arr.empty())
1048
+ {
1049
+ if (model_params.missing_action == Fail && !model_params.standardize_data)
1050
+ {
1051
+ workspace.ext_mean[workspace.ntaken] = 0;
1052
+ workspace.ext_sd = 1;
1053
+ }
1054
+
1055
+ else if (!model_params.standardize_data)
1056
+ {
1057
+ workspace.ext_sd = 1;
1058
+ if (workspace.col_criterion != Uniformly && workspace.has_saved_stats)
1059
+ workspace.ext_mean[workspace.ntaken] = workspace.saved_stat1[workspace.col_chosen];
1060
+ else
1061
+ {
1062
+ workspace.ext_mean[workspace.ntaken]
1063
+ =
1064
+ calc_mean_only_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
1065
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
1066
+ decltype(workspace.weights_arr), ldouble_safe>(
1067
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen,
1068
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1069
+ workspace.weights_arr);
1070
+ }
1071
+ }
1072
+
1073
+ else
1074
+ {
1075
+ if (workspace.col_criterion != Uniformly && workspace.has_saved_stats)
1076
+ {
1077
+ workspace.ext_mean[workspace.ntaken] = workspace.saved_stat1[workspace.col_chosen];
1078
+ workspace.ext_sd = workspace.saved_stat2[workspace.col_chosen];
1079
+ }
1080
+
1081
+ else
1082
+ {
1083
+ calc_mean_and_sd_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
1084
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
1085
+ decltype(workspace.weights_arr), ldouble_safe>(
1086
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen,
1087
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1088
+ workspace.ext_sd, workspace.ext_mean[workspace.ntaken],
1089
+ workspace.weights_arr);
1090
+ }
1091
+ }
1092
+
1093
+ add_linear_comb_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
1094
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
1095
+ decltype(workspace.weights_arr), ldouble_safe>(
1096
+ workspace.ix_arr.data(), workspace.st, workspace.end,
1097
+ workspace.col_chosen, workspace.comb_val.data(),
1098
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1099
+ workspace.ext_coef[workspace.ntaken], workspace.ext_sd, workspace.ext_mean[workspace.ntaken],
1100
+ workspace.ext_fill_val[workspace.ntaken], model_params.missing_action,
1101
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true,
1102
+ workspace.weights_arr);
1103
+ }
1104
+
1105
+ else
1106
+ {
1107
+ if (model_params.missing_action == Fail && !model_params.standardize_data)
1108
+ {
1109
+ workspace.ext_mean[workspace.ntaken] = 0;
1110
+ workspace.ext_sd = 1;
1111
+ }
1112
+
1113
+ else if (!model_params.standardize_data)
1114
+ {
1115
+ workspace.ext_sd = 1;
1116
+ if (workspace.col_criterion != Uniformly && workspace.has_saved_stats)
1117
+ workspace.ext_mean[workspace.ntaken] = workspace.saved_stat1[workspace.col_chosen];
1118
+ else
1119
+ {
1120
+ workspace.ext_mean[workspace.ntaken]
1121
+ =
1122
+ calc_mean_only_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
1123
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
1124
+ decltype(workspace.weights_map), ldouble_safe>(
1125
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen,
1126
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1127
+ workspace.weights_map);
1128
+ }
1129
+ }
1130
+
1131
+ else
1132
+ {
1133
+ if (workspace.col_criterion != Uniformly && workspace.has_saved_stats)
1134
+ {
1135
+ workspace.ext_mean[workspace.ntaken] = workspace.saved_stat1[workspace.col_chosen];
1136
+ workspace.ext_sd = workspace.saved_stat2[workspace.col_chosen];
1137
+ }
1138
+
1139
+ else
1140
+ {
1141
+ calc_mean_and_sd_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
1142
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
1143
+ decltype(workspace.weights_map), ldouble_safe>(
1144
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen,
1145
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1146
+ workspace.ext_sd, workspace.ext_mean[workspace.ntaken],
1147
+ workspace.weights_map);
1148
+ }
1149
+ }
1150
+
1151
+ add_linear_comb_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
1152
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
1153
+ decltype(workspace.weights_map), ldouble_safe>(
1154
+ workspace.ix_arr.data(), workspace.st, workspace.end,
1155
+ workspace.col_chosen, workspace.comb_val.data(),
1156
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1157
+ workspace.ext_coef[workspace.ntaken], workspace.ext_sd, workspace.ext_mean[workspace.ntaken],
1158
+ workspace.ext_fill_val[workspace.ntaken], model_params.missing_action,
1159
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true,
1160
+ workspace.weights_map);
1161
+ }
1162
+
1163
+
1164
+ }
1165
+ break;
1166
+ }
1167
+
1168
+ case Categorical:
1169
+ {
1170
+ switch(model_params.cat_split_type)
1171
+ {
1172
+ case SingleCateg:
1173
+ {
1174
+ workspace.chosen_cat[workspace.ntaken] = choose_cat_from_present(workspace, input_data, workspace.col_chosen);
1175
+ workspace.ext_fill_new[workspace.ntaken] = workspace.coef_norm(workspace.rnd_generator);
1176
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
1177
+ {
1178
+ add_linear_comb<ldouble_safe>(
1179
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
1180
+ input_data.categ_data + workspace.col_chosen * input_data.nrows,
1181
+ input_data.ncat[workspace.col_chosen],
1182
+ NULL, workspace.ext_fill_new[workspace.ntaken],
1183
+ workspace.chosen_cat[workspace.ntaken],
1184
+ workspace.ext_fill_val[workspace.ntaken], workspace.ext_fill_new[workspace.ntaken],
1185
+ NULL, NULL, model_params.new_cat_action, model_params.missing_action, SingleCateg, true);
1186
+ }
1187
+
1188
+ else if (!workspace.weights_arr.empty())
1189
+ {
1190
+ add_linear_comb_weighted<decltype(workspace.weights_arr), ldouble_safe>(
1191
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
1192
+ input_data.categ_data + workspace.col_chosen * input_data.nrows,
1193
+ input_data.ncat[workspace.col_chosen],
1194
+ NULL, workspace.ext_fill_new[workspace.ntaken],
1195
+ workspace.chosen_cat[workspace.ntaken],
1196
+ workspace.ext_fill_val[workspace.ntaken], workspace.ext_fill_new[workspace.ntaken],
1197
+ NULL, model_params.new_cat_action, model_params.missing_action, SingleCateg, true,
1198
+ workspace.weights_arr);
1199
+ }
1200
+
1201
+ else
1202
+ {
1203
+ add_linear_comb_weighted<decltype(workspace.weights_map), ldouble_safe>(
1204
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
1205
+ input_data.categ_data + workspace.col_chosen * input_data.nrows,
1206
+ input_data.ncat[workspace.col_chosen],
1207
+ NULL, workspace.ext_fill_new[workspace.ntaken],
1208
+ workspace.chosen_cat[workspace.ntaken],
1209
+ workspace.ext_fill_val[workspace.ntaken], workspace.ext_fill_new[workspace.ntaken],
1210
+ NULL, model_params.new_cat_action, model_params.missing_action, SingleCateg, true,
1211
+ workspace.weights_map);
1212
+ }
1213
+
1214
+ break;
1215
+ }
1216
+
1217
+ case SubSet:
1218
+ {
1219
+ for (int cat = 0; cat < input_data.ncat[workspace.col_chosen]; cat++)
1220
+ workspace.ext_cat_coef[workspace.ntaken][cat] = workspace.coef_norm(workspace.rnd_generator);
1221
+
1222
+ if (model_params.coef_by_prop)
1223
+ {
1224
+ int ncat = input_data.ncat[workspace.col_chosen];
1225
+ size_t *restrict counts = workspace.buffer_szt.data();
1226
+ size_t *restrict sorted_ix = workspace.buffer_szt.data() + ncat;
1227
+ /* calculate counts and sort by them */
1228
+ std::fill(counts, counts + ncat, (size_t)0);
1229
+ for (size_t ix = workspace.st; ix <= workspace.end; ix++)
1230
+ if (input_data.categ_data[workspace.col_chosen * input_data.nrows + ix] >= 0)
1231
+ counts[input_data.categ_data[workspace.col_chosen * input_data.nrows + ix]]++;
1232
+ std::iota(sorted_ix, sorted_ix + ncat, (size_t)0);
1233
+ std::sort(sorted_ix, sorted_ix + ncat,
1234
+ [&counts](const size_t a, const size_t b){return counts[a] < counts[b];});
1235
+ /* now re-order the coefficients accordingly */
1236
+ std::sort(workspace.ext_cat_coef[workspace.ntaken].begin(),
1237
+ workspace.ext_cat_coef[workspace.ntaken].begin() + ncat);
1238
+ std::copy(workspace.ext_cat_coef[workspace.ntaken].begin(),
1239
+ workspace.ext_cat_coef[workspace.ntaken].begin() + ncat,
1240
+ workspace.buffer_dbl.begin());
1241
+ for (int ix = 0; ix < ncat; ix++)
1242
+ workspace.ext_cat_coef[workspace.ntaken][ix] = workspace.buffer_dbl[sorted_ix[ix]];
1243
+ }
1244
+
1245
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
1246
+ {
1247
+ add_linear_comb<ldouble_safe>(
1248
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
1249
+ input_data.categ_data + workspace.col_chosen * input_data.nrows,
1250
+ input_data.ncat[workspace.col_chosen],
1251
+ workspace.ext_cat_coef[workspace.ntaken].data(), (double)0, (int)0,
1252
+ workspace.ext_fill_val[workspace.ntaken], workspace.ext_fill_new[workspace.ntaken],
1253
+ workspace.buffer_szt.data(), workspace.buffer_szt.data() + input_data.max_categ + 1,
1254
+ model_params.new_cat_action, model_params.missing_action, SubSet, true);
1255
+ }
1256
+
1257
+ else if (!workspace.weights_arr.empty())
1258
+ {
1259
+ add_linear_comb_weighted<decltype(workspace.weights_arr), ldouble_safe>(
1260
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
1261
+ input_data.categ_data + workspace.col_chosen * input_data.nrows,
1262
+ input_data.ncat[workspace.col_chosen],
1263
+ workspace.ext_cat_coef[workspace.ntaken].data(), (double)0, (int)0,
1264
+ workspace.ext_fill_val[workspace.ntaken], workspace.ext_fill_new[workspace.ntaken],
1265
+ workspace.buffer_szt.data(),
1266
+ model_params.new_cat_action, model_params.missing_action, SubSet, true,
1267
+ workspace.weights_arr);
1268
+ }
1269
+
1270
+ else
1271
+ {
1272
+ add_linear_comb_weighted<decltype(workspace.weights_map), ldouble_safe>(
1273
+ workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
1274
+ input_data.categ_data + workspace.col_chosen * input_data.nrows,
1275
+ input_data.ncat[workspace.col_chosen],
1276
+ workspace.ext_cat_coef[workspace.ntaken].data(), (double)0, (int)0,
1277
+ workspace.ext_fill_val[workspace.ntaken], workspace.ext_fill_new[workspace.ntaken],
1278
+ workspace.buffer_szt.data(),
1279
+ model_params.new_cat_action, model_params.missing_action, SubSet, true,
1280
+ workspace.weights_map);
1281
+ }
1282
+
1283
+ break;
1284
+ }
1285
+ }
1286
+ break;
1287
+ }
1288
+
1289
+ default:
1290
+ {
1291
+ unexpected_error();
1292
+ break;
1293
+ }
1294
+ }
1295
+ }
1296
+
1297
+ void shrink_to_fit_hplane(IsoHPlane &hplane, bool clear_vectors)
1298
+ {
1299
+ if (clear_vectors)
1300
+ {
1301
+ hplane.col_num.clear();
1302
+ hplane.col_type.clear();
1303
+ hplane.coef.clear();
1304
+ hplane.mean.clear();
1305
+ hplane.cat_coef.clear();
1306
+ hplane.chosen_cat.clear();
1307
+ hplane.fill_val.clear();
1308
+ hplane.fill_new.clear();
1309
+ }
1310
+
1311
+ hplane.col_num.shrink_to_fit();
1312
+ hplane.col_type.shrink_to_fit();
1313
+ hplane.coef.shrink_to_fit();
1314
+ hplane.mean.shrink_to_fit();
1315
+ hplane.cat_coef.shrink_to_fit();
1316
+ hplane.chosen_cat.shrink_to_fit();
1317
+ hplane.fill_val.shrink_to_fit();
1318
+ hplane.fill_new.shrink_to_fit();
1319
+ }
1320
+
1321
+ template <class InputData, class WorkerMemory>
1322
+ void simplify_hplane(IsoHPlane &hplane, WorkerMemory &workspace, InputData &input_data, ModelParams &model_params)
1323
+ {
1324
+ if (workspace.ntaken_best < model_params.ndim)
1325
+ {
1326
+ hplane.col_num.resize(workspace.ntaken_best);
1327
+ hplane.col_type.resize(workspace.ntaken_best);
1328
+ if (model_params.missing_action != Fail)
1329
+ hplane.fill_val.resize(workspace.ntaken_best);
1330
+ }
1331
+
1332
+ size_t ncols_numeric = 0;
1333
+ size_t ncols_categ = 0;
1334
+
1335
+ if (input_data.ncols_categ)
1336
+ {
1337
+ for (size_t col = 0; col < workspace.ntaken_best; col++)
1338
+ {
1339
+ switch(hplane.col_type[col])
1340
+ {
1341
+ case Numeric:
1342
+ {
1343
+ workspace.ext_coef[ncols_numeric] = hplane.coef[col];
1344
+ workspace.ext_mean[ncols_numeric] = hplane.mean[col];
1345
+ ncols_numeric++;
1346
+ break;
1347
+ }
1348
+
1349
+ case Categorical:
1350
+ {
1351
+ workspace.ext_fill_new[ncols_categ] = hplane.fill_new[col];
1352
+ switch(model_params.cat_split_type)
1353
+ {
1354
+ case SingleCateg:
1355
+ {
1356
+ workspace.chosen_cat[ncols_categ] = hplane.chosen_cat[col];
1357
+ break;
1358
+ }
1359
+
1360
+ case SubSet:
1361
+ {
1362
+ std::copy(hplane.cat_coef[col].begin(),
1363
+ hplane.cat_coef[col].begin() + input_data.ncat[hplane.col_num[col]],
1364
+ workspace.ext_cat_coef[ncols_categ].begin());
1365
+ break;
1366
+ }
1367
+ }
1368
+ ncols_categ++;
1369
+ break;
1370
+ }
1371
+
1372
+ default:
1373
+ {
1374
+ unexpected_error();
1375
+ break;
1376
+ }
1377
+ }
1378
+ }
1379
+ }
1380
+
1381
+ else
1382
+ {
1383
+ ncols_numeric = workspace.ntaken_best;
1384
+ }
1385
+
1386
+
1387
+ hplane.coef.resize(ncols_numeric);
1388
+ hplane.mean.resize(ncols_numeric);
1389
+ if (input_data.ncols_numeric)
1390
+ {
1391
+ std::copy(workspace.ext_coef.begin(), workspace.ext_coef.begin() + ncols_numeric, hplane.coef.begin());
1392
+ std::copy(workspace.ext_mean.begin(), workspace.ext_mean.begin() + ncols_numeric, hplane.mean.begin());
1393
+ }
1394
+
1395
+ /* If there are no categorical columns, all of them will be numerical and there is no need to reorder */
1396
+ if (ncols_categ)
1397
+ {
1398
+ hplane.fill_new.resize(ncols_categ);
1399
+ std::copy(workspace.ext_fill_new.begin(),
1400
+ workspace.ext_fill_new.begin() + ncols_categ,
1401
+ hplane.fill_new.begin());
1402
+
1403
+ hplane.cat_coef.resize(ncols_categ);
1404
+ switch(model_params.cat_split_type)
1405
+ {
1406
+ case SingleCateg:
1407
+ {
1408
+ hplane.chosen_cat.resize(ncols_categ);
1409
+ std::copy(workspace.chosen_cat.begin(),
1410
+ workspace.chosen_cat.begin() + ncols_categ,
1411
+ hplane.chosen_cat.begin());
1412
+ hplane.cat_coef.clear();
1413
+ break;
1414
+ }
1415
+
1416
+ case SubSet:
1417
+ {
1418
+ hplane.chosen_cat.clear();
1419
+ ncols_categ = 0;
1420
+ for (size_t col = 0; col < workspace.ntaken_best; col++)
1421
+ {
1422
+ if (hplane.col_type[col] == Categorical)
1423
+ {
1424
+ hplane.cat_coef[ncols_categ].resize(input_data.ncat[hplane.col_num[col]]);
1425
+ std::copy(workspace.ext_cat_coef[ncols_categ].begin(),
1426
+ workspace.ext_cat_coef[ncols_categ].begin()
1427
+ + input_data.ncat[hplane.col_num[col]],
1428
+ hplane.cat_coef[ncols_categ].begin());
1429
+ hplane.cat_coef[ncols_categ].shrink_to_fit();
1430
+ ncols_categ++;
1431
+ }
1432
+ }
1433
+ break;
1434
+ }
1435
+ }
1436
+ }
1437
+
1438
+ else
1439
+ {
1440
+ hplane.cat_coef.clear();
1441
+ hplane.chosen_cat.clear();
1442
+ hplane.fill_new.clear();
1443
+ }
1444
+ }