isotree 0.2.2 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (151):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2111 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +62 -0
  14. data/vendor/isotree/src/RcppExports.cpp +525 -52
  15. data/vendor/isotree/src/Rwrapper.cpp +1931 -268
  16. data/vendor/isotree/src/c_interface.cpp +953 -0
  17. data/vendor/isotree/src/crit.hpp +4232 -0
  18. data/vendor/isotree/src/dist.hpp +1886 -0
  19. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  20. data/vendor/isotree/src/extended.hpp +1444 -0
  21. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  22. data/vendor/isotree/src/fit_model.hpp +2401 -0
  23. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  24. data/vendor/isotree/src/helpers_iforest.hpp +813 -0
  25. data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
  26. data/vendor/isotree/src/indexer.cpp +515 -0
  27. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  28. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  29. data/vendor/isotree/src/isoforest.hpp +1659 -0
  30. data/vendor/isotree/src/isotree.hpp +1804 -392
  31. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  32. data/vendor/isotree/src/merge_models.cpp +159 -16
  33. data/vendor/isotree/src/mult.hpp +1321 -0
  34. data/vendor/isotree/src/oop_interface.cpp +842 -0
  35. data/vendor/isotree/src/oop_interface.hpp +278 -0
  36. data/vendor/isotree/src/other_helpers.hpp +219 -0
  37. data/vendor/isotree/src/predict.hpp +1932 -0
  38. data/vendor/isotree/src/python_helpers.hpp +134 -0
  39. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  40. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  41. data/vendor/isotree/src/robinmap/README.md +483 -0
  42. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  46. data/vendor/isotree/src/serialize.cpp +4300 -139
  47. data/vendor/isotree/src/sql.cpp +141 -59
  48. data/vendor/isotree/src/subset_models.cpp +174 -0
  49. data/vendor/isotree/src/utils.hpp +3808 -0
  50. data/vendor/isotree/src/xoshiro.hpp +467 -0
  51. data/vendor/isotree/src/ziggurat.hpp +405 -0
  52. metadata +38 -104
  53. data/vendor/cereal/LICENSE +0 -24
  54. data/vendor/cereal/README.md +0 -85
  55. data/vendor/cereal/include/cereal/access.hpp +0 -351
  56. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  57. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  58. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  59. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  60. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  61. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  62. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  63. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  65. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  66. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  67. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  68. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  69. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  70. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  71. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  72. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  74. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  76. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  77. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  78. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  79. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  91. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  92. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  94. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  96. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  97. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  98. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  99. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  100. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  101. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  102. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  103. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  104. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  105. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  106. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  107. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  111. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  112. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  113. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  114. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  115. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  116. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  117. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  118. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  119. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  120. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  121. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  122. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  123. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  124. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  125. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  126. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  127. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  128. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  129. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  130. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  131. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  132. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  133. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  134. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  135. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  136. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  137. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  138. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  139. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  140. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  141. data/vendor/cereal/include/cereal/version.hpp +0 -52
  142. data/vendor/isotree/src/Makevars +0 -4
  143. data/vendor/isotree/src/crit.cpp +0 -912
  144. data/vendor/isotree/src/dist.cpp +0 -749
  145. data/vendor/isotree/src/extended.cpp +0 -790
  146. data/vendor/isotree/src/fit_model.cpp +0 -1090
  147. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  148. data/vendor/isotree/src/isoforest.cpp +0 -771
  149. data/vendor/isotree/src/mult.cpp +0 -607
  150. data/vendor/isotree/src/predict.cpp +0 -853
  151. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -1,790 +0,0 @@
