isotree 0.2.2 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (151) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2111 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +62 -0
  14. data/vendor/isotree/src/RcppExports.cpp +525 -52
  15. data/vendor/isotree/src/Rwrapper.cpp +1931 -268
  16. data/vendor/isotree/src/c_interface.cpp +953 -0
  17. data/vendor/isotree/src/crit.hpp +4232 -0
  18. data/vendor/isotree/src/dist.hpp +1886 -0
  19. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  20. data/vendor/isotree/src/extended.hpp +1444 -0
  21. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  22. data/vendor/isotree/src/fit_model.hpp +2401 -0
  23. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  24. data/vendor/isotree/src/helpers_iforest.hpp +813 -0
  25. data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
  26. data/vendor/isotree/src/indexer.cpp +515 -0
  27. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  28. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  29. data/vendor/isotree/src/isoforest.hpp +1659 -0
  30. data/vendor/isotree/src/isotree.hpp +1804 -392
  31. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  32. data/vendor/isotree/src/merge_models.cpp +159 -16
  33. data/vendor/isotree/src/mult.hpp +1321 -0
  34. data/vendor/isotree/src/oop_interface.cpp +842 -0
  35. data/vendor/isotree/src/oop_interface.hpp +278 -0
  36. data/vendor/isotree/src/other_helpers.hpp +219 -0
  37. data/vendor/isotree/src/predict.hpp +1932 -0
  38. data/vendor/isotree/src/python_helpers.hpp +134 -0
  39. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  40. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  41. data/vendor/isotree/src/robinmap/README.md +483 -0
  42. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  46. data/vendor/isotree/src/serialize.cpp +4300 -139
  47. data/vendor/isotree/src/sql.cpp +141 -59
  48. data/vendor/isotree/src/subset_models.cpp +174 -0
  49. data/vendor/isotree/src/utils.hpp +3808 -0
  50. data/vendor/isotree/src/xoshiro.hpp +467 -0
  51. data/vendor/isotree/src/ziggurat.hpp +405 -0
  52. metadata +38 -104
  53. data/vendor/cereal/LICENSE +0 -24
  54. data/vendor/cereal/README.md +0 -85
  55. data/vendor/cereal/include/cereal/access.hpp +0 -351
  56. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  57. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  58. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  59. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  60. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  61. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  62. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  63. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  65. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  66. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  67. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  68. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  69. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  70. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  71. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  72. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  74. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  76. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  77. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  78. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  79. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  91. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  92. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  94. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  96. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  97. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  98. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  99. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  100. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  101. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  102. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  103. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  104. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  105. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  106. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  107. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  111. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  112. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  113. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  114. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  115. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  116. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  117. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  118. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  119. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  120. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  121. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  122. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  123. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  124. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  125. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  126. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  127. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  128. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  129. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  130. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  131. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  132. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  133. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  134. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  135. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  136. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  137. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  138. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  139. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  140. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  141. data/vendor/cereal/include/cereal/version.hpp +0 -52
  142. data/vendor/isotree/src/Makevars +0 -4
  143. data/vendor/isotree/src/crit.cpp +0 -912
  144. data/vendor/isotree/src/dist.cpp +0 -749
  145. data/vendor/isotree/src/extended.cpp +0 -790
  146. data/vendor/isotree/src/fit_model.cpp +0 -1090
  147. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  148. data/vendor/isotree/src/isoforest.cpp +0 -771
  149. data/vendor/isotree/src/mult.cpp +0 -607
  150. data/vendor/isotree/src/predict.cpp +0 -853
  151. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -0,0 +1,1659 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for the C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4.5: Programs for Machine Learning. Elsevier, 2014.
