isotree 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,262 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
21
+ * [8] Cortes, David. "Distance approximation using Isolation Forests." arXiv preprint arXiv:1910.12362 (2019).
22
+ * [9] Cortes, David. "Imputing missing values with unsupervised random trees." arXiv preprint arXiv:1911.06646 (2019).
23
+ *
24
+ * BSD 2-Clause License
25
+ * Copyright (c) 2019, David Cortes
26
+ * All rights reserved.
27
+ * Redistribution and use in source and binary forms, with or without
28
+ * modification, are permitted provided that the following conditions are met:
29
+ * * Redistributions of source code must retain the above copyright notice, this
30
+ * list of conditions and the following disclaimer.
31
+ * * Redistributions in binary form must reproduce the above copyright notice,
32
+ * this list of conditions and the following disclaimer in the documentation
33
+ * and/or other materials provided with the distribution.
34
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44
+ */
45
+ #include "isotree.hpp"
46
+
47
+ #ifdef _ENABLE_CEREAL
48
+
49
+
50
+ template <class T>
51
+ void serialize_obj(T &obj, std::ostream &output)
52
+ {
53
+ cereal::BinaryOutputArchive archive(output);
54
+ archive(obj);
55
+ }
56
/* Serializes 'obj' with a cereal binary archive and returns the raw bytes as a string.
   The archive lives only for the duration of the call to the stream overload,
   so it is destroyed before the buffer contents are read back. */
template <class T>
std::string serialize_obj(T &obj)
{
    std::stringstream byte_buffer;
    serialize_obj(obj, byte_buffer);
    return byte_buffer.str();
}
66
+ template <class T, class I>
67
+ void deserialize_obj(T &output, I &serialized)
68
+ {
69
+ cereal::BinaryInputArchive archive(serialized);
70
+ archive(output);
71
+ }
72
/* De-serializes an object from a string holding the raw bytes of a cereal binary archive.
 *
 * - output (out)
 *       Object whose contents will be overwritten with the de-serialized data.
 * - serialized (in)
 *       String with the raw bytes.
 * - move_str
 *       Whether to pass the string through 'std::move' when loading it into the
 *       internal stream. Whether this actually avoids a copy depends on the standard
 *       library offering an rvalue overload of 'str()' (added in C++20); when it does,
 *       the input string is left in a moved-from state.
 */
template <class T>
void deserialize_obj(T &output, std::string &serialized, bool move_str)
{
    std::stringstream ss;
    if (move_str)
        ss.str(std::move(serialized));
    else
        /* Bug with GCC4 not implementing the move method for stringstreams
           https://stackoverflow.com/questions/50926506/deleted-function-std-basic-stringstream-in-linux-with-g
           https://github.com/david-cortes/isotree/issues/7 */
        // ss = std::stringstream(serialized); /* <- fails with GCC4, CRAN complains */
        /* 'str' copies from a const reference and leaves 'serialized' untouched,
           so no intermediate copy of the bytes is needed here */
        ss.str(serialized);
    deserialize_obj(output, ss);
}
89
+
90
+
91
+ /* Serialization and de-serialization functions using Cereal
92
+ *
93
+ * Parameters
94
+ * ==========
95
+ * - model (in)
96
+ * A model object to serialize, after being fitted through function 'fit_iforest'.
97
+ * - imputer (in)
98
+ * An imputer object to serialize, after being fitted through function 'fit_iforest'
99
+ * with 'build_imputer=true'.
100
+ * - output_obj (out)
101
+ * An already-allocated object into which a serialized object of the same class will
102
+ * be de-serialized. The contents of this object will be overwritten. Should be initialized
103
+ * through the default constructor (e.g. 'new ExtIsoForest' or 'ExtIsoForest()').
104
+ * - output (out)
105
+ * An output stream (any type will do) in which to save/persist/serialize the
106
+ * model or imputer object using the cereal library. In the functions that do not
107
+ * take this parameter, it will be returned as a string containing the raw bytes.
108
+ * - serialized (in)
109
+ * The input stream which contains the serialized/saved/persisted model or imputer object,
110
+ * which will be de-serialized into 'output_obj'.
111
+ * - output_file_path
112
+ * File name into which to write the serialized model or imputer object as raw bytes.
113
+ * Note that, on Windows, passing non-ASCII characters will fail, and in such case,
114
+ * you might want to use instead the versions that take 'wchar_t', which are
115
+ * only available in the MSVC compiler (it uses 'std::ofstream' internally, which as
116
+ * of C++20, is not required by the standard to accept 'wchar_t' in its constructor).
117
+ * Be aware that it will only write raw bytes, thus metadata such as CPU endianness
118
+ * will be lost. If you need to transfer files between e.g. an x86 computer and a SPARC
119
+ * server, you'll have to use other methods.
120
+ * This functionality is intended for being easily wrapped into scripting languages
121
+ * without having to copy the contents to some intermediate language.
122
+ * - input_file_path
123
+ * File name from which to read a serialized model or imputer object as raw bytes.
124
+ * See the description for 'output_file_path' for more details.
125
+ * - move_str
126
+ * Whether to move ('std::move') the contents of the string passed as input in order
127
+ * to speed things up and avoid making a redundant copy of the raw bytes. If passing
128
+ * 'true', the input string will be rendered empty afterwards.
129
+ */
130
+ void serialize_isoforest(IsoForest &model, std::ostream &output)
131
+ {
132
+ serialize_obj(model, output);
133
+ }
134
+ void serialize_isoforest(IsoForest &model, const char *output_file_path)
135
+ {
136
+ std::ofstream output(output_file_path);
137
+ serialize_obj(model, output);
138
+ }
139
+ std::string serialize_isoforest(IsoForest &model)
140
+ {
141
+ return serialize_obj(model);
142
+ }
143
+ void deserialize_isoforest(IsoForest &output_obj, std::istream &serialized)
144
+ {
145
+ deserialize_obj(output_obj, serialized);
146
+ }
147
+ void deserialize_isoforest(IsoForest &output_obj, const char *input_file_path)
148
+ {
149
+ std::ifstream serialized(input_file_path);
150
+ deserialize_obj(output_obj, serialized);
151
+ }
152
+ void deserialize_isoforest(IsoForest &output_obj, std::string &serialized, bool move_str)
153
+ {
154
+ deserialize_obj(output_obj, serialized, move_str);
155
+ }
156
+
157
+
158
+
159
+ void serialize_ext_isoforest(ExtIsoForest &model, std::ostream &output)
160
+ {
161
+ serialize_obj(model, output);
162
+ }
163
+ void serialize_ext_isoforest(ExtIsoForest &model, const char *output_file_path)
164
+ {
165
+ std::ofstream output(output_file_path);
166
+ serialize_obj(model, output);
167
+ }
168
+ std::string serialize_ext_isoforest(ExtIsoForest &model)
169
+ {
170
+ return serialize_obj(model);
171
+ }
172
+ void deserialize_ext_isoforest(ExtIsoForest &output_obj, std::istream &serialized)
173
+ {
174
+ deserialize_obj(output_obj, serialized);
175
+ }
176
+ void deserialize_ext_isoforest(ExtIsoForest &output_obj, const char *input_file_path)
177
+ {
178
+ std::ifstream serialized(input_file_path);
179
+ deserialize_obj(output_obj, serialized);
180
+ }
181
+ void deserialize_ext_isoforest(ExtIsoForest &output_obj, std::string &serialized, bool move_str)
182
+ {
183
+ deserialize_obj(output_obj, serialized, move_str);
184
+ }
185
+
186
+
187
+
188
+
189
+ void serialize_imputer(Imputer &imputer, std::ostream &output)
190
+ {
191
+ serialize_obj(imputer, output);
192
+ }
193
+ void serialize_imputer(Imputer &imputer, const char *output_file_path)
194
+ {
195
+ std::ofstream output(output_file_path);
196
+ serialize_obj(imputer, output);
197
+ }
198
+ std::string serialize_imputer(Imputer &imputer)
199
+ {
200
+ return serialize_obj(imputer);
201
+ }
202
+ void deserialize_imputer(Imputer &output_obj, std::istream &serialized)
203
+ {
204
+ deserialize_obj(output_obj, serialized);
205
+ }
206
+ void deserialize_imputer(Imputer &output_obj, const char *input_file_path)
207
+ {
208
+ std::ifstream serialized(input_file_path);
209
+ deserialize_obj(output_obj, serialized);
210
+ }
211
+ void deserialize_imputer(Imputer &output_obj, std::string &serialized, bool move_str)
212
+ {
213
+ deserialize_obj(output_obj, serialized, move_str);
214
+ }
215
+
216
+
217
#ifdef _MSC_VER
/* Versions of the file-based functions taking wide-character paths. They are only
   compiled under MSVC, and rely on its file streams accepting 'wchar_t' file names.
   All files are opened in binary mode, since the archives consist of raw bytes and
   text-mode streams on Windows translate newline bytes, which would corrupt them. */