1
- /* Isolation forests and variations thereof, with adjustments for incorporation
2
- * of categorical variables and missing values.
3
- * Writen for C++11 standard and aimed at being used in R and Python.
4
- *
5
- * This library is based on the following works:
6
- * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
- * "Isolation forest."
8
- * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
- * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
- * "Isolation-based anomaly detection."
11
- * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
- * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
- * "Extended Isolation Forest."
14
- * arXiv preprint arXiv:1811.02141 (2018).
15
- * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
- * "On detecting clustered anomalies using SCiForest."
17
- * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
- * [5] https://sourceforge.net/projects/iforest/
19
- * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
- * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
21
- * [8] Cortes, David. "Distance approximation using Isolation Forests." arXiv preprint arXiv:1910.12362 (2019).
22
- * [9] Cortes, David. "Imputing missing values with unsupervised random trees." arXiv preprint arXiv:1911.06646 (2019).
23
- *
24
- * BSD 2-Clause License
25
- * Copyright (c) 2020, David Cortes
26
- * All rights reserved.
27
- * Redistribution and use in source and binary forms, with or without
28
- * modification, are permitted provided that the following conditions are met:
29
- * * Redistributions of source code must retain the above copyright notice, this
30
- * list of conditions and the following disclaimer.
31
- * * Redistributions in binary form must reproduce the above copyright notice,
32
- * this list of conditions and the following disclaimer in the documentation
33
- * and/or other materials provided with the distribution.
34
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44
- */
45
- #include "isotree.hpp"
46
-
47
- void split_hplane_recursive(std::vector<IsoHPlane> &hplanes,
48
- WorkerMemory &workspace,
49
- InputData &input_data,
50
- ModelParams &model_params,
51
- std::vector<ImputeNode> *impute_nodes,
52
- size_t curr_depth)
53
- {
54
- long double sum_weight = -HUGE_VAL;
55
- size_t hplane_from = hplanes.size() - 1;
56
- std::unique_ptr<RecursionState> recursion_state;
57
- std::vector<bool> col_is_taken;
58
- std::unordered_set<size_t> col_is_taken_s;
59
-
60
- /* calculate imputation statistics if desired */
61
- if (impute_nodes != NULL)
62
- {
63
- if (input_data.Xc_indptr != NULL)
64
- std::sort(workspace.ix_arr.begin() + workspace.st,
65
- workspace.ix_arr.begin() + workspace.end + 1);
66
- build_impute_node(impute_nodes->back(), workspace,
67
- input_data, model_params,
68
- *impute_nodes, curr_depth,
69
- model_params.min_imp_obs);
70
- }
71
-
72
- /* check for potential isolated leafs */
73
- if (workspace.end == workspace.st || curr_depth >= model_params.max_depth)
74
- goto terminal_statistics;
75
-
76
- /* with 2 observations and no weights, there's only 1 potential or assumed split */
77
- if ((workspace.end - workspace.st) == 1 && !workspace.weights_arr.size() && !workspace.weights_map.size())
78
- goto terminal_statistics;
79
-
80
- /* when using weights, the split should stop when the sum of weights is <= 2 */
81
- sum_weight = calculate_sum_weights(workspace.ix_arr, workspace.st, workspace.end, curr_depth,
82
- workspace.weights_arr, workspace.weights_map);
83
-
84
- if (curr_depth > 0 && (workspace.weights_arr.size() || workspace.weights_map.size()) && sum_weight < 2.5)
85
- goto terminal_statistics;
86
-
87
- /* for sparse matrices, need to sort the indices */
88
- if (input_data.Xc_indptr != NULL && impute_nodes == NULL)
89
- std::sort(workspace.ix_arr.begin() + workspace.st, workspace.ix_arr.begin() + workspace.end + 1);
90
-
91
- /* pick column to split according to criteria */
92
- workspace.prob_split_type = workspace.rbin(workspace.rnd_generator);
93
-
94
- if (
95
- workspace.prob_split_type
96
- < (
97
- model_params.prob_pick_by_gain_avg +
98
- model_params.prob_pick_by_gain_pl
99
- )
100
- )
101
- {
102
- workspace.ntry = model_params.ntry;
103
- hplanes.back().score = -HUGE_VAL; /* this keeps track of the gain */
104
- if (workspace.prob_split_type < model_params.prob_pick_by_gain_avg)
105
- workspace.criterion = Averaged;
106
- else
107
- workspace.criterion = Pooled;
108
- }
109
-
110
- else
111
- {
112
- workspace.criterion = NoCrit;
113
- workspace.ntry = 1;
114
- }
115
-
116
- workspace.ntaken_best = 0;
117
-
118
- for (size_t attempt = 0; attempt < workspace.ntry; attempt++)
119
- {
120
- if (input_data.ncols_tot < 1e3)
121
- {
122
- if (!col_is_taken.size())
123
- col_is_taken.resize(input_data.ncols_tot, false);
124
- else
125
- col_is_taken.assign(input_data.ncols_tot, false);
126
- }
127
- else
128
- col_is_taken_s.clear();
129
- workspace.ntaken = 0;
130
- workspace.ncols_tried = 0;
131
- std::fill(workspace.comb_val.begin(),
132
- workspace.comb_val.begin() + (workspace.end - workspace.st + 1),
133
- (double)0);
134
-
135
- workspace.tried_all = false;
136
- if (model_params.ndim < input_data.ncols_tot / 2 || workspace.col_sampler.max())
137
- {
138
- while(workspace.ncols_tried < std::max(input_data.ncols_tot / 2, model_params.ndim))
139
- {
140
- workspace.ncols_tried++;
141
- decide_column(input_data.ncols_numeric, input_data.ncols_categ,
142
- workspace.col_chosen, workspace.col_type,
143
- workspace.rnd_generator, workspace.runif,
144
- workspace.col_sampler);
145
-
146
- if (
147
- (workspace.col_type == Numeric && !workspace.cols_possible[workspace.col_chosen])
148
- ||
149
- (workspace.col_type == Categorical && !workspace.cols_possible[workspace.col_chosen + input_data.ncols_numeric])
150
- ||
151
- is_col_taken(col_is_taken, col_is_taken_s, input_data, workspace.col_chosen, workspace.col_type)
152
- )
153
- continue;
154
-
155
-
156
- get_split_range(workspace, input_data, model_params);
157
- if (workspace.unsplittable)
158
- {
159
- add_unsplittable_col(workspace, input_data);
160
- }
161
-
162
- else
163
- {
164
- add_chosen_column(workspace, input_data, model_params, col_is_taken, col_is_taken_s);
165
- if (++workspace.ntaken >= model_params.ndim)
166
- break;
167
- }
168
-
169
- }
170
-
171
- if (workspace.ntaken < model_params.ndim)
172
- {
173
- update_col_sampler(workspace, input_data);
174
- goto probe_all;
175
- }
176
- }
177
-
178
- else /* probe all columns */
179
- {
180
- probe_all:
181
- workspace.tried_all = true;
182
- std::iota(workspace.cols_shuffled.begin(), workspace.cols_shuffled.end(), (size_t)0);
183
- if (model_params.ndim < input_data.ncols_tot)
184
- {
185
-
186
- if (!workspace.col_sampler.max())
187
- {
188
- std::shuffle(workspace.cols_shuffled.begin(),
189
- workspace.cols_shuffled.end(),
190
- workspace.rnd_generator);
191
- }
192
-
193
- else
194
- {
195
- if (!model_params.weigh_by_kurt)
196
- {
197
- weighted_shuffle(workspace.cols_shuffled.data(), input_data.ncols_tot, input_data.col_weights,
198
- workspace.buffer_dbl.data(), workspace.rnd_generator);
199
- }
200
-
201
- else
202
- {
203
- std::vector<double> col_weights = workspace.col_sampler.probabilities();
204
- /* sampler will fail if passed weights of zero, so need to discard those first and then remap */
205
- std::iota(workspace.buffer_szt.begin(), workspace.buffer_szt.begin() + input_data.ncols_tot, (size_t)0);
206
- long st = input_data.ncols_tot - 1;
207
- for (long col = st; col >= 0; col--)
208
- {
209
- if (col_weights[col] <= 0)
210
- {
211
- std::swap(col_weights[st], col_weights[col]);
212
- std::swap(workspace.buffer_szt[st], workspace.buffer_szt[col]);
213
- st--;
214
- }
215
- }
216
-
217
- if ((size_t)st == input_data.ncols_tot - 1)
218
- {
219
- weighted_shuffle(workspace.cols_shuffled.data(), input_data.ncols_tot, col_weights.data(),
220
- workspace.buffer_dbl.data(), workspace.rnd_generator);
221
- }
222
-
223
- else if (st < 0)
224
- {
225
- goto terminal_statistics;
226
- }
227
-
228
- else if (st == 0)
229
- {
230
- std::copy(workspace.buffer_szt.begin(),
231
- workspace.buffer_szt.begin() + input_data.ncols_tot,
232
- workspace.cols_shuffled.begin());
233
- }
234
-
235
- else
236
- {
237
- weighted_shuffle(workspace.buffer_szt.data(), (size_t) ++st, col_weights.data(),
238
- workspace.buffer_dbl.data(), workspace.rnd_generator);
239
- std::copy(workspace.buffer_szt.begin(),
240
- workspace.buffer_szt.begin() + input_data.ncols_tot,
241
- workspace.cols_shuffled.begin());
242
- }
243
- }
244
- }
245
- }
246
-
247
- for (size_t col : workspace.cols_shuffled)
248
- {
249
- if (
250
- !workspace.cols_possible[col]
251
- ||
252
- (workspace.ntaken
253
- &&
254
- is_col_taken(col_is_taken, col_is_taken_s, input_data,
255
- (col < input_data.ncols_numeric)? col : col - input_data.ncols_numeric,
256
- (col < input_data.ncols_numeric)? Numeric : Categorical)
257
- )
258
- )
259
- continue;
260
-
261
- if (col < input_data.ncols_numeric)
262
- {
263
- workspace.col_chosen = col;
264
- workspace.col_type = Numeric;
265
- }
266
-
267
- else
268
- {
269
- workspace.col_chosen = col - input_data.ncols_numeric;
270
- workspace.col_type = Categorical;
271
- }
272
-
273
- get_split_range(workspace, input_data, model_params);
274
- if (workspace.unsplittable)
275
- {
276
- add_unsplittable_col(workspace, input_data);
277
- }
278
-
279
- else
280
- {
281
- add_chosen_column(workspace, input_data, model_params, col_is_taken, col_is_taken_s);
282
- if (++workspace.ntaken >= model_params.ndim)
283
- break;
284
- }
285
- }
286
-
287
- if (model_params.weigh_by_kurt)
288
- update_col_sampler(workspace, input_data);
289
- }
290
-
291
- /* evaluate gain if necessary */
292
- if (workspace.criterion != NoCrit)
293
- workspace.this_gain = eval_guided_crit(workspace.comb_val.data(), workspace.end - workspace.st + 1,
294
- workspace.criterion, model_params.min_gain, workspace.this_split_point,
295
- workspace.xmin, workspace.xmax);
296
-
297
- /* pass to the output object */
298
- if (workspace.ntry == 1 || workspace.this_gain > hplanes.back().score)
299
- {
300
- /* these should be shrunk later according to what ends up used */
301
- hplanes.back().score = workspace.this_gain;
302
- workspace.ntaken_best = workspace.ntaken;
303
- if (workspace.criterion != NoCrit)
304
- {
305
- hplanes.back().split_point = workspace.this_split_point;
306
- if (model_params.penalize_range)
307
- {
308
- hplanes.back().range_low = workspace.xmin - workspace.xmax + hplanes.back().split_point;
309
- hplanes.back().range_high = workspace.xmax - workspace.xmin + hplanes.back().split_point;
310
- }
311
- }
312
- hplanes.back().col_num.assign(workspace.col_take.begin(), workspace.col_take.begin() + workspace.ntaken);
313
- hplanes.back().col_type.assign(workspace.col_take_type.begin(), workspace.col_take_type.begin() + workspace.ntaken);
314
- if (input_data.ncols_numeric)
315
- {
316
- hplanes.back().coef.assign(workspace.ext_coef.begin(), workspace.ext_coef.begin() + workspace.ntaken);
317
- hplanes.back().mean.assign(workspace.ext_mean.begin(), workspace.ext_mean.begin() + workspace.ntaken);
318
- }
319
-
320
- if (model_params.missing_action != Fail)
321
- hplanes.back().fill_val.assign(workspace.ext_fill_val.begin(), workspace.ext_fill_val.begin() + workspace.ntaken);
322
-
323
- if (input_data.ncols_categ)
324
- {
325
- hplanes.back().fill_new.assign(workspace.ext_fill_new.begin(), workspace.ext_fill_new.begin() + workspace.ntaken);
326
- switch(model_params.cat_split_type)
327
- {
328
- case SingleCateg:
329
- {
330
- hplanes.back().chosen_cat.assign(workspace.chosen_cat.begin(),
331
- workspace.chosen_cat.begin() + workspace.ntaken);
332
- break;
333
- }
334
-
335
- case SubSet:
336
- {
337
- if (hplanes.back().cat_coef.size() < workspace.ntaken)
338
- hplanes.back().cat_coef.assign(workspace.ext_cat_coef.begin(),
339
- workspace.ext_cat_coef.begin() + workspace.ntaken);
340
- else
341
- for (size_t col = 0; col < workspace.ntaken_best; col++)
342
- std::copy(workspace.ext_cat_coef[col].begin(),
343
- workspace.ext_cat_coef[col].end(),
344
- hplanes.back().cat_coef[col].begin());
345
- break;
346
- }
347
- }
348
- }
349
- }
350
-
351
- }
352
-
353
- /* if there isn't a single splittable column, end here */
354
- if (!workspace.ntaken_best && !workspace.ntaken && workspace.tried_all)
355
- goto terminal_statistics;
356
-
357
- /* if the best split is not good enough, don't split any further */
358
- if (workspace.criterion != NoCrit && hplanes.back().score <= 0)
359
- goto terminal_statistics;
360
-
361
- /* now need to reproduce the same split from before */
362
- if (workspace.criterion != NoCrit && workspace.ntry > 1)
363
- {
364
- std::fill(workspace.comb_val.begin(),
365
- workspace.comb_val.begin() + (workspace.end - workspace.st + 1),
366
- (double)0);
367
- for (size_t col = 0; col < workspace.ntaken_best; col++)
368
- {
369
- switch(hplanes.back().col_type[col])
370
- {
371
- case Numeric:
372
- {
373
- if (input_data.Xc_indptr == NULL)
374
- {
375
- add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
376
- input_data.numeric_data + hplanes.back().col_num[col] * input_data.nrows,
377
- hplanes.back().coef[col], (double)0, hplanes.back().mean[col],
378
- hplanes.back().fill_val.size()? hplanes.back().fill_val[col] : workspace.this_split_point, /* second case is not used */
379
- model_params.missing_action, NULL, NULL, false);
380
- }
381
-
382
- else
383
- {
384
- add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end,
385
- hplanes.back().col_num[col], workspace.comb_val.data(),
386
- input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
387
- hplanes.back().coef[col], (double)0, hplanes.back().mean[col],
388
- hplanes.back().fill_val.size()? hplanes.back().fill_val[col] : workspace.this_split_point, /* second case is not used */
389
- model_params.missing_action, NULL, NULL, false);
390
- }
391
-
392
- break;
393
- }
394
-
395
- case Categorical:
396
- {
397
- add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
398
- input_data.categ_data + hplanes.back().col_num[col] * input_data.nrows,
399
- input_data.ncat[hplanes.back().col_num[col]],
400
- (model_params.cat_split_type == SubSet)? hplanes.back().cat_coef[col].data() : NULL,
401
- (model_params.cat_split_type == SingleCateg)? hplanes.back().fill_new[col] : (double)0,
402
- (model_params.cat_split_type == SingleCateg)? hplanes.back().chosen_cat[col] : 0,
403
- (hplanes.back().fill_val.size())? hplanes.back().fill_val[col] : workspace.this_split_point, /* second case is not used */
404
- (model_params.cat_split_type == SubSet)? hplanes.back().fill_new[col] : workspace.this_split_point, /* second case is not used */
405
- NULL, NULL, model_params.new_cat_action, model_params.missing_action,
406
- model_params.cat_split_type, false);
407
- break;
408
- }
409
- }
410
- }
411
- }
412
-
413
- /* get the range */
414
- if (workspace.criterion == NoCrit)
415
- {
416
- workspace.xmin = HUGE_VAL;
417
- workspace.xmax = -HUGE_VAL;
418
- for (size_t row = 0; row < (workspace.end - workspace.st + 1); row++)
419
- {
420
- workspace.xmin = (workspace.xmin > workspace.comb_val[row])? workspace.comb_val[row] : workspace.xmin;
421
- workspace.xmax = (workspace.xmax < workspace.comb_val[row])? workspace.comb_val[row] : workspace.xmax;
422
- }
423
- if (workspace.xmin == workspace.xmax)
424
- goto terminal_statistics; /* in theory, could try again too, this could just be an unlucky case */
425
-
426
- hplanes.back().split_point =
427
- std::uniform_real_distribution<double>(workspace.xmin, workspace.xmax)
428
- (workspace.rnd_generator);
429
-
430
- /* determine acceptable range */
431
- if (model_params.penalize_range)
432
- {
433
- hplanes.back().range_low = workspace.xmin - workspace.xmax + hplanes.back().split_point;
434
- hplanes.back().range_high = workspace.xmax - workspace.xmin + hplanes.back().split_point;
435
- }
436
- }
437
-
438
- /* divide */
439
- workspace.split_ix = divide_subset_split(workspace.ix_arr.data(), workspace.comb_val.data(),
440
- workspace.st, workspace.end, hplanes.back().split_point);
441
-
442
- /* set as non-terminal */
443
- hplanes.back().score = -1;
444
-
445
- /* add another round of separation depth for distance */
446
- if (model_params.calc_dist && curr_depth > 0)
447
- add_separation_step(workspace, input_data, (double)(-1));
448
-
449
- /* simplify vectors according to what ends up used */
450
- if (input_data.ncols_categ || workspace.ntaken_best < model_params.ndim)
451
- simplify_hplane(hplanes.back(), workspace, input_data, model_params);
452
-
453
- shrink_to_fit_hplane(hplanes.back(), false);
454
-
455
- /* now split */
456
-
457
- /* back-up where it was */
458
- recursion_state = std::unique_ptr<RecursionState>(new RecursionState);
459
- backup_recursion_state(workspace, *recursion_state);
460
-
461
- /* follow left branch */
462
- hplanes[hplane_from].hplane_left = hplanes.size();
463
- hplanes.emplace_back();
464
- if (impute_nodes != NULL) impute_nodes->emplace_back(hplane_from);
465
- workspace.end = workspace.split_ix - 1;
466
- split_hplane_recursive(hplanes,
467
- workspace,
468
- input_data,
469
- model_params,
470
- impute_nodes,
471
- curr_depth + 1);
472
-
473
-
474
- /* follow right branch */
475
- hplanes[hplane_from].hplane_right = hplanes.size();
476
- restore_recursion_state(workspace, *recursion_state);
477
- hplanes.emplace_back();
478
- if (impute_nodes != NULL) impute_nodes->emplace_back(hplane_from);
479
- workspace.st = workspace.split_ix;
480
- split_hplane_recursive(hplanes,
481
- workspace,
482
- input_data,
483
- model_params,
484
- impute_nodes,
485
- curr_depth + 1);
486
-
487
- return;
488
-
489
- terminal_statistics:
490
- {
491
- if (!workspace.weights_arr.size() && !workspace.weights_map.size())
492
- {
493
- hplanes.back().score = (double)(curr_depth + expected_avg_depth(workspace.end - workspace.st + 1));
494
- }
495
-
496
- else
497
- {
498
- if (sum_weight == -HUGE_VAL)
499
- sum_weight = calculate_sum_weights(workspace.ix_arr, workspace.st, workspace.end, curr_depth,
500
- workspace.weights_arr, workspace.weights_map);
501
- hplanes.back().score = (double)(curr_depth + expected_avg_depth(sum_weight));
502
- }
503
-
504
- /* don't leave any vector initialized */
505
- shrink_to_fit_hplane(hplanes.back(), true);
506
-
507
- hplanes.back().remainder = workspace.weights_arr.size()?
508
- sum_weight : (workspace.weights_map.size()?
509
- sum_weight : ((double)(workspace.end - workspace.st + 1))
510
- );
511
-
512
- /* for distance, assume also the elements keep being split */
513
- if (model_params.calc_dist)
514
- add_remainder_separation_steps(workspace, input_data, sum_weight);
515
-
516
- /* add this depth right away if requested */
517
- if (workspace.row_depths.size())
518
- for (size_t row = workspace.st; row <= workspace.end; row++)
519
- workspace.row_depths[workspace.ix_arr[row]] += hplanes.back().score;
520
-
521
- /* add imputations from node if requested */
522
- if (model_params.impute_at_fit)
523
- add_from_impute_node(impute_nodes->back(), workspace, input_data);
524
- }
525
- }
526
-
527
-
528
/* Registers the column picked for the current extended (hyperplane) split,
   draws its random coefficient(s), and accumulates its contribution into the
   running linear combination 'workspace.comb_val' over rows
   [workspace.st, workspace.end].

   Parameters
   ----------
   workspace      : per-thread scratch state; reads 'col_chosen' / 'col_type' /
                    'ntaken', writes the ext_* coefficient slots and 'comb_val'
   input_data     : input matrices (dense numeric, CSC sparse numeric, categorical)
   model_params   : options for the coefficient distribution, categorical split
                    type, and missing-value handling
   col_is_taken   : bitmap of already-used columns (updated via 'set_col_as_taken')
   col_is_taken_s : same information kept as a hash set (sparse tracking) */
