isotree 0.2.2 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (151) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2111 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +62 -0
  14. data/vendor/isotree/src/RcppExports.cpp +525 -52
  15. data/vendor/isotree/src/Rwrapper.cpp +1931 -268
  16. data/vendor/isotree/src/c_interface.cpp +953 -0
  17. data/vendor/isotree/src/crit.hpp +4232 -0
  18. data/vendor/isotree/src/dist.hpp +1886 -0
  19. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  20. data/vendor/isotree/src/extended.hpp +1444 -0
  21. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  22. data/vendor/isotree/src/fit_model.hpp +2401 -0
  23. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  24. data/vendor/isotree/src/helpers_iforest.hpp +813 -0
  25. data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
  26. data/vendor/isotree/src/indexer.cpp +515 -0
  27. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  28. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  29. data/vendor/isotree/src/isoforest.hpp +1659 -0
  30. data/vendor/isotree/src/isotree.hpp +1804 -392
  31. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  32. data/vendor/isotree/src/merge_models.cpp +159 -16
  33. data/vendor/isotree/src/mult.hpp +1321 -0
  34. data/vendor/isotree/src/oop_interface.cpp +842 -0
  35. data/vendor/isotree/src/oop_interface.hpp +278 -0
  36. data/vendor/isotree/src/other_helpers.hpp +219 -0
  37. data/vendor/isotree/src/predict.hpp +1932 -0
  38. data/vendor/isotree/src/python_helpers.hpp +134 -0
  39. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  40. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  41. data/vendor/isotree/src/robinmap/README.md +483 -0
  42. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  46. data/vendor/isotree/src/serialize.cpp +4300 -139
  47. data/vendor/isotree/src/sql.cpp +141 -59
  48. data/vendor/isotree/src/subset_models.cpp +174 -0
  49. data/vendor/isotree/src/utils.hpp +3808 -0
  50. data/vendor/isotree/src/xoshiro.hpp +467 -0
  51. data/vendor/isotree/src/ziggurat.hpp +405 -0
  52. metadata +38 -104
  53. data/vendor/cereal/LICENSE +0 -24
  54. data/vendor/cereal/README.md +0 -85
  55. data/vendor/cereal/include/cereal/access.hpp +0 -351
  56. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  57. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  58. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  59. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  60. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  61. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  62. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  63. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  65. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  66. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  67. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  68. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  69. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  70. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  71. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  72. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  74. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  76. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  77. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  78. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  79. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  91. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  92. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  94. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  96. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  97. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  98. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  99. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  100. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  101. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  102. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  103. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  104. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  105. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  106. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  107. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  111. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  112. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  113. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  114. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  115. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  116. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  117. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  118. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  119. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  120. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  121. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  122. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  123. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  124. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  125. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  126. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  127. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  128. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  129. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  130. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  131. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  132. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  133. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  134. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  135. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  136. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  137. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  138. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  139. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  140. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  141. data/vendor/cereal/include/cereal/version.hpp +0 -52
  142. data/vendor/isotree/src/Makevars +0 -4
  143. data/vendor/isotree/src/crit.cpp +0 -912
  144. data/vendor/isotree/src/dist.cpp +0 -749
  145. data/vendor/isotree/src/extended.cpp +0 -790
  146. data/vendor/isotree/src/fit_model.cpp +0 -1090
  147. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  148. data/vendor/isotree/src/isoforest.cpp +0 -771
  149. data/vendor/isotree/src/mult.cpp +0 -607
  150. data/vendor/isotree/src/predict.cpp +0 -853
  151. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -0,0 +1,813 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
