isotree 0.2.2 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (152) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2116 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +132 -0
  14. data/vendor/isotree/src/RcppExports.cpp +594 -57
  15. data/vendor/isotree/src/Rwrapper.cpp +2452 -304
  16. data/vendor/isotree/src/c_interface.cpp +958 -0
  17. data/vendor/isotree/src/crit.hpp +4236 -0
  18. data/vendor/isotree/src/digamma.hpp +184 -0
  19. data/vendor/isotree/src/dist.hpp +1886 -0
  20. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  21. data/vendor/isotree/src/extended.hpp +1444 -0
  22. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  23. data/vendor/isotree/src/fit_model.hpp +2401 -0
  24. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  25. data/vendor/isotree/src/helpers_iforest.hpp +814 -0
  26. data/vendor/isotree/src/{impute.cpp → impute.hpp} +382 -123
  27. data/vendor/isotree/src/indexer.cpp +515 -0
  28. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  29. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  30. data/vendor/isotree/src/isoforest.hpp +1659 -0
  31. data/vendor/isotree/src/isotree.hpp +1815 -394
  32. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  33. data/vendor/isotree/src/merge_models.cpp +159 -16
  34. data/vendor/isotree/src/mult.hpp +1321 -0
  35. data/vendor/isotree/src/oop_interface.cpp +844 -0
  36. data/vendor/isotree/src/oop_interface.hpp +278 -0
  37. data/vendor/isotree/src/other_helpers.hpp +219 -0
  38. data/vendor/isotree/src/predict.hpp +1932 -0
  39. data/vendor/isotree/src/python_helpers.hpp +114 -0
  40. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  41. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  42. data/vendor/isotree/src/robinmap/README.md +483 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1639 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  46. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  47. data/vendor/isotree/src/serialize.cpp +4316 -139
  48. data/vendor/isotree/src/sql.cpp +143 -61
  49. data/vendor/isotree/src/subset_models.cpp +174 -0
  50. data/vendor/isotree/src/utils.hpp +3786 -0
  51. data/vendor/isotree/src/xoshiro.hpp +463 -0
  52. data/vendor/isotree/src/ziggurat.hpp +405 -0
  53. metadata +40 -105
  54. data/vendor/cereal/LICENSE +0 -24
  55. data/vendor/cereal/README.md +0 -85
  56. data/vendor/cereal/include/cereal/access.hpp +0 -351
  57. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  58. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  59. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  60. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  61. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  62. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  63. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  65. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  66. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  67. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  68. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  69. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  70. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  71. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  72. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  74. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  76. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  77. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  78. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  79. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  91. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  92. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  94. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  96. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  97. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  98. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  99. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  100. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  101. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  102. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  103. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  104. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  105. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  106. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  107. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  111. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  112. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  113. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  114. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  115. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  116. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  117. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  118. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  119. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  120. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  121. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  122. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  123. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  124. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  125. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  126. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  127. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  128. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  129. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  130. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  131. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  132. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  133. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  134. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  135. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  136. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  137. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  138. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  139. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  140. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  141. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  142. data/vendor/cereal/include/cereal/version.hpp +0 -52
  143. data/vendor/isotree/src/Makevars +0 -4
  144. data/vendor/isotree/src/crit.cpp +0 -912
  145. data/vendor/isotree/src/dist.cpp +0 -749
  146. data/vendor/isotree/src/extended.cpp +0 -790
  147. data/vendor/isotree/src/fit_model.cpp +0 -1090
  148. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  149. data/vendor/isotree/src/isoforest.cpp +0 -771
  150. data/vendor/isotree/src/mult.cpp +0 -607
  151. data/vendor/isotree/src/predict.cpp +0 -853
  152. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -0,0 +1,814 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for the C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4.5: programs for machine learning. Elsevier, 2014.