void serialize_isoforest(IsoForest &model, const wchar_t *output_file_path)
{
    std::ofstream output(output_file_path, std::ios::out | std::ios::binary);
    serialize_obj(model, output);
}
void deserialize_isoforest(IsoForest &output_obj, const wchar_t *input_file_path)
{
    std::ifstream serialized(input_file_path, std::ios::in | std::ios::binary);
    deserialize_obj(output_obj, serialized);
}
void serialize_ext_isoforest(ExtIsoForest &model, const wchar_t *output_file_path)
{
    std::ofstream output(output_file_path, std::ios::out | std::ios::binary);
    serialize_obj(model, output);
}
void deserialize_ext_isoforest(ExtIsoForest &output_obj, const wchar_t *input_file_path)
{
    std::ifstream serialized(input_file_path, std::ios::in | std::ios::binary);
    deserialize_obj(output_obj, serialized);
}
void serialize_imputer(Imputer &imputer, const wchar_t *output_file_path)
{
    std::ofstream output(output_file_path, std::ios::out | std::ios::binary);
    serialize_obj(imputer, output);
}
void deserialize_imputer(Imputer &output_obj, const wchar_t *input_file_path)
{
    std::ifstream serialized(input_file_path, std::ios::in | std::ios::binary);
    deserialize_obj(output_obj, serialized);
}

/* Tells whether the library was compiled with MSVC, i.e. whether the
   'wchar_t' versions of the functions above are available. */
bool has_msvc()
{
    return true;
}

#else
/* Tells whether the library was compiled with MSVC, i.e. whether the
   'wchar_t' versions of the file-based functions are available. */
bool has_msvc()
{
    return false;
}

#endif /* ifdef _MSC_VER */
260
+
261
+
262
+ #endif /* _ENABLE_CEREAL */
@@ -0,0 +1,1574 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
21
+ * [8] Cortes, David. "Distance approximation using Isolation Forests." arXiv preprint arXiv:1910.12362 (2019).
22
+ * [9] Cortes, David. "Imputing missing values with unsupervised random trees." arXiv preprint arXiv:1911.06646 (2019).
23
+ *
24
+ * BSD 2-Clause License
25
+ * Copyright (c) 2019, David Cortes
26
+ * All rights reserved.
27
+ * Redistribution and use in source and binary forms, with or without
28
+ * modification, are permitted provided that the following conditions are met:
29
+ * * Redistributions of source code must retain the above copyright notice, this
30
+ * list of conditions and the following disclaimer.
31
+ * * Redistributions in binary form must reproduce the above copyright notice,
32
+ * this list of conditions and the following disclaimer in the documentation
33
+ * and/or other materials provided with the distribution.
34
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44
+ */
45
+ #include "isotree.hpp"
46
+
47
+ /* ceil(log2(x)) done with bit-wise operations ensures perfect precision (and it's faster too)
48
+ https://stackoverflow.com/questions/2589096/find-most-significant-bit-left-most-that-is-set-in-a-bit-array
49
+ https://stackoverflow.com/questions/11376288/fast-computing-of-log2-for-64-bit-integers */
50
/* Computes ceil(log2(x)) using integer operations only.
   For x <= 1 the result is defined as zero: the bit tricks below operate on
   (x - 1) and would otherwise either count the leading zeros of zero (undefined
   for the compiler builtins) or index the lookup tables with a meaningless hash. */
#if defined(__GNUC__) || defined(__clang__)
/* Compiler builtins are not preprocessor macros, so their availability has to be
   determined from the compiler identification macros instead of '#ifdef __builtin_clz'.
   The 'long long' variant is used so that the operand is at least 64 bits wide
   regardless of the width of 'long' on the platform. */
size_t log2ceil(size_t x)
{
    if (x <= 1) return 0;
    return (size_t)(8 * sizeof(unsigned long long))
           - (size_t)__builtin_clzll((unsigned long long)(x - 1));
}
#elif SIZE_MAX == UINT32_MAX /* 32-bit systems */
static const int MultiplyDeBruijnBitPosition[32] =
{
    0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30,
    8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31
};
size_t log2ceil( size_t v )
{
    if (v <= 1) return 0;

    v--;
    v |= v >> 1; // first round down to one less than a power of 2
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;

    return MultiplyDeBruijnBitPosition[( uint32_t )( v * 0x07C4ACDDU ) >> 27] + 1;
}
#elif SIZE_MAX == UINT64_MAX /* 64-bit systems */
static const uint64_t tab64[64] = {
    63,  0, 58,  1, 59, 47, 53,  2,
    60, 39, 48, 27, 54, 33, 42,  3,
    61, 51, 37, 40, 49, 18, 28, 20,
    55, 30, 34, 11, 43, 14, 22,  4,
    62, 57, 46, 52, 38, 26, 32, 41,
    50, 36, 17, 19, 29, 10, 13, 21,
    56, 45, 25, 31, 35, 16,  9, 12,
    44, 24, 15,  8, 23,  7,  6,  5};

size_t log2ceil(size_t value)
{
    if (value <= 1) return 0;

    /* set every bit below the highest set bit of (value - 1) */
    value--;
    value |= value >> 1;
    value |= value >> 2;
    value |= value >> 4;
    value |= value >> 8;
    value |= value >> 16;
    value |= value >> 32;
    /* isolate the highest set bit and look up its position */
    return tab64[((uint64_t)((value - (value >> 1))*0x07EDD5E59A4E28C2)) >> 58] + 1;
}
#else /* other architectures - might not be entirely precise, and will be slower */
size_t log2ceil(size_t x)
{
    if (x <= 1) return 0;
    return (size_t)(ceill(log2l((long double) x)));
}
#endif
101
+
102
+ /* http://fredrik-j.blogspot.com/2009/02/how-not-to-compute-harmonic-numbers.html
103
+ https://en.wikipedia.org/wiki/Harmonic_number */
104
#define THRESHOLD_EXACT_H 256 /* above this will get approximated */

/* Sums 1/k for the integers k in [a, b) by recursively halving the range.
   Expects integer-valued 'a' and 'b'. An empty range (b <= a) sums to zero;
   without that check, such a call would recurse without ever terminating. */
double harmonic_recursive(double a, double b)
{
    if (b <= a) return 0.;
    if (b == a + 1) return 1 / a;
    double m = floor((a + b) / 2);
    return harmonic_recursive(a, m) + harmonic_recursive(m, b);
}

/* Harmonic number H(n) = 1 + 1/2 + ... + 1/n, with H(0) = 0.
   Calculated as an explicit sum up to THRESHOLD_EXACT_H, and approximated
   as log(n) + Euler-Mascheroni constant above that. */