void add_chosen_column(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params,
                       std::vector<bool> &col_is_taken, std::unordered_set<size_t> &col_is_taken_s)
{
    /* mark the column as used and record it in slot 'ntaken' */
    set_col_as_taken(col_is_taken, col_is_taken_s, input_data, workspace.col_chosen, workspace.col_type);
    workspace.col_take[workspace.ntaken] = workspace.col_chosen;
    workspace.col_take_type[workspace.ntaken] = workspace.col_type;

    switch(workspace.col_type)
    {
        case Numeric:
        {
            /* draw the linear coefficient from the configured distribution */
            switch(model_params.coef_type)
            {
                case Uniform:
                {
                    workspace.ext_coef[workspace.ntaken] = workspace.coef_unif(workspace.rnd_generator);
                    break;
                }

                case Normal:
                {
                    workspace.ext_coef[workspace.ntaken] = workspace.coef_norm(workspace.rnd_generator);
                    break;
                }
            }

            /* standardize the column (mean/sd over the rows in this node), then
               add coef * (x - mean) / sd into 'comb_val'; dense and CSC-sparse
               inputs use parallel overloads of the same helpers */
            if (input_data.Xc_indptr == NULL)
            {
                calc_mean_and_sd(workspace.ix_arr.data(), workspace.st, workspace.end,
                                 input_data.numeric_data + workspace.col_chosen * input_data.nrows,
                                 model_params.missing_action, workspace.ext_sd, workspace.ext_mean[workspace.ntaken]);
                add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
                                input_data.numeric_data + workspace.col_chosen * input_data.nrows,
                                workspace.ext_coef[workspace.ntaken], workspace.ext_sd, workspace.ext_mean[workspace.ntaken],
                                workspace.ext_fill_val[workspace.ntaken], model_params.missing_action,
                                workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true);
            }

            else
            {
                calc_mean_and_sd(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen,
                                 input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
                                 workspace.ext_sd, workspace.ext_mean[workspace.ntaken]);
                add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end,
                                workspace.col_chosen, workspace.comb_val.data(),
                                input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
                                workspace.ext_coef[workspace.ntaken], workspace.ext_sd, workspace.ext_mean[workspace.ntaken],
                                workspace.ext_fill_val[workspace.ntaken], model_params.missing_action,
                                workspace.buffer_dbl.data(), workspace.buffer_szt.data(), true);
            }
            break;
        }

        case Categorical:
        {
            switch(model_params.cat_split_type)
            {
                case SingleCateg:
                {
                    /* one category gets a random (normal) coefficient; all
                       other categories contribute nothing to the combination */
                    workspace.chosen_cat[workspace.ntaken] = choose_cat_from_present(workspace, input_data, workspace.col_chosen);
                    workspace.ext_fill_new[workspace.ntaken] = workspace.coef_norm(workspace.rnd_generator);
                    add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
                                    input_data.categ_data + workspace.col_chosen * input_data.nrows,
                                    input_data.ncat[workspace.col_chosen],
                                    NULL, workspace.ext_fill_new[workspace.ntaken],
                                    workspace.chosen_cat[workspace.ntaken],
                                    workspace.ext_fill_val[workspace.ntaken], workspace.ext_fill_new[workspace.ntaken],
                                    NULL, NULL, model_params.new_cat_action, model_params.missing_action, SingleCateg, true);

                    break;
                }

                case SubSet:
                {
                    /* every category gets its own normal coefficient */
                    for (int cat = 0; cat < input_data.ncat[workspace.col_chosen]; cat++)
                        workspace.ext_cat_coef[workspace.ntaken][cat] = workspace.coef_norm(workspace.rnd_generator);

                    if (model_params.coef_by_prop)
                    {
                        int ncat = input_data.ncat[workspace.col_chosen];
                        size_t *restrict counts = workspace.buffer_szt.data();
                        size_t *restrict sorted_ix = workspace.buffer_szt.data() + ncat;
                        /* calculate counts and sort by them */
                        std::fill(counts, counts + ncat, (size_t)0);
                        for (size_t ix = workspace.st; ix <= workspace.end; ix++)
                            if (input_data.categ_data[workspace.col_chosen * input_data.nrows + ix] >= 0)
                                counts[input_data.categ_data[workspace.col_chosen * input_data.nrows + ix]]++;
                        std::iota(sorted_ix, sorted_ix + ncat, (size_t)0);
                        std::sort(sorted_ix, sorted_ix + ncat,
                                  [&counts](const size_t a, const size_t b){return counts[a] < counts[b];});
                        /* now re-order the coefficients accordingly */
                        /* NOTE(review): category 'ix' receives the sorted_ix[ix]-th
                           smallest coefficient, i.e. the inverse of the permutation
                           that would assign coefficients by frequency rank
                           (ext_cat_coef[sorted_ix[ix]] = buffer_dbl[ix]) — confirm
                           the intended mapping */
                        std::sort(workspace.ext_cat_coef[workspace.ntaken].begin(),
                                  workspace.ext_cat_coef[workspace.ntaken].begin() + ncat);
                        std::copy(workspace.ext_cat_coef[workspace.ntaken].begin(),
                                  workspace.ext_cat_coef[workspace.ntaken].begin() + ncat,
                                  workspace.buffer_dbl.begin());
                        for (size_t ix = 0; ix < ncat; ix++)
                            workspace.ext_cat_coef[workspace.ntaken][ix] = workspace.buffer_dbl[sorted_ix[ix]];
                    }

                    add_linear_comb(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.comb_val.data(),
                                    input_data.categ_data + workspace.col_chosen * input_data.nrows,
                                    input_data.ncat[workspace.col_chosen],
                                    workspace.ext_cat_coef[workspace.ntaken].data(), (double)0, (int)0,
                                    workspace.ext_fill_val[workspace.ntaken], workspace.ext_fill_new[workspace.ntaken],
                                    workspace.buffer_szt.data(), workspace.buffer_szt.data() + input_data.max_categ + 1,
                                    model_params.new_cat_action, model_params.missing_action, SubSet, true);
                    break;
                }
            }
            break;
        }
    }

    /* NOTE(review): xmin/xmax are computed over the updated combination but
       never used within this function in this version — dead code unless a
       later revision consumes them */
    double xmin = HUGE_VAL, xmax = -HUGE_VAL;
    for (size_t row = workspace.st; row <= workspace.end; row++)
    {
        xmin = fmin(xmin, workspace.comb_val[row - workspace.st]);
        xmax = fmax(xmax, workspace.comb_val[row - workspace.st]);
    }
}
649
-
650
- void shrink_to_fit_hplane(IsoHPlane &hplane, bool clear_vectors)
651
- {
652
- if (clear_vectors)
653
- {
654
- hplane.col_num.clear();
655
- hplane.col_type.clear();
656
- hplane.coef.clear();
657
- hplane.mean.clear();
658
- hplane.cat_coef.clear();
659
- hplane.chosen_cat.clear();
660
- hplane.fill_val.clear();
661
- hplane.fill_new.clear();
662
- }
663
-
664
- hplane.col_num.shrink_to_fit();
665
- hplane.col_type.shrink_to_fit();
666
- hplane.coef.shrink_to_fit();
667
- hplane.mean.shrink_to_fit();
668
- hplane.cat_coef.shrink_to_fit();
669
- hplane.chosen_cat.shrink_to_fit();
670
- hplane.fill_val.shrink_to_fit();
671
- hplane.fill_new.shrink_to_fit();
672
- }
673
-
674
/* Compacts a fitted hyperplane node after the best split has been chosen:
   trims its vectors to the number of columns actually taken
   ('workspace.ntaken_best') and re-packs the per-column data so that the
   numeric entries (coef/mean) and categorical entries (cat_coef / chosen_cat /
   fill_new) are each stored contiguously, in order of appearance.

   Parameters
   ----------
   hplane       : the node being finalized (vectors resized/repacked in place)
   workspace    : scratch buffers (ext_*) used as the staging area for the repack
   input_data   : provides per-column category counts ('ncat') and column totals
   model_params : determines the categorical split representation */