21
+ * [8] Cortes, David.
22
+ * "Distance approximation using Isolation Forests."
23
+ * arXiv preprint arXiv:1910.12362 (2019).
24
+ * [9] Cortes, David.
25
+ * "Imputing missing values with unsupervised random trees."
26
+ * arXiv preprint arXiv:1911.06646 (2019).
27
+ * [10] https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom
28
+ * [11] Cortes, David.
29
+ * "Revisiting randomized choices in isolation forests."
30
+ * arXiv preprint arXiv:2110.13402 (2021).
31
+ * [12] Guha, Sudipto, et al.
32
+ * "Robust random cut forest based anomaly detection on streams."
33
+ * International conference on machine learning. PMLR, 2016.
34
+ * [13] Cortes, David.
35
+ * "Isolation forests: looking beyond tree depth."
36
+ * arXiv preprint arXiv:2111.11639 (2021).
37
+ * [14] Ting, Kai Ming, Yue Zhu, and Zhi-Hua Zhou.
38
+ * "Isolation kernel and its effect on SVM"
39
+ * Proceedings of the 24th ACM SIGKDD
40
+ * International Conference on Knowledge Discovery & Data Mining. 2018.
41
+ *
42
+ * BSD 2-Clause License
43
+ * Copyright (c) 2019-2022, David Cortes
44
+ * All rights reserved.
45
+ * Redistribution and use in source and binary forms, with or without
46
+ * modification, are permitted provided that the following conditions are met:
47
+ * * Redistributions of source code must retain the above copyright notice, this
48
+ * list of conditions and the following disclaimer.
49
+ * * Redistributions in binary form must reproduce the above copyright notice,
50
+ * this list of conditions and the following disclaimer in the documentation
51
+ * and/or other materials provided with the distribution.
52
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
53
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
54
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
55
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
56
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
57
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
58
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
59
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
60
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
61
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
62
+ */
63
+ #include "isotree.hpp"
64
+
65
+
66
+ /* for use in regular model */
67
+ template <class InputData, class WorkerMemory>
68
+ void get_split_range(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params, IsoTree &tree)
69
+ {
70
+ if (tree.col_num < input_data.ncols_numeric)
71
+ {
72
+ tree.col_type = Numeric;
73
+
74
+ if (input_data.Xc_indptr == NULL)
75
+ get_range(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * tree.col_num,
76
+ workspace.st, workspace.end, model_params.missing_action,
77
+ workspace.xmin, workspace.xmax, workspace.unsplittable);
78
+ else
79
+ get_range(workspace.ix_arr.data(), workspace.st, workspace.end, tree.col_num,
80
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
81
+ model_params.missing_action, workspace.xmin, workspace.xmax, workspace.unsplittable);
82
+ }
83
+
84
+ else
85
+ {
86
+ tree.col_num -= input_data.ncols_numeric;
87
+ tree.col_type = Categorical;
88
+
89
+ get_categs(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * tree.col_num,
90
+ workspace.st, workspace.end, input_data.ncat[tree.col_num],
91
+ model_params.missing_action, workspace.categs.data(), workspace.npresent, workspace.unsplittable);
92
+ }
93
+ }
94
+
95
+ /* for use in extended model */
96
+ template <class InputData, class WorkerMemory>
97
+ void get_split_range(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params)
98
+ {
99
+ if (workspace.col_chosen < input_data.ncols_numeric)
100
+ {
101
+ workspace.col_type = Numeric;
102
+
103
+ if (input_data.Xc_indptr == NULL)
104
+ get_range(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * workspace.col_chosen,
105
+ workspace.st, workspace.end, model_params.missing_action,
106
+ workspace.xmin, workspace.xmax, workspace.unsplittable);
107
+ else
108
+ get_range(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen,
109
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
110
+ model_params.missing_action, workspace.xmin, workspace.xmax, workspace.unsplittable);
111
+ }
112
+
113
+ else
114
+ {
115
+ workspace.col_type = Categorical;
116
+ workspace.col_chosen -= input_data.ncols_numeric;
117
+
118
+ get_categs(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * workspace.col_chosen,
119
+ workspace.st, workspace.end, input_data.ncat[workspace.col_chosen],
120
+ model_params.missing_action, workspace.categs.data(), workspace.npresent, workspace.unsplittable);
121
+ }
122
+ }
123
+
124
+ /* for use in regular model with ntry>1 */
125
+ template <class InputData, class WorkerMemory>
126
+ void get_split_range_v2(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params)
127
+ {
128
+ get_split_range(workspace, input_data, model_params);
129
+ if (workspace.col_type == Categorical)
130
+ workspace.col_chosen += input_data.ncols_numeric;
131
+ }
132
+
133
+ template <class InputData, class WorkerMemory>
134
+ int choose_cat_from_present(WorkerMemory &workspace, InputData &input_data, size_t col_num)
135
+ {
136
+ int chosen_cat = std::uniform_int_distribution<int>
137
+ (0, workspace.npresent - 1)
138
+ (workspace.rnd_generator);
139
+ workspace.ncat_tried = 0;
140
+ for (int cat = 0; cat < input_data.ncat[col_num]; cat++)
141
+ {
142
+ if (workspace.categs[cat] > 0)
143
+ {
144
+ if (workspace.ncat_tried == chosen_cat)
145
+ return cat;
146
+ else
147
+ workspace.ncat_tried++;
148
+ }
149
+ }
150
+
151
+ unreachable();
152
+ return -1; /* this will never be reached, but CRAN complains otherwise */
153
+ }
154
+
155
+ bool is_col_taken(std::vector<bool> &col_is_taken, hashed_set<size_t> &col_is_taken_s,
156
+ size_t col_num)
157
+ {
158
+ if (!col_is_taken.empty())
159
+ return col_is_taken[col_num];
160
+ else
161
+ return col_is_taken_s.find(col_num) != col_is_taken_s.end();
162
+ }
163
+
164
+ template <class InputData>
165
+ void set_col_as_taken(std::vector<bool> &col_is_taken, hashed_set<size_t> &col_is_taken_s,
166
+ InputData &input_data, size_t col_num, ColType col_type)
167
+ {
168
+ col_num += ((col_type == Numeric)? 0 : input_data.ncols_numeric);
169
+ if (!col_is_taken.empty())
170
+ col_is_taken[col_num] = true;
171
+ else
172
+ col_is_taken_s.insert(col_num);
173
+ }
174
+
175
+ template <class InputData>
176
+ void set_col_as_taken(std::vector<bool> &col_is_taken, hashed_set<size_t> &col_is_taken_s,
177
+ InputData &input_data, size_t col_num)
178
+ {
179
+ if (!col_is_taken.empty())
180
+ col_is_taken[col_num] = true;
181
+ else
182
+ col_is_taken_s.insert(col_num);
183
+ }
184
+
185
/* Calls 'increase_comb_counter' to add 'remainder' into 'workspace.tmat_sep' for the rows
   selected by workspace.ix_arr between workspace.st and workspace.end.
   Row weights are passed only when 'workspace.changed_weights' is set. They are taken from
   'workspace.weights_arr' if that is non-empty, and from 'workspace.weights_map' otherwise. */