double harmonic(size_t n)
{
    if (n > THRESHOLD_EXACT_H)
        return logl((long double)n) + (long double)0.5772156649;
    else
        return harmonic_recursive((double)1, (double)(n + 1));
}
119
+
120
+ /* https://stats.stackexchange.com/questions/423542/isolation-forest-and-average-expected-depth-formula
121
+ https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom */
122
+ double expected_avg_depth(size_t sample_size)
123
+ {
124
+ switch(sample_size)
125
+ {
126
+ case 1: return 0.;
127
+ case 2: return 1.;
128
+ case 3: return 5.0/3.0;
129
+ case 4: return 13.0/6.0;
130
+ case 5: return 77.0/30.0;
131
+ case 6: return 29.0/10.0;
132
+ case 7: return 223.0/70.0;
133
+ case 8: return 481.0/140.0;
134
+ case 9: return 4609.0/1260.0;
135
+ default:
136
+ {
137
+ return 2 * (harmonic(sample_size) - 1);
138
+ }
139
+ }
140
+ }
141
+
142
+ double expected_avg_depth(long double approx_sample_size)
143
+ {
144
+ if (approx_sample_size < 1.5)
145
+ return 0;
146
+ else if (approx_sample_size < 2.5)
147
+ return 1;
148
+ else if (approx_sample_size <= THRESHOLD_EXACT_H)
149
+ return expected_avg_depth((size_t) roundl(approx_sample_size));
150
+ else
151
+ return 2 * logl(approx_sample_size) - (long double)1.4227843351;
152
+ }
153
+
154
+ /* https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree */
155
+ #define THRESHOLD_EXACT_S 87670 /* difference is <5e-4 */
156
+ double expected_separation_depth(size_t n)
157
+ {
158
+ switch(n)
159
+ {
160
+ case 0: return 0.;
161
+ case 1: return 0.;
162
+ case 2: return 1.;
163
+ case 3: return 1. + (1./3.);
164
+ case 4: return 1. + (1./3.) + (2./9.);
165
+ case 5: return 1.71666666667;
166
+ case 6: return 1.84;
167
+ case 7: return 1.93809524;
168
+ case 8: return 2.01836735;
169
+ case 9: return 2.08551587;
170
+ case 10: return 2.14268078;
171
+ default:
172
+ {
173
+ if (n >= THRESHOLD_EXACT_S)
174
+ return 3;
175
+ else
176
+ return expected_separation_depth_hotstart((double)2.14268078, (size_t)10, n);
177
+ }
178
+ }
179
+ }
180
+
181
+ double expected_separation_depth_hotstart(double curr, size_t n_curr, size_t n_final)
182
+ {
183
+ if (n_final >= 1360)
184
+ {
185
+ if (n_final >= THRESHOLD_EXACT_S)
186
+ return 3;
187
+ else if (n_final >= 40774)
188
+ return 2.999;
189
+ else if (n_final >= 18844)
190
+ return 2.998;
191
+ else if (n_final >= 11956)
192
+ return 2.997;
193
+ else if (n_final >= 8643)
194
+ return 2.996;
195
+ else if (n_final >= 6713)
196
+ return 2.995;
197
+ else if (n_final >= 4229)
198
+ return 2.9925;
199
+ else if (n_final >= 3040)
200
+ return 2.99;
201
+ else if (n_final >= 2724)
202
+ return 2.989;
203
+ else if (n_final >= 1902)
204
+ return 2.985;
205
+ else if (n_final >= 1360)
206
+ return 2.98;
207
+
208
+ /* Note on the chosen precision: when calling it on smaller sample sizes,
209
+ the standard error of the separation depth will be larger, thus it's less
210
+ critical to get it right down to the smallest possible precision, while for
211
+ larger samples the standard error of the separation depth will be smaller */
212
+ }
213
+
214
+ for (size_t i = n_curr + 1; i <= n_final; i++)
215
+ curr += (-curr * (double)i + 3. * (double)i - 4.) / ((double)i * ((double)(i-1)));
216
+ return curr;
217
+ }
218
+
219
+ /* linear interpolation */
220
+ double expected_separation_depth(long double n)
221
+ {
222
+ if (n >= THRESHOLD_EXACT_S)
223
+ return 3;
224
+ double s_l = expected_separation_depth((size_t) floorl(n));
225
+ long double u = ceill(n);
226
+ double s_u = s_l + (-s_l * u + 3. * u - 4.) / (u * (u - 1.));
227
+ double diff = n - floorl(n);
228
+ return s_l + diff * s_u;
229
+ }
230
+
231
/* Position of the pair (i, j), with i < j < n, in an array holding one entry per
   unordered pair of the 'n' elements (ncomb = n*(n-1)/2 entries in total) */
#define ix_comb(i, j, n, ncomb)  (  ((ncomb)  + ((j) - (i))) - 1 - (((n) - (i)) * ((n) - (i) - 1)) / 2  )

/* Adds to 'counter' one increment for every pair of elements found among
   positions 'st' through 'end' (both inclusive) of 'ix_arr'.
   The increment is one when 'exp_remainder' is <= 1, and 'exp_remainder' otherwise. */
void increase_comb_counter(size_t ix_arr[], size_t st, size_t end, size_t n, double counter[], double exp_remainder)
{
    const size_t ncomb = (n * (n - 1)) / 2;
    const double increment = (exp_remainder <= 1) ? 1. : exp_remainder;

    for (size_t pos_a = st; pos_a < end; pos_a++)
    {
        for (size_t pos_b = pos_a + 1; pos_b <= end; pos_b++)
        {
            /* the pair index is defined for (smaller, larger) */
            const size_t lower  = std::min(ix_arr[pos_a], ix_arr[pos_b]);
            const size_t higher = std::max(ix_arr[pos_a], ix_arr[pos_b]);
            counter[ix_comb(lower, higher, n, ncomb)] += increment;
        }
    }
}
258
+
259
+ void increase_comb_counter(size_t ix_arr[], size_t st, size_t end, size_t n,
260
+ double *restrict counter, double *restrict weights, double exp_remainder)
261
+ {
262
+ size_t i, j;
263
+ size_t ncomb = (n * (n - 1)) / 2;
264
+ if (exp_remainder <= 1)
265
+ for (size_t el1 = st; el1 < end; el1++)
266
+ {
267
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
268
+ {
269
+ i = std::min(ix_arr[el1], ix_arr[el2]);
270
+ j = std::max(ix_arr[el1], ix_arr[el2]);
271
+ // counter[i * (n - (i+1)/2) + j - i - 1] += weights[i] * weights[j]; /* beaware integer division */
272
+ counter[ix_comb(i, j, n, ncomb)] += weights[i] * weights[j];
273
+ }
274
+ }
275
+ else
276
+ for (size_t el1 = st; el1 < end; el1++)
277
+ {
278
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
279
+ {
280
+ i = std::min(ix_arr[el1], ix_arr[el2]);
281
+ j = std::max(ix_arr[el1], ix_arr[el2]);
282
+ counter[ix_comb(i, j, n, ncomb)] += weights[i] * weights[j] * exp_remainder;
283
+ }
284
+ }
285
+ }
286
+
287
+ /* Note to self: don't try merge this into a template with the one above, as the other one has 'restrict' qualifier */
288
+ void increase_comb_counter(size_t ix_arr[], size_t st, size_t end, size_t n,
289
+ double counter[], std::unordered_map<size_t, double> &weights, double exp_remainder)
290
+ {
291
+ size_t i, j;
292
+ size_t ncomb = (n * (n - 1)) / 2;
293
+ if (exp_remainder <= 1)
294
+ for (size_t el1 = st; el1 < end; el1++)
295
+ {
296
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
297
+ {
298
+ i = std::min(ix_arr[el1], ix_arr[el2]);
299
+ j = std::max(ix_arr[el1], ix_arr[el2]);
300
+ // counter[i * (n - (i+1)/2) + j - i - 1] += weights[i] * weights[j]; /* beaware integer division */
301
+ counter[ix_comb(i, j, n, ncomb)] += weights[i] * weights[j];
302
+ }
303
+ }
304
+ else
305
+ for (size_t el1 = st; el1 < end; el1++)
306
+ {
307
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
308
+ {
309
+ i = std::min(ix_arr[el1], ix_arr[el2]);
310
+ j = std::max(ix_arr[el1], ix_arr[el2]);
311
+ counter[ix_comb(i, j, n, ncomb)] += weights[i] * weights[j] * exp_remainder;
312
+ }
313
+ }
314
+ }
315
+
316
/* Increases the counter for every pair formed by one element with index below
   'split_ix' and one element with index at or above it, among positions 'st'
   through 'end' (both inclusive) of 'ix_arr'. The first group is taken to be the
   leading run of entries that are below 'split_ix'.
   'counter' is indexed as a row-major matrix having (n - split_ix) columns.
   The increment is one when 'exp_remainder' is <= 1, and 'exp_remainder' otherwise. */