21
+ * [8] Cortes, David.
22
+ * "Distance approximation using Isolation Forests."
23
+ * arXiv preprint arXiv:1910.12362 (2019).
24
+ * [9] Cortes, David.
25
+ * "Imputing missing values with unsupervised random trees."
26
+ * arXiv preprint arXiv:1911.06646 (2019).
27
+ * [10] https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom
28
+ * [11] Cortes, David.
29
+ * "Revisiting randomized choices in isolation forests."
30
+ * arXiv preprint arXiv:2110.13402 (2021).
31
+ * [12] Guha, Sudipto, et al.
32
+ * "Robust random cut forest based anomaly detection on streams."
33
+ * International conference on machine learning. PMLR, 2016.
34
+ * [13] Cortes, David.
35
+ * "Isolation forests: looking beyond tree depth."
36
+ * arXiv preprint arXiv:2111.11639 (2021).
37
+ * [14] Ting, Kai Ming, Yue Zhu, and Zhi-Hua Zhou.
38
+ * "Isolation kernel and its effect on SVM."
39
+ * Proceedings of the 24th ACM SIGKDD
40
+ * International Conference on Knowledge Discovery & Data Mining. 2018.
41
+ *
42
+ * BSD 2-Clause License
43
+ * Copyright (c) 2019-2022, David Cortes
44
+ * All rights reserved.
45
+ * Redistribution and use in source and binary forms, with or without
46
+ * modification, are permitted provided that the following conditions are met:
47
+ * * Redistributions of source code must retain the above copyright notice, this
48
+ * list of conditions and the following disclaimer.
49
+ * * Redistributions in binary form must reproduce the above copyright notice,
50
+ * this list of conditions and the following disclaimer in the documentation
51
+ * and/or other materials provided with the distribution.
52
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
53
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
54
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
55
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
56
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
57
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
58
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
59
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
60
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
61
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
62
+ */
63
+ #include "isotree.hpp"
64
+
65
+ template <class InputData, class WorkerMemory, class ldouble_safe>
66
+ void split_itree_recursive(std::vector<IsoTree> &trees,
67
+ WorkerMemory &workspace,
68
+ InputData &input_data,
69
+ ModelParams &model_params,
70
+ std::vector<ImputeNode> *impute_nodes,
71
+ size_t curr_depth)
72
+ {
73
+ if (interrupt_switch) return;
74
+ ldouble_safe sum_weight = -HUGE_VAL;
75
+
76
+ /* calculate imputation statistics if desired */
77
+ if (impute_nodes != NULL)
78
+ {
79
+ if (input_data.Xc_indptr != NULL)
80
+ std::sort(workspace.ix_arr.begin() + workspace.st,
81
+ workspace.ix_arr.begin() + workspace.end + 1);
82
+ build_impute_node<decltype(input_data), decltype(workspace), ldouble_safe>(
83
+ impute_nodes->back(), workspace,
84
+ input_data, model_params,
85
+ *impute_nodes, curr_depth,
86
+ model_params.min_imp_obs);
87
+ }
88
+
89
+ /* check for potential isolated leafs or unique splits */
90
+ if (workspace.end == workspace.st || (workspace.end - workspace.st) == 1 || curr_depth >= model_params.max_depth)
91
+ goto terminal_statistics;
92
+
93
+ /* when using weights, the split should stop when the sum of weights is <= 1 */
94
+ if (workspace.changed_weights)
95
+ {
96
+ sum_weight = calculate_sum_weights<ldouble_safe>(
97
+ workspace.ix_arr, workspace.st, workspace.end, curr_depth,
98
+ workspace.weights_arr, workspace.weights_map);
99
+ if (curr_depth > 0 && sum_weight <= 1)
100
+ goto terminal_statistics;
101
+ }
102
+
103
+ /* if there's no columns left to split, can end here */
104
+ if (!workspace.col_sampler.get_remaining_cols())
105
+ goto terminal_statistics;
106
+
107
+ /* for sparse matrices, need to sort the indices */
108
+ if (input_data.Xc_indptr != NULL && impute_nodes == NULL)
109
+ std::sort(workspace.ix_arr.begin() + workspace.st, workspace.ix_arr.begin() + workspace.end + 1);
110
+
111
+ /* pick column to split according to criteria */
112
+ workspace.prob_split_type = workspace.rbin(workspace.rnd_generator);
113
+
114
+
115
+ /* case1: guided, pick column and/or point with best gain */
116
+ if (
117
+ workspace.prob_split_type
118
+ < (
119
+ model_params.prob_pick_by_gain_avg +
120
+ model_params.prob_pick_by_gain_pl +
121
+ model_params.prob_pick_by_full_gain +
122
+ model_params.prob_pick_by_dens
123
+ )
124
+ )
125
+ {
126
+ /* case 1.1: column and/or threshold is/are decided by averaged gain */
127
+ if (workspace.prob_split_type < model_params.prob_pick_by_gain_avg)
128
+ workspace.criterion = Averaged;
129
+
130
+ /* case 1.2: column and/or threshold is/are decided by pooled gain */
131
+ else if (workspace.prob_split_type < model_params.prob_pick_by_gain_avg +
132
+ model_params.prob_pick_by_gain_pl)
133
+ workspace.criterion = Pooled;
134
+ /* case 1.3: column and/or threshold is/are decided by full gain (pooled gain in all columns) */
135
+ else if (workspace.prob_split_type < model_params.prob_pick_by_gain_avg +
136
+ model_params.prob_pick_by_gain_pl +
137
+ model_params.prob_pick_by_full_gain)
138
+ workspace.criterion = FullGain;
139
+ /* case 1.4: column and/or threshold is/are decided by density pooled gain */
140
+ else
141
+ workspace.criterion = DensityCrit;
142
+
143
+ workspace.determine_split = model_params.ntry <= 1 || workspace.col_sampler.get_remaining_cols() == 1;
144
+
145
+ if (workspace.criterion == FullGain)
146
+ {
147
+ workspace.col_sampler.get_array_remaining_cols(workspace.col_indices);
148
+ }
149
+ }
150
+
151
+ /* case2: column and split point is decided at random */
152
+ else
153
+ {
154
+ workspace.criterion = NoCrit;
155
+ workspace.determine_split = true;
156
+ }
157
+
158
+ /* pick column selection method also according to criteria */
159
+ if (
160
+ (workspace.criterion != NoCrit &&
161
+ std::max(workspace.ntry, (size_t)1) >= workspace.col_sampler.get_remaining_cols())
162
+ ||
163
+ (workspace.col_sampler.get_remaining_cols() == 1)
164
+ ) {
165
+ workspace.prob_split_type = 0;
166
+ }
167
+ else {
168
+ workspace.prob_split_type = workspace.rbin(workspace.rnd_generator);
169
+ }
170
+
171
+ if (
172
+ workspace.prob_split_type
173
+ < model_params.prob_pick_col_by_range
174
+ )
175
+ {
176
+ workspace.col_criterion = ByRange;
177
+ if (curr_depth == 0 && is_boxed_metric(model_params.scoring_metric))
178
+ {
179
+ workspace.has_saved_stats = false;
180
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
181
+ workspace.node_col_weights[col] = workspace.density_calculator.box_high[col]
182
+ - workspace.density_calculator.box_low[col];
183
+
184
+ add_col_weights_to_ranges:
185
+ if (workspace.tree_kurtoses != NULL)
186
+ {
187
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
188
+ {
189
+ if (workspace.node_col_weights[col] <= 0) continue;
190
+ workspace.node_col_weights[col] *= workspace.tree_kurtoses[col];
191
+ workspace.node_col_weights[col] = std::fmax(workspace.node_col_weights[col], 1e-100);
192
+ }
193
+ }
194
+ else if (input_data.col_weights != NULL)
195
+ {
196
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
197
+ {
198
+ if (workspace.node_col_weights[col] <= 0) continue;
199
+ workspace.node_col_weights[col] *= input_data.col_weights[col];
200
+ workspace.node_col_weights[col] = std::fmax(workspace.node_col_weights[col], 1e-100);
201
+ }
202
+ }
203
+ }
204
+
205
+ else if (curr_depth == 0 &&
206
+ model_params.sample_size == input_data.nrows &&
207
+ !model_params.with_replacement &&
208
+ input_data.range_low != NULL &&
209
+ model_params.ncols_per_tree == input_data.ncols_tot)
210
+ {
211
+ workspace.has_saved_stats = false;
212
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
213
+ workspace.node_col_weights[col] = input_data.range_high[col]
214
+ - input_data.range_low[col];
215
+ goto add_col_weights_to_ranges;
216
+ }
217
+
218
+ else
219
+ {
220
+ workspace.has_saved_stats = workspace.criterion == NoCrit;
221
+ calc_ranges_all_cols(input_data, workspace, model_params, workspace.node_col_weights.data(),
222
+ workspace.has_saved_stats? workspace.saved_stat1.data() : NULL,
223
+ workspace.has_saved_stats? workspace.saved_stat2.data() : NULL);
224
+ }
225
+ }
226
+
227
+ else if (
228
+ workspace.prob_split_type
229
+ < (model_params.prob_pick_col_by_range +
230
+ model_params.prob_pick_col_by_var)
231
+ )
232
+ {
233
+ workspace.col_criterion = ByVar;
234
+ workspace.has_saved_stats = workspace.criterion == NoCrit;
235
+ calc_var_all_cols<InputData, WorkerMemory, ldouble_safe>(
236
+ input_data, workspace, model_params,
237
+ workspace.node_col_weights.data(),
238
+ workspace.has_saved_stats? workspace.saved_stat1.data() : NULL,
239
+ workspace.has_saved_stats? workspace.saved_stat2.data() : NULL,
240
+ NULL, NULL);
241
+ }
242
+
243
+ else if (
244
+ workspace.prob_split_type
245
+ < (model_params.prob_pick_col_by_range +
246
+ model_params.prob_pick_col_by_var +
247
+ model_params.prob_pick_col_by_kurt)
248
+ )
249
+ {
250
+ workspace.col_criterion = ByKurt;
251
+ workspace.has_saved_stats = workspace.criterion == NoCrit;
252
+ calc_kurt_all_cols<decltype(input_data), decltype(workspace), ldouble_safe>(
253
+ input_data, workspace, model_params, workspace.node_col_weights.data(),
254
+ workspace.has_saved_stats? workspace.saved_stat1.data() : NULL,
255
+ workspace.has_saved_stats? workspace.saved_stat2.data() : NULL);
256
+ }
257
+
258
+ else
259
+ {
260
+ workspace.col_criterion = Uniformly;
261
+ }
262
+
263
+ if (workspace.col_criterion != Uniformly)
264
+ {
265
+ if (!workspace.node_col_sampler.initialize(workspace.node_col_weights.data(),
266
+ &workspace.col_sampler.col_indices,
267
+ workspace.col_sampler.curr_pos,
268
+ (workspace.criterion == NoCrit)? 1 : model_params.ntry,
269
+ false))
270
+ {
271
+ goto terminal_statistics;
272
+ }
273
+ }
274
+
275
+ /* when column is chosen at random */
276
+ if (workspace.determine_split)
277
+ {
278
+ if (workspace.col_criterion != Uniformly)
279
+ {
280
+ if (!workspace.node_col_sampler.sample_col(trees.back().col_num, workspace.rnd_generator))
281
+ {
282
+ goto terminal_statistics;
283
+ }
284
+
285
+ if (trees.back().col_num < input_data.ncols_numeric)
286
+ {
287
+ trees.back().col_type = Numeric;
288
+ if (workspace.has_saved_stats)
289
+ {
290
+ workspace.xmin = workspace.saved_stat1[trees.back().col_num];
291
+ workspace.xmax = workspace.saved_stat2[trees.back().col_num];
292
+ }
293
+
294
+ else
295
+ {
296
+ get_split_range(workspace, input_data, model_params, trees.back());
297
+ if (workspace.unsplittable)
298
+ unexpected_error();
299
+ }
300
+ }
301
+
302
+ else
303
+ {
304
+ get_split_range(workspace, input_data, model_params, trees.back());
305
+ if (workspace.unsplittable)
306
+ unexpected_error();
307
+ }
308
+
309
+ goto produce_split;
310
+ }
311
+
312
+ if (!workspace.col_sampler.has_weights())
313
+ {
314
+ while (workspace.col_sampler.sample_col(trees.back().col_num, workspace.rnd_generator))
315
+ {
316
+ if (interrupt_switch) return;
317
+
318
+ get_split_range(workspace, input_data, model_params, trees.back());
319
+ if (workspace.unsplittable)
320
+ workspace.col_sampler.drop_col(trees.back().col_num + ((trees.back().col_type == Numeric)? (size_t)0 : input_data.ncols_numeric));
321
+ else
322
+ goto produce_split;
323
+ }
324
+ goto terminal_statistics;
325
+ }
326
+
327
+ else
328
+ {
329
+ if (workspace.try_all)
330
+ workspace.col_sampler.shuffle_remainder(workspace.rnd_generator);
331
+ workspace.ntried = 0;
332
+ size_t threshold_shuffle = (workspace.col_sampler.get_remaining_cols() + 1) / 2;
333
+
334
+ while (
335
+ workspace.try_all?
336
+ workspace.col_sampler.sample_col(trees.back().col_num)
337
+ :
338
+ workspace.col_sampler.sample_col(trees.back().col_num, workspace.rnd_generator)
339
+ )
340
+ {
341
+ if (interrupt_switch) return;
342
+
343
+ get_split_range(workspace, input_data, model_params, trees.back());
344
+ if (workspace.unsplittable)
345
+ {
346
+ workspace.col_sampler.drop_col(trees.back().col_num + ((trees.back().col_type == Numeric)? (size_t)0 : input_data.ncols_numeric));
347
+ workspace.ntried++;
348
+ if (!workspace.try_all && workspace.ntried >= threshold_shuffle)
349
+ {
350
+ workspace.try_all = true;
351
+ workspace.col_sampler.shuffle_remainder(workspace.rnd_generator);
352
+ }
353
+ }
354
+
355
+ else
356
+ {
357
+ goto produce_split;
358
+ }
359
+ }
360
+ goto terminal_statistics;
361
+ }
362
+ }
363
+
364
+
365
+ /* when choosing both column and threshold */
366
+ else
367
+ {
368
+ if (model_params.ntry >= workspace.col_sampler.get_remaining_cols())
369
+ workspace.col_sampler.prepare_full_pass();
370
+ else if (workspace.try_all && workspace.col_criterion == Uniformly)
371
+ workspace.col_sampler.shuffle_remainder(workspace.rnd_generator);
372
+
373
+ std::vector<bool> col_is_taken;
374
+ hashed_set<size_t> col_is_taken_s;
375
+ if (model_params.ntry < workspace.col_sampler.get_remaining_cols() && workspace.col_criterion == Uniformly)
376
+ {
377
+ if (input_data.ncols_tot < 1e5 ||
378
+ ((ldouble_safe)model_params.ntry / (ldouble_safe)workspace.col_sampler.get_remaining_cols()) > .25
379
+ )
380
+ {
381
+ col_is_taken.resize(input_data.ncols_tot, false);
382
+ }
383
+ else {
384
+ col_is_taken_s.reserve(model_params.ntry);
385
+ }
386
+ }
387
+
388
+ size_t threshold_shuffle = (workspace.col_sampler.get_remaining_cols() + 1) / 2;
389
+ workspace.ntried = 0; /* <- used to determine when to shuffle the remainder */
390
+ workspace.ntaken = 0; /* <- used to count how many columns have been evaluated */
391
+ trees.back().score = -HUGE_VAL; /* this is used to track the best gain */
392
+
393
+ while (
394
+ (workspace.col_criterion != Uniformly)?
395
+ workspace.node_col_sampler.sample_col(workspace.col_chosen, workspace.rnd_generator)
396
+ :
397
+ (workspace.try_all?
398
+ workspace.col_sampler.sample_col(workspace.col_chosen)
399
+ :
400
+ workspace.col_sampler.sample_col(workspace.col_chosen, workspace.rnd_generator))
401
+ )
402
+ {
403
+ if (interrupt_switch) return;
404
+
405
+ if (workspace.col_criterion != Uniformly)
406
+ {
407
+ workspace.ntaken++;
408
+ goto probe_this_col;
409
+ }
410
+
411
+ workspace.ntried++;
412
+ if (!workspace.try_all && workspace.ntried >= threshold_shuffle)
413
+ {
414
+ workspace.try_all = true;
415
+ workspace.col_sampler.shuffle_remainder(workspace.rnd_generator);
416
+ }
417
+
418
+ if ((col_is_taken.size() || col_is_taken_s.size()) && !workspace.try_all)
419
+ {
420
+ if (is_col_taken(col_is_taken, col_is_taken_s, workspace.col_chosen))
421
+ continue;
422
+ set_col_as_taken(col_is_taken, col_is_taken_s, input_data, workspace.col_chosen);
423
+ }
424
+
425
+ get_split_range_v2(workspace, input_data, model_params);
426
+ if (workspace.unsplittable)
427
+ {
428
+ workspace.col_sampler.drop_col(workspace.col_chosen);
429
+ continue;
430
+ }
431
+
432
+ else
433
+ {
434
+ probe_this_col:
435
+ if (workspace.col_chosen < input_data.ncols_numeric)
436
+ {
437
+ if (input_data.Xc_indptr == NULL)
438
+ {
439
+ if (!workspace.changed_weights)
440
+ workspace.this_gain = eval_guided_crit<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
441
+ ldouble_safe>(
442
+ workspace.ix_arr.data(), workspace.st, workspace.end,
443
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
444
+ workspace.buffer_dbl.data(), false,
445
+ workspace.imputed_x_buffer.data(),
446
+ &workspace.saved_xmedian,
447
+ workspace.split_ix, workspace.this_split_point,
448
+ workspace.xmin, workspace.xmax,
449
+ workspace.criterion, model_params.min_gain,
450
+ model_params.missing_action,
451
+ workspace.col_indices.data(),
452
+ workspace.col_sampler.get_remaining_cols(),
453
+ model_params.ncols_per_tree < input_data.ncols_tot,
454
+ input_data.X_row_major.data(),
455
+ input_data.ncols_numeric,
456
+ input_data.Xr.data(),
457
+ input_data.Xr_ind.data(),
458
+ input_data.Xr_indptr.data());
459
+ else if (!workspace.weights_arr.empty())
460
+ workspace.this_gain = eval_guided_crit_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
461
+ decltype(workspace.weights_arr), ldouble_safe>(
462
+ workspace.ix_arr.data(), workspace.st, workspace.end,
463
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
464
+ workspace.buffer_dbl.data(), false,
465
+ workspace.imputed_x_buffer.data(),
466
+ &workspace.saved_xmedian,
467
+ workspace.split_ix, workspace.this_split_point,
468
+ workspace.xmin, workspace.xmax,
469
+ workspace.criterion, model_params.min_gain,
470
+ model_params.missing_action,
471
+ workspace.col_indices.data(),
472
+ workspace.col_sampler.get_remaining_cols(),
473
+ model_params.ncols_per_tree < input_data.ncols_tot,
474
+ input_data.X_row_major.data(),
475
+ input_data.ncols_numeric,
476
+ input_data.Xr.data(),
477
+ input_data.Xr_ind.data(),
478
+ input_data.Xr_indptr.data(),
479
+ workspace.weights_arr);
480
+ else
481
+ workspace.this_gain = eval_guided_crit_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
482
+ decltype(workspace.weights_map), ldouble_safe>(
483
+ workspace.ix_arr.data(), workspace.st, workspace.end,
484
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
485
+ workspace.buffer_dbl.data(), false,
486
+ workspace.imputed_x_buffer.data(),
487
+ &workspace.saved_xmedian,
488
+ workspace.split_ix, workspace.this_split_point,
489
+ workspace.xmin, workspace.xmax,
490
+ workspace.criterion, model_params.min_gain,
491
+ model_params.missing_action,
492
+ workspace.col_indices.data(),
493
+ workspace.col_sampler.get_remaining_cols(),
494
+ model_params.ncols_per_tree < input_data.ncols_tot,
495
+ input_data.X_row_major.data(),
496
+ input_data.ncols_numeric,
497
+ input_data.Xr.data(),
498
+ input_data.Xr_ind.data(),
499
+ input_data.Xr_indptr.data(),
500
+ workspace.weights_map);
501
+ }
502
+
503
+ else
504
+ {
505
+ if (!workspace.changed_weights)
506
+ workspace.this_gain = eval_guided_crit<typename std::remove_pointer<decltype(input_data.Xc)>::type,
507
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
508
+ ldouble_safe>(
509
+ workspace.ix_arr.data(), workspace.st, workspace.end,
510
+ workspace.col_chosen, input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
511
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), false,
512
+ &workspace.saved_xmedian,
513
+ workspace.this_split_point, workspace.xmin, workspace.xmax,
514
+ workspace.criterion, model_params.min_gain, model_params.missing_action,
515
+ workspace.col_indices.data(),
516
+ workspace.col_sampler.get_remaining_cols(),
517
+ model_params.ncols_per_tree < input_data.ncols_tot,
518
+ input_data.X_row_major.data(),
519
+ input_data.ncols_numeric,
520
+ input_data.Xr.data(),
521
+ input_data.Xr_ind.data(),
522
+ input_data.Xr_indptr.data());
523
+ else if (!workspace.weights_arr.empty())
524
+ workspace.this_gain = eval_guided_crit_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
525
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
526
+ decltype(workspace.weights_arr), ldouble_safe>(
527
+ workspace.ix_arr.data(), workspace.st, workspace.end,
528
+ workspace.col_chosen, input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
529
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), false,
530
+ &workspace.saved_xmedian,
531
+ workspace.this_split_point, workspace.xmin, workspace.xmax,
532
+ workspace.criterion, model_params.min_gain, model_params.missing_action,
533
+ workspace.col_indices.data(),
534
+ workspace.col_sampler.get_remaining_cols(),
535
+ model_params.ncols_per_tree < input_data.ncols_tot,
536
+ input_data.X_row_major.data(),
537
+ input_data.ncols_numeric,
538
+ input_data.Xr.data(),
539
+ input_data.Xr_ind.data(),
540
+ input_data.Xr_indptr.data(),
541
+ workspace.weights_arr);
542
+ else
543
+ workspace.this_gain = eval_guided_crit_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
544
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
545
+ decltype(workspace.weights_map),
546
+ ldouble_safe>(
547
+ workspace.ix_arr.data(), workspace.st, workspace.end,
548
+ workspace.col_chosen, input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
549
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), false,
550
+ &workspace.saved_xmedian,
551
+ workspace.this_split_point, workspace.xmin, workspace.xmax,
552
+ workspace.criterion, model_params.min_gain, model_params.missing_action,
553
+ workspace.col_indices.data(),
554
+ workspace.col_sampler.get_remaining_cols(),
555
+ model_params.ncols_per_tree < input_data.ncols_tot,
556
+ input_data.X_row_major.data(),
557
+ input_data.ncols_numeric,
558
+ input_data.Xr.data(),
559
+ input_data.Xr_ind.data(),
560
+ input_data.Xr_indptr.data(),
561
+ workspace.weights_map);
562
+ }
563
+ }
564
+
565
+ else
566
+ {
567
+ if (!workspace.changed_weights)
568
+ workspace.this_gain = eval_guided_crit<ldouble_safe>(
569
+ workspace.ix_arr.data(), workspace.st, workspace.end,
570
+ input_data.categ_data + (workspace.col_chosen - input_data.ncols_numeric) * input_data.nrows,
571
+ input_data.ncat[workspace.col_chosen - input_data.ncols_numeric],
572
+ &workspace.saved_cat_mode,
573
+ workspace.buffer_szt.data(), workspace.buffer_szt.data() + input_data.max_categ,
574
+ workspace.buffer_dbl.data(), workspace.this_categ, workspace.this_split_categ.data(),
575
+ workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
576
+ model_params.all_perm, model_params.missing_action, model_params.cat_split_type);
577
+ else if (!workspace.weights_arr.empty())
578
+ workspace.this_gain = eval_guided_crit_weighted<decltype(workspace.weights_arr), ldouble_safe>(
579
+ workspace.ix_arr.data(), workspace.st, workspace.end,
580
+ input_data.categ_data + (workspace.col_chosen - input_data.ncols_numeric) * input_data.nrows,
581
+ input_data.ncat[workspace.col_chosen - input_data.ncols_numeric],
582
+ &workspace.saved_cat_mode,
583
+ workspace.buffer_szt.data(),
584
+ workspace.buffer_dbl.data(), workspace.this_categ, workspace.this_split_categ.data(),
585
+ workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
586
+ model_params.all_perm, model_params.missing_action, model_params.cat_split_type,
587
+ workspace.weights_arr);
588
+ else
589
+ workspace.this_gain = eval_guided_crit_weighted<decltype(workspace.weights_map), ldouble_safe>(
590
+ workspace.ix_arr.data(), workspace.st, workspace.end,
591
+ input_data.categ_data + (workspace.col_chosen - input_data.ncols_numeric) * input_data.nrows,
592
+ input_data.ncat[workspace.col_chosen - input_data.ncols_numeric],
593
+ &workspace.saved_cat_mode,
594
+ workspace.buffer_szt.data(),
595
+ workspace.buffer_dbl.data(), workspace.this_categ, workspace.this_split_categ.data(),
596
+ workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
597
+ model_params.all_perm, model_params.missing_action, model_params.cat_split_type,
598
+ workspace.weights_map);
599
+ }
600
+
601
+ if (std::isnan(workspace.this_gain) || workspace.this_gain <= -HUGE_VAL)
602
+ continue;
603
+
604
+
605
+ if (workspace.this_gain > trees.back().score)
606
+ {
607
+ if (workspace.col_chosen < input_data.ncols_numeric)
608
+ {
609
+ trees.back().score = workspace.this_gain;
610
+ trees.back().col_num = workspace.col_chosen;
611
+ trees.back().col_type = Numeric;
612
+ trees.back().num_split = workspace.this_split_point;
613
+ if (model_params.penalize_range)
614
+ {
615
+ trees.back().range_low = workspace.xmin - workspace.xmax + trees.back().num_split;
616
+ trees.back().range_high = workspace.xmax - workspace.xmin + trees.back().num_split;
617
+ }
618
+
619
+ if (model_params.scoring_metric != Depth && !is_boxed_metric(model_params.scoring_metric))
620
+ {
621
+ workspace.density_calculator.save_range(workspace.xmin, workspace.xmax);
622
+ }
623
+
624
+ workspace.best_xmedian = workspace.saved_xmedian;
625
+ }
626
+
627
+ else
628
+ {
629
+ trees.back().score = workspace.this_gain;
630
+ trees.back().col_num = workspace.col_chosen - input_data.ncols_numeric;
631
+ trees.back().col_type = Categorical;
632
+ switch (model_params.cat_split_type)
633
+ {
634
+ case SingleCateg:
635
+ {
636
+ trees.back().chosen_cat = workspace.this_categ;
637
+ break;
638
+ }
639
+
640
+ case SubSet:
641
+ {
642
+ trees.back().cat_split.assign(workspace.this_split_categ.begin(),
643
+ workspace.this_split_categ.begin()
644
+ + input_data.ncat[trees.back().col_num]);
645
+ break;
646
+ }
647
+ }
648
+
649
+ workspace.best_cat_mode = workspace.saved_cat_mode;
650
+
651
+ if (model_params.scoring_metric != Depth && !is_boxed_metric(model_params.scoring_metric))
652
+ {
653
+ if (model_params.scoring_metric == Density)
654
+ {
655
+ switch (model_params.cat_split_type)
656
+ {
657
+ case SingleCateg:
658
+ {
659
+ workspace.density_calculator.save_n_present(workspace.buffer_szt.data(),
660
+ input_data.ncat[trees.back().col_num]);
661
+ break;
662
+ }
663
+
664
+ case SubSet:
665
+ {
666
+ workspace.density_calculator.save_n_present_and_left(
667
+ workspace.this_split_categ.data(),
668
+ input_data.ncat[trees.back().col_num]
669
+ );
670
+ break;
671
+ }
672
+ }
673
+ }
674
+
675
+ else
676
+ {
677
+ workspace.density_calculator.save_counts(workspace.buffer_szt.data(),
678
+ input_data.ncat[trees.back().col_num]);
679
+ }
680
+ }
681
+ }
682
+ }
683
+
684
+ if (++workspace.ntaken >= model_params.ntry)
685
+ break;
686
+ }
687
+ }
688
+
689
+ if (!workspace.ntaken)
690
+ goto terminal_statistics;
691
+
692
+ if (trees.back().score <= 0.)
693
+ goto terminal_statistics;
694
+ else
695
+ trees.back().score = 0.;
696
+ }
697
+
698
+
699
+ /* for numeric, choose a random point, or pick the best point as determined earlier */
700
+ produce_split:
701
+ if (trees.back().col_type == Numeric)
702
+ {
703
+ if (workspace.determine_split)
704
+ {
705
+ switch(workspace.criterion)
706
+ {
707
+ case NoCrit:
708
+ {
709
+ trees.back().num_split = sample_random_uniform(workspace.xmin, workspace.xmax, workspace.rnd_generator);
710
+ break;
711
+ }
712
+
713
+ default:
714
+ {
715
+ if (input_data.Xc_indptr == NULL)
716
+ {
717
+ if (!workspace.changed_weights)
718
+ workspace.this_gain =
719
+ eval_guided_crit<typename std::remove_pointer<decltype(input_data.numeric_data)>::type, ldouble_safe>(
720
+ workspace.ix_arr.data(), workspace.st, workspace.end,
721
+ input_data.numeric_data + trees.back().col_num * input_data.nrows,
722
+ workspace.buffer_dbl.data(), true,
723
+ workspace.imputed_x_buffer.data(),
724
+ &workspace.best_xmedian,
725
+ workspace.split_ix, trees.back().num_split,
726
+ workspace.xmin, workspace.xmax,
727
+ workspace.criterion, model_params.min_gain,
728
+ model_params.missing_action,
729
+ workspace.col_indices.data(),
730
+ workspace.col_sampler.get_remaining_cols(),
731
+ model_params.ncols_per_tree < input_data.ncols_tot,
732
+ input_data.X_row_major.data(),
733
+ input_data.ncols_numeric,
734
+ input_data.Xr.data(),
735
+ input_data.Xr_ind.data(),
736
+ input_data.Xr_indptr.data());
737
+ else if (!workspace.weights_arr.empty())
738
+ workspace.this_gain =
739
+ eval_guided_crit_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type, decltype(workspace.weights_arr), ldouble_safe>(
740
+ workspace.ix_arr.data(), workspace.st, workspace.end,
741
+ input_data.numeric_data + trees.back().col_num * input_data.nrows,
742
+ workspace.buffer_dbl.data(), true,
743
+ workspace.imputed_x_buffer.data(),
744
+ &workspace.best_xmedian,
745
+ workspace.split_ix, trees.back().num_split,
746
+ workspace.xmin, workspace.xmax,
747
+ workspace.criterion, model_params.min_gain,
748
+ model_params.missing_action,
749
+ workspace.col_indices.data(),
750
+ workspace.col_sampler.get_remaining_cols(),
751
+ model_params.ncols_per_tree < input_data.ncols_tot,
752
+ input_data.X_row_major.data(),
753
+ input_data.ncols_numeric,
754
+ input_data.Xr.data(),
755
+ input_data.Xr_ind.data(),
756
+ input_data.Xr_indptr.data(),
757
+ workspace.weights_arr);
758
+ else
759
+ workspace.this_gain =
760
+ eval_guided_crit_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type, decltype(workspace.weights_map), ldouble_safe>(
761
+ workspace.ix_arr.data(), workspace.st, workspace.end,
762
+ input_data.numeric_data + trees.back().col_num * input_data.nrows,
763
+ workspace.buffer_dbl.data(), true,
764
+ workspace.imputed_x_buffer.data(),
765
+ &workspace.best_xmedian,
766
+ workspace.split_ix, trees.back().num_split,
767
+ workspace.xmin, workspace.xmax,
768
+ workspace.criterion, model_params.min_gain,
769
+ model_params.missing_action,
770
+ workspace.col_indices.data(),
771
+ workspace.col_sampler.get_remaining_cols(),
772
+ model_params.ncols_per_tree < input_data.ncols_tot,
773
+ input_data.X_row_major.data(),
774
+ input_data.ncols_numeric,
775
+ input_data.Xr.data(),
776
+ input_data.Xr_ind.data(),
777
+ input_data.Xr_indptr.data(),
778
+ workspace.weights_map);
779
+
780
+ if (std::isnan(workspace.this_gain) || workspace.this_gain <= -HUGE_VAL)
781
+ goto terminal_statistics;
782
+
783
+ if (
784
+ model_params.missing_action == Fail
785
+ ||
786
+ (model_params.missing_action == Impute && input_data.Xc_indptr == NULL)
787
+ ) /* data is already split in this case */
788
+ {
789
+ if (model_params.missing_action == Impute)
790
+ {
791
+ workspace.st_NA = workspace.split_ix + 1;
792
+ workspace.end_NA = workspace.st_NA;
793
+ }
794
+
795
+ workspace.split_ix++;
796
+ if (model_params.penalize_range)
797
+ {
798
+ trees.back().range_low = workspace.xmin - workspace.xmax + trees.back().num_split;
799
+ trees.back().range_high = workspace.xmax - workspace.xmin + trees.back().num_split;
800
+ }
801
+ goto follow_branches;
802
+ }
803
+ }
804
+
805
+ else
806
+ {
807
+ if (!workspace.changed_weights)
808
+ workspace.this_gain =
809
+ eval_guided_crit<typename std::remove_pointer<decltype(input_data.Xc)>::type,
810
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
811
+ ldouble_safe>(
812
+ workspace.ix_arr.data(), workspace.st, workspace.end,
813
+ trees.back().col_num, input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
814
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true,
815
+ &workspace.best_xmedian,
816
+ trees.back().num_split, workspace.xmin, workspace.xmax,
817
+ workspace.criterion, model_params.min_gain,
818
+ model_params.missing_action,
819
+ workspace.col_indices.data(),
820
+ workspace.col_sampler.get_remaining_cols(),
821
+ model_params.ncols_per_tree < input_data.ncols_tot,
822
+ input_data.X_row_major.data(),
823
+ input_data.ncols_numeric,
824
+ input_data.Xr.data(),
825
+ input_data.Xr_ind.data(),
826
+ input_data.Xr_indptr.data());
827
+ else if (!workspace.weights_arr.empty())
828
+ workspace.this_gain =
829
+ eval_guided_crit_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
830
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
831
+ decltype(workspace.weights_arr),
832
+ ldouble_safe>(
833
+ workspace.ix_arr.data(), workspace.st, workspace.end,
834
+ trees.back().col_num, input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
835
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true,
836
+ &workspace.best_xmedian,
837
+ trees.back().num_split, workspace.xmin, workspace.xmax,
838
+ workspace.criterion, model_params.min_gain,
839
+ model_params.missing_action,
840
+ workspace.col_indices.data(),
841
+ workspace.col_sampler.get_remaining_cols(),
842
+ model_params.ncols_per_tree < input_data.ncols_tot,
843
+ input_data.X_row_major.data(),
844
+ input_data.ncols_numeric,
845
+ input_data.Xr.data(),
846
+ input_data.Xr_ind.data(),
847
+ input_data.Xr_indptr.data(),
848
+ workspace.weights_arr);
849
+ else
850
+ workspace.this_gain =
851
+ eval_guided_crit_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
852
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
853
+ decltype(workspace.weights_map),
854
+ ldouble_safe>(
855
+ workspace.ix_arr.data(), workspace.st, workspace.end,
856
+ trees.back().col_num, input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
857
+ workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true,
858
+ &workspace.best_xmedian,
859
+ trees.back().num_split, workspace.xmin, workspace.xmax,
860
+ workspace.criterion, model_params.min_gain,
861
+ model_params.missing_action,
862
+ workspace.col_indices.data(),
863
+ workspace.col_sampler.get_remaining_cols(),
864
+ model_params.ncols_per_tree < input_data.ncols_tot,
865
+ input_data.X_row_major.data(),
866
+ input_data.ncols_numeric,
867
+ input_data.Xr.data(),
868
+ input_data.Xr_ind.data(),
869
+ input_data.Xr_indptr.data(),
870
+ workspace.weights_map);
871
+ }
872
+
873
+ if (std::isnan(workspace.this_gain) || workspace.this_gain <= -HUGE_VAL)
874
+ goto terminal_statistics;
875
+
876
+ break;
877
+ }
878
+ }
879
+
880
+ if (model_params.penalize_range)
881
+ {
882
+ trees.back().range_low = workspace.xmin - workspace.xmax + trees.back().num_split;
883
+ trees.back().range_high = workspace.xmax - workspace.xmin + trees.back().num_split;
884
+ }
885
+ }
886
+
887
+ if (model_params.missing_action == Fail && std::isnan(trees.back().num_split))
888
+ throw std::runtime_error("Data has missing values. Try using a different value for 'missing_action'.\n");
889
+
890
+ /* TODO: make this work, can end up messing with the start and end indices somehow */
891
+ /* It should also consider that 'split_ix' might not match when using missing_action == Impute */
892
+ // if (input_data.Xc_indptr == NULL && model_params.missing_action == Fail && workspace.ntaken == 1)
893
+ // goto follow_branches;
894
+
895
+ if (input_data.Xc_indptr == NULL)
896
+ divide_subset_split(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * trees.back().col_num,
897
+ workspace.st, workspace.end, trees.back().num_split, model_params.missing_action,
898
+ workspace.st_NA, workspace.end_NA, workspace.split_ix);
899
+ else
900
+ divide_subset_split(workspace.ix_arr.data(), workspace.st, workspace.end, trees.back().col_num,
901
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr, trees.back().num_split,
902
+ model_params.missing_action, workspace.st_NA, workspace.end_NA, workspace.split_ix);
903
+ }
904
+
905
+ /* for categorical, there are different ways of splitting */
906
+ else
907
+ {
908
+ /* if the column is binary, there's only one possible split */
909
+ if (input_data.ncat[trees.back().col_num] <= 2)
910
+ {
911
+ trees.back().chosen_cat = 0;
912
+ divide_subset_split(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * trees.back().col_num,
913
+ workspace.st, workspace.end, (int)0, model_params.missing_action,
914
+ workspace.st_NA, workspace.end_NA, workspace.split_ix);
915
+ trees.back().cat_split.clear();
916
+ trees.back().cat_split.shrink_to_fit();
917
+ }
918
+
919
+ /* otherwise, split according to desired type (single/subset) */
920
+ /* TODO: refactor this */
921
+ else
922
+ {
923
+
924
+ switch (model_params.cat_split_type)
925
+ {
926
+
927
+ case SingleCateg:
928
+ {
929
+
930
+ if (workspace.determine_split)
931
+ {
932
+ switch (workspace.criterion)
933
+ {
934
+ case NoCrit:
935
+ {
936
+ trees.back().chosen_cat = choose_cat_from_present(workspace, input_data, trees.back().col_num);
937
+ break;
938
+ }
939
+
940
+ default:
941
+ {
942
+ if (!workspace.changed_weights)
943
+ workspace.this_gain =
944
+ eval_guided_crit<ldouble_safe>(
945
+ workspace.ix_arr.data(), workspace.st, workspace.end,
946
+ input_data.categ_data + trees.back().col_num * input_data.nrows, input_data.ncat[trees.back().col_num],
947
+ &workspace.best_cat_mode,
948
+ workspace.buffer_szt.data(), workspace.buffer_szt.data() + input_data.max_categ,
949
+ workspace.buffer_dbl.data(), trees.back().chosen_cat, workspace.this_split_categ.data(),
950
+ workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
951
+ model_params.all_perm, model_params.missing_action, model_params.cat_split_type);
952
+ else if (!workspace.weights_arr.empty())
953
+ workspace.this_gain =
954
+ eval_guided_crit_weighted<decltype(workspace.weights_arr), ldouble_safe>(
955
+ workspace.ix_arr.data(), workspace.st, workspace.end,
956
+ input_data.categ_data + trees.back().col_num * input_data.nrows, input_data.ncat[trees.back().col_num],
957
+ &workspace.best_cat_mode,
958
+ workspace.buffer_szt.data(),
959
+ workspace.buffer_dbl.data(), trees.back().chosen_cat, workspace.this_split_categ.data(),
960
+ workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
961
+ model_params.all_perm, model_params.missing_action, model_params.cat_split_type,
962
+ workspace.weights_arr);
963
+ else
964
+ workspace.this_gain =
965
+ eval_guided_crit_weighted<decltype(workspace.weights_map), ldouble_safe>(
966
+ workspace.ix_arr.data(), workspace.st, workspace.end,
967
+ input_data.categ_data + trees.back().col_num * input_data.nrows, input_data.ncat[trees.back().col_num],
968
+ &workspace.best_cat_mode,
969
+ workspace.buffer_szt.data(),
970
+ workspace.buffer_dbl.data(), trees.back().chosen_cat, workspace.this_split_categ.data(),
971
+ workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
972
+ model_params.all_perm, model_params.missing_action, model_params.cat_split_type,
973
+ workspace.weights_map);
974
+
975
+ if (std::isnan(workspace.this_gain) || workspace.this_gain <= -HUGE_VAL)
976
+ goto terminal_statistics;
977
+
978
+ break;
979
+ }
980
+ }
981
+ }
982
+
983
+
984
+ divide_subset_split(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * trees.back().col_num,
985
+ workspace.st, workspace.end, trees.back().chosen_cat, model_params.missing_action,
986
+ workspace.st_NA, workspace.end_NA, workspace.split_ix);
987
+ break;
988
+ }
989
+
990
+
991
+ case SubSet:
992
+ {
993
+
994
+ if (workspace.determine_split)
995
+ {
996
+ switch(workspace.criterion)
997
+ {
998
+ case NoCrit:
999
+ {
1000
+ workspace.unsplittable = true;
1001
+ while(workspace.unsplittable)
1002
+ {
1003
+ workspace.npresent = 0;
1004
+ workspace.ncols_tried = 0;
1005
+ for (int cat = 0; cat < input_data.ncat[trees.back().col_num]; cat++)
1006
+ {
1007
+ if (workspace.categs[cat] >= 0)
1008
+ {
1009
+ workspace.categs[cat] = workspace.rbin(workspace.rnd_generator) < 0.5;
1010
+ workspace.npresent += workspace.categs[cat];
1011
+ workspace.ncols_tried += !workspace.categs[cat];
1012
+ }
1013
+ workspace.unsplittable = !(workspace.npresent && workspace.ncols_tried);
1014
+ }
1015
+ }
1016
+
1017
+ trees.back().cat_split.assign(workspace.categs.begin(), workspace.categs.begin() + input_data.ncat[trees.back().col_num]);
1018
+ break; /* NoCrit */
1019
+ }
1020
+
1021
+ default:
1022
+ {
1023
+ trees.back().cat_split.resize(input_data.ncat[trees.back().col_num]);
1024
+ if (!workspace.changed_weights)
1025
+ workspace.this_gain =
1026
+ eval_guided_crit<ldouble_safe>(
1027
+ workspace.ix_arr.data(), workspace.st, workspace.end,
1028
+ input_data.categ_data + trees.back().col_num * input_data.nrows, input_data.ncat[trees.back().col_num],
1029
+ &workspace.best_cat_mode,
1030
+ workspace.buffer_szt.data(), workspace.buffer_szt.data() + input_data.max_categ,
1031
+ workspace.buffer_dbl.data(), trees.back().chosen_cat, trees.back().cat_split.data(),
1032
+ workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
1033
+ model_params.all_perm, model_params.missing_action, model_params.cat_split_type);
1034
+ else if (!workspace.weights_arr.empty())
1035
+ workspace.this_gain =
1036
+ eval_guided_crit_weighted<decltype(workspace.weights_arr), ldouble_safe>(
1037
+ workspace.ix_arr.data(), workspace.st, workspace.end,
1038
+ input_data.categ_data + trees.back().col_num * input_data.nrows, input_data.ncat[trees.back().col_num],
1039
+ &workspace.best_cat_mode,
1040
+ workspace.buffer_szt.data(),
1041
+ workspace.buffer_dbl.data(), trees.back().chosen_cat, trees.back().cat_split.data(),
1042
+ workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
1043
+ model_params.all_perm, model_params.missing_action, model_params.cat_split_type,
1044
+ workspace.weights_arr);
1045
+ else
1046
+ workspace.this_gain =
1047
+ eval_guided_crit_weighted<decltype(workspace.weights_map), ldouble_safe>(
1048
+ workspace.ix_arr.data(), workspace.st, workspace.end,
1049
+ input_data.categ_data + trees.back().col_num * input_data.nrows, input_data.ncat[trees.back().col_num],
1050
+ &workspace.best_cat_mode,
1051
+ workspace.buffer_szt.data(),
1052
+ workspace.buffer_dbl.data(), trees.back().chosen_cat, trees.back().cat_split.data(),
1053
+ workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
1054
+ model_params.all_perm, model_params.missing_action, model_params.cat_split_type,
1055
+ workspace.weights_map);
1056
+
1057
+ if (std::isnan(workspace.this_gain) || workspace.this_gain <= -HUGE_VAL)
1058
+ goto terminal_statistics;
1059
+ break;
1060
+ }
1061
+ }
1062
+ }
1063
+
1064
+ if (model_params.new_cat_action == Random)
1065
+ {
1066
+ if (model_params.scoring_metric == Density)
1067
+ {
1068
+ workspace.density_calculator.save_n_present_and_left(trees.back().cat_split.data(), input_data.ncat[trees.back().col_num]);
1069
+ }
1070
+
1071
+ for (int cat = 0; cat < input_data.ncat[trees.back().col_num]; cat++)
1072
+ if (trees.back().cat_split[cat] < 0)
1073
+ trees.back().cat_split[cat] = workspace.rbin(workspace.rnd_generator) < 0.5;
1074
+ }
1075
+
1076
+ divide_subset_split(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * trees.back().col_num,
1077
+ workspace.st, workspace.end, trees.back().cat_split.data(), model_params.missing_action,
1078
+ workspace.st_NA, workspace.end_NA, workspace.split_ix);
1079
+ }
1080
+
1081
+ }
1082
+
1083
+ }
1084
+
1085
+ }
1086
+
1087
+
1088
+ /* if it hasn't reached the limit, continue splitting from here */
1089
+ follow_branches:
1090
+ {
1091
+ /* add another round of separation depth for distance */
1092
+ if (model_params.calc_dist && curr_depth > 0)
1093
+ add_separation_step(workspace, input_data, (double)(-1));
1094
+
1095
+ /* if it split by a categorical variable with only 2 values,
1096
+ the column will no longer be splittable in either branch */
1097
+ if (trees.back().col_type == Categorical &&
1098
+ ((model_params.cat_split_type == SubSet && trees.back().cat_split.empty()) ||
1099
+ (model_params.cat_split_type == SingleCateg && input_data.ncat[trees.back().col_num] == 2)))
1100
+ {
1101
+ workspace.col_sampler.drop_col(trees.back().col_num + input_data.ncols_numeric,
1102
+ workspace.end - workspace.st + 1);
1103
+ }
1104
+
1105
+ size_t tree_from = trees.size() - 1;
1106
+ std::unique_ptr<RecursionState>
1107
+ recursion_state(new RecursionState(workspace, model_params.missing_action != Fail));
1108
+ trees.back().score = -1;
1109
+
1110
+ /* compute statistics for NAs and remember recursion indices/weights */
1111
+ if (model_params.missing_action != Fail)
1112
+ {
1113
+ if (
1114
+ model_params.missing_action == Impute &&
1115
+ workspace.criterion != NoCrit &&
1116
+ workspace.st_NA < workspace.end_NA
1117
+ ) {
1118
+ bool move_NAs_left;
1119
+ if (trees.back().col_type == Numeric)
1120
+ {
1121
+ move_NAs_left = workspace.best_xmedian <= trees.back().num_split;
1122
+ }
1123
+
1124
+ else
1125
+ {
1126
+ if (trees.back().cat_split.empty())
1127
+ move_NAs_left = workspace.best_cat_mode == trees.back().chosen_cat;
1128
+ else
1129
+ move_NAs_left = trees.back().cat_split[workspace.best_cat_mode] == 1;
1130
+ }
1131
+
1132
+ if (move_NAs_left)
1133
+ workspace.st_NA = workspace.end_NA;
1134
+ else
1135
+ workspace.end_NA = workspace.st_NA;
1136
+ }
1137
+
1138
+ if (!workspace.changed_weights)
1139
+ {
1140
+ trees.back().pct_tree_left = (ldouble_safe)(workspace.st_NA - workspace.st)
1141
+ /
1142
+ (ldouble_safe)(workspace.end - workspace.st + 1 - (workspace.end_NA - workspace.st_NA));
1143
+
1144
+ if (model_params.missing_action == Divide && workspace.st_NA < workspace.end_NA)
1145
+ {
1146
+ workspace.changed_weights = true;
1147
+
1148
+ if (input_data.Xc_indptr != NULL && model_params.sample_size < input_data.nrows / 20) {
1149
+ workspace.weights_arr.clear();
1150
+ workspace.weights_map.reserve(workspace.end - workspace.st + 1);
1151
+ for (size_t row = workspace.st; row < workspace.end_NA; row++)
1152
+ workspace.weights_map[workspace.ix_arr[row]] = 1;
1153
+ }
1154
+
1155
+ else {
1156
+ workspace.weights_arr.resize(input_data.nrows);
1157
+ for (size_t row = workspace.st; row < workspace.end_NA; row++)
1158
+ workspace.weights_arr[workspace.ix_arr[row]] = 1;
1159
+ }
1160
+ }
1161
+ }
1162
+
1163
+ else
1164
+ {
1165
+ ldouble_safe sum_weight_left = 0;
1166
+ ldouble_safe sum_weight_right = 0;
1167
+
1168
+ if (!workspace.weights_arr.empty()) {
1169
+ for (size_t row = workspace.st; row < workspace.st_NA; row++)
1170
+ sum_weight_left += workspace.weights_arr[workspace.ix_arr[row]];
1171
+ for (size_t row = workspace.end_NA; row <= workspace.end; row++)
1172
+ sum_weight_right += workspace.weights_arr[workspace.ix_arr[row]];
1173
+ }
1174
+
1175
+ else {
1176
+ for (size_t row = workspace.st; row < workspace.st_NA; row++)
1177
+ sum_weight_left += workspace.weights_map[workspace.ix_arr[row]];
1178
+ for (size_t row = workspace.end_NA; row <= workspace.end; row++)
1179
+ sum_weight_right += workspace.weights_map[workspace.ix_arr[row]];
1180
+ }
1181
+
1182
+ trees.back().pct_tree_left = sum_weight_left / (sum_weight_left + sum_weight_right);
1183
+ }
1184
+
1185
+ switch(model_params.missing_action)
1186
+ {
1187
+ case Impute:
1188
+ {
1189
+ if (trees.back().pct_tree_left >= .5)
1190
+ workspace.end = workspace.end_NA - 1;
1191
+ else
1192
+ workspace.end = workspace.st_NA - 1;
1193
+ break;
1194
+ }
1195
+
1196
+
1197
+ case Divide:
1198
+ {
1199
+ if (!workspace.weights_arr.empty())
1200
+ for (size_t row = workspace.st_NA; row < workspace.end_NA; row++)
1201
+ workspace.weights_arr[workspace.ix_arr[row]] *= trees.back().pct_tree_left;
1202
+ else
1203
+ for (size_t row = workspace.st_NA; row < workspace.end_NA; row++)
1204
+ workspace.weights_map[workspace.ix_arr[row]] *= trees.back().pct_tree_left;
1205
+ workspace.end = workspace.end_NA - 1;
1206
+ break;
1207
+ }
1208
+
1209
+ default:
1210
+ {
1211
+ unexpected_error();
1212
+ break;
1213
+ }
1214
+ }
1215
+ }
1216
+
1217
+ else
1218
+ {
1219
+ trees.back().pct_tree_left = (ldouble_safe) (workspace.split_ix - workspace.st)
1220
+ /
1221
+ (ldouble_safe) (workspace.end - workspace.st + 1);
1222
+ workspace.end = workspace.split_ix - 1;
1223
+ }
1224
+
1225
+ /* Depending on the scoring metric, might need to calculate fractions of data and volume */
1226
+ if (model_params.scoring_metric != Depth && !is_boxed_metric(model_params.scoring_metric))
1227
+ {
1228
+ switch (trees.back().col_type)
1229
+ {
1230
+ case Numeric:
1231
+ {
1232
+ if (!workspace.determine_split)
1233
+ workspace.density_calculator.restore_range(workspace.xmin, workspace.xmax);
1234
+
1235
+ if (model_params.scoring_metric == Density)
1236
+ workspace.density_calculator.push_density(workspace.xmin, workspace.xmax, trees.back().num_split);
1237
+ else
1238
+ workspace.density_calculator.push_adj(workspace.xmin, workspace.xmax,
1239
+ trees.back().num_split, trees.back().pct_tree_left,
1240
+ model_params.scoring_metric);
1241
+ break;
1242
+ }
1243
+
1244
+ case Categorical:
1245
+ {
1246
+ switch (model_params.cat_split_type)
1247
+ {
1248
+ case SingleCateg:
1249
+ {
1250
+ if (model_params.scoring_metric == Density)
1251
+ {
1252
+ if (workspace.determine_split)
1253
+ {
1254
+ if (workspace.criterion == NoCrit)
1255
+ workspace.density_calculator.push_density(workspace.npresent);
1256
+ else
1257
+ workspace.density_calculator.push_density(workspace.buffer_szt.data(),
1258
+ input_data.ncat[trees.back().col_num]);
1259
+ }
1260
+
1261
+ else
1262
+ {
1263
+ workspace.density_calculator.push_density(workspace.density_calculator.counts.data(),
1264
+ input_data.ncat[trees.back().col_num]);
1265
+ }
1266
+ }
1267
+
1268
+ else
1269
+ {
1270
+ if (workspace.determine_split)
1271
+ {
1272
+ if (workspace.criterion == NoCrit)
1273
+ {
1274
+ count_categs(workspace.ix_arr.data(), workspace.st, workspace.end,
1275
+ input_data.categ_data + trees.back().col_num * input_data.nrows,
1276
+ input_data.ncat[trees.back().col_num],
1277
+ workspace.density_calculator.counts.data());
1278
+ workspace.density_calculator.push_adj(workspace.density_calculator.counts.data(),
1279
+ input_data.ncat[trees.back().col_num],
1280
+ trees.back().chosen_cat,
1281
+ model_params.scoring_metric);
1282
+ }
1283
+
1284
+ else
1285
+ {
1286
+ workspace.density_calculator.push_adj(workspace.buffer_szt.data(),
1287
+ input_data.ncat[trees.back().col_num],
1288
+ trees.back().chosen_cat,
1289
+ model_params.scoring_metric);
1290
+ }
1291
+ }
1292
+
1293
+ else
1294
+ {
1295
+
1296
+ workspace.density_calculator.push_adj(workspace.density_calculator.counts.data(),
1297
+ input_data.ncat[trees.back().col_num],
1298
+ trees.back().chosen_cat,
1299
+ model_params.scoring_metric);
1300
+ }
1301
+ }
1302
+ break;
1303
+ }
1304
+
1305
+ case SubSet:
1306
+ {
1307
+ if (model_params.scoring_metric == Density)
1308
+ {
1309
+ if (!trees.back().cat_split.size())
1310
+ {
1311
+ workspace.density_calculator.push_density();
1312
+ }
1313
+
1314
+ else
1315
+ {
1316
+ workspace.density_calculator.push_density(workspace.density_calculator.n_left,
1317
+ workspace.density_calculator.n_present);
1318
+ }
1319
+
1320
+ }
1321
+
1322
+ else
1323
+ {
1324
+ if (!trees.back().cat_split.size())
1325
+ {
1326
+ workspace.density_calculator.push_adj(trees.back().pct_tree_left,
1327
+ model_params.scoring_metric);
1328
+ }
1329
+
1330
+ else
1331
+ {
1332
+ if (workspace.determine_split)
1333
+ {
1334
+ if (workspace.criterion == NoCrit)
1335
+ {
1336
+ count_categs(workspace.ix_arr.data(), workspace.st, workspace.end,
1337
+ input_data.categ_data + trees.back().col_num * input_data.nrows,
1338
+ input_data.ncat[trees.back().col_num],
1339
+ workspace.density_calculator.counts.data());
1340
+ workspace.density_calculator.push_adj(trees.back().cat_split.data(),
1341
+ workspace.density_calculator.counts.data(),
1342
+ input_data.ncat[trees.back().col_num],
1343
+ model_params.scoring_metric);
1344
+ }
1345
+
1346
+ else
1347
+ {
1348
+ workspace.density_calculator.push_adj(trees.back().cat_split.data(),
1349
+ workspace.buffer_szt.data(),
1350
+ input_data.ncat[trees.back().col_num],
1351
+ model_params.scoring_metric);
1352
+ }
1353
+ }
1354
+
1355
+ else
1356
+ {
1357
+ workspace.density_calculator.push_adj(trees.back().cat_split.data(),
1358
+ workspace.density_calculator.counts.data(),
1359
+ input_data.ncat[trees.back().col_num],
1360
+ model_params.scoring_metric);
1361
+ }
1362
+ }
1363
+ }
1364
+ break;
1365
+ }
1366
+ }
1367
+ break;
1368
+ }
1369
+
1370
+ default:
1371
+ {
1372
+ assert(0);
1373
+ }
1374
+ }
1375
+ }
1376
+
1377
+ else if (is_boxed_metric(model_params.scoring_metric))
1378
+ {
1379
+ switch (trees.back().col_type)
1380
+ {
1381
+ case Numeric:
1382
+ {
1383
+ workspace.density_calculator.push_bdens(trees.back().num_split, trees.back().col_num);
1384
+ break;
1385
+ }
1386
+
1387
+ case Categorical:
1388
+ {
1389
+ switch (model_params.cat_split_type)
1390
+ {
1391
+ case SingleCateg:
1392
+ {
1393
+ workspace.density_calculator.push_bdens((int)1, trees.back().col_num);
1394
+ break;
1395
+ }
1396
+
1397
+ case SubSet:
1398
+ {
1399
+ if (trees.back().cat_split.empty())
1400
+ {
1401
+ workspace.density_calculator.push_bdens((int)1, trees.back().col_num);
1402
+ }
1403
+
1404
+ else
1405
+ {
1406
+ workspace.density_calculator.push_bdens(trees.back().cat_split, trees.back().col_num);
1407
+ }
1408
+ break;
1409
+ }
1410
+ }
1411
+ break;
1412
+ }
1413
+
1414
+ default:
1415
+ {
1416
+ assert(0);
1417
+ }
1418
+ }
1419
+ }
1420
+
1421
+ /* Branch where to assign new categories can be pre-determined in this case */
1422
+ if (
1423
+ trees.back().col_type == Categorical &&
1424
+ model_params.cat_split_type == SubSet &&
1425
+ input_data.ncat[trees.back().col_num] > 2 &&
1426
+ model_params.new_cat_action == Smallest
1427
+ )
1428
+ {
1429
+ bool new_to_left = trees.back().pct_tree_left < 0.5;
1430
+ for (int cat = 0; cat < input_data.ncat[trees.back().col_num]; cat++)
1431
+ if (trees.back().cat_split[cat] < 0)
1432
+ trees.back().cat_split[cat] = new_to_left;
1433
+ }
1434
+
1435
+ /* If doing single-category splits, the branch that got only one category will not
1436
+ be splittable anymore, so it can be dropped for the remainder of that branch */
1437
+ if (trees.back().col_type == Categorical &&
1438
+ model_params.cat_split_type == SingleCateg &&
1439
+ input_data.ncat[trees.back().col_num] > 2 /* <- in this case, would have been dropped earlier */
1440
+ )
1441
+ {
1442
+ workspace.col_sampler.drop_col(trees.back().col_num + input_data.ncols_numeric,
1443
+ workspace.end - workspace.st + 1);
1444
+ }
1445
+
1446
+ /* left branch */
1447
+ trees.back().tree_left = trees.size();
1448
+ trees.emplace_back();
1449
+ if (impute_nodes != NULL) impute_nodes->emplace_back(tree_from);
1450
+ split_itree_recursive<InputData, WorkerMemory, ldouble_safe>(
1451
+ trees,
1452
+ workspace,
1453
+ input_data,
1454
+ model_params,
1455
+ impute_nodes,
1456
+ curr_depth + 1);
1457
+
1458
+
1459
+ /* right branch */
1460
+ recursion_state->restore_state(workspace);
1461
+ if (is_boxed_metric(model_params.scoring_metric))
1462
+ {
1463
+ if (trees[tree_from].col_type == Numeric)
1464
+ workspace.density_calculator.pop_bdens(trees[tree_from].col_num);
1465
+ else
1466
+ workspace.density_calculator.pop_bdens_cat(trees[tree_from].col_num);
1467
+ }
1468
+ else if (model_params.scoring_metric != Depth)
1469
+ {
1470
+ workspace.density_calculator.pop();
1471
+ }
1472
+ if (model_params.missing_action != Fail)
1473
+ {
1474
+ switch(model_params.missing_action)
1475
+ {
1476
+ case Impute:
1477
+ {
1478
+ if (trees[tree_from].pct_tree_left >= .5)
1479
+ workspace.st = workspace.end_NA;
1480
+ else
1481
+ workspace.st = workspace.st_NA;
1482
+ break;
1483
+ }
1484
+
1485
+ case Divide:
1486
+ {
1487
+ if (!workspace.changed_weights && workspace.st_NA < workspace.end_NA)
1488
+ {
1489
+ workspace.changed_weights = true;
1490
+
1491
+ if (!workspace.weights_arr.empty()) {
1492
+ for (size_t row = workspace.st_NA; row <= workspace.end; row++)
1493
+ workspace.weights_arr[workspace.ix_arr[row]] = 1;
1494
+ }
1495
+
1496
+ else {
1497
+ for (size_t row = workspace.st_NA; row <= workspace.end; row++)
1498
+ workspace.weights_map[workspace.ix_arr[row]] = 1;
1499
+ }
1500
+ }
1501
+
1502
+ if (!workspace.weights_arr.empty())
1503
+ for (size_t row = workspace.st_NA; row < workspace.end_NA; row++)
1504
+ workspace.weights_arr[workspace.ix_arr[row]] *= (1. - trees[tree_from].pct_tree_left);
1505
+ else
1506
+ for (size_t row = workspace.st_NA; row < workspace.end_NA; row++)
1507
+ workspace.weights_map[workspace.ix_arr[row]] *= (1. - trees[tree_from].pct_tree_left);
1508
+ workspace.st = workspace.st_NA;
1509
+ break;
1510
+ }
1511
+
1512
+ default:
1513
+ {
1514
+ unexpected_error();
1515
+ break;
1516
+ }
1517
+ }
1518
+ }
1519
+
1520
+ else
1521
+ {
1522
+ workspace.st = workspace.split_ix;
1523
+ }
1524
+
1525
+ trees[tree_from].tree_right = trees.size();
1526
+ trees.emplace_back();
1527
+ if (impute_nodes != NULL) impute_nodes->emplace_back(tree_from);
1528
+ split_itree_recursive<InputData, WorkerMemory, ldouble_safe>(
1529
+ trees,
1530
+ workspace,
1531
+ input_data,
1532
+ model_params,
1533
+ impute_nodes,
1534
+ curr_depth + 1);
1535
+ if (is_boxed_metric(model_params.scoring_metric))
1536
+ {
1537
+ if (trees[tree_from].col_type == Numeric)
1538
+ workspace.density_calculator.pop_bdens_right(trees[tree_from].col_num);
1539
+ else
1540
+ workspace.density_calculator.pop_bdens_cat_right(trees[tree_from].col_num);
1541
+ }
1542
+ else if (model_params.scoring_metric != Depth)
1543
+ {
1544
+ workspace.density_calculator.pop_right();
1545
+ }
1546
+ }
1547
+ return;
1548
+
1549
+ /* if it reached the limit, calculate terminal statistics */
1550
+ terminal_statistics:
1551
+ {
1552
+ trees.back().tree_left = 0;
1553
+
1554
+ if (workspace.changed_weights)
1555
+ {
1556
+ if (sum_weight <= -HUGE_VAL)
1557
+ sum_weight = calculate_sum_weights<ldouble_safe>(
1558
+ workspace.ix_arr, workspace.st, workspace.end, curr_depth,
1559
+ workspace.weights_arr, workspace.weights_map);
1560
+ }
1561
+
1562
+ switch (model_params.scoring_metric)
1563
+ {
1564
+ case Depth:
1565
+ {
1566
+ if (!workspace.changed_weights)
1567
+ trees.back().score = curr_depth + expected_avg_depth<ldouble_safe>(workspace.end - workspace.st + 1);
1568
+ else
1569
+ trees.back().score = curr_depth + expected_avg_depth<ldouble_safe>(sum_weight);
1570
+ break;
1571
+ }
1572
+
1573
+ case AdjDepth:
1574
+ {
1575
+ if (!workspace.changed_weights)
1576
+ trees.back().score = workspace.density_calculator.calc_adj_depth() + expected_avg_depth<ldouble_safe>(workspace.end - workspace.st + 1);
1577
+ else
1578
+ trees.back().score = workspace.density_calculator.calc_adj_depth() + expected_avg_depth<ldouble_safe>(sum_weight);
1579
+ break;
1580
+ }
1581
+
1582
+ case Density:
1583
+ {
1584
+ if (!workspace.changed_weights)
1585
+ trees.back().score = workspace.density_calculator.calc_density(workspace.end - workspace.st + 1, model_params.sample_size);
1586
+ else
1587
+ trees.back().score = workspace.density_calculator.calc_density(sum_weight, model_params.sample_size);
1588
+ break;
1589
+ }
1590
+
1591
+ case AdjDensity:
1592
+ {
1593
+ trees.back().score = workspace.density_calculator.calc_adj_density();
1594
+ break;
1595
+ }
1596
+
1597
+ case BoxedRatio:
1598
+ {
1599
+ trees.back().score = workspace.density_calculator.calc_bratio();
1600
+ break;
1601
+ }
1602
+
1603
+ case BoxedDensity:
1604
+ {
1605
+ if (!workspace.changed_weights)
1606
+ trees.back().score = workspace.density_calculator.calc_bdens(workspace.end - workspace.st + 1, model_params.sample_size);
1607
+ else
1608
+ trees.back().score = workspace.density_calculator.calc_bdens(sum_weight, model_params.sample_size);
1609
+ break;
1610
+ }
1611
+
1612
+ case BoxedDensity2:
1613
+ {
1614
+ if (!workspace.changed_weights)
1615
+ trees.back().score = workspace.density_calculator.calc_bdens2(workspace.end - workspace.st + 1, model_params.sample_size);
1616
+ else
1617
+ trees.back().score = workspace.density_calculator.calc_bdens2(sum_weight, model_params.sample_size);
1618
+ break;
1619
+ }
1620
+ }
1621
+
1622
+ trees.back().cat_split.clear();
1623
+ trees.back().cat_split.shrink_to_fit();
1624
+
1625
+ trees.back().remainder = workspace.changed_weights?
1626
+ (double)sum_weight : (double)(workspace.end - workspace.st + 1);
1627
+
1628
+ /* for distance, assume also the elements keep being split */
1629
+ if (model_params.calc_dist)
1630
+ add_remainder_separation_steps<InputData, WorkerMemory, ldouble_safe>(workspace, input_data, sum_weight);
1631
+
1632
+ /* add this depth right away if requested */
1633
+ if (workspace.row_depths.size())
1634
+ {
1635
+ if (!workspace.changed_weights)
1636
+ {
1637
+ for (size_t row = workspace.st; row <= workspace.end; row++)
1638
+ workspace.row_depths[workspace.ix_arr[row]] += trees.back().score;
1639
+ }
1640
+
1641
+ else if (!workspace.weights_arr.empty())
1642
+ {
1643
+ for (size_t row = workspace.st; row <= workspace.end; row++)
1644
+ workspace.row_depths[workspace.ix_arr[row]] += workspace.weights_arr[workspace.ix_arr[row]] * trees.back().score;
1645
+ }
1646
+
1647
+ else
1648
+ {
1649
+ for (size_t row = workspace.st; row <= workspace.end; row++)
1650
+ workspace.row_depths[workspace.ix_arr[row]] += workspace.weights_map[workspace.ix_arr[row]] * trees.back().score;
1651
+ }
1652
+ }
1653
+
1654
+ /* add imputations from node if requested */
1655
+ if (model_params.impute_at_fit)
1656
+ add_from_impute_node(impute_nodes->back(), workspace, input_data);
1657
+ }
1658
+
1659
+ }