template <class InputData, class WorkerMemory>
void add_separation_step(WorkerMemory &workspace, InputData &input_data, double remainder)
{
    if (workspace.changed_weights)
    {
        if (workspace.weights_arr.empty())
            increase_comb_counter(workspace.ix_arr.data(), workspace.st, workspace.end,
                                  input_data.nrows, workspace.tmat_sep.data(), workspace.weights_map, remainder);
        else
            increase_comb_counter(workspace.ix_arr.data(), workspace.st, workspace.end,
                                  input_data.nrows, workspace.tmat_sep.data(), workspace.weights_arr.data(), remainder);
        return;
    }

    increase_comb_counter(workspace.ix_arr.data(), workspace.st, workspace.end,
                          input_data.nrows, workspace.tmat_sep.data(), remainder);
}
198
+
199
/* When the current range holds more than one row (end - st > 0, with 'end' inclusive), adds
   'expected_separation_depth(...) + 1' through 'add_separation_step'.
   The expected depth is computed from the row count (end - st + 1) when weights are unchanged,
   and from 'sum_weight' otherwise. With changed weights, nothing is added unless
   'sum_weight' > 0. */
template <class InputData, class WorkerMemory, class ldouble_safe>
void add_remainder_separation_steps(WorkerMemory &workspace, InputData &input_data, ldouble_safe sum_weight)
{
    const bool has_several_rows  = (workspace.end - workspace.st) > 0;
    const bool has_usable_weight = !workspace.changed_weights || sum_weight > 0;
    if (!has_several_rows || !has_usable_weight)
        return;

    double expected_dsep;
    if (workspace.changed_weights)
        expected_dsep = expected_separation_depth(sum_weight);
    else
        expected_dsep = expected_separation_depth(workspace.end - workspace.st + 1);

    add_separation_step(workspace, input_data, expected_dsep + 1);
}
213
+
214
+ template <class PredictionData, class sparse_ix>
215
+ void remap_terminal_trees(IsoForest *model_outputs, ExtIsoForest *model_outputs_ext,
216
+ PredictionData &prediction_data, sparse_ix *restrict tree_num, int nthreads)
217
+ {
218
+ size_t ntrees = (model_outputs != NULL)? model_outputs->trees.size() : model_outputs_ext->hplanes.size();
219
+ size_t max_tree, curr_term;
220
+ std::vector<sparse_ix> tree_mapping;
221
+ if (model_outputs != NULL)
222
+ {
223
+ max_tree = std::accumulate(model_outputs->trees.begin(),
224
+ model_outputs->trees.end(),
225
+ (size_t)0,
226
+ [](const size_t curr_max, const std::vector<IsoTree> &tr)
227
+ {return std::max(curr_max, tr.size());});
228
+ tree_mapping.resize(max_tree);
229
+ for (size_t tree = 0; tree < ntrees; tree++)
230
+ {
231
+ std::fill(tree_mapping.begin(), tree_mapping.end(), (size_t)0);
232
+ curr_term = 0;
233
+ for (size_t node = 0; node < model_outputs->trees[tree].size(); node++)
234
+ if (model_outputs->trees[tree][node].tree_left == 0)
235
+ tree_mapping[node] = curr_term++;
236
+
237
+ #pragma omp parallel for schedule(static) num_threads(nthreads) shared(tree_num, tree_mapping, tree, prediction_data)
238
+ for (size_t_for row = 0; row < (decltype(row))prediction_data.nrows; row++)
239
+ tree_num[row + tree * prediction_data.nrows] = tree_mapping[tree_num[row + tree * prediction_data.nrows]];
240
+ }
241
+ }
242
+
243
+ else
244
+ {
245
+ max_tree = std::accumulate(model_outputs_ext->hplanes.begin(),
246
+ model_outputs_ext->hplanes.end(),
247
+ (size_t)0,
248
+ [](const size_t curr_max, const std::vector<IsoHPlane> &tr)
249
+ {return std::max(curr_max, tr.size());});
250
+ tree_mapping.resize(max_tree);
251
+ for (size_t tree = 0; tree < ntrees; tree++)
252
+ {
253
+ std::fill(tree_mapping.begin(), tree_mapping.end(), (size_t)0);
254
+ curr_term = 0;
255
+ for (size_t node = 0; node < model_outputs_ext->hplanes[tree].size(); node++)
256
+ if (model_outputs_ext->hplanes[tree][node].hplane_left == 0)
257
+ tree_mapping[node] = curr_term++;
258
+
259
+ #pragma omp parallel for schedule(static) num_threads(nthreads) shared(tree_num, tree_mapping, tree, prediction_data)
260
+ for (size_t_for row = 0; row < (decltype(row))prediction_data.nrows; row++)
261
+ tree_num[row + tree * prediction_data.nrows] = tree_mapping[tree_num[row + tree * prediction_data.nrows]];
262
+ }
263
+ }
264
+ }
265
+
266
+
267
+ template <class WorkerMemory>
268
+ RecursionState::RecursionState(WorkerMemory &workspace, bool full_state)
269
+ {
270
+ this->full_state = full_state;
271
+
272
+ this->split_ix = workspace.split_ix;
273
+ this->end = workspace.end;
274
+ if (!workspace.col_sampler.has_weights())
275
+ this->sampler_pos = workspace.col_sampler.curr_pos;
276
+ else {
277
+ this->col_sampler_weights = workspace.col_sampler.tree_weights;
278
+ this->n_dropped = workspace.col_sampler.n_dropped;
279
+ }
280
+
281
+ if (this->full_state)
282
+ {
283
+ this->st = workspace.st;
284
+ this->st_NA = workspace.st_NA;
285
+ this->end_NA = workspace.end_NA;
286
+
287
+ this->changed_weights = workspace.changed_weights;
288
+
289
+ /* for the extended model, it's not necessary to copy everything */
290
+ if (workspace.comb_val.empty() && workspace.st_NA < workspace.end_NA)
291
+ {
292
+ this->ix_arr = std::vector<size_t>(workspace.ix_arr.begin() + workspace.st_NA,
293
+ workspace.ix_arr.begin() + workspace.end_NA);
294
+ if (this->changed_weights)
295
+ {
296
+ size_t tot = workspace.end_NA - workspace.st_NA;
297
+ this->weights_arr = std::unique_ptr<double[]>(new double[tot]);
298
+ if (!workspace.weights_arr.empty())
299
+ for (size_t ix = 0; ix < tot; ix++)
300
+ this->weights_arr[ix] = workspace.weights_arr[workspace.ix_arr[ix + workspace.st_NA]];
301
+ else
302
+ for (size_t ix = 0; ix < tot; ix++)
303
+ this->weights_arr[ix] = workspace.weights_map[workspace.ix_arr[ix + workspace.st_NA]];
304
+ }
305
+ }
306
+ }
307
+ }
308
+
309
+
310
+ template <class WorkerMemory>
311
+ void RecursionState::restore_state(WorkerMemory &workspace)
312
+ {
313
+ workspace.split_ix = this->split_ix;
314
+ workspace.end = this->end;
315
+ if (!workspace.col_sampler.has_weights())
316
+ workspace.col_sampler.curr_pos = this->sampler_pos;
317
+ else {
318
+ workspace.col_sampler.tree_weights = std::move(this->col_sampler_weights);
319
+ workspace.col_sampler.n_dropped = this->n_dropped;
320
+ }
321
+
322
+ if (this->full_state)
323
+ {
324
+ workspace.st = this->st;
325
+ workspace.st_NA = this->st_NA;
326
+ workspace.end_NA = this->end_NA;
327
+
328
+ workspace.changed_weights = this->changed_weights;
329
+
330
+ if (workspace.comb_val.empty() && !this->ix_arr.empty())
331
+ {
332
+ std::copy(this->ix_arr.begin(),
333
+ this->ix_arr.end(),
334
+ workspace.ix_arr.begin() + this->st_NA);
335
+ if (this->changed_weights)
336
+ {
337
+ size_t tot = workspace.end_NA - workspace.st_NA;
338
+ if (!workspace.weights_arr.empty())
339
+ for (size_t ix = 0; ix < tot; ix++)
340
+ workspace.weights_arr[workspace.ix_arr[ix + workspace.st_NA]] = this->weights_arr[ix];
341
+ else
342
+ for (size_t ix = 0; ix < tot; ix++)
343
+ workspace.weights_map[workspace.ix_arr[ix + workspace.st_NA]] = this->weights_arr[ix];
344
+ }
345
+ }
346
+ }
347
+ }
348
+
349
+ template <class InputData, class ldouble_safe>
350
+ std::vector<double> calc_kurtosis_all_data(InputData &input_data, ModelParams &model_params, RNG_engine &rnd_generator)
351
+ {
352
+ std::unique_ptr<double[]> buffer_double;
353
+ std::unique_ptr<size_t[]> buffer_size_t;
354
+ if (input_data.ncols_categ)
355
+ {
356
+ buffer_double = std::unique_ptr<double[]>(new double[input_data.max_categ]);
357
+ if (!(input_data.sample_weights != NULL && !input_data.weight_as_sample))
358
+ buffer_size_t = std::unique_ptr<size_t[]>(new size_t[input_data.max_categ + 1]);
359
+ }
360
+
361
+
362
+ std::vector<double> kurt_weights(input_data.ncols_numeric + input_data.ncols_categ);
363
+ for (size_t col = 0; col < input_data.ncols_tot; col++)
364
+ {
365
+ if (col < input_data.ncols_numeric)
366
+ {
367
+ if (input_data.Xc_indptr == NULL)
368
+ {
369
+ if (!(input_data.sample_weights != NULL && !input_data.weight_as_sample))
370
+ {
371
+ kurt_weights[col]
372
+ = calc_kurtosis<typename std::remove_pointer<decltype(input_data.numeric_data)>::type, ldouble_safe>(
373
+ input_data.numeric_data + col * input_data.nrows,
374
+ input_data.nrows, model_params.missing_action);
375
+ }
376
+
377
+ else
378
+ {
379
+ kurt_weights[col]
380
+ = calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
381
+ ldouble_safe>(
382
+ input_data.numeric_data + col * input_data.nrows, input_data.nrows,
383
+ model_params.missing_action, input_data.sample_weights);
384
+ }
385
+ }
386
+
387
+ else
388
+ {
389
+ if (!(input_data.sample_weights != NULL && !input_data.weight_as_sample))
390
+ {
391
+ kurt_weights[col]
392
+ = calc_kurtosis<typename std::remove_pointer<decltype(input_data.Xc)>::type,
393
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
394
+ ldouble_safe>(
395
+ col, input_data.nrows,
396
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
397
+ model_params.missing_action);
398
+ }
399
+
400
+ else
401
+ {
402
+ kurt_weights[col]
403
+ = calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
404
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
405
+ ldouble_safe>(
406
+ col, input_data.nrows,
407
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
408
+ model_params.missing_action, input_data.sample_weights);
409
+ }
410
+ }
411
+ }
412
+
413
+ else
414
+ {
415
+ if (!(input_data.sample_weights != NULL && !input_data.weight_as_sample))
416
+ {
417
+ kurt_weights[col]
418
+ = calc_kurtosis<ldouble_safe>(input_data.nrows,
419
+ input_data.categ_data + (col- input_data.ncols_numeric) * input_data.nrows,
420
+ input_data.ncat[col - input_data.ncols_numeric],
421
+ buffer_size_t.get(), buffer_double.get(),
422
+ model_params.missing_action, model_params.cat_split_type, rnd_generator);
423
+ }
424
+
425
+ else
426
+ {
427
+ kurt_weights[col]
428
+ = calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.sample_weights)>::type,
429
+ ldouble_safe>(
430
+ input_data.nrows,
431
+ input_data.categ_data + (col- input_data.ncols_numeric) * input_data.nrows,
432
+ input_data.ncat[col - input_data.ncols_numeric],
433
+ buffer_double.get(),
434
+ model_params.missing_action, model_params.cat_split_type,
435
+ rnd_generator, input_data.sample_weights);
436
+ }
437
+ }
438
+ }
439
+
440
+ for (auto &w : kurt_weights) w = (w == -HUGE_VAL)? 0. : std::fmax(1e-8, -1. + w);
441
+ if (input_data.col_weights != NULL)
442
+ {
443
+ for (size_t col = 0; col < input_data.ncols_tot; col++)
444
+ {
445
+ if (kurt_weights[col] <= 0) continue;
446
+ kurt_weights[col] *= input_data.col_weights[col];
447
+ kurt_weights[col] = std::fmax(kurt_weights[col], 1e-100);
448
+ }
449
+ }
450
+
451
+ return kurt_weights;
452
+ }
453
+
454
+ template <class InputData, class WorkerMemory>
455
+ void calc_ranges_all_cols(InputData &input_data, WorkerMemory &workspace, ModelParams &model_params,
456
+ double *restrict ranges, double *restrict saved_xmin, double *restrict saved_xmax)
457
+ {
458
+ workspace.col_sampler.prepare_full_pass();
459
+ while (workspace.col_sampler.sample_col(workspace.col_chosen))
460
+ {
461
+ get_split_range(workspace, input_data, model_params);
462
+
463
+ if (workspace.unsplittable) {
464
+ workspace.col_sampler.drop_col(workspace.col_chosen);
465
+ ranges[workspace.col_chosen] = 0;
466
+ if (saved_xmin != NULL) {
467
+ saved_xmin[workspace.col_chosen] = 0;
468
+ saved_xmax[workspace.col_chosen] = 0;
469
+ }
470
+ }
471
+ else {
472
+ ranges[workspace.col_chosen] = workspace.xmax - workspace.xmin;
473
+ if (workspace.tree_kurtoses != NULL) {
474
+ ranges[workspace.col_chosen] *= workspace.tree_kurtoses[workspace.col_chosen];
475
+ ranges[workspace.col_chosen] = std::fmax(ranges[workspace.col_chosen], 1e-100);
476
+ }
477
+ else if (input_data.col_weights != NULL) {
478
+ ranges[workspace.col_chosen] *= input_data.col_weights[workspace.col_chosen];
479
+ ranges[workspace.col_chosen] = std::fmax(ranges[workspace.col_chosen], 1e-100);
480
+ }
481
+ if (saved_xmin != NULL) {
482
+ saved_xmin[workspace.col_chosen] = workspace.xmin;
483
+ saved_xmax[workspace.col_chosen] = workspace.xmax;
484
+ }
485
+ }
486
+ }
487
+ }
488
+
489
+ template <class InputData, class WorkerMemory, class ldouble_safe>
490
+ void calc_var_all_cols(InputData &input_data, WorkerMemory &workspace, ModelParams &model_params,
491
+ double *restrict variances, double *restrict saved_xmin, double *restrict saved_xmax,
492
+ double *restrict saved_means, double *restrict saved_sds)
493
+ {
494
+ double xmean, xsd;
495
+ if (saved_means != NULL)
496
+ workspace.has_saved_stats = true;
497
+
498
+ workspace.col_sampler.prepare_full_pass();
499
+ while (workspace.col_sampler.sample_col(workspace.col_chosen))
500
+ {
501
+ if (workspace.col_chosen < input_data.ncols_numeric)
502
+ {
503
+ get_split_range(workspace, input_data, model_params);
504
+ if (workspace.unsplittable)
505
+ {
506
+ workspace.col_sampler.drop_col(workspace.col_chosen);
507
+ variances[workspace.col_chosen] = 0;
508
+ if (saved_xmin != NULL)
509
+ {
510
+ saved_xmin[workspace.col_chosen] = 0;
511
+ saved_xmax[workspace.col_chosen] = 0;
512
+ }
513
+ continue;
514
+ }
515
+
516
+ if (saved_xmin != NULL)
517
+ {
518
+ saved_xmin[workspace.col_chosen] = workspace.xmin;
519
+ saved_xmax[workspace.col_chosen] = workspace.xmax;
520
+ }
521
+
522
+
523
+ if (input_data.Xc_indptr == NULL)
524
+ {
525
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
526
+ {
527
+ calc_mean_and_sd<typename std::remove_pointer<decltype(input_data.numeric_data)>::type, ldouble_safe>(
528
+ workspace.ix_arr.data(), workspace.st, workspace.end,
529
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
530
+ model_params.missing_action, xsd, xmean);
531
+ }
532
+
533
+ else if (!workspace.weights_arr.empty())
534
+ {
535
+ calc_mean_and_sd_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
536
+ decltype(workspace.weights_arr), ldouble_safe>(
537
+ workspace.ix_arr.data(), workspace.st, workspace.end,
538
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
539
+ workspace.weights_arr,
540
+ model_params.missing_action, xsd, xmean);
541
+ }
542
+
543
+ else
544
+ {
545
+ calc_mean_and_sd_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
546
+ decltype(workspace.weights_map), ldouble_safe>(
547
+ workspace.ix_arr.data(), workspace.st, workspace.end,
548
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
549
+ workspace.weights_map,
550
+ model_params.missing_action, xsd, xmean);
551
+ }
552
+ }
553
+
554
+ else
555
+ {
556
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
557
+ {
558
+ calc_mean_and_sd<typename std::remove_pointer<decltype(input_data.Xc)>::type,
559
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
560
+ ldouble_safe>(
561
+ workspace.ix_arr.data(), workspace.st, workspace.end,
562
+ workspace.col_chosen,
563
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
564
+ xsd, xmean);
565
+ }
566
+
567
+ else if (!workspace.weights_arr.empty())
568
+ {
569
+ calc_mean_and_sd_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
570
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
571
+ decltype(workspace.weights_arr), ldouble_safe>(
572
+ workspace.ix_arr.data(), workspace.st, workspace.end,
573
+ workspace.col_chosen,
574
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
575
+ xsd, xmean, workspace.weights_arr);
576
+ }
577
+
578
+ else
579
+ {
580
+ calc_mean_and_sd_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
581
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
582
+ decltype(workspace.weights_map), ldouble_safe>(
583
+ workspace.ix_arr.data(), workspace.st, workspace.end,
584
+ workspace.col_chosen,
585
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
586
+ xsd, xmean, workspace.weights_map);
587
+ }
588
+ }
589
+
590
+ if (saved_means != NULL) saved_means[workspace.col_chosen] = xmean;
591
+ if (saved_sds != NULL) saved_sds[workspace.col_chosen] = xsd;
592
+ }
593
+
594
+ else
595
+ {
596
+ size_t col = workspace.col_chosen - input_data.ncols_numeric;
597
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
598
+ {
599
+ if (workspace.buffer_szt.size() < (size_t)2 * (size_t)input_data.ncat[col] + 1)
600
+ workspace.buffer_szt.resize((size_t)2 * (size_t)input_data.ncat[col] + 1);
601
+ xsd = expected_sd_cat<size_t, ldouble_safe>(
602
+ workspace.ix_arr.data(), workspace.st, workspace.end,
603
+ input_data.categ_data + col * input_data.nrows,
604
+ input_data.ncat[col],
605
+ model_params.missing_action,
606
+ workspace.buffer_szt.data(),
607
+ workspace.buffer_szt.data() + input_data.ncat[col] + 1,
608
+ workspace.buffer_dbl.data());
609
+ }
610
+
611
+ else if (!workspace.weights_arr.empty())
612
+ {
613
+ if (workspace.buffer_dbl.size() < (size_t)2 * (size_t)input_data.ncat[col] + 1)
614
+ workspace.buffer_dbl.resize((size_t)2 * (size_t)input_data.ncat[col] + 1);
615
+ xsd = expected_sd_cat_weighted<decltype(workspace.weights_arr), size_t, ldouble_safe>(
616
+ workspace.ix_arr.data(), workspace.st, workspace.end,
617
+ input_data.categ_data + col * input_data.nrows,
618
+ input_data.ncat[col],
619
+ model_params.missing_action, workspace.weights_arr,
620
+ workspace.buffer_dbl.data(),
621
+ workspace.buffer_szt.data(),
622
+ workspace.buffer_dbl.data() + input_data.ncat[col] + 1);
623
+ }
624
+
625
+ else
626
+ {
627
+ if (workspace.buffer_dbl.size() < (size_t)2 * (size_t)input_data.ncat[col] + 1)
628
+ workspace.buffer_dbl.resize((size_t)2 * (size_t)input_data.ncat[col] + 1);
629
+ xsd = expected_sd_cat_weighted<decltype(workspace.weights_map), size_t, ldouble_safe>(
630
+ workspace.ix_arr.data(), workspace.st, workspace.end,
631
+ input_data.categ_data + col * input_data.nrows,
632
+ input_data.ncat[col],
633
+ model_params.missing_action, workspace.weights_map,
634
+ workspace.buffer_dbl.data(),
635
+ workspace.buffer_szt.data(),
636
+ workspace.buffer_dbl.data() + input_data.ncat[col] + 1);
637
+ }
638
+ }
639
+
640
+ if (xsd)
641
+ {
642
+ variances[workspace.col_chosen] = square(xsd);
643
+ if (workspace.tree_kurtoses != NULL)
644
+ variances[workspace.col_chosen] *= workspace.tree_kurtoses[workspace.col_chosen];
645
+ else if (input_data.col_weights != NULL)
646
+ variances[workspace.col_chosen] *= input_data.col_weights[workspace.col_chosen];
647
+ variances[workspace.col_chosen] = std::fmax(variances[workspace.col_chosen], 1e-100);
648
+ }
649
+
650
+ else
651
+ {
652
+ workspace.col_sampler.drop_col(workspace.col_chosen);
653
+ variances[workspace.col_chosen] = 0;
654
+ }
655
+ }
656
+ }
657
+
658
+ template <class InputData, class WorkerMemory, class ldouble_safe>
659
+ void calc_kurt_all_cols(InputData &input_data, WorkerMemory &workspace, ModelParams &model_params,
660
+ double *restrict kurtosis, double *restrict saved_xmin, double *restrict saved_xmax)
661
+ {
662
+ workspace.col_sampler.prepare_full_pass();
663
+ while (workspace.col_sampler.sample_col(workspace.col_chosen))
664
+ {
665
+ if (saved_xmin != NULL)
666
+ {
667
+ get_split_range(workspace, input_data, model_params);
668
+ if (workspace.unsplittable)
669
+ {
670
+ workspace.col_sampler.drop_col(workspace.col_chosen);
671
+ continue;
672
+ }
673
+
674
+ if (saved_xmin != NULL)
675
+ {
676
+ saved_xmin[workspace.col_chosen] = workspace.xmin;
677
+ saved_xmax[workspace.col_chosen] = workspace.xmax;
678
+ }
679
+ }
680
+
681
+ if (workspace.col_chosen < input_data.ncols_numeric)
682
+ {
683
+ if (input_data.Xc_indptr == NULL)
684
+ {
685
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
686
+ {
687
+ kurtosis[workspace.col_chosen] =
688
+ calc_kurtosis<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
689
+ ldouble_safe>(
690
+ workspace.ix_arr.data(), workspace.st, workspace.end,
691
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
692
+ model_params.missing_action);
693
+ }
694
+
695
+ else if (!workspace.weights_arr.empty())
696
+ {
697
+ kurtosis[workspace.col_chosen] =
698
+ calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
699
+ decltype(workspace.weights_arr), ldouble_safe>(
700
+ workspace.ix_arr.data(), workspace.st, workspace.end,
701
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
702
+ model_params.missing_action, workspace.weights_arr);
703
+ }
704
+
705
+ else
706
+ {
707
+ kurtosis[workspace.col_chosen] =
708
+ calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.numeric_data)>::type,
709
+ decltype(workspace.weights_map), ldouble_safe>(
710
+ workspace.ix_arr.data(), workspace.st, workspace.end,
711
+ input_data.numeric_data + workspace.col_chosen * input_data.nrows,
712
+ model_params.missing_action, workspace.weights_map);
713
+ }
714
+ }
715
+
716
+ else
717
+ {
718
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
719
+ {
720
+ kurtosis[workspace.col_chosen] =
721
+ calc_kurtosis<typename std::remove_pointer<decltype(input_data.Xc)>::type,
722
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
723
+ ldouble_safe>(
724
+ workspace.ix_arr.data(), workspace.st, workspace.end,
725
+ workspace.col_chosen,
726
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
727
+ model_params.missing_action);
728
+ }
729
+
730
+ else if (!workspace.weights_arr.empty())
731
+ {
732
+ kurtosis[workspace.col_chosen] =
733
+ calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
734
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
735
+ decltype(workspace.weights_arr), ldouble_safe>(
736
+ workspace.ix_arr.data(), workspace.st, workspace.end,
737
+ workspace.col_chosen,
738
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
739
+ model_params.missing_action, workspace.weights_arr);
740
+ }
741
+
742
+ else
743
+ {
744
+ kurtosis[workspace.col_chosen] =
745
+ calc_kurtosis_weighted<typename std::remove_pointer<decltype(input_data.Xc)>::type,
746
+ typename std::remove_pointer<decltype(input_data.Xc_indptr)>::type,
747
+ decltype(workspace.weights_map), ldouble_safe>(
748
+ workspace.ix_arr.data(), workspace.st, workspace.end,
749
+ workspace.col_chosen,
750
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
751
+ model_params.missing_action, workspace.weights_map);
752
+ }
753
+ }
754
+ }
755
+
756
+ else
757
+ {
758
+ size_t col = workspace.col_chosen - input_data.ncols_numeric;
759
+ if (workspace.weights_arr.empty() && workspace.weights_map.empty())
760
+ {
761
+ kurtosis[workspace.col_chosen] =
762
+ calc_kurtosis<ldouble_safe>(
763
+ workspace.ix_arr.data(), workspace.st, workspace.end,
764
+ input_data.categ_data + col * input_data.nrows,
765
+ input_data.ncat[col],
766
+ workspace.buffer_szt.data(), workspace.buffer_dbl.data(),
767
+ model_params.missing_action, model_params.cat_split_type,
768
+ workspace.rnd_generator);
769
+ }
770
+
771
+ else if (!workspace.weights_arr.empty())
772
+ {
773
+ kurtosis[workspace.col_chosen] =
774
+ calc_kurtosis_weighted<decltype(workspace.weights_arr), ldouble_safe>(
775
+ workspace.ix_arr.data(), workspace.st, workspace.end,
776
+ input_data.categ_data + col * input_data.nrows,
777
+ input_data.ncat[col],
778
+ workspace.buffer_dbl.data(),
779
+ model_params.missing_action, model_params.cat_split_type,
780
+ workspace.rnd_generator, workspace.weights_arr);
781
+ }
782
+
783
+ else
784
+ {
785
+ kurtosis[workspace.col_chosen] =
786
+ calc_kurtosis_weighted<decltype(workspace.weights_map), ldouble_safe>(
787
+ workspace.ix_arr.data(), workspace.st, workspace.end,
788
+ input_data.categ_data + col * input_data.nrows,
789
+ input_data.ncat[col],
790
+ workspace.buffer_dbl.data(),
791
+ model_params.missing_action, model_params.cat_split_type,
792
+ workspace.rnd_generator, workspace.weights_map);
793
+ }
794
+ }
795
+
796
+ if (kurtosis[workspace.col_chosen] == -HUGE_VAL)
797
+ workspace.col_sampler.drop_col(workspace.col_chosen);
798
+
799
+ kurtosis[workspace.col_chosen] = (kurtosis[workspace.col_chosen] == -HUGE_VAL)?
800
+ 0. : std::fmax(1e-8, -1. + kurtosis[workspace.col_chosen]);
801
+ if (input_data.col_weights != NULL && kurtosis[workspace.col_chosen] > 0)
802
+ {
803
+ kurtosis[workspace.col_chosen] *= input_data.col_weights[workspace.col_chosen];
804
+ kurtosis[workspace.col_chosen] = std::fmax(kurtosis[workspace.col_chosen], 1e-100);
805
+ }
806
+ }
807
+ }
808
+
809
+ bool is_boxed_metric(const ScoringMetric scoring_metric)
810
+ {
811
+ return scoring_metric == BoxedDensity ||
812
+ scoring_metric == BoxedDensity2 ||
813
+ scoring_metric == BoxedRatio;
814
+ }