void increase_comb_counter_in_groups(size_t ix_arr[], size_t st, size_t end, size_t split_ix, size_t n,
                                     double counter[], double exp_remainder)
{
    /* first position at which the entries stop being below 'split_ix' */
    size_t second_group_st = st;
    while (second_group_st <= end && ix_arr[second_group_st] < split_ix)
        second_group_st++;

    const size_t n_cols = n - split_ix;
    const double increment = (exp_remainder <= 1) ? 1. : exp_remainder;

    for (size_t pos_a = st; pos_a < second_group_st; pos_a++)
        for (size_t pos_b = second_group_st; pos_b <= end; pos_b++)
            counter[ix_arr[pos_a] * n_cols + ix_arr[pos_b] - split_ix] += increment;
}
337
+
338
+ void increase_comb_counter_in_groups(size_t ix_arr[], size_t st, size_t end, size_t split_ix, size_t n,
339
+ double *restrict counter, double *restrict weights, double exp_remainder)
340
+ {
341
+ size_t n_group = 0;
342
+ for (size_t ix = st; ix <= end; ix++)
343
+ if (ix_arr[ix] < split_ix)
344
+ n_group++;
345
+ else
346
+ break;
347
+
348
+ n = n - split_ix;
349
+
350
+ if (exp_remainder <= 1)
351
+ for (size_t ix1 = st; ix1 < st + n_group; ix1++)
352
+ for (size_t ix2 = st + n_group; ix2 <= end; ix2++)
353
+ counter[ix_arr[ix1] * n + ix_arr[ix2] - split_ix]
354
+ +=
355
+ weights[ix_arr[ix1]] * weights[ix_arr[ix2]];
356
+ else
357
+ for (size_t ix1 = st; ix1 < st + n_group; ix1++)
358
+ for (size_t ix2 = st + n_group; ix2 <= end; ix2++)
359
+ counter[ix_arr[ix1] * n + ix_arr[ix2] - split_ix]
360
+ +=
361
+ weights[ix_arr[ix1]] * weights[ix_arr[ix2]] * exp_remainder;
362
+ }
363
+
364
+ void tmat_to_dense(double *restrict tmat, double *restrict dmat, size_t n, bool diag_to_one)
365
+ {
366
+ size_t ncomb = (n * (n - 1)) / 2;
367
+ for (size_t i = 0; i < (n-1); i++)
368
+ {
369
+ for (size_t j = i + 1; j < n; j++)
370
+ {
371
+ // dmat[i + j * n] = dmat[j + i * n] = tmat[i * (n - (i+1)/2) + j - i - 1];
372
+ dmat[i + j * n] = dmat[j + i * n] = tmat[ix_comb(i, j, n, ncomb)];
373
+ }
374
+ }
375
+ if (diag_to_one)
376
+ for (size_t i = 0; i < n; i++)
377
+ dmat[i + i * n] = 1;
378
+ else
379
+ for (size_t i = 0; i < n; i++)
380
+ dmat[i + i * n] = 0;
381
+ }
382
+
383
+ /* Note: do NOT divide by (n-1) as in some situations it will still need to calculate
384
+ the standard deviation with 1-2 observations only (e.g. when using the extended model
385
+ and some column has many rows but only 2 non-missing values, or when using the non-pooled
386
+ std criterion) */
387
+ #define SD_MIN 1e-12
388
+ double calc_sd_raw(size_t cnt, long double sum, long double sum_sq)
389
+ {
390
+ if (cnt <= 1)
391
+ return 0.;
392
+ else
393
+ return sqrtl(fmax(SD_MIN, (sum_sq - (square(sum) / (long double)cnt)) / (long double)cnt ));
394
+ }
395
+
396
+ long double calc_sd_raw_l(size_t cnt, long double sum, long double sum_sq)
397
+ {
398
+ if (cnt <= 1)
399
+ return 0.;
400
+ else
401
+ return sqrtl(fmaxl(SD_MIN, (sum_sq - (square(sum) / (long double)cnt)) / (long double)cnt ));
402
+ }
403
+
404
+ void build_btree_sampler(std::vector<double> &btree_weights, double *restrict sample_weights,
405
+ size_t nrows, size_t &log2_n, size_t &btree_offset)
406
+ {
407
+ /* build a perfectly-balanced binary search tree in which each node will
408
+ hold the sum of the weights of its children */
409
+ log2_n = log2ceil(nrows);
410
+ if (!btree_weights.size())
411
+ btree_weights.resize(pow2(log2_n + 1), 0);
412
+ else
413
+ btree_weights.assign(btree_weights.size(), 0);
414
+ btree_offset = pow2(log2_n) - 1;
415
+
416
+ std::copy(sample_weights, sample_weights + nrows, btree_weights.begin() + btree_offset);
417
+ for (size_t ix = btree_weights.size() - 1; ix > 0; ix--)
418
+ btree_weights[ix_parent(ix)] += btree_weights[ix];
419
+
420
+ if (is_na_or_inf(btree_weights[0]))
421
+ {
422
+ fprintf(stderr, "Numeric precision error with sample weights, will not use them.\n");
423
+ log2_n = 0;
424
+ btree_weights.clear();
425
+ btree_weights.shrink_to_fit();
426
+ }
427
+ }
428
+
429
+ void sample_random_rows(std::vector<size_t> &ix_arr, size_t nrows, bool with_replacement,
430
+ RNG_engine &rnd_generator, std::vector<size_t> &ix_all,
431
+ double sample_weights[], std::vector<double> &btree_weights,
432
+ size_t log2_n, size_t btree_offset, std::vector<bool> &is_repeated)
433
+ {
434
+ size_t ntake = ix_arr.size();
435
+
436
+ /* if with replacement, just generate random uniform numbers */
437
+ if (with_replacement)
438
+ {
439
+ if (sample_weights == NULL)
440
+ {
441
+ std::uniform_int_distribution<size_t> runif(0, nrows - 1);
442
+ for (size_t &ix : ix_arr)
443
+ ix = runif(rnd_generator);
444
+ }
445
+
446
+ else
447
+ {
448
+ std::discrete_distribution<size_t> runif(sample_weights, sample_weights + nrows);
449
+ for (size_t &ix : ix_arr)
450
+ ix = runif(rnd_generator);
451
+ }
452
+ }
453
+
454
+ /* if all the elements are needed, don't bother with any sampling */
455
+ else if (ntake == nrows)
456
+ {
457
+ std::iota(ix_arr.begin(), ix_arr.end(), (size_t)0);
458
+ }
459
+
460
+
461
+ /* if there are sample weights, use binary trees to keep track and update weight
462
+ https://stackoverflow.com/questions/57599509/c-random-non-repeated-integers-with-weights */
463
+ else if (sample_weights != NULL)
464
+ {
465
+ double rnd_subrange, w_left;
466
+ double curr_subrange;
467
+ size_t curr_ix;
468
+ for (size_t &ix : ix_arr)
469
+ {
470
+ /* go down the tree by drawing a random number and
471
+ checking if it falls in the left or right ranges */
472
+ curr_ix = 0;
473
+ curr_subrange = btree_weights[0];
474
+ for (size_t lev = 0; lev < log2_n; lev++)
475
+ {
476
+ rnd_subrange = std::uniform_real_distribution<double>(0, curr_subrange)(rnd_generator);
477
+ w_left = btree_weights[ix_child(curr_ix)];
478
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
479
+ curr_subrange = btree_weights[curr_ix];
480
+ }
481
+
482
+ /* finally, determine element to choose in this iteration */
483
+ ix = curr_ix - btree_offset;
484
+
485
+ /* now remove the weight of the chosen element */
486
+ btree_weights[curr_ix] = 0;
487
+ for (size_t lev = 0; lev < log2_n; lev++)
488
+ {
489
+ curr_ix = ix_parent(curr_ix);
490
+ btree_weights[curr_ix] = btree_weights[ix_child(curr_ix)]
491
+ + btree_weights[ix_child(curr_ix) + 1];
492
+ }
493
+ }
494
+ }
495
+
496
+ /* if no sample weights and not with replacement (most common case expected),
497
+ then use different algorithms depending on the sampled fraction */
498
+ else
499
+ {
500
+
501
+ /* if sampling a larger fraction, fill an array enumerating the rows, shuffle, and take first N */
502
+ if (ntake >= (nrows / 2))
503
+ {
504
+
505
+ if (!ix_all.size())
506
+ ix_all.resize(nrows);
507
+
508
+ /* in order for random seeds to always be reproducible, don't re-use previous shuffles */
509
+ std::iota(ix_all.begin(), ix_all.end(), (size_t)0);
510
+
511
+ /* If the number of sampled elements is large, do a full shuffle, enjoy simd-instructs when copying over */
512
+ if (ntake >= ((nrows * 3)/4))
513
+ {
514
+ std::shuffle(ix_all.begin(), ix_all.end(), rnd_generator);
515
+ ix_arr.assign(ix_all.begin(), ix_all.begin() + ntake);
516
+ }
517
+
518
+ /* otherwise, do only a partial shuffle (use Yates algorithm) and copy elements along the way */
519
+ else
520
+ {
521
+ size_t chosen;
522
+ for (size_t i = nrows - 1; i >= nrows - ntake; i--)
523
+ {
524
+ chosen = std::uniform_int_distribution<size_t>(0, i)(rnd_generator);
525
+ ix_arr[nrows - i - 1] = ix_all[chosen];
526
+ ix_all[chosen] = ix_all[i];
527
+ }
528
+ }
529
+
530
+ }
531
+
532
+ /* If the sample size is small, use Floyd's random sampling algorithm
533
+ https://stackoverflow.com/questions/2394246/algorithm-to-select-a-single-random-combination-of-values */
534
+ else
535
+ {
536
+
537
+ size_t candidate;
538
+
539
+ /* if the sample size is relatively large, use a temporary boolean vector */
540
+ if (((long double)ntake / (long double)nrows) > (1. / 20.))
541
+ {
542
+
543
+ if (!is_repeated.size())
544
+ is_repeated.resize(nrows, false);
545
+ else
546
+ is_repeated.assign(is_repeated.size(), false);
547
+
548
+ for (size_t rnd_ix = nrows - ntake; rnd_ix < nrows; rnd_ix++)
549
+ {
550
+ candidate = std::uniform_int_distribution<size_t>(0, rnd_ix)(rnd_generator);
551
+ if (is_repeated[candidate])
552
+ {
553
+ ix_arr[ntake - (nrows - rnd_ix)] = rnd_ix;
554
+ is_repeated[rnd_ix] = true;
555
+ }
556
+
557
+ else
558
+ {
559
+ ix_arr[ntake - (nrows - rnd_ix)] = candidate;
560
+ is_repeated[candidate] = true;
561
+ }
562
+ }
563
+
564
+ }
565
+
566
+ /* if the sample size is very small, use an unordered set */
567
+ else
568
+ {
569
+
570
+ std::unordered_set<size_t> repeated_set;
571
+ for (size_t rnd_ix = nrows - ntake; rnd_ix < nrows; rnd_ix++)
572
+ {
573
+ candidate = std::uniform_int_distribution<size_t>(0, rnd_ix)(rnd_generator);
574
+ if (repeated_set.find(candidate) == repeated_set.end()) /* TODO: switch to C++20 'contains' */
575
+ {
576
+ ix_arr[ntake - (nrows - rnd_ix)] = candidate;
577
+ repeated_set.insert(candidate);
578
+ }
579
+
580
+ else
581
+ {
582
+ ix_arr[ntake - (nrows - rnd_ix)] = rnd_ix;
583
+ repeated_set.insert(rnd_ix);
584
+ }
585
+ }
586
+
587
+ }
588
+
589
+ }
590
+
591
+ }
592
+ }
593
+
594
+ /* https://stackoverflow.com/questions/57599509/c-random-non-repeated-integers-with-weights */
595
+ void weighted_shuffle(size_t *restrict outp, size_t n, double *restrict weights, double *restrict buffer_arr, RNG_engine &rnd_generator)
596
+ {
597
+ /* determine smallest power of two that is larger than N */
598
+ size_t tree_levels = log2ceil(n);
599
+
600
+ /* initialize vector with place-holders for perfectly-balanced tree */
601
+ std::fill(buffer_arr, buffer_arr + pow2(tree_levels + 1), (double)0);
602
+
603
+ /* compute sums for the tree leaves at each node */
604
+ size_t offset = pow2(tree_levels) - 1;
605
+ for (size_t ix = 0; ix < n; ix++) {
606
+ buffer_arr[ix + offset] = weights[ix];
607
+ }
608
+ for (size_t ix = pow2(tree_levels+1) - 1; ix > 0; ix--) {
609
+ buffer_arr[ix_parent(ix)] += buffer_arr[ix];
610
+ }
611
+
612
+ /* sample according to uniform distribution */
613
+ double rnd_subrange, w_left;
614
+ double curr_subrange;
615
+ int curr_ix;
616
+
617
+ for (size_t el = 0; el < n; el++)
618
+ {
619
+ /* go down the tree by drawing a random number and
620
+ checking if it falls in the left or right sub-ranges */
621
+ curr_ix = 0;
622
+ curr_subrange = buffer_arr[0];
623
+ for (size_t lev = 0; lev < tree_levels; lev++)
624
+ {
625
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
626
+ w_left = buffer_arr[ix_child(curr_ix)];
627
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
628
+ curr_subrange = buffer_arr[curr_ix];
629
+ }
630
+
631
+ /* finally, add element from this iteration */
632
+ outp[el] = curr_ix - offset;
633
+
634
+ /* now remove the weight of the chosen element */
635
+ buffer_arr[curr_ix] = 0;
636
+ for (size_t lev = 0; lev < tree_levels; lev++)
637
+ {
638
+ curr_ix = ix_parent(curr_ix);
639
+ buffer_arr[curr_ix] = buffer_arr[ix_child(curr_ix)]
640
+ + buffer_arr[ix_child(curr_ix) + 1];
641
+ }
642
+ }
643
+
644
+ }
645
+
646
/* For hyperplane intersections.
   Moves to the front of ix_arr[st..end] every entry whose associated value is less than or
   equal to 'split_point'. The values in 'x' are indexed by position within the range (the
   entry at ix_arr[st + k] is judged by x[k]), not by the row number stored in 'ix_arr'.
   Returns the first position after the entries that were moved to the front. */
