isotree 0.2.2 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (151) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2111 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +62 -0
  14. data/vendor/isotree/src/RcppExports.cpp +525 -52
  15. data/vendor/isotree/src/Rwrapper.cpp +1931 -268
  16. data/vendor/isotree/src/c_interface.cpp +953 -0
  17. data/vendor/isotree/src/crit.hpp +4232 -0
  18. data/vendor/isotree/src/dist.hpp +1886 -0
  19. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  20. data/vendor/isotree/src/extended.hpp +1444 -0
  21. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  22. data/vendor/isotree/src/fit_model.hpp +2401 -0
  23. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  24. data/vendor/isotree/src/helpers_iforest.hpp +813 -0
  25. data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
  26. data/vendor/isotree/src/indexer.cpp +515 -0
  27. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  28. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  29. data/vendor/isotree/src/isoforest.hpp +1659 -0
  30. data/vendor/isotree/src/isotree.hpp +1804 -392
  31. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  32. data/vendor/isotree/src/merge_models.cpp +159 -16
  33. data/vendor/isotree/src/mult.hpp +1321 -0
  34. data/vendor/isotree/src/oop_interface.cpp +842 -0
  35. data/vendor/isotree/src/oop_interface.hpp +278 -0
  36. data/vendor/isotree/src/other_helpers.hpp +219 -0
  37. data/vendor/isotree/src/predict.hpp +1932 -0
  38. data/vendor/isotree/src/python_helpers.hpp +134 -0
  39. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  40. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  41. data/vendor/isotree/src/robinmap/README.md +483 -0
  42. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  46. data/vendor/isotree/src/serialize.cpp +4300 -139
  47. data/vendor/isotree/src/sql.cpp +141 -59
  48. data/vendor/isotree/src/subset_models.cpp +174 -0
  49. data/vendor/isotree/src/utils.hpp +3808 -0
  50. data/vendor/isotree/src/xoshiro.hpp +467 -0
  51. data/vendor/isotree/src/ziggurat.hpp +405 -0
  52. metadata +38 -104
  53. data/vendor/cereal/LICENSE +0 -24
  54. data/vendor/cereal/README.md +0 -85
  55. data/vendor/cereal/include/cereal/access.hpp +0 -351
  56. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  57. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  58. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  59. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  60. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  61. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  62. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  63. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  65. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  66. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  67. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  68. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  69. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  70. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  71. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  72. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  74. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  76. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  77. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  78. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  79. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  91. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  92. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  94. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  96. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  97. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  98. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  99. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  100. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  101. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  102. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  103. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  104. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  105. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  106. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  107. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  111. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  112. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  113. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  114. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  115. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  116. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  117. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  118. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  119. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  120. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  121. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  122. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  123. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  124. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  125. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  126. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  127. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  128. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  129. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  130. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  131. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  132. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  133. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  134. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  135. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  136. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  137. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  138. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  139. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  140. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  141. data/vendor/cereal/include/cereal/version.hpp +0 -52
  142. data/vendor/isotree/src/Makevars +0 -4
  143. data/vendor/isotree/src/crit.cpp +0 -912
  144. data/vendor/isotree/src/dist.cpp +0 -749
  145. data/vendor/isotree/src/extended.cpp +0 -790
  146. data/vendor/isotree/src/fit_model.cpp +0 -1090
  147. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  148. data/vendor/isotree/src/isoforest.cpp +0 -771
  149. data/vendor/isotree/src/mult.cpp +0 -607
  150. data/vendor/isotree/src/predict.cpp +0 -853
  151. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -1,324 +0,0 @@