21
+ * [8] Cortes, David.
22
+ * "Distance approximation using Isolation Forests."
23
+ * arXiv preprint arXiv:1910.12362 (2019).
24
+ * [9] Cortes, David.
25
+ * "Imputing missing values with unsupervised random trees."
26
+ * arXiv preprint arXiv:1911.06646 (2019).
27
+ * [10] https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom
28
+ * [11] Cortes, David.
29
+ * "Revisiting randomized choices in isolation forests."
30
+ * arXiv preprint arXiv:2110.13402 (2021).
31
+ * [12] Guha, Sudipto, et al.
32
+ * "Robust random cut forest based anomaly detection on streams."
33
+ * International conference on machine learning. PMLR, 2016.
34
+ * [13] Cortes, David.
35
+ * "Isolation forests: looking beyond tree depth."
36
+ * arXiv preprint arXiv:2111.11639 (2021).
37
+ * [14] Ting, Kai Ming, Yue Zhu, and Zhi-Hua Zhou.
38
+ * "Isolation kernel and its effect on SVM"
39
+ * Proceedings of the 24th ACM SIGKDD
40
+ * International Conference on Knowledge Discovery & Data Mining. 2018.
41
+ *
42
+ * BSD 2-Clause License
43
+ * Copyright (c) 2019-2022, David Cortes
44
+ * All rights reserved.
45
+ * Redistribution and use in source and binary forms, with or without
46
+ * modification, are permitted provided that the following conditions are met:
47
+ * * Redistributions of source code must retain the above copyright notice, this
48
+ * list of conditions and the following disclaimer.
49
+ * * Redistributions in binary form must reproduce the above copyright notice,
50
+ * this list of conditions and the following disclaimer in the documentation
51
+ * and/or other materials provided with the distribution.
52
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
53
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
54
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
55
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
56
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
57
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
58
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
59
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
60
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
61
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
62
+ */
63
+ #include "isotree.hpp"
64
+
65
+
66
+ /* for use in regular model */
67
+ template <class InputData, class WorkerMemory>
68
+ void get_split_range(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params, IsoTree &tree)
69
+ {
70
+ if (tree.col_num < input_data.ncols_numeric)
71
+ {
72
+ tree.col_type = Numeric;
73
+
74
+ if (input_data.Xc_indptr == NULL)
75
+ get_range(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * tree.col_num,
76
+ workspace.st, workspace.end, model_params.missing_action,
77
+ workspace.xmin, workspace.xmax, workspace.unsplittable);
78
+ else
79
+ get_range(workspace.ix_arr.data(), workspace.st, workspace.end, tree.col_num,
80
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
81
+ model_params.missing_action, workspace.xmin, workspace.xmax, workspace.unsplittable);
82
+ }
83
+
84
+ else
85
+ {
86
+ tree.col_num -= input_data.ncols_numeric;
87
+ tree.col_type = Categorical;
88
+
89
+ get_categs(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * tree.col_num,
90
+ workspace.st, workspace.end, input_data.ncat[tree.col_num],
91
+ model_params.missing_action, workspace.categs.data(), workspace.npresent, workspace.unsplittable);
92
+ }
93
+ }
94
+
95
+ /* for use in extended model */
96
+ template <class InputData, class WorkerMemory>
97
+ void get_split_range(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params)
98
+ {
99
+ if (workspace.col_chosen < input_data.ncols_numeric)
100
+ {
101
+ workspace.col_type = Numeric;
102
+
103
+ if (input_data.Xc_indptr == NULL)
104
+ get_range(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * workspace.col_chosen,
105
+ workspace.st, workspace.end, model_params.missing_action,
106
+ workspace.xmin, workspace.xmax, workspace.unsplittable);
107
+ else
108
+ get_range(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen,
109
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
110
+ model_params.missing_action, workspace.xmin, workspace.xmax, workspace.unsplittable);
111
+ }
112
+
113
+ else
114
+ {
115
+ workspace.col_type = Categorical;
116
+ workspace.col_chosen -= input_data.ncols_numeric;
117
+
118
+ get_categs(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * workspace.col_chosen,
119
+ workspace.st, workspace.end, input_data.ncat[workspace.col_chosen],
120
+ model_params.missing_action, workspace.categs.data(), workspace.npresent, workspace.unsplittable);
121
+ }
122
+ }
123
+
124
+ /* for use in regular model with ntry>1 */
125
+ template <class InputData, class WorkerMemory>
126
+ void get_split_range_v2(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params)
127
+ {
128
+ get_split_range(workspace, input_data, model_params);
129
+ if (workspace.col_type == Categorical)
130
+ workspace.col_chosen += input_data.ncols_numeric;
131
+ }
132
+
133
+ template <class InputData, class WorkerMemory>
134
+ int choose_cat_from_present(WorkerMemory &workspace, InputData &input_data, size_t col_num)
135
+ {
136
+ int chosen_cat = std::uniform_int_distribution<int>
137
+ (0, workspace.npresent - 1)
138
+ (workspace.rnd_generator);
139
+ workspace.ncat_tried = 0;
140
+ for (int cat = 0; cat < input_data.ncat[col_num]; cat++)
141
+ {
142
+ if (workspace.categs[cat] > 0)
143
+ {
144
+ if (workspace.ncat_tried == chosen_cat)
145
+ return cat;
146
+ else
147
+ workspace.ncat_tried++;
148
+ }
149
+ }
150
+
151
+ return -1; /* this will never be reached, but CRAN complains otherwise */
152
+ }
153
+
154
+ bool is_col_taken(std::vector<bool> &col_is_taken, hashed_set<size_t> &col_is_taken_s,
155
+ size_t col_num)
156
+ {
157
+ if (!col_is_taken.empty())
158
+ return col_is_taken[col_num];
159
+ else
160
+ return col_is_taken_s.find(col_num) != col_is_taken_s.end();
161
+ }
162
+
163
+ template <class InputData>
164
+ void set_col_as_taken(std::vector<bool> &col_is_taken, hashed_set<size_t> &col_is_taken_s,
165
+ InputData &input_data, size_t col_num, ColType col_type)
166
+ {
167
+ col_num += ((col_type == Numeric)? 0 : input_data.ncols_numeric);
168
+ if (!col_is_taken.empty())
169
+ col_is_taken[col_num] = true;
170
+ else
171
+ col_is_taken_s.insert(col_num);
172
+ }
173
+
174
+ template <class InputData>
175
+ void set_col_as_taken(std::vector<bool> &col_is_taken, hashed_set<size_t> &col_is_taken_s,
176
+ InputData &input_data, size_t col_num)
177
+ {
178
+ if (!col_is_taken.empty())
179
+ col_is_taken[col_num] = true;
180
+ else
181
+ col_is_taken_s.insert(col_num);
182
+ }
183
+
184
+ template <class InputData, class WorkerMemory>
185
+ void add_separation_step(WorkerMemory &workspace, InputData &input_data, double remainder)
186
+ {
187
+ if (!workspace.changed_weights)
188
+ increase_comb_counter(workspace.ix_arr.data(), workspace.st, workspace.end,
189
+ input_data.nrows, workspace.tmat_sep.data(), remainder);
190
+ else if (!workspace.weights_arr.empty())
191
+ increase_comb_counter(workspace.ix_arr.data(), workspace.st, workspace.end,
192
+ input_data.nrows, workspace.tmat_sep.data(), workspace.weights_arr.data(), remainder);
193
+ else
194
+ increase_comb_counter(workspace.ix_arr.data(), workspace.st, workspace.end,
195
+ input_data.nrows, workspace.tmat_sep.data(), workspace.weights_map, remainder);
196
+ }
197
+
198
+ template <class InputData, class WorkerMemory, class ldouble_safe>
199
+ void add_remainder_separation_steps(WorkerMemory &workspace, InputData &input_data, ldouble_safe sum_weight)
200
+ {
201
+ if ((workspace.end - workspace.st) > 0 && (!workspace.changed_weights || sum_weight > 0))
202
+ {
203
+ double expected_dsep;
204
+ if (!workspace.changed_weights)
205
+ expected_dsep = expected_separation_depth(workspace.end - workspace.st + 1);
206
+ else
207
+ expected_dsep = expected_separation_depth(sum_weight);
208
+
209
+ add_separation_step(workspace, input_data, expected_dsep + 1);
210
+ }
211
+ }
212
+
213
+ template <class PredictionData, class sparse_ix>
214
+ void remap_terminal_trees(IsoForest *model_outputs, ExtIsoForest *model_outputs_ext,
215
+ PredictionData &prediction_data, sparse_ix *restrict tree_num, int nthreads)
216
+ {
217
+ size_t ntrees = (model_outputs != NULL)? model_outputs->trees.size() : model_outputs_ext->hplanes.size();
218
+ size_t max_tree, curr_term;
219
+ std::vector<sparse_ix> tree_mapping;
220
+ if (model_outputs != NULL)
221
+ {
222
+ max_tree = std::accumulate(model_outputs->trees.begin(),
223
+ model_outputs->trees.end(),
224
+ (size_t)0,
225
+ [](const size_t curr_max, const std::vector<IsoTree> &tr)
226
+ {return std::max(curr_max, tr.size());});
227
+ tree_mapping.resize(max_tree);
228
+ for (size_t tree = 0; tree < ntrees; tree++)
229
+ {
230
+ std::fill(tree_mapping.begin(), tree_mapping.end(), (size_t)0);
231
+ curr_term = 0;
232
+ for (size_t node = 0; node < model_outputs->trees[tree].size(); node++)
233
+ if (model_outputs->trees[tree][node].tree_left == 0)
234
+ tree_mapping[node] = curr_term++;
235
+
236
+ #pragma omp parallel for schedule(static) num_threads(nthreads) shared(tree_num, tree_mapping, tree, prediction_data)
237
+ for (size_t_for row = 0; row < (decltype(row))prediction_data.nrows; row++)
238
+ tree_num[row + tree * prediction_data.nrows] = tree_mapping[tree_num[row + tree * prediction_data.nrows]];
239
+ }
240
+ }
241
+
242
+ else
243
+ {
244
+ max_tree = std::accumulate(model_outputs_ext->hplanes.begin(),
245
+ model_outputs_ext->hplanes.end(),
246
+ (size_t)0,
247
+ [](const size_t curr_max, const std::vector<IsoHPlane> &tr)
248
+ {return std::max(curr_max, tr.size());});
249
+ tree_mapping.resize(max_tree);
250
+ for (size_t tree = 0; tree < ntrees; tree++)
251
+ {
252
+ std::fill(tree_mapping.begin(), tree_mapping.end(), (size_t)0);
253
+ curr_term = 0;
254
+ for (size_t node = 0; node < model_outputs_ext->hplanes[tree].size(); node++)
255
+ if (model_outputs_ext->hplanes[tree][node].hplane_left == 0)
256
+ tree_mapping[node] = curr_term++;
257
+
258
+ #pragma omp parallel for schedule(static) num_threads(nthreads) shared(tree_num, tree_mapping, tree, prediction_data)
259
+ for (size_t_for row = 0; row < (decltype(row))prediction_data.nrows; row++)
260
+ tree_num[row + tree * prediction_data.nrows] = tree_mapping[tree_num[row + tree * prediction_data.nrows]];
261
+ }
262
+ }
263
+ }
264
+
265
+
266
+ template <class WorkerMemory>
267
+ RecursionState::RecursionState(WorkerMemory &workspace, bool full_state)
268
+ {
269
+ this->full_state = full_state;
270
+
271
+ this->split_ix = workspace.split_ix;
272
+ this->end = workspace.end;
273
+ if (!workspace.col_sampler.has_weights())
274
+ this->sampler_pos = workspace.col_sampler.curr_pos;
275
+ else {
276
+ this->col_sampler_weights = workspace.col_sampler.tree_weights;
277
+ this->n_dropped = workspace.col_sampler.n_dropped;
278
+ }
279
+
280
+ if (this->full_state)
281
+ {
282
+ this->st = workspace.st;
283
+ this->st_NA = workspace.st_NA;
284
+ this->end_NA = workspace.end_NA;
285
+
286
+ this->changed_weights = workspace.changed_weights;
287
+
288
+ /* for the extended model, it's not necessary to copy everything */
289
+ if (workspace.comb_val.empty() && workspace.st_NA < workspace.end_NA)
290
+ {
291
+ this->ix_arr = std::vector<size_t>(workspace.ix_arr.begin() + workspace.st_NA,
292
+ workspace.ix_arr.begin() + workspace.end_NA);
293
+ if (this->changed_weights)
294
+ {
295
+ size_t tot = workspace.end_NA - workspace.st_NA;
296
+ this->weights_arr = std::unique_ptr<double[]>(new double[tot]);
297
+ if (!workspace.weights_arr.empty())
298
+ for (size_t ix = 0; ix < tot; ix++)
299
+ this->weights_arr[ix] = workspace.weights_arr[workspace.ix_arr[ix + workspace.st_NA]];
300
+ else
301
+ for (size_t ix = 0; ix < tot; ix++)
302
+ this->weights_arr[ix] = workspace.weights_map[workspace.ix_arr[ix + workspace.st_NA]];
303
+ }
304
+ }
305
+ }
306
+ }
307
+
308
+
309
+ template <class WorkerMemory>
310
+ void RecursionState::restore_state(WorkerMemory &workspace)
311
+ {
312
+ workspace.split_ix = this->split_ix;
313
+ workspace.end = this->end;
314
+ if (!workspace.col_sampler.has_weights())
315
+ workspace.col_sampler.curr_pos = this->sampler_pos;
316
+ else {
317
+ workspace.col_sampler.tree_weights = std::move(this->col_sampler_weights);
318
+ workspace.col_sampler.n_dropped = this->n_dropped;
319
+ }
320
+
321
+ if (this->full_state)
322
+ {
323
+ workspace.st = this->st;
324
+ workspace.st_NA = this->st_NA;
325
+ workspace.end_NA = this->end_NA;
326
+
327
+ workspace.changed_weights = this->changed_weights;
328
+
329
+ if (workspace.comb_val.empty() && !this->ix_arr.empty())
330
+ {
331
+ std::copy(this->ix_arr.begin(),
332
+ this->ix_arr.end(),
333
+ workspace.ix_arr.begin() + this->st_NA);
334
+ if (this->changed_weights)
335
+ {
336
+ size_t tot = workspace.end_NA - workspace.st_NA;
337
+ if (!workspace.weights_arr.empty())
338
+ for (size_t ix = 0; ix < tot; ix++)
339
+ workspace.weights_arr[workspace.ix_arr[ix + workspace.st_NA]] = this->weights_arr[ix];
340
+ else
341
+ for (size_t ix = 0; ix < tot; ix++)
342
+ workspace.weights_map[workspace.ix_arr[ix + workspace.st_NA]] = this->weights_arr[ix];
343
+ }
344
+ }
345
+ }
346
+ }
347
+
348
+ template <class InputData, class ldouble_safe>
349
+ std::vector<double> calc_kurtosis_all_data(InputData &input_data, ModelParams &model_params, RNG_engine &rnd_generator)
350
+ {
351
+ std::unique_ptr<double[]> buffer_double;
352
+ std::unique_ptr<size_t[]> buffer_size_t;
353
+ if (input_data.ncols_categ)
354
+ {
355
+ buffer_double = std::unique_ptr<double[]>(new double[input_data.max_categ]);
356
+ if (!(input_data.sample_weights != NULL && !input_data.weight_as_sample))
357
+ buffer_size_t = std::unique_ptr<size_t[]>(new size_t[input_data.max_categ + 1]);
358
+ }
359
+
360
+
361
+ std::vector<double> kurt_weights(input_data.ncols_numeric + input_data.ncols_categ);
362
+ for (size_t col = 0; col < input_data.ncols_tot; col++)
363
+ {
364
+ if (col < input_data.ncols_numeric)
365
+ {
366
+ if (input_data.Xc_indptr == NULL)
367
+ {
368
+ if (!(input_data.sample_weights != NULL && !input_data.weight_as_sample))
369
+ {
370
+ kurt_weights[col]
371
+ = calc_kurtosis<typename std::remove_pointer<decltype(input_data.numeric_data)>::type, ldouble_safe>(
372
+ input_data.numeric_data + col * input_data.nrows,
373
+ input_data.nrows, model_params.missing_action);
374
+ }
375
+
376
+ else
377
+ {
378
+ kurt_weights[col]
379
+ = calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
380
+ ldouble_safe>(
381
+ input_data.numeric_data + col * input_data.nrows, input_data.nrows,
382
+ model_params.missing_action, input_data.sample_weights);
383
+ }
384
+ }
385
+
386
+ else
387
+ {
388
+ if (!(input_data.sample_weights != NULL && !input_data.weight_as_sample))
389
+ {
390
+ kurt_weights[col]
391
+ = calc_kurtosis<typename std::remove_pointer<decltype(input_data.Xc)>::type,
392
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
393
+ ldouble_safe>(
394
+ col, input_data.nrows,
395
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
396
+ model_params.missing_action);
397
+ }
398
+
399
+ else
400
+ {
401
+ kurt_weights[col]
402
+ = calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
403
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
404
+ ldouble_safe>(
405
+ col, input_data.nrows,
406
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
407
+ model_params.missing_action, input_data.sample_weights);
408
+ }
409
+ }
410
+ }
411
+
412
+ else
413
+ {
414
+ if (!(input_data.sample_weights != NULL && !input_data.weight_as_sample))
415
+ {
416
+ kurt_weights[col]
417
+ = calc_kurtosis<ldouble_safe>(input_data.nrows,
418
+ input_data.categ_data + (col- input_data.ncols_numeric) * input_data.nrows,
419
+ input_data.ncat[col - input_data.ncols_numeric],
420
+ buffer_size_t.get(), buffer_double.get(),
421
+ model_params.missing_action, model_params.cat_split_type, rnd_generator);
422
+ }
423
+
424
+ else
425
+ {
426
+ kurt_weights[col]
427
+ = calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.sample_weights)>::type,
428
+ ldouble_safe>(
429
+ input_data.nrows,
430
+ input_data.categ_data + (col- input_data.ncols_numeric) * input_data.nrows,
431
+ input_data.ncat[col - input_data.ncols_numeric],
432
+ buffer_double.get(),
433
+ model_params.missing_action, model_params.cat_split_type,
434
+ rnd_generator, input_data.sample_weights);
435
+ }
436
+ }
437
+ }
438
+
439
+ for (auto &w : kurt_weights) w = (w == -HUGE_VAL)? 0. : std::fmax(1e-8, -1. + w);
440
+ if (input_data.col_weights != NULL)
441
+ {
442
+ for (size_t col = 0; col < input_data.ncols_tot; col++)
443
+ {
444
+ if (kurt_weights[col] <= 0) continue;
445
+ kurt_weights[col] *= input_data.col_weights[col];
446
+ kurt_weights[col] = std::fmax(kurt_weights[col], 1e-100);
447
+ }
448
+ }
449
+
450
+ return kurt_weights;
451
+ }
452
+
453
+ template <class InputData, class WorkerMemory>
454
+ void calc_ranges_all_cols(InputData &input_data, WorkerMemory &workspace, ModelParams &model_params,
455
+ double *restrict ranges, double *restrict saved_xmin, double *restrict saved_xmax)
456
+ {
457
+ workspace.col_sampler.prepare_full_pass();
458
+ while (workspace.col_sampler.sample_col(workspace.col_chosen))
459
+ {
460
+ get_split_range(workspace, input_data, model_params);
461
+
462
+ if (workspace.unsplittable) {
463
+ workspace.col_sampler.drop_col(workspace.col_chosen);
464
+ ranges[workspace.col_chosen] = 0;
465
+ if (saved_xmin != NULL) {
466
+ saved_xmin[workspace.col_chosen] = 0;
467
+ saved_xmax[workspace.col_chosen] = 0;
468
+ }
469
+ }
470
+ else {
471
+ ranges[workspace.col_chosen] = workspace.xmax - workspace.xmin;
472
+ if (workspace.tree_kurtoses != NULL) {
473
+ ranges[workspace.col_chosen] *= workspace.tree_kurtoses[workspace.col_chosen];
474
+ ranges[workspace.col_chosen] = std::fmax(ranges[workspace.col_chosen], 1e-100);
475
+ }
476
+ else if (input_data.col_weights != NULL) {
477
+ ranges[workspace.col_chosen] *= input_data.col_weights[workspace.col_chosen];
478
+ ranges[workspace.col_chosen] = std::fmax(ranges[workspace.col_chosen], 1e-100);
479
+ }
480
+ if (saved_xmin != NULL) {
481
+ saved_xmin[workspace.col_chosen] = workspace.xmin;
482
+ saved_xmax[workspace.col_chosen] = workspace.xmax;
483
+ }
484
+ }
485
+ }
486
+ }
487
+
488
+ template <class InputData, class WorkerMemory, class ldouble_safe>
489
+ void calc_var_all_cols(InputData &input_data, WorkerMemory &workspace, ModelParams &model_params,
490
+ double *restrict variances, double *restrict saved_xmin, double *restrict saved_xmax,
491
+ double *restrict saved_means, double *restrict saved_sds)
492
+ {
493
+ double xmean, xsd;
494
+ if (saved_means != NULL)
495
+ workspace.has_saved_stats = true;
496
+
497
+ workspace.col_sampler.prepare_full_pass();
498
+ while (workspace.col_sampler.sample_col(workspace.col_chosen))
499
+ {
500
+ if (workspace.col_chosen < input_data.ncols_numeric)
501
+ {
502
+ get_split_range(workspace, input_data, model_params);
503
+ if (workspace.unsplittable)
504
+ {
505
+ workspace.col_sampler.drop_col(workspace.col_chosen);
506
+ variances[workspace.col_chosen] = 0;
507
+ if (saved_xmin != NULL)
508
+ {
509
+ saved_xmin[workspace.col_chosen] = 0;
510
+ saved_xmax[workspace.col_chosen] = 0;
511
+ }
512
+ continue;
513
+ }
514
+
515
+ if (saved_xmin != NULL)
516
+ {
517
+ saved_xmin[workspace.col_chosen] = workspace.xmin;
518
+ saved_xmax[workspace.col_chosen] = workspace.xmax;
519
+ }
520
+
521
+
522
+ if (input_data.Xc_indptr == NULL)
523
+ {
524
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
525
+ {
526
+ calc_mean_and_sd<typename std::remove_pointer<decltype(input_data.numeric_data)>::type, ldouble_safe>(
527
+ workspace.ix_arr.data(), workspace.st, workspace.end,
528
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
529
+ model_params.missing_action, xsd, xmean);
530
+ }
531
+
532
+ else if (!workspace.weights_arr.empty())
533
+ {
534
+ calc_mean_and_sd_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
535
+ decltype(workspace.weights_arr), ldouble_safe>(
536
+ workspace.ix_arr.data(), workspace.st, workspace.end,
537
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
538
+ workspace.weights_arr,
539
+ model_params.missing_action, xsd, xmean);
540
+ }
541
+
542
+ else
543
+ {
544
+ calc_mean_and_sd_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
545
+ decltype(workspace.weights_map), ldouble_safe>(
546
+ workspace.ix_arr.data(), workspace.st, workspace.end,
547
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
548
+ workspace.weights_map,
549
+ model_params.missing_action, xsd, xmean);
550
+ }
551
+ }
552
+
553
+ else
554
+ {
555
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
556
+ {
557
+ calc_mean_and_sd<typename std::remove_pointer<decltype(input_data.Xc)>::type,
558
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
559
+ ldouble_safe>(
560
+ workspace.ix_arr.data(), workspace.st, workspace.end,
561
+ workspace.col_chosen,
562
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
563
+ xsd, xmean);
564
+ }
565
+
566
+ else if (!workspace.weights_arr.empty())
567
+ {
568
+ calc_mean_and_sd_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
569
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
570
+ decltype(workspace.weights_arr), ldouble_safe>(
571
+ workspace.ix_arr.data(), workspace.st, workspace.end,
572
+ workspace.col_chosen,
573
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
574
+ xsd, xmean, workspace.weights_arr);
575
+ }
576
+
577
+ else
578
+ {
579
+ calc_mean_and_sd_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
580
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
581
+ decltype(workspace.weights_map), ldouble_safe>(
582
+ workspace.ix_arr.data(), workspace.st, workspace.end,
583
+ workspace.col_chosen,
584
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
585
+ xsd, xmean, workspace.weights_map);
586
+ }
587
+ }
588
+
589
+ if (saved_means != NULL) saved_means[workspace.col_chosen] = xmean;
590
+ if (saved_sds != NULL) saved_sds[workspace.col_chosen] = xsd;
591
+ }
592
+
593
+ else
594
+ {
595
+ size_t col = workspace.col_chosen - input_data.ncols_numeric;
596
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
597
+ {
598
+ if (workspace.buffer_szt.size() < (size_t)2 * (size_t)input_data.ncat[col] + 1)
599
+ workspace.buffer_szt.resize((size_t)2 * (size_t)input_data.ncat[col] + 1);
600
+ xsd = expected_sd_cat<size_t, ldouble_safe>(
601
+ workspace.ix_arr.data(), workspace.st, workspace.end,
602
+ input_data.categ_data + col * input_data.nrows,
603
+ input_data.ncat[col],
604
+ model_params.missing_action,
605
+ workspace.buffer_szt.data(),
606
+ workspace.buffer_szt.data() + input_data.ncat[col] + 1,
607
+ workspace.buffer_dbl.data());
608
+ }
609
+
610
+ else if (!workspace.weights_arr.empty())
611
+ {
612
+ if (workspace.buffer_dbl.size() < (size_t)2 * (size_t)input_data.ncat[col] + 1)
613
+ workspace.buffer_dbl.resize((size_t)2 * (size_t)input_data.ncat[col] + 1);
614
+ xsd = expected_sd_cat_weighted<decltype(workspace.weights_arr), size_t, ldouble_safe>(
615
+ workspace.ix_arr.data(), workspace.st, workspace.end,
616
+ input_data.categ_data + col * input_data.nrows,
617
+ input_data.ncat[col],
618
+ model_params.missing_action, workspace.weights_arr,
619
+ workspace.buffer_dbl.data(),
620
+ workspace.buffer_szt.data(),
621
+ workspace.buffer_dbl.data() + input_data.ncat[col] + 1);
622
+ }
623
+
624
+ else
625
+ {
626
+ if (workspace.buffer_dbl.size() < (size_t)2 * (size_t)input_data.ncat[col] + 1)
627
+ workspace.buffer_dbl.resize((size_t)2 * (size_t)input_data.ncat[col] + 1);
628
+ xsd = expected_sd_cat_weighted<decltype(workspace.weights_map), size_t, ldouble_safe>(
629
+ workspace.ix_arr.data(), workspace.st, workspace.end,
630
+ input_data.categ_data + col * input_data.nrows,
631
+ input_data.ncat[col],
632
+ model_params.missing_action, workspace.weights_map,
633
+ workspace.buffer_dbl.data(),
634
+ workspace.buffer_szt.data(),
635
+ workspace.buffer_dbl.data() + input_data.ncat[col] + 1);
636
+ }
637
+ }
638
+
639
+ if (xsd)
640
+ {
641
+ variances[workspace.col_chosen] = square(xsd);
642
+ if (workspace.tree_kurtoses != NULL)
643
+ variances[workspace.col_chosen] *= workspace.tree_kurtoses[workspace.col_chosen];
644
+ else if (input_data.col_weights != NULL)
645
+ variances[workspace.col_chosen] *= input_data.col_weights[workspace.col_chosen];
646
+ variances[workspace.col_chosen] = std::fmax(variances[workspace.col_chosen], 1e-100);
647
+ }
648
+
649
+ else
650
+ {
651
+ workspace.col_sampler.drop_col(workspace.col_chosen);
652
+ variances[workspace.col_chosen] = 0;
653
+ }
654
+ }
655
+ }
656
+
657
+ template <class InputData, class WorkerMemory, class ldouble_safe>
658
+ void calc_kurt_all_cols(InputData &input_data, WorkerMemory &workspace, ModelParams &model_params,
659
+ double *restrict kurtosis, double *restrict saved_xmin, double *restrict saved_xmax)
660
+ {
661
+ workspace.col_sampler.prepare_full_pass();
662
+ while (workspace.col_sampler.sample_col(workspace.col_chosen))
663
+ {
664
+ if (saved_xmin != NULL)
665
+ {
666
+ get_split_range(workspace, input_data, model_params);
667
+ if (workspace.unsplittable)
668
+ {
669
+ workspace.col_sampler.drop_col(workspace.col_chosen);
670
+ continue;
671
+ }
672
+
673
+ if (saved_xmin != NULL)
674
+ {
675
+ saved_xmin[workspace.col_chosen] = workspace.xmin;
676
+ saved_xmax[workspace.col_chosen] = workspace.xmax;
677
+ }
678
+ }
679
+
680
+ if (workspace.col_chosen < input_data.ncols_numeric)
681
+ {
682
+ if (input_data.Xc_indptr == NULL)
683
+ {
684
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
685
+ {
686
+ kurtosis[workspace.col_chosen] =
687
+ calc_kurtosis<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
688
+ ldouble_safe>(
689
+ workspace.ix_arr.data(), workspace.st, workspace.end,
690
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
691
+ model_params.missing_action);
692
+ }
693
+
694
+ else if (!workspace.weights_arr.empty())
695
+ {
696
+ kurtosis[workspace.col_chosen] =
697
+ calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
698
+ decltype(workspace.weights_arr), ldouble_safe>(
699
+ workspace.ix_arr.data(), workspace.st, workspace.end,
700
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
701
+ model_params.missing_action, workspace.weights_arr);
702
+ }
703
+
704
+ else
705
+ {
706
+ kurtosis[workspace.col_chosen] =
707
+ calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
708
+ decltype(workspace.weights_map), ldouble_safe>(
709
+ workspace.ix_arr.data(), workspace.st, workspace.end,
710
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
711
+ model_params.missing_action, workspace.weights_map);
712
+ }
713
+ }
714
+
715
+ else
716
+ {
717
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
718
+ {
719
+ kurtosis[workspace.col_chosen] =
720
+ calc_kurtosis<typename std::remove_pointer<decltype(input_data.Xc)>::type,
721
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
722
+ ldouble_safe>(
723
+ workspace.ix_arr.data(), workspace.st, workspace.end,
724
+ workspace.col_chosen,
725
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
726
+ model_params.missing_action);
727
+ }
728
+
729
+ else if (!workspace.weights_arr.empty())
730
+ {
731
+ kurtosis[workspace.col_chosen] =
732
+ calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
733
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
734
+ decltype(workspace.weights_arr), ldouble_safe>(
735
+ workspace.ix_arr.data(), workspace.st, workspace.end,
736
+ workspace.col_chosen,
737
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
738
+ model_params.missing_action, workspace.weights_arr);
739
+ }
740
+
741
+ else
742
+ {
743
+ kurtosis[workspace.col_chosen] =
744
+ calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
745
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
746
+ decltype(workspace.weights_map), ldouble_safe>(
747
+ workspace.ix_arr.data(), workspace.st, workspace.end,
748
+ workspace.col_chosen,
749
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
750
+ model_params.missing_action, workspace.weights_map);
751
+ }
752
+ }
753
+ }
754
+
755
+ else
756
+ {
757
+ size_t col = workspace.col_chosen - input_data.ncols_numeric;
758
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
759
+ {
760
+ kurtosis[workspace.col_chosen] =
761
+ calc_kurtosis<ldouble_safe>(
762
+ workspace.ix_arr.data(), workspace.st, workspace.end,
763
+ input_data.categ_data + col * input_data.nrows,
764
+ input_data.ncat[col],
765
+ workspace.buffer_szt.data(), workspace.buffer_dbl.data(),
766
+ model_params.missing_action, model_params.cat_split_type,
767
+ workspace.rnd_generator);
768
+ }
769
+
770
+ else if (!workspace.weights_arr.empty())
771
+ {
772
+ kurtosis[workspace.col_chosen] =
773
+ calc_kurtosis_weighted<decltype(workspace.weights_arr), ldouble_safe>(
774
+ workspace.ix_arr.data(), workspace.st, workspace.end,
775
+ input_data.categ_data + col * input_data.nrows,
776
+ input_data.ncat[col],
777
+ workspace.buffer_dbl.data(),
778
+ model_params.missing_action, model_params.cat_split_type,
779
+ workspace.rnd_generator, workspace.weights_arr);
780
+ }
781
+
782
+ else
783
+ {
784
+ kurtosis[workspace.col_chosen] =
785
+ calc_kurtosis_weighted<decltype(workspace.weights_map), ldouble_safe>(
786
+ workspace.ix_arr.data(), workspace.st, workspace.end,
787
+ input_data.categ_data + col * input_data.nrows,
788
+ input_data.ncat[col],
789
+ workspace.buffer_dbl.data(),
790
+ model_params.missing_action, model_params.cat_split_type,
791
+ workspace.rnd_generator, workspace.weights_map);
792
+ }
793
+ }
794
+
795
+ if (kurtosis[workspace.col_chosen] == -HUGE_VAL)
796
+ workspace.col_sampler.drop_col(workspace.col_chosen);
797
+
798
+ kurtosis[workspace.col_chosen] = (kurtosis[workspace.col_chosen] == -HUGE_VAL)?
799
+ 0. : std::fmax(1e-8, -1. + kurtosis[workspace.col_chosen]);
800
+ if (input_data.col_weights != NULL && kurtosis[workspace.col_chosen] > 0)
801
+ {
802
+ kurtosis[workspace.col_chosen] *= input_data.col_weights[workspace.col_chosen];
803
+ kurtosis[workspace.col_chosen] = std::fmax(kurtosis[workspace.col_chosen], 1e-100);
804
+ }
805
+ }
806
+ }
807
+
808
+ bool is_boxed_metric(const ScoringMetric scoring_metric)
809
+ {
810
+ return scoring_metric == BoxedDensity ||
811
+ scoring_metric == BoxedDensity2 ||
812
+ scoring_metric == BoxedRatio;
813
+ }