size_t divide_subset_split(size_t ix_arr[], double x[], size_t st, size_t end, double split_point)
{
    size_t next_left = st;
    for (size_t pos = st; pos <= end; pos++)
    {
        /* written as a negation so that NaN values are left in place */
        if (!(x[pos - st] <= split_point))
            continue;

        std::swap(ix_arr[next_left], ix_arr[pos]);
        next_left++;
    }
    return next_left;
}
663
+
664
+ /* For numerical columns */
665
+ void divide_subset_split(size_t ix_arr[], double x[], size_t st, size_t end, double split_point,
666
+ MissingAction missing_action, size_t &st_NA, size_t &end_NA, size_t &split_ix)
667
+ {
668
+ size_t temp;
669
+
670
+ /* if NAs are not to be bothered with, just need to do a single pass */
671
+ if (missing_action == Fail)
672
+ {
673
+ /* move to the left if it's l.e. split point */
674
+ for (size_t row = st; row <= end; row++)
675
+ {
676
+ if (x[ix_arr[row]] <= split_point)
677
+ {
678
+ temp = ix_arr[st];
679
+ ix_arr[st] = ix_arr[row];
680
+ ix_arr[row] = temp;
681
+ st++;
682
+ }
683
+ }
684
+ split_ix = st;
685
+ }
686
+
687
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
688
+ else
689
+ {
690
+ for (size_t row = st; row <= end; row++)
691
+ {
692
+ if (!isnan(x[ix_arr[row]]) && x[ix_arr[row]] <= split_point)
693
+ {
694
+ temp = ix_arr[st];
695
+ ix_arr[st] = ix_arr[row];
696
+ ix_arr[row] = temp;
697
+ st++;
698
+ }
699
+ }
700
+ st_NA = st;
701
+
702
+ for (size_t row = st; row <= end; row++)
703
+ {
704
+ if (isnan(x[ix_arr[row]]))
705
+ {
706
+ temp = ix_arr[st];
707
+ ix_arr[st] = ix_arr[row];
708
+ ix_arr[row] = temp;
709
+ st++;
710
+ }
711
+ }
712
+ end_NA = st;
713
+ }
714
+ }
715
+
716
+ /* For sparse numeric columns */
717
+ void divide_subset_split(size_t ix_arr[], size_t st, size_t end, size_t col_num,
718
+ double Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[], double split_point,
719
+ MissingAction missing_action, size_t &st_NA, size_t &end_NA, size_t &split_ix)
720
+ {
721
+ /* TODO: this is a mess, needs refactoring */
722
+ /* TODO: when moving zeros, would be better to instead move by '>' (opposite as in here) */
723
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
724
+ {
725
+ if (missing_action == Fail)
726
+ {
727
+ split_ix = (0 <= split_point)? (end+1) : st;
728
+ }
729
+
730
+ else
731
+ {
732
+ st_NA = (0 <= split_point)? (end+1) : st;
733
+ end_NA = (0 <= split_point)? (end+1) : st;
734
+ }
735
+
736
+ }
737
+
738
+ size_t st_col = Xc_indptr[col_num];
739
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
740
+ size_t curr_pos = st_col;
741
+ size_t ind_end_col = Xc_ind[end_col];
742
+ size_t temp;
743
+ bool move_zeros = 0 <= split_point;
744
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
745
+
746
+ if (move_zeros && ptr_st > ix_arr + st)
747
+ st = ptr_st - ix_arr;
748
+
749
+ if (missing_action == Fail)
750
+ {
751
+ if (move_zeros)
752
+ {
753
+ for (size_t *row = ptr_st;
754
+ row != ix_arr + end + 1;
755
+ )
756
+ {
757
+ if (curr_pos >= end_col + 1)
758
+ {
759
+ for (size_t *r = row; r <= ix_arr + end; r++)
760
+ {
761
+ temp = ix_arr[st];
762
+ ix_arr[st] = *r;
763
+ *r = temp;
764
+ st++;
765
+ }
766
+ break;
767
+ }
768
+
769
+ if (Xc_ind[curr_pos] == *row)
770
+ {
771
+ if (Xc[curr_pos] <= split_point)
772
+ {
773
+ temp = ix_arr[st];
774
+ ix_arr[st] = *row;
775
+ *row = temp;
776
+ st++;
777
+ }
778
+ if (curr_pos == end_col && row < ix_arr + end)
779
+ for (size_t *r = row + 1; r <= ix_arr + end; r++)
780
+ {
781
+ temp = ix_arr[st];
782
+ ix_arr[st] = *r;
783
+ *r = temp;
784
+ st++;
785
+ }
786
+ if (row == ix_arr + end || curr_pos == end_col) break;
787
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
788
+ }
789
+
790
+ else
791
+ {
792
+ if (Xc_ind[curr_pos] > *row)
793
+ {
794
+ while (row <= ix_arr + end && Xc_ind[curr_pos] > *row)
795
+ {
796
+ temp = ix_arr[st];
797
+ ix_arr[st] = *row;
798
+ *row = temp;
799
+ st++; row++;
800
+ }
801
+ }
802
+
803
+ else
804
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
805
+ }
806
+ }
807
+ }
808
+
809
+ else /* don't move zeros */
810
+ {
811
+ for (size_t *row = ptr_st;
812
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
813
+ )
814
+ {
815
+ if (Xc_ind[curr_pos] == *row)
816
+ {
817
+ if (Xc[curr_pos] <= split_point)
818
+ {
819
+ temp = ix_arr[st];
820
+ ix_arr[st] = *row;
821
+ *row = temp;
822
+ st++;
823
+ }
824
+ if (row == ix_arr + end || curr_pos == end_col) break;
825
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
826
+ }
827
+
828
+ else
829
+ {
830
+ if (Xc_ind[curr_pos] > *row)
831
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
832
+ else
833
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
834
+ }
835
+ }
836
+ }
837
+
838
+ split_ix = st;
839
+ }
840
+
841
+ else /* can have NAs */
842
+ {
843
+
844
+ bool has_NAs = false;
845
+ if (move_zeros)
846
+ {
847
+ for (size_t *row = ptr_st;
848
+ row != ix_arr + end + 1;
849
+ )
850
+ {
851
+ if (curr_pos >= end_col + 1)
852
+ {
853
+ for (size_t *r = row; r <= ix_arr + end; r++)
854
+ {
855
+ temp = ix_arr[st];
856
+ ix_arr[st] = *r;
857
+ *r = temp;
858
+ st++;
859
+ }
860
+ break;
861
+ }
862
+
863
+ if (Xc_ind[curr_pos] == *row)
864
+ {
865
+ if (isnan(Xc[curr_pos]))
866
+ has_NAs = true;
867
+ else if (Xc[curr_pos] <= split_point)
868
+ {
869
+ temp = ix_arr[st];
870
+ ix_arr[st] = *row;
871
+ *row = temp;
872
+ st++;
873
+ }
874
+ if (curr_pos == end_col && row < ix_arr + end)
875
+ for (size_t *r = row + 1; r <= ix_arr + end; r++)
876
+ {
877
+ temp = ix_arr[st];
878
+ ix_arr[st] = *r;
879
+ *r = temp;
880
+ st++;
881
+ }
882
+ if (row == ix_arr + end || curr_pos == end_col) break;
883
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
884
+ }
885
+
886
+ else
887
+ {
888
+ if (Xc_ind[curr_pos] > *row)
889
+ {
890
+ while (row <= ix_arr + end && Xc_ind[curr_pos] > *row)
891
+ {
892
+ temp = ix_arr[st];
893
+ ix_arr[st] = *row;
894
+ *row = temp;
895
+ st++; row++;
896
+ }
897
+ }
898
+
899
+ else
900
+ {
901
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
902
+ }
903
+ }
904
+ }
905
+ }
906
+
907
+ else /* don't move zeros */
908
+ {
909
+ for (size_t *row = ptr_st;
910
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
911
+ )
912
+ {
913
+ if (Xc_ind[curr_pos] == *row)
914
+ {
915
+ if (isnan(Xc[curr_pos])) has_NAs = true;
916
+ if (Xc[curr_pos] <= split_point && !isnan(Xc[curr_pos]))
917
+ {
918
+ temp = ix_arr[st];
919
+ ix_arr[st] = *row;
920
+ *row = temp;
921
+ st++;
922
+ }
923
+ if (row == ix_arr + end || curr_pos == end_col) break;
924
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
925
+ }
926
+
927
+ else
928
+ {
929
+ if (Xc_ind[curr_pos] > *row)
930
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
931
+ else
932
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
933
+ }
934
+ }
935
+ }
936
+
937
+
938
+ st_NA = st;
939
+ if (has_NAs)
940
+ {
941
+ curr_pos = st_col;
942
+ std::sort(ix_arr + st, ix_arr + end + 1);
943
+ for (size_t *row = ix_arr + st;
944
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
945
+ )
946
+ {
947
+ if (Xc_ind[curr_pos] == *row)
948
+ {
949
+ if (isnan(Xc[curr_pos]))
950
+ {
951
+ temp = ix_arr[st];
952
+ ix_arr[st] = *row;
953
+ *row = temp;
954
+ st++;
955
+ }
956
+ if (row == ix_arr + end || curr_pos == end_col) break;
957
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
958
+ }
959
+
960
+ else
961
+ {
962
+ if (Xc_ind[curr_pos] > *row)
963
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
964
+ else
965
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
966
+ }
967
+ }
968
+ }
969
+ end_NA = st;
970
+
971
+ }
972
+
973
+ }
974
+
975
+ /* For categorical columns split by subset */
976
+ void divide_subset_split(size_t ix_arr[], int x[], size_t st, size_t end, char split_categ[],
977
+ MissingAction missing_action, size_t &st_NA, size_t &end_NA, size_t &split_ix)
978
+ {
979
+ size_t temp;
980
+
981
+ /* if NAs are not to be bothered with, just need to do a single pass */
982
+ if (missing_action == Fail)
983
+ {
984
+ /* move to the left if it's l.e. than the split point */
985
+ for (size_t row = st; row <= end; row++)
986
+ {
987
+ if (split_categ[ x[ix_arr[row]] ] == 1)
988
+ {
989
+ temp = ix_arr[st];
990
+ ix_arr[st] = ix_arr[row];
991
+ ix_arr[row] = temp;
992
+ st++;
993
+ }
994
+ }
995
+ split_ix = st;
996
+ }
997
+
998
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
999
+ else
1000
+ {
1001
+ for (size_t row = st; row <= end; row++)
1002
+ {
1003
+ if (x[ix_arr[row]] >= 0 && split_categ[ x[ix_arr[row]] ] == 1)
1004
+ {
1005
+ temp = ix_arr[st];
1006
+ ix_arr[st] = ix_arr[row];
1007
+ ix_arr[row] = temp;
1008
+ st++;
1009
+ }
1010
+ }
1011
+ st_NA = st;
1012
+
1013
+ for (size_t row = st; row <= end; row++)
1014
+ {
1015
+ if (x[ix_arr[row]] < 0)
1016
+ {
1017
+ temp = ix_arr[st];
1018
+ ix_arr[st] = ix_arr[row];
1019
+ ix_arr[row] = temp;
1020
+ st++;
1021
+ }
1022
+ }
1023
+ end_NA = st;
1024
+ }
1025
+ }
1026
+
1027
+ /* For categorical columns split by subset, used at prediction time (with similarity) */
1028
+ void divide_subset_split(size_t ix_arr[], int x[], size_t st, size_t end, char split_categ[],
1029
+ int ncat, MissingAction missing_action, NewCategAction new_cat_action,
1030
+ bool move_new_to_left, size_t &st_NA, size_t &end_NA, size_t &split_ix)
1031
+ {
1032
+ size_t temp;
1033
+
1034
+ /* if NAs are not to be bothered with, just need to do a single pass */
1035
+ if (missing_action == Fail && new_cat_action != Weighted)
1036
+ {
1037
+ if (new_cat_action == Smallest && move_new_to_left)
1038
+ {
1039
+ for (size_t row = st; row <= end; row++)
1040
+ {
1041
+ if (split_categ[ x[ix_arr[row]] ] == 1 || x[ix_arr[row]] >= ncat)
1042
+ {
1043
+ temp = ix_arr[st];
1044
+ ix_arr[st] = ix_arr[row];
1045
+ ix_arr[row] = temp;
1046
+ st++;
1047
+ }
1048
+ }
1049
+ }
1050
+
1051
+ else
1052
+ {
1053
+ for (size_t row = st; row <= end; row++)
1054
+ {
1055
+ if (split_categ[ x[ix_arr[row]] ] == 1)
1056
+ {
1057
+ temp = ix_arr[st];
1058
+ ix_arr[st] = ix_arr[row];
1059
+ ix_arr[row] = temp;
1060
+ st++;
1061
+ }
1062
+ }
1063
+ }
1064
+
1065
+ split_ix = st;
1066
+ }
1067
+
1068
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
1069
+ else
1070
+ {
1071
+ for (size_t row = st; row <= end; row++)
1072
+ {
1073
+ if (x[ix_arr[row]] >= 0 && split_categ[ x[ix_arr[row]] ] == 1)
1074
+ {
1075
+ temp = ix_arr[st];
1076
+ ix_arr[st] = ix_arr[row];
1077
+ ix_arr[row] = temp;
1078
+ st++;
1079
+ }
1080
+ }
1081
+ st_NA = st;
1082
+
1083
+ if (new_cat_action == Weighted)
1084
+ {
1085
+ for (size_t row = st; row <= end; row++)
1086
+ {
1087
+ if (x[ix_arr[row]] < 0 || split_categ[ x[ix_arr[row]] ] == (-1))
1088
+ {
1089
+ temp = ix_arr[st];
1090
+ ix_arr[st] = ix_arr[row];
1091
+ ix_arr[row] = temp;
1092
+ st++;
1093
+ }
1094
+ }
1095
+ }
1096
+
1097
+ else
1098
+ {
1099
+ for (size_t row = st; row <= end; row++)
1100
+ {
1101
+ if (x[ix_arr[row]] < 0)
1102
+ {
1103
+ temp = ix_arr[st];
1104
+ ix_arr[st] = ix_arr[row];
1105
+ ix_arr[row] = temp;
1106
+ st++;
1107
+ }
1108
+ }
1109
+ }
1110
+
1111
+ end_NA = st;
1112
+ }
1113
+ }
1114
+
1115
+ /* For categoricals split on a single category */
1116
+ void divide_subset_split(size_t ix_arr[], int x[], size_t st, size_t end, int split_categ,
1117
+ MissingAction missing_action, size_t &st_NA, size_t &end_NA, size_t &split_ix)
1118
+ {
1119
+ size_t temp;
1120
+
1121
+ /* if NAs are not to be bothered with, just need to do a single pass */
1122
+ if (missing_action == Fail)
1123
+ {
1124
+ /* move to the left if it's l.e. than the split point */
1125
+ for (size_t row = st; row <= end; row++)
1126
+ {
1127
+ if (x[ix_arr[row]] == split_categ)
1128
+ {
1129
+ temp = ix_arr[st];
1130
+ ix_arr[st] = ix_arr[row];
1131
+ ix_arr[row] = temp;
1132
+ st++;
1133
+ }
1134
+ }
1135
+ split_ix = st;
1136
+ }
1137
+
1138
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
1139
+ else
1140
+ {
1141
+ for (size_t row = st; row <= end; row++)
1142
+ {
1143
+ if (x[ix_arr[row]] == split_categ)
1144
+ {
1145
+ temp = ix_arr[st];
1146
+ ix_arr[st] = ix_arr[row];
1147
+ ix_arr[row] = temp;
1148
+ st++;
1149
+ }
1150
+ }
1151
+ st_NA = st;
1152
+
1153
+ for (size_t row = st; row <= end; row++)
1154
+ {
1155
+ if (x[ix_arr[row]] < 0)
1156
+ {
1157
+ temp = ix_arr[st];
1158
+ ix_arr[st] = ix_arr[row];
1159
+ ix_arr[row] = temp;
1160
+ st++;
1161
+ }
1162
+ }
1163
+ end_NA = st;
1164
+ }
1165
+ }
1166
+
1167
+ /* For categoricals split on sub-set that turned out to have 2 categories only (prediction-time) */
1168
+ void divide_subset_split(size_t ix_arr[], int x[], size_t st, size_t end,
1169
+ MissingAction missing_action, NewCategAction new_cat_action,
1170
+ bool move_new_to_left, size_t &st_NA, size_t &end_NA, size_t &split_ix)
1171
+ {
1172
+ size_t temp;
1173
+
1174
+ /* if NAs are not to be bothered with, just need to do a single pass */
1175
+ if (missing_action == Fail)
1176
+ {
1177
+ /* move to the left if it's l.e. than the split point */
1178
+ if (new_cat_action == Smallest && move_new_to_left)
1179
+ {
1180
+ for (size_t row = st; row <= end; row++)
1181
+ {
1182
+ if (x[ix_arr[row]] == 0 || x[ix_arr[row]] > 1)
1183
+ {
1184
+ temp = ix_arr[st];
1185
+ ix_arr[st] = ix_arr[row];
1186
+ ix_arr[row] = temp;
1187
+ st++;
1188
+ }
1189
+ }
1190
+ }
1191
+
1192
+ else
1193
+ {
1194
+ for (size_t row = st; row <= end; row++)
1195
+ {
1196
+ if (x[ix_arr[row]] == 0)
1197
+ {
1198
+ temp = ix_arr[st];
1199
+ ix_arr[st] = ix_arr[row];
1200
+ ix_arr[row] = temp;
1201
+ st++;
1202
+ }
1203
+ }
1204
+ }
1205
+ split_ix = st;
1206
+ }
1207
+
1208
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
1209
+ else
1210
+ {
1211
+ if (new_cat_action == Smallest && move_new_to_left)
1212
+ {
1213
+ for (size_t row = st; row <= end; row++)
1214
+ {
1215
+ if (x[ix_arr[row]] == 0 || x[ix_arr[row]] > 1)
1216
+ {
1217
+ temp = ix_arr[st];
1218
+ ix_arr[st] = ix_arr[row];
1219
+ ix_arr[row] = temp;
1220
+ st++;
1221
+ }
1222
+ }
1223
+ st_NA = st;
1224
+
1225
+ for (size_t row = st; row <= end; row++)
1226
+ {
1227
+ if (x[ix_arr[row]] < 0)
1228
+ {
1229
+ temp = ix_arr[st];
1230
+ ix_arr[st] = ix_arr[row];
1231
+ ix_arr[row] = temp;
1232
+ st++;
1233
+ }
1234
+ }
1235
+ end_NA = st;
1236
+ }
1237
+
1238
+ else
1239
+ {
1240
+ for (size_t row = st; row <= end; row++)
1241
+ {
1242
+ if (x[ix_arr[row]] == 0)
1243
+ {
1244
+ temp = ix_arr[st];
1245
+ ix_arr[st] = ix_arr[row];
1246
+ ix_arr[row] = temp;
1247
+ st++;
1248
+ }
1249
+ }
1250
+ st_NA = st;
1251
+
1252
+ for (size_t row = st; row <= end; row++)
1253
+ {
1254
+ if (x[ix_arr[row]] < 0)
1255
+ {
1256
+ temp = ix_arr[st];
1257
+ ix_arr[st] = ix_arr[row];
1258
+ ix_arr[row] = temp;
1259
+ st++;
1260
+ }
1261
+ }
1262
+ end_NA = st;
1263
+ }
1264
+ }
1265
+ }
1266
+
1267
+ /* for regular numeric columns */
1268
+ void get_range(size_t ix_arr[], double x[], size_t st, size_t end,
1269
+ MissingAction missing_action, double &xmin, double &xmax, bool &unsplittable)
1270
+ {
1271
+ xmin = HUGE_VAL;
1272
+ xmax = -HUGE_VAL;
1273
+
1274
+ if (missing_action == Fail)
1275
+ {
1276
+ for (size_t row = st; row <= end; row++)
1277
+ {
1278
+ xmin = (x[ix_arr[row]] < xmin)? x[ix_arr[row]] : xmin;
1279
+ xmax = (x[ix_arr[row]] > xmax)? x[ix_arr[row]] : xmax;
1280
+ }
1281
+ }
1282
+
1283
+
1284
+ else
1285
+ {
1286
+ for (size_t row = st; row <= end; row++)
1287
+ {
1288
+ xmin = fmin(xmin, x[ix_arr[row]]);
1289
+ xmax = fmax(xmax, x[ix_arr[row]]);
1290
+ }
1291
+ }
1292
+
1293
+ unsplittable = (xmin == xmax) || (xmin == HUGE_VAL && xmax == -HUGE_VAL);
1294
+ }
1295
+
1296
+ /* for sparse inputs */
1297
+ void get_range(size_t ix_arr[], size_t st, size_t end, size_t col_num,
1298
+ double Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
1299
+ MissingAction missing_action, double &xmin, double &xmax, bool &unsplittable)
1300
+ {
1301
+ /* ix_arr must already be sorted beforehand */
1302
+ xmin = HUGE_VAL;
1303
+ xmax = -HUGE_VAL;
1304
+
1305
+ size_t st_col = Xc_indptr[col_num];
1306
+ size_t end_col = Xc_indptr[col_num + 1];
1307
+ size_t nnz_col = end_col - st_col;
1308
+ end_col--;
1309
+ size_t curr_pos = st_col;
1310
+
1311
+ if (!nnz_col ||
1312
+ Xc_ind[st_col] > ix_arr[end] ||
1313
+ ix_arr[st] > Xc_ind[end_col]
1314
+ )
1315
+ {
1316
+ unsplittable = true;
1317
+ return;
1318
+ }
1319
+
1320
+ if (nnz_col < end - st + 1 ||
1321
+ Xc_ind[st_col] > ix_arr[st] ||
1322
+ Xc_ind[end_col] < ix_arr[end]
1323
+ )
1324
+ {
1325
+ xmin = 0;
1326
+ xmax = 0;
1327
+ }
1328
+
1329
+ size_t ind_end_col = Xc_ind[end_col];
1330
+ size_t nmatches = 0;
1331
+
1332
+ if (missing_action == Fail)
1333
+ {
1334
+ for (size_t *row = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
1335
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
1336
+ )
1337
+ {
1338
+ if (Xc_ind[curr_pos] == *row)
1339
+ {
1340
+ nmatches++;
1341
+ xmin = (Xc[curr_pos] < xmin)? Xc[curr_pos] : xmin;
1342
+ xmax = (Xc[curr_pos] > xmax)? Xc[curr_pos] : xmax;
1343
+ if (row == ix_arr + end || curr_pos == end_col) break;
1344
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
1345
+ }
1346
+
1347
+ else
1348
+ {
1349
+ if (Xc_ind[curr_pos] > *row)
1350
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
1351
+ else
1352
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
1353
+ }
1354
+ }
1355
+ }
1356
+
1357
+ else /* can have NAs */
1358
+ {
1359
+ for (size_t *row = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
1360
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
1361
+ )
1362
+ {
1363
+ if (Xc_ind[curr_pos] == *row)
1364
+ {
1365
+ nmatches++;
1366
+ xmin = fmin(xmin, Xc[curr_pos]);
1367
+ xmax = fmax(xmax, Xc[curr_pos]);
1368
+ if (row == ix_arr + end || curr_pos == end_col) break;
1369
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
1370
+ }
1371
+
1372
+ else
1373
+ {
1374
+ if (Xc_ind[curr_pos] > *row)
1375
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
1376
+ else
1377
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
1378
+ }
1379
+ }
1380
+
1381
+ }
1382
+
1383
+ if (nmatches < (end - st + 1))
1384
+ {
1385
+ xmin = fmin(xmin, 0);
1386
+ xmax = fmax(xmax, 0);
1387
+ }
1388
+ unsplittable = (xmin == xmax) || (xmin == HUGE_VAL && xmax == -HUGE_VAL);
1389
+
1390
+ }
1391
+
1392
+
1393
+ void get_categs(size_t ix_arr[], int x[], size_t st, size_t end, int ncat,
1394
+ MissingAction missing_action, char categs[], size_t &npresent, bool &unsplittable)
1395
+ {
1396
+ std::fill(categs, categs + ncat, -1);
1397
+ npresent = 0;
1398
+ for (size_t row = st; row <= end; row++)
1399
+ if (x[ix_arr[row]] >= 0)
1400
+ categs[x[ix_arr[row]]] = 1;
1401
+
1402
+ npresent = std::accumulate(categs,
1403
+ categs + ncat,
1404
+ (size_t)0,
1405
+ [](const size_t a, const char b){return a + (b > 0);}
1406
+ );
1407
+
1408
+ unsplittable = npresent < 2;
1409
+ }
1410
+
1411
/* Sums the weights of the rows listed in ix_arr[st..end].
   The sum is only computed when curr_depth > 0: from 'weights_arr' if it is non-empty,
   else from 'weights_map' if that one is non-empty. In every other case the function
   returns -HUGE_VAL.
   Note that the map is read through operator[], so rows that are absent from it get added
   to it with a weight of zero. */