1
- /* Isolation forests and variations thereof, with adjustments for incorporation
2
- * of categorical variables and missing values.
3
- * Written for C++11 standard and aimed at being used in R and Python.
4
- *
5
- * This library is based on the following works:
6
- * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
- * "Isolation forest."
8
- * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
- * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
- * "Isolation-based anomaly detection."
11
- * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
- * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
- * "Extended Isolation Forest."
14
- * arXiv preprint arXiv:1811.02141 (2018).
15
- * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
- * "On detecting clustered anomalies using SCiForest."
17
- * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
- * [5] https://sourceforge.net/projects/iforest/
19
- * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
- * [7] Quinlan, J. Ross. C4.5: programs for machine learning. Elsevier, 2014.
21
- * [8] Cortes, David. "Distance approximation using Isolation Forests." arXiv preprint arXiv:1910.12362 (2019).
22
- * [9] Cortes, David. "Imputing missing values with unsupervised random trees." arXiv preprint arXiv:1911.06646 (2019).
23
- *
24
- * BSD 2-Clause License
25
- * Copyright (c) 2020, David Cortes
26
- * All rights reserved.
27
- * Redistribution and use in source and binary forms, with or without
28
- * modification, are permitted provided that the following conditions are met:
29
- * * Redistributions of source code must retain the above copyright notice, this
30
- * list of conditions and the following disclaimer.
31
- * * Redistributions in binary form must reproduce the above copyright notice,
32
- * this list of conditions and the following disclaimer in the documentation
33
- * and/or other materials provided with the distribution.
34
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44
- */
45
- #include "isotree.hpp"
46
-
47
- void decide_column(size_t ncols_numeric, size_t ncols_categ, size_t &col_chosen, ColType &col_type,
48
- RNG_engine &rnd_generator, std::uniform_int_distribution<size_t> &runif,
49
- std::discrete_distribution<size_t> &col_sampler)
50
- {
51
- if (!col_sampler.max())
52
- col_chosen = runif(rnd_generator);
53
- else
54
- col_chosen = col_sampler(rnd_generator);
55
-
56
- if (col_chosen >= ncols_numeric)
57
- {
58
- col_chosen -= ncols_numeric;
59
- col_type = Categorical;
60
- }
61
-
62
- else { col_type = Numeric; }
63
- }
64
-
65
- void add_unsplittable_col(WorkerMemory &workspace, IsoTree &tree, InputData &input_data)
66
- {
67
- if (tree.col_type == Numeric)
68
- workspace.cols_possible[tree.col_num] = false;
69
- else
70
- workspace.cols_possible[tree.col_num + input_data.ncols_numeric] = false;
71
- }
72
-
73
- void add_unsplittable_col(WorkerMemory &workspace, InputData &input_data)
74
- {
75
- if (workspace.col_type == Numeric)
76
- workspace.cols_possible[workspace.col_chosen] = false;
77
- else
78
- workspace.cols_possible[workspace.col_chosen + input_data.ncols_numeric] = false;
79
- }
80
-
81
- bool check_is_not_unsplittable_col(WorkerMemory &workspace, IsoTree &tree, InputData &input_data)
82
- {
83
- if (tree.col_type == Numeric)
84
- return workspace.cols_possible[tree.col_num];
85
- else
86
- return workspace.cols_possible[tree.col_num + input_data.ncols_numeric];
87
- }
88
-
89
- /* for use in regular model */
90
- void get_split_range(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params, IsoTree &tree)
91
- {
92
- if (tree.col_type == Numeric)
93
- {
94
- if (input_data.Xc_indptr == NULL)
95
- get_range(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * tree.col_num,
96
- workspace.st, workspace.end, model_params.missing_action,
97
- workspace.xmin, workspace.xmax, workspace.unsplittable);
98
- else
99
- get_range(workspace.ix_arr.data(), workspace.st, workspace.end, tree.col_num,
100
- input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
101
- model_params.missing_action, workspace.xmin, workspace.xmax, workspace.unsplittable);
102
- }
103
-
104
- else
105
- {
106
- get_categs(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * tree.col_num,
107
- workspace.st, workspace.end, input_data.ncat[tree.col_num],
108
- model_params.missing_action, workspace.categs.data(), workspace.npresent, workspace.unsplittable);
109
- }
110
- }
111
-
112
- /* for use in extended model */
113
- void get_split_range(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params)
114
- {
115
- if (workspace.col_type == Numeric)
116
- {
117
- if (input_data.Xc_indptr == NULL)
118
- get_range(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * workspace.col_chosen,
119
- workspace.st, workspace.end, model_params.missing_action,
120
- workspace.xmin, workspace.xmax, workspace.unsplittable);
121
- else
122
- get_range(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen,
123
- input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
124
- model_params.missing_action, workspace.xmin, workspace.xmax, workspace.unsplittable);
125
- }
126
-
127
- else
128
- {
129
- get_categs(workspace.ix_arr.data(), input_data.categ_data + input_data.nrows * workspace.col_chosen,
130
- workspace.st, workspace.end, input_data.ncat[workspace.col_chosen],
131
- model_params.missing_action, workspace.categs.data(), workspace.npresent, workspace.unsplittable);
132
- }
133
- }
134
-
135
- int choose_cat_from_present(WorkerMemory &workspace, InputData &input_data, size_t col_num)
136
- {
137
- int chosen_cat = std::uniform_int_distribution<int>
138
- (0, workspace.npresent - 1)
139
- (workspace.rnd_generator);
140
- workspace.ncat_tried = 0;
141
- for (int cat = 0; cat < input_data.ncat[col_num]; cat++)
142
- {
143
- if (workspace.categs[cat] > 0)
144
- {
145
- if (workspace.ncat_tried == chosen_cat)
146
- return cat;
147
- else
148
- workspace.ncat_tried++;
149
- }
150
- }
151
-
152
- return -1; /* this will never be reached, but CRAN complains otherwise */
153
- }
154
-
155
- void update_col_sampler(WorkerMemory &workspace, InputData &input_data)
156
- {
157
- if (!workspace.col_sampler.max())
158
- return;
159
-
160
- std::vector<double> col_weights = workspace.col_sampler.probabilities();
161
- for (size_t col = 0; col < input_data.ncols_numeric; col++)
162
- if (!workspace.cols_possible[col])
163
- col_weights[col] = 0;
164
- for (size_t col = 0; col < input_data.ncols_categ; col++)
165
- if (!workspace.cols_possible[col + input_data.ncols_numeric])
166
- col_weights[col + input_data.ncols_numeric] = 0;
167
- workspace.col_sampler = std::discrete_distribution<size_t>(col_weights.begin(), col_weights.end());
168
- }
169
-
170
- bool is_col_taken(std::vector<bool> &col_is_taken, std::unordered_set<size_t> &col_is_taken_s,
171
- InputData &input_data, size_t col_num, ColType col_type)
172
- {
173
- col_num += ((col_type == Categorical)? 0 : input_data.ncols_categ);
174
- if (col_is_taken.size())
175
- return col_is_taken[col_num];
176
- else
177
- return col_is_taken_s.find(col_num) != col_is_taken_s.end();
178
- }
179
-
180
- void set_col_as_taken(std::vector<bool> &col_is_taken, std::unordered_set<size_t> &col_is_taken_s,
181
- InputData &input_data, size_t col_num, ColType col_type)
182
- {
183
- col_num += ((col_type == Categorical)? 0 : input_data.ncols_categ);
184
- if (col_is_taken.size())
185
- col_is_taken[col_num] = true;
186
- else
187
- col_is_taken_s.insert(col_num);
188
- }
189
-
190
- void add_separation_step(WorkerMemory &workspace, InputData &input_data, double remainder)
191
- {
192
- if (workspace.weights_arr.size())
193
- increase_comb_counter(workspace.ix_arr.data(), workspace.st, workspace.end,
194
- input_data.nrows, workspace.tmat_sep.data(), workspace.weights_arr.data(), remainder);
195
- else if (workspace.weights_map.size())
196
- increase_comb_counter(workspace.ix_arr.data(), workspace.st, workspace.end,
197
- input_data.nrows, workspace.tmat_sep.data(), workspace.weights_map, remainder);
198
- else
199
- increase_comb_counter(workspace.ix_arr.data(), workspace.st, workspace.end,
200
- input_data.nrows, workspace.tmat_sep.data(), remainder);
201
- }
202
-
203
- void add_remainder_separation_steps(WorkerMemory &workspace, InputData &input_data, long double sum_weight)
204
- {
205
- if (
206
- ((workspace.end - workspace.st) > 0 && !workspace.weights_arr.size() && !workspace.weights_map.size()) ||
207
- (sum_weight > 1 && (workspace.weights_arr.size() || workspace.weights_map.size()))
208
- )
209
- {
210
- double expected_dsep;
211
- if (!workspace.weights_arr.size() && !workspace.weights_map.size())
212
- expected_dsep = expected_separation_depth(workspace.end - workspace.st + 1);
213
- else
214
- expected_dsep = expected_separation_depth(sum_weight);
215
-
216
- add_separation_step(workspace, input_data, expected_dsep + 1);
217
- }
218
- }
219
-
220
- void remap_terminal_trees(IsoForest *model_outputs, ExtIsoForest *model_outputs_ext,
221
- PredictionData &prediction_data, sparse_ix *restrict tree_num, int nthreads)
222
- {
223
- size_t ntrees = (model_outputs != NULL)? model_outputs->trees.size() : model_outputs_ext->hplanes.size();
224
- size_t max_tree, curr_term;
225
- std::vector<sparse_ix> tree_mapping;
226
- if (model_outputs != NULL)
227
- {
228
- max_tree = std::accumulate(model_outputs->trees.begin(),
229
- model_outputs->trees.end(),
230
- (size_t)0,
231
- [](const size_t curr_max, const std::vector<IsoTree> &tr)
232
- {return std::max(curr_max, tr.size());});
233
- tree_mapping.resize(max_tree);
234
- for (size_t tree = 0; tree < ntrees; tree++)
235
- {
236
- std::fill(tree_mapping.begin(), tree_mapping.end(), (size_t)0);
237
- curr_term = 0;
238
- for (size_t node = 0; node < model_outputs->trees[tree].size(); node++)
239
- if (model_outputs->trees[tree][node].score >= 0)
240
- tree_mapping[node] = curr_term++;
241
-
242
- #pragma omp parallel for schedule(static) num_threads(nthreads) shared(tree_num, tree_mapping, tree, prediction_data)
243
- for (size_t_for row = 0; row < prediction_data.nrows; row++)
244
- tree_num[row + tree * prediction_data.nrows] = tree_mapping[tree_num[row + tree * prediction_data.nrows]];
245
- }
246
- }
247
-
248
- else
249
- {
250
- max_tree = std::accumulate(model_outputs_ext->hplanes.begin(),
251
- model_outputs_ext->hplanes.end(),
252
- (size_t)0,
253
- [](const size_t curr_max, const std::vector<IsoHPlane> &tr)
254
- {return std::max(curr_max, tr.size());});
255
- tree_mapping.resize(max_tree);
256
- for (size_t tree = 0; tree < ntrees; tree++)
257
- {
258
- std::fill(tree_mapping.begin(), tree_mapping.end(), (size_t)0);
259
- curr_term = 0;
260
- for (size_t node = 0; node < model_outputs_ext->hplanes[tree].size(); node++)
261
- if (model_outputs_ext->hplanes[tree][node].score >= 0)
262
- tree_mapping[node] = curr_term++;
263
-
264
- #pragma omp parallel for schedule(static) num_threads(nthreads) shared(tree_num, tree_mapping, tree, prediction_data)
265
- for (size_t_for row = 0; row < prediction_data.nrows; row++)
266
- tree_num[row + tree * prediction_data.nrows] = tree_mapping[tree_num[row + tree * prediction_data.nrows]];
267
- }
268
- }
269
- }
270
-
271
- void backup_recursion_state(WorkerMemory &workspace, RecursionState &recursion_state)
272
- {
273
- recursion_state.st = workspace.st;
274
- recursion_state.st_NA = workspace.st_NA;
275
- recursion_state.end_NA = workspace.end_NA;
276
- recursion_state.split_ix = workspace.split_ix;
277
- recursion_state.end = workspace.end;
278
- recursion_state.cols_possible = workspace.cols_possible;
279
- recursion_state.col_sampler = workspace.col_sampler;
280
-
281
- /* for the extended model, it's not necessary to copy everything */
282
- if (!workspace.comb_val.size())
283
- {
284
- recursion_state.ix_arr = std::vector<size_t>(workspace.ix_arr.begin() + workspace.st_NA,
285
- workspace.ix_arr.begin() + workspace.end + 1);
286
- size_t tot = workspace.end - workspace.st_NA + 1;
287
- if (workspace.weights_arr.size() || workspace.weights_map.size())
288
- recursion_state.weights_arr = std::unique_ptr<double[]>(new double[tot]);
289
- if (workspace.weights_arr.size())
290
- for (size_t ix = 0; ix < tot; ix++)
291
- recursion_state.weights_arr[ix] = workspace.weights_arr[workspace.ix_arr[ix + workspace.st_NA]];
292
- else if (workspace.weights_map.size())
293
- for (size_t ix = 0; ix < tot; ix++)
294
- recursion_state.weights_arr[ix] = workspace.weights_map[workspace.ix_arr[ix + workspace.st_NA]];
295
-
296
-
297
- }
298
- }
299
-
300
-
301
- void restore_recursion_state(WorkerMemory &workspace, RecursionState &recursion_state)
302
- {
303
- workspace.st = recursion_state.st;
304
- workspace.st_NA = recursion_state.st_NA;
305
- workspace.end_NA = recursion_state.end_NA;
306
- workspace.split_ix = recursion_state.split_ix;
307
- workspace.end = recursion_state.end;
308
- workspace.cols_possible = std::move(recursion_state.cols_possible);
309
- workspace.col_sampler = std::move(recursion_state.col_sampler);
310
-
311
- if (!workspace.comb_val.size())
312
- {
313
- std::copy(recursion_state.ix_arr.begin(),
314
- recursion_state.ix_arr.end(),
315
- workspace.ix_arr.begin() + recursion_state.st_NA);
316
- size_t tot = workspace.end - workspace.st_NA + 1;
317
- if (workspace.weights_arr.size())
318
- for (size_t ix = 0; ix < tot; ix++)
319
- workspace.weights_arr[workspace.ix_arr[ix + workspace.st_NA]] = recursion_state.weights_arr[ix];
320
- else if (workspace.weights_map.size())
321
- for (size_t ix = 0; ix < tot; ix++)
322
- workspace.weights_map[workspace.ix_arr[ix + workspace.st_NA]] = recursion_state.weights_arr[ix];
323
- }
324
- }