void simplify_hplane(IsoHPlane &hplane, WorkerMemory &workspace, InputData &input_data, ModelParams &model_params)
{
    /* fewer columns than 'ndim' may have been taken; drop the unused slots */
    if (workspace.ntaken_best < model_params.ndim)
    {
        hplane.col_num.resize(workspace.ntaken_best);
        hplane.col_type.resize(workspace.ntaken_best);
        if (model_params.missing_action != Fail)
            hplane.fill_val.resize(workspace.ntaken_best);
    }

    size_t ncols_numeric = 0;
    size_t ncols_categ = 0;

    /* stage the taken columns' data into the workspace buffers, splitting
       numeric from categorical while preserving their relative order */
    if (input_data.ncols_categ)
    {
        for (size_t col = 0; col < workspace.ntaken_best; col++)
        {
            switch(hplane.col_type[col])
            {
                case Numeric:
                {
                    workspace.ext_coef[ncols_numeric] = hplane.coef[col];
                    workspace.ext_mean[ncols_numeric] = hplane.mean[col];
                    ncols_numeric++;
                    break;
                }

                case Categorical:
                {
                    workspace.ext_fill_new[ncols_categ] = hplane.fill_new[col];
                    switch(model_params.cat_split_type)
                    {
                        case SingleCateg:
                        {
                            workspace.chosen_cat[ncols_categ] = hplane.chosen_cat[col];
                            break;
                        }

                        case SubSet:
                        {
                            std::copy(hplane.cat_coef[col].begin(),
                                      hplane.cat_coef[col].begin() + input_data.ncat[hplane.col_num[col]],
                                      workspace.ext_cat_coef[ncols_categ].begin());
                            break;
                        }
                    }
                    ncols_categ++;
                    break;
                }
            }
        }
    }

    else
    {
        /* no categorical columns in the data => every taken column is numeric
           and coef/mean are already in the right order */
        ncols_numeric = workspace.ntaken_best;
    }


    /* copy the numeric data back, sized exactly to what was taken */
    hplane.coef.resize(ncols_numeric);
    hplane.mean.resize(ncols_numeric);
    if (input_data.ncols_numeric)
    {
        std::copy(workspace.ext_coef.begin(), workspace.ext_coef.begin() + ncols_numeric, hplane.coef.begin());
        std::copy(workspace.ext_mean.begin(), workspace.ext_mean.begin() + ncols_numeric, hplane.mean.begin());
    }

    /* If there are no categorical columns, all of them will be numerical and there is no need to reorder */
    if (ncols_categ)
    {
        hplane.fill_new.resize(ncols_categ);
        std::copy(workspace.ext_fill_new.begin(),
                  workspace.ext_fill_new.begin() + ncols_categ,
                  hplane.fill_new.begin());

        hplane.cat_coef.resize(ncols_categ);
        switch(model_params.cat_split_type)
        {
            case SingleCateg:
            {
                /* single-category splits keep only the chosen category id;
                   per-category coefficient vectors are not needed */
                hplane.chosen_cat.resize(ncols_categ);
                std::copy(workspace.chosen_cat.begin(),
                          workspace.chosen_cat.begin() + ncols_categ,
                          hplane.chosen_cat.begin());
                hplane.cat_coef.clear();
                break;
            }

            case SubSet:
            {
                /* subset splits keep the per-category coefficients, each
                   vector trimmed to that column's actual category count;
                   'ncols_categ' is re-counted while walking the taken columns */
                hplane.chosen_cat.clear();
                ncols_categ = 0;
                for (size_t col = 0; col < workspace.ntaken_best; col++)
                {
                    if (hplane.col_type[col] == Categorical)
                    {
                        hplane.cat_coef[ncols_categ].resize(input_data.ncat[hplane.col_num[col]]);
                        std::copy(workspace.ext_cat_coef[ncols_categ].begin(),
                                  workspace.ext_cat_coef[ncols_categ].begin()
                                    + input_data.ncat[hplane.col_num[col]],
                                  hplane.cat_coef[ncols_categ].begin());
                        hplane.cat_coef[ncols_categ].shrink_to_fit();
                        ncols_categ++;
                    }
                }
                break;
            }
        }
    }

    else
    {
        /* no categorical columns taken: drop all categorical storage */
        hplane.cat_coef.clear();
        hplane.chosen_cat.clear();
        hplane.fill_new.clear();
    }
}