long double calculate_sum_weights(std::vector<size_t> &ix_arr, size_t st, size_t end, size_t curr_depth,
                                  std::vector<double> &weights_arr, std::unordered_map<size_t, double> &weights_map)
{
    if (curr_depth == 0)
        return -HUGE_VAL;

    const auto first = ix_arr.begin() + st;
    const auto last  = ix_arr.begin() + end + 1;
    long double total = 0;

    if (weights_arr.size())
    {
        for (auto it = first; it != last; ++it)
            total = total + weights_arr[*it];
        return total;
    }

    if (weights_map.size())
    {
        for (auto it = first; it != last; ++it)
            total = total + weights_map[*it];
        return total;
    }

    return -HUGE_VAL;
}
1427
+
1428
+ size_t move_NAs_to_front(size_t ix_arr[], size_t st, size_t end, double x[])
1429
+ {
1430
+ size_t st_non_na = st;
1431
+ size_t temp;
1432
+
1433
+ for (size_t row = st; row <= end; row++)
1434
+ {
1435
+ if (is_na_or_inf(x[ix_arr[row]]))
1436
+ {
1437
+ temp = ix_arr[st_non_na];
1438
+ ix_arr[st_non_na] = ix_arr[row];
1439
+ ix_arr[row] = temp;
1440
+ st_non_na++;
1441
+ }
1442
+ }
1443
+
1444
+ return st_non_na;
1445
+ }
1446
+
1447
+ size_t move_NAs_to_front(size_t ix_arr[], size_t st, size_t end, size_t col_num, double Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[])
1448
+ {
1449
+ size_t st_non_na = st;
1450
+ size_t temp;
1451
+
1452
+ size_t st_col = Xc_indptr[col_num];
1453
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
1454
+ size_t curr_pos = st_col;
1455
+ size_t ind_end_col = Xc_ind[end_col];
1456
+ std::sort(ix_arr + st, ix_arr + end + 1);
1457
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
1458
+
1459
+ for (size_t *row = ptr_st;
1460
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
1461
+ )
1462
+ {
1463
+ if (Xc_ind[curr_pos] == *row)
1464
+ {
1465
+ if (is_na_or_inf(Xc[curr_pos]))
1466
+ {
1467
+ temp = ix_arr[st_non_na];
1468
+ ix_arr[st_non_na] = *row;
1469
+ *row = temp;
1470
+ st_non_na++;
1471
+ }
1472
+
1473
+ if (row == ix_arr + end || curr_pos == end_col) break;
1474
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
1475
+ }
1476
+
1477
+ else
1478
+ {
1479
+ if (Xc_ind[curr_pos] > *row)
1480
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
1481
+ else
1482
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
1483
+ }
1484
+ }
1485
+
1486
+ return st_non_na;
1487
+ }
1488
+
1489
/* Categorical version: moves to the front of ix_arr[st..end] the rows whose category in 'x'
   is negative. Returns the first position after the rows that were moved. */
