isotree 0.2.2 → 0.3.0

Files changed (151)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2111 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +62 -0
  14. data/vendor/isotree/src/RcppExports.cpp +525 -52
  15. data/vendor/isotree/src/Rwrapper.cpp +1931 -268
  16. data/vendor/isotree/src/c_interface.cpp +953 -0
  17. data/vendor/isotree/src/crit.hpp +4232 -0
  18. data/vendor/isotree/src/dist.hpp +1886 -0
  19. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  20. data/vendor/isotree/src/extended.hpp +1444 -0
  21. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  22. data/vendor/isotree/src/fit_model.hpp +2401 -0
  23. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  24. data/vendor/isotree/src/helpers_iforest.hpp +813 -0
  25. data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
  26. data/vendor/isotree/src/indexer.cpp +515 -0
  27. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  28. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  29. data/vendor/isotree/src/isoforest.hpp +1659 -0
  30. data/vendor/isotree/src/isotree.hpp +1804 -392
  31. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  32. data/vendor/isotree/src/merge_models.cpp +159 -16
  33. data/vendor/isotree/src/mult.hpp +1321 -0
  34. data/vendor/isotree/src/oop_interface.cpp +842 -0
  35. data/vendor/isotree/src/oop_interface.hpp +278 -0
  36. data/vendor/isotree/src/other_helpers.hpp +219 -0
  37. data/vendor/isotree/src/predict.hpp +1932 -0
  38. data/vendor/isotree/src/python_helpers.hpp +134 -0
  39. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  40. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  41. data/vendor/isotree/src/robinmap/README.md +483 -0
  42. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  46. data/vendor/isotree/src/serialize.cpp +4300 -139
  47. data/vendor/isotree/src/sql.cpp +141 -59
  48. data/vendor/isotree/src/subset_models.cpp +174 -0
  49. data/vendor/isotree/src/utils.hpp +3808 -0
  50. data/vendor/isotree/src/xoshiro.hpp +467 -0
  51. data/vendor/isotree/src/ziggurat.hpp +405 -0
  52. metadata +38 -104
  53. data/vendor/cereal/LICENSE +0 -24
  54. data/vendor/cereal/README.md +0 -85
  55. data/vendor/cereal/include/cereal/access.hpp +0 -351
  56. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  57. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  58. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  59. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  60. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  61. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  62. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  63. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  65. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  66. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  67. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  68. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  69. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  70. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  71. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  72. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  74. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  76. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  77. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  78. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  79. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  91. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  92. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  94. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  96. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  97. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  98. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  99. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  100. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  101. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  102. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  103. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  104. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  105. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  106. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  107. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  111. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  112. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  113. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  114. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  115. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  116. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  117. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  118. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  119. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  120. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  121. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  122. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  123. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  124. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  125. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  126. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  127. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  128. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  129. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  130. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  131. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  132. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  133. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  134. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  135. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  136. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  137. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  138. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  139. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  140. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  141. data/vendor/cereal/include/cereal/version.hpp +0 -52
  142. data/vendor/isotree/src/Makevars +0 -4
  143. data/vendor/isotree/src/crit.cpp +0 -912
  144. data/vendor/isotree/src/dist.cpp +0 -749
  145. data/vendor/isotree/src/extended.cpp +0 -790
  146. data/vendor/isotree/src/fit_model.cpp +0 -1090
  147. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  148. data/vendor/isotree/src/isoforest.cpp +0 -771
  149. data/vendor/isotree/src/mult.cpp +0 -607
  150. data/vendor/isotree/src/predict.cpp +0 -853
  151. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -1,771 +0,0 @@