size_t move_NAs_to_front(size_t ix_arr[], size_t st, size_t end, int x[])
{
    size_t first_kept = st;

    for (size_t pos = st; pos <= end; pos++)
    {
        if (x[ix_arr[pos]] >= 0)
            continue;

        std::swap(ix_arr[first_kept], ix_arr[pos]);
        first_kept++;
    }

    return first_kept;
}
1507
+
1508
+ size_t center_NAs(size_t *restrict ix_arr, size_t st_left, size_t st, size_t curr_pos)
1509
+ {
1510
+ size_t temp;
1511
+ for (size_t row = st_left; row < st; row++)
1512
+ {
1513
+ temp = ix_arr[--curr_pos];
1514
+ ix_arr[curr_pos] = ix_arr[row];
1515
+ ix_arr[row] = temp;
1516
+ }
1517
+
1518
+ return curr_pos;
1519
+ }
1520
+
1521
+ void todense(size_t ix_arr[], size_t st, size_t end,
1522
+ size_t col_num, double *restrict Xc, sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
1523
+ double *restrict buffer_arr)
1524
+ {
1525
+ std::fill(buffer_arr, buffer_arr + (end - st + 1), (double)0);
1526
+
1527
+ size_t st_col = Xc_indptr[col_num];
1528
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
1529
+ size_t curr_pos = st_col;
1530
+ size_t ind_end_col = Xc_ind[end_col];
1531
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
1532
+
1533
+ for (size_t *row = ptr_st;
1534
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
1535
+ )
1536
+ {
1537
+ if (Xc_ind[curr_pos] == *row)
1538
+ {
1539
+ buffer_arr[row - (ix_arr + st)] = Xc[curr_pos];
1540
+ if (row == ix_arr + end || curr_pos == end_col) break;
1541
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
1542
+ }
1543
+
1544
+ else
1545
+ {
1546
+ if (Xc_ind[curr_pos] > *row)
1547
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
1548
+ else
1549
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
1550
+ }
1551
+ }
1552
+ }
1553
+
1554
+ /* Function to handle interrupt signals */
1555
+ void set_interrup_global_variable(int s)
1556
+ {
1557
+ fprintf(stderr, "Error: procedure was interrupted\n");
1558
+ #pragma omp critical
1559
+ {
1560
+ interrupt_switch = true;
1561
+ }
1562
+ }
1563
+
1564
/* Accessor for the EXIT_SUCCESS constant from the standard header. Cython cannot import
   #define'd macro values, so this lets it tell whether the value returned by 'fit_model'
   signals success. */
int return_EXIT_SUCCESS()
{
    const int success_code = EXIT_SUCCESS;
    return success_code;
}
1571
/* Accessor for the EXIT_FAILURE constant from the standard header, for callers that
   cannot read the macro directly. */
int return_EXIT_FAILURE()
{
    const int failure_code = EXIT_FAILURE;
    return failure_code;
}
1574
+ }