- /* Isolation forests and variations thereof, with adjustments for incorporation
- * of categorical variables and missing values.
- * Writen for C++11 standard and aimed at being used in R and Python.
- *
- * This library is based on the following works:
- * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
- * "Isolation forest."
- * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
- * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
- * "Isolation-based anomaly detection."
- * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
- * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
- * "Extended Isolation Forest."
- * arXiv preprint arXiv:1811.02141 (2018).
- * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
- * "On detecting clustered anomalies using SCiForest."
- * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
- * [5] https://sourceforge.net/projects/iforest/
- * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
- * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
- * [8] Cortes, David. "Distance approximation using Isolation Forests." arXiv preprint arXiv:1910.12362 (2019).
- * [9] Cortes, David. "Imputing missing values with unsupervised random trees." arXiv preprint arXiv:1911.06646 (2019).
- *
- * BSD 2-Clause License
- * Copyright (c) 2020, David Cortes
- * All rights reserved.
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright notice, this
- * list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright notice,
- * this list of conditions and the following disclaimer in the documentation
- * and/or other materials provided with the distribution.
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- */
- #include "isotree.hpp"
-
- void split_itree_recursive(std::vector<IsoTree> &trees,
- WorkerMemory &workspace,
- InputData &input_data,
- ModelParams &model_params,
- std::vector<ImputeNode> *impute_nodes,
- size_t curr_depth)
- {
- long double sum_weight = -HUGE_VAL;
-
- /* calculate imputation statistics if desired */
- if (impute_nodes != NULL)
- {
- if (input_data.Xc_indptr != NULL)
- std::sort(workspace.ix_arr.begin() + workspace.st,
- workspace.ix_arr.begin() + workspace.end + 1);
- build_impute_node(impute_nodes->back(), workspace,
- input_data, model_params,
- *impute_nodes, curr_depth,
- model_params.min_imp_obs);
- }
-
- /* check for potential isolated leafs */
- if (workspace.end == workspace.st || curr_depth >= model_params.max_depth)
- goto terminal_statistics;
-
- /* with 2 observations and no weights, there's only 1 potential or assumed split */
- if ((workspace.end - workspace.st) == 1 && !workspace.weights_arr.size() && !workspace.weights_map.size())
- goto terminal_statistics;
-
- /* when using weights, the split should stop when the sum of weights is <= 2 */
- sum_weight = calculate_sum_weights(workspace.ix_arr, workspace.st, workspace.end, curr_depth,
- workspace.weights_arr, workspace.weights_map);
-
- if (curr_depth > 0 && (workspace.weights_arr.size() || workspace.weights_map.size()) && sum_weight < 2.5)
- goto terminal_statistics;
-
- /* for sparse matrices, need to sort the indices */
- if (input_data.Xc_indptr != NULL && impute_nodes == NULL)
- std::sort(workspace.ix_arr.begin() + workspace.st, workspace.ix_arr.begin() + workspace.end + 1);
-
- /* pick column to split according to criteria */
- workspace.prob_split_type = workspace.rbin(workspace.rnd_generator);
-
- /* case1: guided, pick column and point with best gain */
- if (
- workspace.prob_split_type
- < (
- model_params.prob_pick_by_gain_avg +
- model_params.prob_pick_by_gain_pl
- )
- )
- {
- workspace.determine_split = false;
-
- /* case 1.1: column is decided by averaged gain */
- if (workspace.prob_split_type < model_params.prob_pick_by_gain_avg)
- workspace.criterion = Averaged;
-
- /* case 1.2: column is decided by pooled gain */
- else
- workspace.criterion = Pooled;
-
- /* evaluate gain for all columns */
- trees.back().score = -HUGE_VAL; /* this is used to track the best gain */
- if (input_data.Xc_indptr == NULL)
- {
- for (size_t col = 0; col < input_data.ncols_numeric; col++)
- {
- workspace.this_gain = eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
- input_data.numeric_data + col * input_data.nrows,
- workspace.split_ix, workspace.this_split_point,
- workspace.xmin, workspace.xmax,
- workspace.criterion, model_params.min_gain,
- model_params.missing_action);
- if (workspace.this_gain <= -HUGE_VAL)
- {
- workspace.cols_possible[col] = false;
- }
-
- else if (workspace.this_gain > trees.back().score)
- {
- trees.back().score = workspace.this_gain;
- trees.back().col_num = col;
- trees.back().num_split = workspace.this_split_point;
- if (model_params.penalize_range)
- {
- trees.back().range_low = workspace.xmin - workspace.xmax + trees.back().num_split;
- trees.back().range_high = workspace.xmax - workspace.xmin + trees.back().num_split;
- }
- }
- }
-
- }
-
- else
- {
- for (size_t col = 0; col < input_data.ncols_numeric; col++)
- {
- workspace.this_gain = eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
- col, input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
- workspace.buffer_dbl.data(), workspace.buffer_szt.data(),
- workspace.this_split_point, workspace.xmin, workspace.xmax,
- workspace.criterion, model_params.min_gain, model_params.missing_action);
- if (workspace.this_gain <= -HUGE_VAL)
- {
- workspace.cols_possible[col] = false;
- }
-
- else if (workspace.this_gain > trees.back().score)
- {
- trees.back().score = workspace.this_gain;
- trees.back().col_num = col;
- trees.back().num_split = workspace.this_split_point;
- if (model_params.penalize_range)
- {
- trees.back().range_low = workspace.xmin - workspace.xmax + trees.back().num_split;
- trees.back().range_high = workspace.xmax - workspace.xmin + trees.back().num_split;
- }
- }
- }
- }
-
- for (size_t col = 0; col < input_data.ncols_categ; col++)
- {
- workspace.this_gain = eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
- input_data.categ_data + col * input_data.nrows, input_data.ncat[col],
- workspace.buffer_szt.data(), workspace.buffer_szt.data() + input_data.max_categ,
- workspace.buffer_dbl.data(), workspace.this_categ, workspace.this_split_categ.data(),
- workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
- model_params.all_perm, model_params.missing_action, model_params.cat_split_type);
- if (workspace.this_gain <= -HUGE_VAL)
- {
- workspace.cols_possible[col + input_data.ncols_numeric] = false;
- }
-
- else if (workspace.this_gain > trees.back().score)
- {
- trees.back().score = workspace.this_gain;
- trees.back().col_num = col + input_data.ncols_numeric;
- switch(model_params.cat_split_type)
- {
- case SingleCateg:
- {
- trees.back().chosen_cat = workspace.this_categ;
- break;
- }
-
- case SubSet:
- {
- trees.back().cat_split.assign(workspace.this_split_categ.begin(),
- workspace.this_split_categ.begin() + input_data.ncat[col]);
- break;
- }
- }
- }
- }
-
-
- if (trees.back().score <= 0.)
- goto terminal_statistics;
- else
- trees.back().score = 0.;
-
- if (trees.back().col_num < input_data.ncols_numeric)
- {
- trees.back().col_type = Numeric;
- }
-
- else
- {
- trees.back().col_type = Categorical;
- trees.back().col_num -= input_data.ncols_numeric;
- }
- }
-
- /* case2: column is chosen at random */
- else
- {
- workspace.determine_split = true;
-
- /* case 2.1: split point is chosen according to gain (averaged) */
- if (
- workspace.prob_split_type
- < (
- model_params.prob_pick_by_gain_avg +
- model_params.prob_pick_by_gain_pl +
- model_params.prob_split_by_gain_avg
- )
- )
- workspace.criterion = Averaged;
-
- /* case 2.2: split point is chosen according to gain (pooled) */
- else if (
- workspace.prob_split_type
- < (
- model_params.prob_pick_by_gain_avg +
- model_params.prob_pick_by_gain_pl +
- model_params.prob_split_by_gain_avg +
- model_params.prob_split_by_gain_pl
- )
- )
- workspace.criterion = Pooled;
-
- /* case 2.3: split point is chosen randomly (like in the original paper) */
- else
- workspace.criterion = NoCrit;
-
-
- /* pick column at random */
- decide_column(input_data.ncols_numeric, input_data.ncols_categ,
- trees.back().col_num, trees.back().col_type,
- workspace.rnd_generator, workspace.runif,
- workspace.col_sampler);
-
- /* get the range of possible splits */
- get_split_range(workspace, input_data, model_params, trees.back());
-
- /* if it's not possible to split, will have to try more */
- if (workspace.unsplittable)
- {
- /* keep track of which columns are tried */
- add_unsplittable_col(workspace, trees.back(), input_data);
-
- /* try more random columns for {(1/2) * ncols} times */
- workspace.ncols_tried = 1;
- do
- {
- decide_column(input_data.ncols_numeric, input_data.ncols_categ,
- trees.back().col_num, trees.back().col_type,
- workspace.rnd_generator, workspace.runif,
- workspace.col_sampler);
- if (!check_is_not_unsplittable_col(workspace, trees.back(), input_data))
- {
- get_split_range(workspace, input_data, model_params, trees.back());
- if (!workspace.unsplittable)
- break;
- else
- add_unsplittable_col(workspace, trees.back(), input_data);
- }
- workspace.ncols_tried++;
- }
- while (workspace.ncols_tried < input_data.ncols_tot / 2);
-
- /* if that didn't work, then check all the columns that are still splittable */
- if (workspace.unsplittable)
- {
- workspace.ncols_tried = 0; /* note: this is used here as a counter for the number of still splittable columns */
- if (input_data.Xc_indptr == NULL)
- {
- for (size_t col = 0; col < input_data.ncols_numeric; col++)
- {
- if (!workspace.cols_possible[col]) continue;
- get_range(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * col,
- workspace.st, workspace.end, model_params.missing_action,
- workspace.xmin, workspace.xmax, workspace.unsplittable);
- workspace.cols_possible[col] = !workspace.unsplittable;
- workspace.ncols_tried += !workspace.unsplittable;
- }
- }
-
- else
- {
- for (size_t col = 0; col < input_data.ncols_numeric; col++)
- {
- if (!workspace.cols_possible[col]) continue;
- get_range(workspace.ix_arr.data(), workspace.st, workspace.end, col,
- input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
- model_params.missing_action, workspace.xmin, workspace.xmax, workspace.unsplittable);
- workspace.cols_possible[col] = !workspace.unsplittable;
- workspace.ncols_tried += !workspace.unsplittable;
- }
- }
-
- for (size_t col = 0; col < input_data.ncols_categ; col++)
- {
- if (!workspace.cols_possible[col + input_data.ncols_numeric]) continue;
- get_categs(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * col,
- workspace.st, workspace.end, input_data.ncat[col],
- model_params.missing_action, workspace.categs.data(), workspace.npresent, workspace.unsplittable);
- workspace.cols_possible[col + input_data.ncols_numeric] = !workspace.unsplittable;
- workspace.ncols_tried += !workspace.unsplittable;
- }
-
-
- /* if no further splits are possible, end the procedure here */
- workspace.npresent = workspace.ncols_tried;
- if (!workspace.npresent) goto terminal_statistics;
-
- /* otherwise, pick a column at random from the possible ones */
- if (!workspace.col_sampler.max())
- {
- /* no weights by columns */
- trees.back().col_num = std::uniform_int_distribution<size_t>
- (0, workspace.npresent - 1)
- (workspace.rnd_generator);
- workspace.ncols_tried = 0;
- for (size_t col = 0; col < input_data.ncols_tot; col++)
- {
-
- if (workspace.cols_possible[col])
- {
- if (workspace.ncols_tried == trees.back().col_num)
- {
- if (col < input_data.ncols_numeric)
- {
- trees.back().col_num = col;
- trees.back().col_type = Numeric;
- }
-
- else
- {
- trees.back().col_num = col - input_data.ncols_numeric;
- trees.back().col_type = Categorical;
- }
- break;
- }
- workspace.ncols_tried++;
- }
-
- }
- }
-
- else
- {
- /* weights by columns */
- std::vector<double> col_weights = workspace.col_sampler.probabilities();
- update_col_sampler(workspace, input_data);
-
- decide_column(input_data.ncols_numeric, input_data.ncols_categ,
- trees.back().col_num, trees.back().col_type,
- workspace.rnd_generator, workspace.runif,
- workspace.col_sampler);
- }
-
- }
-
- /* finally, check the range if needed, and later decide on the split point */
- if (workspace.criterion == NoCrit)
- get_split_range(workspace, input_data, model_params, trees.back());
-
- }
-
- }
-
-
- /* for numeric, choose a random point, or pick the best point as determined earlier */
- if (trees.back().col_type == Numeric)
- {
- if (workspace.determine_split)
- {
- switch(workspace.criterion)
- {
- case NoCrit:
- {
- trees.back().num_split = std::uniform_real_distribution<double>
- (workspace.xmin, workspace.xmax)
- (workspace.rnd_generator);
- break;
- }
-
- default:
- {
- if (input_data.Xc_indptr == NULL)
- {
- eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
- input_data.numeric_data + trees.back().col_num * input_data.nrows,
- workspace.split_ix, trees.back().num_split,
- workspace.xmin, workspace.xmax,
- workspace.criterion, model_params.min_gain,
- model_params.missing_action);
- if (model_params.missing_action == Fail) /* data is already split */
- {
- workspace.split_ix++;
- goto follow_branches;
- }
- }
-
- else
- {
- eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
- trees.back().col_num, input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
- workspace.buffer_dbl.data(), workspace.buffer_szt.data(),
- trees.back().num_split, workspace.xmin, workspace.xmax,
- workspace.criterion, model_params.min_gain,
- model_params.missing_action);
- }
- break;
- }
- }
-
- if (model_params.penalize_range)
- {
- trees.back().range_low = workspace.xmin - workspace.xmax + trees.back().num_split;
- trees.back().range_high = workspace.xmax - workspace.xmin + trees.back().num_split;
- }
- }
-
- if (input_data.Xc_indptr == NULL)
- divide_subset_split(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * trees.back().col_num,
- workspace.st, workspace.end, trees.back().num_split, model_params.missing_action,
- workspace.st_NA, workspace.end_NA, workspace.split_ix);
- else
- divide_subset_split(workspace.ix_arr.data(), workspace.st, workspace.end, trees.back().col_num,
- input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr, trees.back().num_split,
- model_params.missing_action, workspace.st_NA, workspace.end_NA, workspace.split_ix);
- }
-
- /* for categorical, there are different ways of splitting */
- else
- {
- /* if the columns is binary, there's only one possible split */
- if (input_data.ncat[trees.back().col_num] <= 2)
- {
- trees.back().chosen_cat = 0;
- divide_subset_split(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * trees.back().col_num,
- workspace.st, workspace.end, (int)0, model_params.missing_action,
- workspace.st_NA, workspace.end_NA, workspace.split_ix);
- trees.back().cat_split.clear();
- trees.back().cat_split.shrink_to_fit();
- }
-
- /* otherwise, split according to desired type (single/subset) */
- /* TODO: refactor this */
- else
- {
-
- switch(model_params.cat_split_type)
- {
-
- case SingleCateg:
- {
-
- if (workspace.determine_split)
- {
- switch(workspace.criterion)
- {
- case NoCrit:
- {
- trees.back().chosen_cat = choose_cat_from_present(workspace, input_data, trees.back().col_num);
- break;
- }
-
- default:
- {
- eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
- input_data.categ_data + trees.back().col_num * input_data.nrows, input_data.ncat[trees.back().col_num],
- workspace.buffer_szt.data(), workspace.buffer_szt.data() + input_data.max_categ,
- workspace.buffer_dbl.data(), trees.back().chosen_cat, workspace.this_split_categ.data(),
- workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
- model_params.all_perm, model_params.missing_action, model_params.cat_split_type);
- break;
- }
- }
- }
-
-
- divide_subset_split(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * trees.back().col_num,
- workspace.st, workspace.end, trees.back().chosen_cat, model_params.missing_action,
- workspace.st_NA, workspace.end_NA, workspace.split_ix);
- break;
- }
-
-
- case SubSet:
- {
-
- if (workspace.determine_split)
- {
- switch(workspace.criterion)
- {
- case NoCrit:
- {
- workspace.unsplittable = true;
- while(workspace.unsplittable)
- {
- workspace.npresent = 0;
- workspace.ncols_tried = 0;
- for (int cat = 0; cat < input_data.ncat[trees.back().col_num]; cat++)
- {
- if (workspace.categs[cat] >= 0)
- {
- workspace.categs[cat] = workspace.rbin(workspace.rnd_generator) < 0.5;
- workspace.npresent += workspace.categs[cat];
- workspace.ncols_tried += !workspace.categs[cat];
- }
- workspace.unsplittable = !(workspace.npresent && workspace.ncols_tried);
- }
- }
-
- trees.back().cat_split.assign(workspace.categs.begin(), workspace.categs.begin() + input_data.ncat[trees.back().col_num]);
- break; /* NoCrit */
- }
-
- default:
- {
- trees.back().cat_split.resize(input_data.ncat[trees.back().col_num]);
- eval_guided_crit(workspace.ix_arr.data(), workspace.st, workspace.end,
- input_data.categ_data + trees.back().col_num * input_data.nrows, input_data.ncat[trees.back().col_num],
- workspace.buffer_szt.data(), workspace.buffer_szt.data() + input_data.max_categ,
- workspace.buffer_dbl.data(), trees.back().chosen_cat, trees.back().cat_split.data(),
- workspace.buffer_chr.data(), workspace.criterion, model_params.min_gain,
- model_params.all_perm, model_params.missing_action, model_params.cat_split_type);
- break;
- }
- }
- }
-
- if (model_params.new_cat_action == Random)
- for (int cat = 0; cat < input_data.ncat[trees.back().col_num]; cat++)
- if (trees.back().cat_split[cat] < 0)
- trees.back().cat_split[cat] = workspace.rbin(workspace.rnd_generator) < 0.5;
-
- divide_subset_split(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * trees.back().col_num,
- workspace.st, workspace.end, trees.back().cat_split.data(), model_params.missing_action,
- workspace.st_NA, workspace.end_NA, workspace.split_ix);
- }
-
- }
-
- }
-
- }
-
-
- /* if it hasn't reached the limit, continue splitting from here */
- follow_branches:
- {
- /* add another round of separation depth for distance */
- if (model_params.calc_dist && curr_depth > 0)
- add_separation_step(workspace, input_data, (double)(-1));
-
- size_t tree_from = trees.size() - 1;
- size_t ix2, ix3;
- std::unique_ptr<std::vector<bool>> cols_possible_ptr;
- std::unique_ptr<std::discrete_distribution<size_t>> col_sampler_ptr;
- trees.back().score = -1;
-
- /* compute statistics for NAs and remember recursion indices/weights */
- std::unique_ptr<RecursionState> recursion_state;
- if (model_params.missing_action != Fail)
- {
- recursion_state = std::unique_ptr<RecursionState>(new RecursionState);
- backup_recursion_state(workspace, *recursion_state);
-
- trees.back().pct_tree_left = (long double)(workspace.st_NA - workspace.st)
- /
- (long double)(workspace.end - workspace.st + 1 - (workspace.end_NA - workspace.st_NA));
-
- switch(model_params.missing_action)
- {
- case Impute:
- {
- if (trees.back().pct_tree_left >= .5)
- workspace.end = workspace.end_NA - 1;
- else
- workspace.end = workspace.st_NA - 1;
- break;
- }
-
-
- case Divide:
- {
- if (workspace.weights_map.size())
- for (size_t row = workspace.st_NA; row < workspace.end_NA; row++)
- workspace.weights_map[workspace.ix_arr[row]] *= trees.back().pct_tree_left;
- else
- for (size_t row = workspace.st_NA; row < workspace.end_NA; row++)
- workspace.weights_arr[workspace.ix_arr[row]] *= trees.back().pct_tree_left;
- workspace.end = workspace.end_NA - 1;
- break;
- }
- }
- }
-
- else
- {
- trees.back().pct_tree_left = (long double) (workspace.split_ix - workspace.st)
- /
- (long double) (workspace.end - workspace.st + 1);
-
- ix2 = workspace.split_ix;
- ix3 = workspace.end;
- cols_possible_ptr = std::unique_ptr<std::vector<bool>>(new std::vector<bool>);
- *cols_possible_ptr = workspace.cols_possible;
- if (workspace.col_sampler.max())
- {
- col_sampler_ptr = std::unique_ptr<std::discrete_distribution<size_t>>(new std::discrete_distribution<size_t>);
- *col_sampler_ptr = workspace.col_sampler;
- }
- workspace.end = workspace.split_ix - 1;
- }
-
- /* Branch where to assign new categories can be pre-determined in this case */
- if (
- trees.back().col_type == Categorical &&
- model_params.cat_split_type == SubSet &&
- input_data.ncat[trees.back().col_num] > 2 &&
- model_params.new_cat_action == Smallest
- )
- {
- bool new_to_left = trees.back().pct_tree_left < 0.5;
- for (int cat = 0; cat < input_data.ncat[trees.back().col_num]; cat++)
- if (trees.back().cat_split[cat] < 0)
- trees.back().cat_split[cat] = new_to_left;
- }
-
- /* left branch */
- trees.back().tree_left = trees.size();
- trees.emplace_back();
- if (impute_nodes != NULL) impute_nodes->emplace_back(tree_from);
- split_itree_recursive(trees,
- workspace,
- input_data,
- model_params,
- impute_nodes,
- curr_depth + 1);
-
-
- /* right branch */
- if (model_params.missing_action != Fail)
- {
- restore_recursion_state(workspace, *recursion_state);
-
- switch(model_params.missing_action)
- {
- case Impute:
- {
- if (trees[tree_from].pct_tree_left >= .5)
- workspace.st = workspace.end_NA;
- else
- workspace.st = workspace.st_NA;
- break;
- }
-
- case Divide:
- {
- if (workspace.weights_map.size())
- for (size_t row = workspace.st_NA; row < workspace.end_NA; row++)
- workspace.weights_map[workspace.ix_arr[row]] *= (1 - trees[tree_from].pct_tree_left);
- else
- for (size_t row = workspace.st_NA; row < workspace.end_NA; row++)
- workspace.weights_arr[workspace.ix_arr[row]] *= (1 - trees[tree_from].pct_tree_left);
- workspace.st = workspace.st_NA;
- break;
- }
- }
- }
-
- else
- {
- workspace.st = ix2;
- workspace.end = ix3;
- workspace.cols_possible = std::move(*cols_possible_ptr);
- if (col_sampler_ptr)
- workspace.col_sampler = std::move(*col_sampler_ptr);
- }
-
- trees[tree_from].tree_right = trees.size();
- trees.emplace_back();
- if (impute_nodes != NULL) impute_nodes->emplace_back(tree_from);
- split_itree_recursive(trees,
- workspace,
- input_data,
- model_params,
- impute_nodes,
- curr_depth + 1);
- }
- return;
-
- /* if it reached the limit, calculate terminal statistics */
- terminal_statistics:
- {
- if (!workspace.weights_arr.size() && !workspace.weights_map.size())
- {
- trees.back().score = (double)(curr_depth + expected_avg_depth(workspace.end - workspace.st + 1));
- }
-
- else
- {
- if (sum_weight == -HUGE_VAL)
- sum_weight = calculate_sum_weights(workspace.ix_arr, workspace.st, workspace.end, curr_depth,
- workspace.weights_arr, workspace.weights_map);
- trees.back().score = (double)(curr_depth + expected_avg_depth(sum_weight));
- }
-
- trees.back().cat_split.clear();
- trees.back().cat_split.shrink_to_fit();
-
- trees.back().remainder = workspace.weights_arr.size()?
- sum_weight : (workspace.weights_map.size()?
- sum_weight : ((double)(workspace.end - workspace.st + 1))
- );
-
- /* for distance, assume also the elements keep being split */
- if (model_params.calc_dist)
- add_remainder_separation_steps(workspace, input_data, sum_weight);
-
- /* add this depth right away if requested */
- if (workspace.row_depths.size())
- {
- if (!workspace.weights_arr.size() && !workspace.weights_map.size())
- {
- for (size_t row = workspace.st; row <= workspace.end; row++)
- workspace.row_depths[workspace.ix_arr[row]] += trees.back().score;
- }
-
- else if (workspace.weights_arr.size())
- {
- for (size_t row = workspace.st; row <= workspace.end; row++)
- workspace.row_depths[workspace.ix_arr[row]] += workspace.weights_arr[workspace.ix_arr[row]] * trees.back().score;
- }
-
- else
- {
- for (size_t row = workspace.st; row <= workspace.end; row++)
- workspace.row_depths[workspace.ix_arr[row]] += workspace.weights_map[workspace.ix_arr[row]] * trees.back().score;
- }
- }
-
- /* add imputations from node if requested */
- if (model_params.impute_at_fit)
- add_from_impute_node(impute_nodes->back(), workspace, input_data);
- }
-
- }