isotree 0.2.2 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (151) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2111 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +62 -0
  14. data/vendor/isotree/src/RcppExports.cpp +525 -52
  15. data/vendor/isotree/src/Rwrapper.cpp +1931 -268
  16. data/vendor/isotree/src/c_interface.cpp +953 -0
  17. data/vendor/isotree/src/crit.hpp +4232 -0
  18. data/vendor/isotree/src/dist.hpp +1886 -0
  19. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  20. data/vendor/isotree/src/extended.hpp +1444 -0
  21. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  22. data/vendor/isotree/src/fit_model.hpp +2401 -0
  23. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  24. data/vendor/isotree/src/helpers_iforest.hpp +813 -0
  25. data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
  26. data/vendor/isotree/src/indexer.cpp +515 -0
  27. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  28. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  29. data/vendor/isotree/src/isoforest.hpp +1659 -0
  30. data/vendor/isotree/src/isotree.hpp +1804 -392
  31. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  32. data/vendor/isotree/src/merge_models.cpp +159 -16
  33. data/vendor/isotree/src/mult.hpp +1321 -0
  34. data/vendor/isotree/src/oop_interface.cpp +842 -0
  35. data/vendor/isotree/src/oop_interface.hpp +278 -0
  36. data/vendor/isotree/src/other_helpers.hpp +219 -0
  37. data/vendor/isotree/src/predict.hpp +1932 -0
  38. data/vendor/isotree/src/python_helpers.hpp +134 -0
  39. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  40. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  41. data/vendor/isotree/src/robinmap/README.md +483 -0
  42. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  46. data/vendor/isotree/src/serialize.cpp +4300 -139
  47. data/vendor/isotree/src/sql.cpp +141 -59
  48. data/vendor/isotree/src/subset_models.cpp +174 -0
  49. data/vendor/isotree/src/utils.hpp +3808 -0
  50. data/vendor/isotree/src/xoshiro.hpp +467 -0
  51. data/vendor/isotree/src/ziggurat.hpp +405 -0
  52. metadata +38 -104
  53. data/vendor/cereal/LICENSE +0 -24
  54. data/vendor/cereal/README.md +0 -85
  55. data/vendor/cereal/include/cereal/access.hpp +0 -351
  56. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  57. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  58. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  59. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  60. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  61. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  62. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  63. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  65. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  66. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  67. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  68. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  69. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  70. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  71. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  72. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  74. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  76. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  77. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  78. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  79. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  91. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  92. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  94. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  96. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  97. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  98. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  99. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  100. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  101. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  102. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  103. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  104. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  105. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  106. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  107. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  111. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  112. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  113. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  114. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  115. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  116. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  117. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  118. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  119. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  120. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  121. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  122. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  123. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  124. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  125. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  126. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  127. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  128. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  129. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  130. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  131. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  132. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  133. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  134. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  135. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  136. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  137. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  138. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  139. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  140. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  141. data/vendor/cereal/include/cereal/version.hpp +0 -52
  142. data/vendor/isotree/src/Makevars +0 -4
  143. data/vendor/isotree/src/crit.cpp +0 -912
  144. data/vendor/isotree/src/dist.cpp +0 -749
  145. data/vendor/isotree/src/extended.cpp +0 -790
  146. data/vendor/isotree/src/fit_model.cpp +0 -1090
  147. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  148. data/vendor/isotree/src/isoforest.cpp +0 -771
  149. data/vendor/isotree/src/mult.cpp +0 -607
  150. data/vendor/isotree/src/predict.cpp +0 -853
  151. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -0,0 +1,3808 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4.5: Programs for Machine Learning. Elsevier, 2014.
21
+ * [8] Cortes, David.
22
+ * "Distance approximation using Isolation Forests."
23
+ * arXiv preprint arXiv:1910.12362 (2019).
24
+ * [9] Cortes, David.
25
+ * "Imputing missing values with unsupervised random trees."
26
+ * arXiv preprint arXiv:1911.06646 (2019).
27
+ * [10] https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom
28
+ * [11] Cortes, David.
29
+ * "Revisiting randomized choices in isolation forests."
30
+ * arXiv preprint arXiv:2110.13402 (2021).
31
+ * [12] Guha, Sudipto, et al.
32
+ * "Robust random cut forest based anomaly detection on streams."
33
+ * International conference on machine learning. PMLR, 2016.
34
+ * [13] Cortes, David.
35
+ * "Isolation forests: looking beyond tree depth."
36
+ * arXiv preprint arXiv:2111.11639 (2021).
37
+ * [14] Ting, Kai Ming, Yue Zhu, and Zhi-Hua Zhou.
38
+ * "Isolation kernel and its effect on SVM"
39
+ * Proceedings of the 24th ACM SIGKDD
40
+ * International Conference on Knowledge Discovery & Data Mining. 2018.
41
+ *
42
+ * BSD 2-Clause License
43
+ * Copyright (c) 2019-2022, David Cortes
44
+ * All rights reserved.
45
+ * Redistribution and use in source and binary forms, with or without
46
+ * modification, are permitted provided that the following conditions are met:
47
+ * * Redistributions of source code must retain the above copyright notice, this
48
+ * list of conditions and the following disclaimer.
49
+ * * Redistributions in binary form must reproduce the above copyright notice,
50
+ * this list of conditions and the following disclaimer in the documentation
51
+ * and/or other materials provided with the distribution.
52
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
53
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
54
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
55
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
56
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
57
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
58
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
59
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
60
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
61
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
62
+ */
63
+ #include "isotree.hpp"
64
+
65
+ /* ceil(log2(x)) done with bit-wise operations ensures perfect precision (and it's faster too)
66
+ https://stackoverflow.com/questions/2589096/find-most-significant-bit-left-most-that-is-set-in-a-bit-array
67
+ https://stackoverflow.com/questions/11376288/fast-computing-of-log2-for-64-bit-integers */
68
+ #if SIZE_MAX == UINT32_MAX /* 32-bit systems */
69
/* De Bruijn lookup table: for a value 'v' whose bits below the highest set
   bit are all turned on, entry '(uint32_t)(v * 0x07C4ACDD) >> 27' holds
   floor(log2(v)). */
constexpr static const uint32_t MultiplyDeBruijnBitPosition[32] =
{
    0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30,
    8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31
};

/* Returns ceil(log2(v)) using only integer operations (32-bit 'size_t' variant).

   Parameters
   - v: value for which to compute the rounded-up base-2 logarithm.

   Returns
   - ceil(log2(v)) for v >= 2, and zero for v <= 1. The bit trick below would
     otherwise give 1 for v=1 and 32 for v=0, which disagrees with the generic
     (non-table) variants of this function, so those inputs are handled upfront. */
size_t log2ceil( size_t v )
{
    if (v <= 1) return 0;

    /* floor(log2(v-1)) + 1 == ceil(log2(v)) for v >= 2 */
    v--;
    v |= v >> 1; // smear the highest set bit downwards -> one less than a power of 2
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;

    return MultiplyDeBruijnBitPosition[( uint32_t )( v * 0x07C4ACDDU ) >> 27] + 1;
}
85
+ #elif SIZE_MAX == UINT64_MAX /* 64-bit systems */
86
/* De Bruijn-style lookup table: for an isolated highest bit 'b' (a power of two),
   entry '(b * 0x07EDD5E59A4E28C2) >> 58' holds log2(b). */
constexpr static const uint64_t tab64[64] = {
    63,  0, 58,  1, 59, 47, 53,  2,
    60, 39, 48, 27, 54, 33, 42,  3,
    61, 51, 37, 40, 49, 18, 28, 20,
    55, 30, 34, 11, 43, 14, 22,  4,
    62, 57, 46, 52, 38, 26, 32, 41,
    50, 36, 17, 19, 29, 10, 13, 21,
    56, 45, 25, 31, 35, 16,  9, 12,
    44, 24, 15,  8, 23,  7,  6,  5};

/* Returns ceil(log2(value)) using only integer operations (64-bit 'size_t' variant).

   Parameters
   - value: value for which to compute the rounded-up base-2 logarithm.

   Returns
   - ceil(log2(value)) for value >= 2, and zero for value <= 1. Without the
     upfront check, the table look-up would produce 64 for both 0 and 1 (the
     decremented value has no isolated highest bit to look up), which callers
     would then use as a shift amount. */
size_t log2ceil(size_t value)
{
    if (value <= 1) return 0;

    /* floor(log2(value-1)) + 1 == ceil(log2(value)) for value >= 2 */
    value--;
    value |= value >> 1; /* smear the highest set bit downwards */
    value |= value >> 2;
    value |= value >> 4;
    value |= value >> 8;
    value |= value >> 16;
    value |= value >> 32;
    /* 'value - (value >> 1)' leaves only the highest set bit */
    return tab64[((uint64_t)((value - (value >> 1))*0x07EDD5E59A4E28C2)) >> 58] + 1;
}
107
+ #else /* other architectures - might be much slower */
108
#if (__cplusplus >= 202002L)
#include <bit>
/* Returns ceil(log2(value)) through the C++20 bit-counting facilities.
   Inputs of zero and one both produce zero (a zero input would otherwise
   lead to a shift by an out-of-range amount in the power-of-two check). */
size_t log2ceil(size_t value)
{
    if (value <= 1) return 0;
    /* number of bits needed to represent 'value' */
    size_t out = std::numeric_limits<size_t>::digits - std::countl_zero(value);
    /* exact powers of two need one less */
    out -= (value == ((size_t)1 << (out-1)));
    return out;
}
#else
/* Returns ceil(log2(value)) by counting bits one at a time (portable fallback).
   Inputs of zero and one both produce zero (a zero input would otherwise
   lead to a shift by an out-of-range amount in the power-of-two check). */
size_t log2ceil(size_t value)
{
    if (value <= 1) return 0;
    /* number of bits needed to represent 'value' */
    size_t n_bits = 0;
    for (size_t remainder = value; remainder >= 1; remainder >>= 1) {
        n_bits++;
    }
    /* exact powers of two need one less */
    n_bits -= (value == ((size_t)1 << (n_bits-1)));
    return n_bits;
}
#endif
129
+ #endif
130
+
131
+ /* adapted from cephes */
132
+ #define EULERS_GAMMA 0.577215664901532860606512
133
+ double digamma(double x)
134
+ {
135
+ double y, z, z2;
136
+
137
+ /* check for positive integer up to 128 */
138
+ if (unlikely((x <= 64) && (x == std::floor(x)))) {
139
+ return harmonic_recursive(1.0, (double)x) - EULERS_GAMMA;
140
+ }
141
+
142
+ if (likely(x < 1.0e17 ))
143
+ {
144
+ z = 1.0/(x * x);
145
+ z2 = square(z);
146
+ y = z * ( 8.33333333333333333333E-2
147
+ -8.33333333333333333333E-3*z
148
+ +3.96825396825396825397E-3*z2
149
+ -4.16666666666666666667E-3*z2*z
150
+ +7.57575757575757575758E-3*square(z2)
151
+ -2.10927960927960927961E-2*square(z2)*z
152
+ +8.33333333333333333333E-2*square(z2)*z2);
153
+ }
154
+ else {
155
+ y = 0.0;
156
+ }
157
+
158
+ y = ((-0.5/x) - y) + std::log(x);
159
+ return y;
160
+ }
161
+
162
+ /* http://fredrik-j.blogspot.com/2009/02/how-not-to-compute-harmonic-numbers.html
163
+ https://en.wikipedia.org/wiki/Harmonic_number
164
+ https://github.com/scikit-learn/scikit-learn/pull/19087 */
165
+ template <class ldouble_safe>
166
+ double harmonic(size_t n)
167
+ {
168
+ ldouble_safe temp = (ldouble_safe)1 / square((ldouble_safe)n);
169
+ return - (ldouble_safe)0.5 * temp * ( (ldouble_safe)1/(ldouble_safe)6 - temp * ((ldouble_safe)1/(ldouble_safe)60 - ((ldouble_safe)1/(ldouble_safe)126)*temp) )
170
+ + (ldouble_safe)0.5 * ((ldouble_safe)1/(ldouble_safe)n)
171
+ + std::log((ldouble_safe)n) + (ldouble_safe)EULERS_GAMMA;
172
+ }
173
+
174
/* Sum of 1/k for k = a, a+1, ..., b-1, computed by recursively splitting the
   range in halves.

   Usage for getting harmonic(n) is like this: harmonic_recursive((double)1, (double)(n + 1));

   Parameters
   - a: first term of the sum (inclusive). Expected to hold an integer value.
   - b: end of the sum (exclusive). Expected to hold an integer value.

   Returns
   - The sum described above, or zero when the range is empty (b <= a). The
     empty-range check is what stops the recursion for such inputs, as the
     midpoint would otherwise never move away from 'a'. */
double harmonic_recursive(double a, double b)
{
    if (b <= a) return 0.;
    if (b == a + 1) return 1. / a;
    double m = std::floor((a + b) / 2.);
    return harmonic_recursive(a, m) + harmonic_recursive(m, b);
}
181
+
182
+ /* https://stats.stackexchange.com/questions/423542/isolation-forest-and-average-expected-depth-formula
183
+ https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom */
184
+ #include "exp_depth_table.hpp"
185
+ template <class ldouble_safe>
186
+ double expected_avg_depth(size_t sample_size)
187
+ {
188
+ if (likely(sample_size <= N_PRECALC_EXP_DEPTH)) {
189
+ return exp_depth_table[sample_size - 1];
190
+ }
191
+ return 2. * (harmonic<ldouble_safe>(sample_size) - 1.);
192
+ }
193
+
194
+ /* Note: H(x) = psi(x+1) + gamma */
195
+ template <class ldouble_safe>
196
+ double expected_avg_depth(ldouble_safe approx_sample_size)
197
+ {
198
+ if (approx_sample_size <= 1)
199
+ return 0;
200
+ else if (approx_sample_size < (ldouble_safe)INT32_MAX)
201
+ return 2. * (digamma(approx_sample_size + 1.) + EULERS_GAMMA - 1.);
202
+ else {
203
+ ldouble_safe temp = (ldouble_safe)1 / square(approx_sample_size);
204
+ return (ldouble_safe)2 * std::log(approx_sample_size) + (ldouble_safe)2*((ldouble_safe)EULERS_GAMMA - (ldouble_safe)1)
205
+ + ((ldouble_safe)1/approx_sample_size)
206
+ - temp * ( (ldouble_safe)1/(ldouble_safe)6 - temp * ((ldouble_safe)1/(ldouble_safe)60 - ((ldouble_safe)1/(ldouble_safe)126)*temp) );
207
+ }
208
+ }
209
+
210
+ /* https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree */
211
+ #define THRESHOLD_EXACT_S 87670 /* difference is <5e-4 */
212
+ double expected_separation_depth(size_t n)
213
+ {
214
+ switch(n)
215
+ {
216
+ case 0: return 0.;
217
+ case 1: return 0.;
218
+ case 2: return 1.;
219
+ case 3: return 1. + (1./3.);
220
+ case 4: return 1. + (1./3.) + (2./9.);
221
+ case 5: return 1.71666666667;
222
+ case 6: return 1.84;
223
+ case 7: return 1.93809524;
224
+ case 8: return 2.01836735;
225
+ case 9: return 2.08551587;
226
+ case 10: return 2.14268078;
227
+ default:
228
+ {
229
+ if (n >= THRESHOLD_EXACT_S)
230
+ return 3;
231
+ else
232
+ return expected_separation_depth_hotstart((double)2.14268078, (size_t)10, n);
233
+ }
234
+ }
235
+ }
236
+
237
+ double expected_separation_depth_hotstart(double curr, size_t n_curr, size_t n_final)
238
+ {
239
+ if (n_final >= 1360)
240
+ {
241
+ if (n_final >= THRESHOLD_EXACT_S)
242
+ return 3;
243
+ else if (n_final >= 40774)
244
+ return 2.999;
245
+ else if (n_final >= 18844)
246
+ return 2.998;
247
+ else if (n_final >= 11956)
248
+ return 2.997;
249
+ else if (n_final >= 8643)
250
+ return 2.996;
251
+ else if (n_final >= 6713)
252
+ return 2.995;
253
+ else if (n_final >= 4229)
254
+ return 2.9925;
255
+ else if (n_final >= 3040)
256
+ return 2.99;
257
+ else if (n_final >= 2724)
258
+ return 2.989;
259
+ else if (n_final >= 1902)
260
+ return 2.985;
261
+ else if (n_final >= 1360)
262
+ return 2.98;
263
+
264
+ /* Note on the chosen precision: when calling it on smaller sample sizes,
265
+ the standard error of the separation depth will be larger, thus it's less
266
+ critical to get it right down to the smallest possible precision, while for
267
+ larger samples the standard error of the separation depth will be smaller */
268
+ }
269
+
270
+ for (size_t i = n_curr + 1; i <= n_final; i++)
271
+ curr += (-curr * (double)i + 3. * (double)i - 4.) / ((double)i * ((double)(i-1)));
272
+ return curr;
273
+ }
274
+
275
+ /* linear interpolation */
276
+ template <class ldouble_safe>
277
+ double expected_separation_depth(ldouble_safe n)
278
+ {
279
+ if (n >= THRESHOLD_EXACT_S)
280
+ return 3;
281
+ double s_l = expected_separation_depth((size_t) std::floor(n));
282
+ ldouble_safe u = std::ceil(n);
283
+ double s_u = s_l + (-s_l * u + 3. * u - 4.) / (u * (u - 1.));
284
+ double diff = n - std::floor(n);
285
+ return s_l + diff * s_u;
286
+ }
287
+
288
+ void increase_comb_counter(size_t ix_arr[], size_t st, size_t end, size_t n, double counter[], double exp_remainder)
289
+ {
290
+ size_t i, j;
291
+ size_t ncomb = calc_ncomb(n);
292
+ if (exp_remainder <= 1)
293
+ for (size_t el1 = st; el1 < end; el1++)
294
+ {
295
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
296
+ {
297
+ // counter[i * (n - (i+1)/2) + j - i - 1]++; /* beaware integer division */
298
+ i = ix_arr[el1]; j = ix_arr[el2];
299
+ counter[ix_comb(i, j, n, ncomb)]++;
300
+ }
301
+ }
302
+ else
303
+ for (size_t el1 = st; el1 < end; el1++)
304
+ {
305
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
306
+ {
307
+ i = ix_arr[el1]; j = ix_arr[el2];
308
+ counter[ix_comb(i, j, n, ncomb)] += exp_remainder;
309
+ }
310
+ }
311
+ }
312
+
313
+ void increase_comb_counter(size_t ix_arr[], size_t st, size_t end, size_t n,
314
+ double *restrict counter, double *restrict weights, double exp_remainder)
315
+ {
316
+ size_t i, j;
317
+ size_t ncomb = calc_ncomb(n);
318
+ if (exp_remainder <= 1)
319
+ for (size_t el1 = st; el1 < end; el1++)
320
+ {
321
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
322
+ {
323
+ i = ix_arr[el1]; j = ix_arr[el2];
324
+ counter[ix_comb(i, j, n, ncomb)] += weights[i] * weights[j];
325
+ }
326
+ }
327
+ else
328
+ for (size_t el1 = st; el1 < end; el1++)
329
+ {
330
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
331
+ {
332
+ i = ix_arr[el1]; j = ix_arr[el2];
333
+ counter[ix_comb(i, j, n, ncomb)] += weights[i] * weights[j] * exp_remainder;
334
+ }
335
+ }
336
+ }
337
+
338
+ /* Note to self: don't try merge this into a template with the one above, as the other one has 'restrict' qualifier */
339
+ void increase_comb_counter(size_t ix_arr[], size_t st, size_t end, size_t n,
340
+ double counter[], hashed_map<size_t, double> &weights, double exp_remainder)
341
+ {
342
+ size_t i, j;
343
+ size_t ncomb = calc_ncomb(n);
344
+ if (exp_remainder <= 1)
345
+ for (size_t el1 = st; el1 < end; el1++)
346
+ {
347
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
348
+ {
349
+ i = ix_arr[el1]; j = ix_arr[el2];
350
+ counter[ix_comb(i, j, n, ncomb)] += weights[i] * weights[j];
351
+ }
352
+ }
353
+ else
354
+ for (size_t el1 = st; el1 < end; el1++)
355
+ {
356
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
357
+ {
358
+ i = ix_arr[el1]; j = ix_arr[el2];
359
+ counter[ix_comb(i, j, n, ncomb)] += weights[i] * weights[j] * exp_remainder;
360
+ }
361
+ }
362
+ }
363
+
364
/* Adds an increment to the entries of 'counter' corresponding to pairs of rows
   in which one row has an index below 'split_ix' and the other does not.

   Parameters
   - ix_arr: array with row indices; the segment 'ix_arr[st..end]' is searched
     with a binary search, so it is taken to be sorted in ascending order.
   - st, end: first and last positions of 'ix_arr' to consider (both inclusive).
   - split_ix: index separating the two groups of rows.
   - n: total count from which 'split_ix' is subtracted in order to obtain the
     row length of 'counter'.
   - counter: array to update; the entry for a pair (a, b) with a < split_ix <= b
     is at position 'a * (n - split_ix) + b - split_ix'.
   - exp_remainder: amount to add to each pair; values of one or less
     result in adding exactly one. */
void increase_comb_counter_in_groups(size_t ix_arr[], size_t st, size_t end, size_t split_ix, size_t n,
                                     double counter[], double exp_remainder)
{
    size_t *const seg_begin = ix_arr + st;
    size_t *const seg_end = ix_arr + end + 1;
    /* first element that is not below the split point */
    size_t *const group_boundary = std::lower_bound(seg_begin, seg_end, split_ix);

    const size_t row_length = n - split_ix;
    const double increment = (exp_remainder <= 1)? 1. : exp_remainder;

    for (size_t *low = seg_begin; low != group_boundary; ++low)
        for (size_t *high = group_boundary; high != seg_end; ++high)
            counter[(*low) * row_length + (*high) - split_ix] += increment;
}
380
+
381
+ void increase_comb_counter_in_groups(size_t ix_arr[], size_t st, size_t end, size_t split_ix, size_t n,
382
+ double *restrict counter, double *restrict weights, double exp_remainder)
383
+ {
384
+ size_t *ptr_split_ix = std::lower_bound(ix_arr + st, ix_arr + end + 1, split_ix);
385
+ size_t n_group = std::distance(ix_arr + st, ptr_split_ix);
386
+ n = n - split_ix;
387
+
388
+ if (exp_remainder <= 1)
389
+ for (size_t ix1 = st; ix1 < st + n_group; ix1++)
390
+ for (size_t ix2 = st + n_group; ix2 <= end; ix2++)
391
+ counter[ix_arr[ix1] * n + ix_arr[ix2] - split_ix]
392
+ +=
393
+ weights[ix_arr[ix1]] * weights[ix_arr[ix2]];
394
+ else
395
+ for (size_t ix1 = st; ix1 < st + n_group; ix1++)
396
+ for (size_t ix2 = st + n_group; ix2 <= end; ix2++)
397
+ counter[ix_arr[ix1] * n + ix_arr[ix2] - split_ix]
398
+ +=
399
+ weights[ix_arr[ix1]] * weights[ix_arr[ix2]] * exp_remainder;
400
+ }
401
+
402
+ void tmat_to_dense(double *restrict tmat, double *restrict dmat, size_t n, double fill_diag)
403
+ {
404
+ size_t ncomb = calc_ncomb(n);
405
+ for (size_t i = 0; i < (n-1); i++)
406
+ {
407
+ for (size_t j = i + 1; j < n; j++)
408
+ {
409
+ // dmat[i + j * n] = dmat[j + i * n] = tmat[i * (n - (i+1)/2) + j - i - 1];
410
+ dmat[i + j * n] = dmat[j + i * n] = tmat[ix_comb(i, j, n, ncomb)];
411
+ }
412
+ }
413
+ for (size_t i = 0; i < n; i++)
414
+ dmat[i + i * n] = fill_diag;
415
+ }
416
+
417
+ template <class real_t>
418
+ void build_btree_sampler(std::vector<double> &btree_weights, real_t *restrict sample_weights,
419
+ size_t nrows, size_t &restrict log2_n, size_t &restrict btree_offset)
420
+ {
421
+ /* build a perfectly-balanced binary search tree in which each node will
422
+ hold the sum of the weights of its children */
423
+ log2_n = log2ceil(nrows);
424
+ if (btree_weights.empty())
425
+ btree_weights.resize(pow2(log2_n + 1), 0);
426
+ else
427
+ btree_weights.assign(btree_weights.size(), 0);
428
+ btree_offset = pow2(log2_n) - 1;
429
+
430
+ for (size_t ix = 0; ix < nrows; ix++)
431
+ btree_weights[ix + btree_offset] = std::fmax(0., sample_weights[ix]);
432
+ for (size_t ix = btree_weights.size() - 1; ix > 0; ix--)
433
+ btree_weights[ix_parent(ix)] += btree_weights[ix];
434
+
435
+ if (std::isnan(btree_weights[0]) || btree_weights[0] <= 0)
436
+ {
437
+ fprintf(stderr, "Numeric precision error with sample weights, will not use them.\n");
438
+ log2_n = 0;
439
+ btree_weights.clear();
440
+ btree_weights.shrink_to_fit();
441
+ }
442
+ }
443
+
444
+ template <class real_t, class ldouble_safe>
445
+ void sample_random_rows(std::vector<size_t> &restrict ix_arr, size_t nrows, bool with_replacement,
446
+ RNG_engine &rnd_generator, std::vector<size_t> &restrict ix_all,
447
+ real_t *restrict sample_weights, std::vector<double> &restrict btree_weights,
448
+ size_t log2_n, size_t btree_offset, std::vector<bool> &is_repeated)
449
+ {
450
+ size_t ntake = ix_arr.size();
451
+
452
+ /* if with replacement, just generate random uniform numbers */
453
+ if (with_replacement)
454
+ {
455
+ if (sample_weights == NULL)
456
+ {
457
+ std::uniform_int_distribution<size_t> runif(0, nrows - 1);
458
+ for (size_t &ix : ix_arr)
459
+ ix = runif(rnd_generator);
460
+ }
461
+
462
+ else
463
+ {
464
+ std::discrete_distribution<size_t> runif(sample_weights, sample_weights + nrows);
465
+ for (size_t &ix : ix_arr)
466
+ ix = runif(rnd_generator);
467
+ }
468
+ }
469
+
470
+ /* if all the elements are needed, don't bother with any sampling */
471
+ else if (ntake == nrows)
472
+ {
473
+ std::iota(ix_arr.begin(), ix_arr.end(), (size_t)0);
474
+ }
475
+
476
+
477
+ /* if there are sample weights, use binary trees to keep track and update weight
478
+ https://stackoverflow.com/questions/57599509/c-random-non-repeated-integers-with-weights */
479
+ else if (sample_weights != NULL)
480
+ {
481
+ /* TODO: here could instead generate only 1 random number from zero to the full weight,
482
+ and then subtract from it as it goes down every level. Would have less precision
483
+ but should still work fine. */
484
+
485
+ double rnd_subrange, w_left;
486
+ double curr_subrange;
487
+ size_t curr_ix;
488
+ for (size_t &ix : ix_arr)
489
+ {
490
+ /* go down the tree by drawing a random number and
491
+ checking if it falls in the left or right ranges */
492
+ curr_ix = 0;
493
+ curr_subrange = btree_weights[0];
494
+ for (size_t lev = 0; lev < log2_n; lev++)
495
+ {
496
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
497
+ w_left = btree_weights[ix_child(curr_ix)];
498
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
499
+ curr_subrange = btree_weights[curr_ix];
500
+ }
501
+
502
+ /* finally, determine element to choose in this iteration */
503
+ ix = curr_ix - btree_offset;
504
+
505
+ /* now remove the weight of the chosen element */
506
+ btree_weights[curr_ix] = 0;
507
+ for (size_t lev = 0; lev < log2_n; lev++)
508
+ {
509
+ curr_ix = ix_parent(curr_ix);
510
+ btree_weights[curr_ix] = btree_weights[ix_child(curr_ix)]
511
+ + btree_weights[ix_child(curr_ix) + 1];
512
+ }
513
+ }
514
+ }
515
+
516
+ /* if no sample weights and not with replacement (most common case expected),
517
+ then use different algorithms depending on the sampled fraction */
518
+ else
519
+ {
520
+
521
+ /* if sampling a larger fraction, fill an array enumerating the rows, shuffle, and take first N */
522
+ if (ntake >= (nrows / 2))
523
+ {
524
+
525
+ if (ix_all.empty())
526
+ ix_all.resize(nrows);
527
+
528
+ /* in order for random seeds to always be reproducible, don't re-use previous shuffles */
529
+ std::iota(ix_all.begin(), ix_all.end(), (size_t)0);
530
+
531
+ /* If the number of sampled elements is large, do a full shuffle, enjoy simd-instructs when copying over */
532
+ if (ntake >= ((nrows * 3)/4))
533
+ {
534
+ std::shuffle(ix_all.begin(), ix_all.end(), rnd_generator);
535
+ ix_arr.assign(ix_all.begin(), ix_all.begin() + ntake);
536
+ }
537
+
538
+ /* otherwise, do only a partial shuffle (use Yates algorithm) and copy elements along the way */
539
+ else
540
+ {
541
+ size_t chosen;
542
+ for (size_t i = nrows - 1; i >= nrows - ntake; i--)
543
+ {
544
+ chosen = std::uniform_int_distribution<size_t>(0, i)(rnd_generator);
545
+ ix_arr[nrows - i - 1] = ix_all[chosen];
546
+ ix_all[chosen] = ix_all[i];
547
+ }
548
+ }
549
+
550
+ }
551
+
552
+ /* If the sample size is small, use Floyd's random sampling algorithm
553
+ https://stackoverflow.com/questions/2394246/algorithm-to-select-a-single-random-combination-of-values */
554
+ else
555
+ {
556
+
557
+ size_t candidate;
558
+
559
+ /* if the sample size is relatively large, use a temporary boolean vector */
560
+ if (((ldouble_safe)ntake / (ldouble_safe)nrows) > (1. / 50.))
561
+ {
562
+
563
+ if (is_repeated.empty())
564
+ is_repeated.resize(nrows, false);
565
+ else
566
+ is_repeated.assign(is_repeated.size(), false);
567
+
568
+ for (size_t rnd_ix = nrows - ntake; rnd_ix < nrows; rnd_ix++)
569
+ {
570
+ candidate = std::uniform_int_distribution<size_t>(0, rnd_ix)(rnd_generator);
571
+ if (is_repeated[candidate])
572
+ {
573
+ ix_arr[ntake - (nrows - rnd_ix)] = rnd_ix;
574
+ is_repeated[rnd_ix] = true;
575
+ }
576
+
577
+ else
578
+ {
579
+ ix_arr[ntake - (nrows - rnd_ix)] = candidate;
580
+ is_repeated[candidate] = true;
581
+ }
582
+ }
583
+
584
+ }
585
+
586
+ /* if the sample size is very small, use an unordered set */
587
+ else
588
+ {
589
+
590
+ hashed_set<size_t> repeated_set;
591
+ repeated_set.reserve(ntake);
592
+ for (size_t rnd_ix = nrows - ntake; rnd_ix < nrows; rnd_ix++)
593
+ {
594
+ candidate = std::uniform_int_distribution<size_t>(0, rnd_ix)(rnd_generator);
595
+ if (repeated_set.find(candidate) == repeated_set.end()) /* TODO: switch to C++20 'contains' */
596
+ {
597
+ ix_arr[ntake - (nrows - rnd_ix)] = candidate;
598
+ repeated_set.insert(candidate);
599
+ }
600
+
601
+ else
602
+ {
603
+ ix_arr[ntake - (nrows - rnd_ix)] = rnd_ix;
604
+ repeated_set.insert(rnd_ix);
605
+ }
606
+ }
607
+
608
+ }
609
+
610
+ }
611
+
612
+ }
613
+ }
614
+
615
+ /* https://stackoverflow.com/questions/57599509/c-random-non-repeated-integers-with-weights */
616
+ template <class real_t>
617
+ void weighted_shuffle(size_t *restrict outp, size_t n, real_t *restrict weights, double *restrict buffer_arr, RNG_engine &rnd_generator)
618
+ {
619
+ /* determine smallest power of two that is larger than N */
620
+ size_t tree_levels = log2ceil(n);
621
+
622
+ /* initialize vector with place-holders for perfectly-balanced tree */
623
+ std::fill(buffer_arr, buffer_arr + pow2(tree_levels + 1), (double)0);
624
+
625
+ /* compute sums for the tree leaves at each node */
626
+ size_t offset = pow2(tree_levels) - 1;
627
+ for (size_t ix = 0; ix < n; ix++) {
628
+ buffer_arr[ix + offset] = std::fmax(0., weights[ix]);
629
+ }
630
+ for (size_t ix = pow2(tree_levels+1) - 1; ix > 0; ix--) {
631
+ buffer_arr[ix_parent(ix)] += buffer_arr[ix];
632
+ }
633
+
634
+ /* if the weights are invalid, produce an unweighted shuffle */
635
+ if (std::isnan(buffer_arr[0]) || buffer_arr[0] <= 0)
636
+ {
637
+ std::iota(outp, outp + n, (size_t)0);
638
+ std::shuffle(outp, outp + n, rnd_generator);
639
+ return;
640
+ }
641
+
642
+ /* sample according to uniform distribution */
643
+ double rnd_subrange, w_left;
644
+ double curr_subrange;
645
+ size_t curr_ix;
646
+
647
+ for (size_t el = 0; el < n; el++)
648
+ {
649
+ /* go down the tree by drawing a random number and
650
+ checking if it falls in the left or right sub-ranges */
651
+ curr_ix = 0;
652
+ curr_subrange = buffer_arr[0];
653
+ for (size_t lev = 0; lev < tree_levels; lev++)
654
+ {
655
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
656
+ w_left = buffer_arr[ix_child(curr_ix)];
657
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
658
+ curr_subrange = buffer_arr[curr_ix];
659
+ }
660
+
661
+ /* finally, add element from this iteration */
662
+ outp[el] = curr_ix - offset;
663
+
664
+ /* now remove the weight of the chosen element */
665
+ buffer_arr[curr_ix] = 0;
666
+ for (size_t lev = 0; lev < tree_levels; lev++)
667
+ {
668
+ curr_ix = ix_parent(curr_ix);
669
+ buffer_arr[curr_ix] = buffer_arr[ix_child(curr_ix)]
670
+ + buffer_arr[ix_child(curr_ix) + 1];
671
+ }
672
+ }
673
+ }
674
+
675
+ double sample_random_uniform(double xmin, double xmax, RNG_engine &rng) noexcept
676
+ {
677
+ double out;
678
+ std::uniform_real_distribution<double> runif(xmin, xmax);
679
+ for (int attempt = 0; attempt < 100; attempt++)
680
+ {
681
+ out = runif(rng);
682
+ if (likely(out < xmax)) return out;
683
+ }
684
+ return xmin;
685
+ }
686
+
687
+ template <class ldouble_safe>
688
+ template <class other_t>
689
+ ColumnSampler<ldouble_safe>& ColumnSampler<ldouble_safe>::operator=(const ColumnSampler<other_t> &other)
690
+ {
691
+ this->col_indices = other.col_indices;
692
+ this->tree_weights = other.tree_weights;
693
+ this->curr_pos = other.curr_pos;
694
+ this->curr_col = other.curr_col;
695
+ this->last_given = other.last_given;
696
+ this->n_cols = other.n_cols;
697
+ this->tree_levels = other.tree_levels;
698
+ this->offset = other.offset;
699
+ this->n_dropped = other.n_dropped;
700
+ return *this;
701
+ }
702
+
703
+ /* This one samples with replacement. When using weights, the algorithm is the
704
+ same as for the row sampler, but keeping the weights after taking each iteration. */
705
+ /* TODO: this column sampler could use coroutines from C++20 once compilers implement them. */
706
+ template <class ldouble_safe>
707
+ template <class real_t>
708
+ void ColumnSampler<ldouble_safe>::initialize(real_t weights[], size_t n_cols)
709
+ {
710
+ this->n_cols = n_cols;
711
+ this->tree_levels = log2ceil(n_cols);
712
+ if (this->tree_weights.empty())
713
+ this->tree_weights.resize(pow2(this->tree_levels + 1), 0);
714
+ else {
715
+ if (this->tree_weights.size() != pow2(this->tree_levels + 1))
716
+ this->tree_weights.resize(this->tree_levels);
717
+ std::fill(this->tree_weights.begin(), this->tree_weights.end(), 0.);
718
+ }
719
+
720
+ /* compute sums for the tree leaves at each node */
721
+ this->offset = pow2(this->tree_levels) - 1;
722
+ for (size_t ix = 0; ix < this->n_cols; ix++)
723
+ this->tree_weights[ix + this->offset] = std::fmax(0., weights[ix]);
724
+ for (size_t ix = this->tree_weights.size() - 1; ix > 0; ix--)
725
+ this->tree_weights[ix_parent(ix)] += this->tree_weights[ix];
726
+
727
+ /* if the weights are invalid, make it an unweighted sampler */
728
+ if (unlikely(std::isnan(this->tree_weights[0]) || this->tree_weights[0] <= 0))
729
+ {
730
+ this->drop_weights();
731
+ }
732
+
733
+ this->n_dropped = 0;
734
+ }
735
+
736
+ template <class ldouble_safe>
737
+ void ColumnSampler<ldouble_safe>::drop_weights()
738
+ {
739
+ this->tree_weights.clear();
740
+ this->tree_weights.shrink_to_fit();
741
+ this->initialize(n_cols);
742
+ this->n_dropped = 0;
743
+ }
744
+
745
+ template <class ldouble_safe>
746
+ bool ColumnSampler<ldouble_safe>::has_weights()
747
+ {
748
+ return !this->tree_weights.empty();
749
+ }
750
+
751
+ template <class ldouble_safe>
752
+ void ColumnSampler<ldouble_safe>::initialize(size_t n_cols)
753
+ {
754
+ if (!this->has_weights())
755
+ {
756
+ this->n_cols = n_cols;
757
+ this->curr_pos = n_cols;
758
+ this->col_indices.resize(n_cols);
759
+ std::iota(this->col_indices.begin(), this->col_indices.end(), (size_t)0);
760
+ }
761
+ }
762
+
763
+ /* TODO: this one should instead call the same function for sampling rows,
764
+ and should be done at the time of initialization so as to avoid allocating
765
+ and filling the whole array. That way it'd be faster and use less memory. */
766
+ template <class ldouble_safe>
767
+ void ColumnSampler<ldouble_safe>::leave_m_cols(size_t m, RNG_engine &rnd_generator)
768
+ {
769
+ if (m == 0 || m >= this->n_cols)
770
+ return;
771
+
772
+ if (!this->has_weights())
773
+ {
774
+ size_t chosen;
775
+ if (m <= this->n_cols / 4)
776
+ {
777
+ for (this->curr_pos = 0; this->curr_pos < m; this->curr_pos++)
778
+ {
779
+ chosen = std::uniform_int_distribution<size_t>(0, this->n_cols - this->curr_pos - 1)(rnd_generator);
780
+ std::swap(this->col_indices[this->curr_pos + chosen], this->col_indices[this->curr_pos]);
781
+ }
782
+ }
783
+
784
+ else if ((ldouble_safe)m >= (ldouble_safe)(3./4.) * (ldouble_safe)this->n_cols)
785
+ {
786
+ for (this->curr_pos = this->n_cols-1; this->curr_pos > this->n_cols - m; this->curr_pos--)
787
+ {
788
+ chosen = std::uniform_int_distribution<size_t>(0, this->curr_pos)(rnd_generator);
789
+ std::swap(this->col_indices[chosen], this->col_indices[this->curr_pos]);
790
+ }
791
+ this->curr_pos = m;
792
+ }
793
+
794
+ else
795
+ {
796
+ std::shuffle(this->col_indices.begin(), this->col_indices.end(), rnd_generator);
797
+ this->curr_pos = m;
798
+ }
799
+ }
800
+
801
+ else
802
+ {
803
+ std::vector<double> curr_weights = this->tree_weights;
804
+ std::fill(this->tree_weights.begin(), this->tree_weights.end(), 0.);
805
+ double rnd_subrange, w_left;
806
+ double curr_subrange;
807
+ size_t curr_ix;
808
+
809
+ for (size_t col = 0; col < m; col++)
810
+ {
811
+ curr_ix = 0;
812
+ curr_subrange = curr_weights[0];
813
+ if (curr_subrange <= 0)
814
+ {
815
+ if (col == 0)
816
+ {
817
+ this->drop_weights();
818
+ return;
819
+ }
820
+
821
+ else
822
+ {
823
+ m = col;
824
+ goto rebuild_tree;
825
+ }
826
+ }
827
+
828
+ for (size_t lev = 0; lev < this->tree_levels; lev++)
829
+ {
830
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
831
+ w_left = curr_weights[ix_child(curr_ix)];
832
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
833
+ curr_subrange = curr_weights[curr_ix];
834
+ }
835
+
836
+ this->tree_weights[curr_ix] = curr_weights[curr_ix];
837
+
838
+ /* now remove the weight of the chosen element */
839
+ curr_weights[curr_ix] = 0;
840
+ for (size_t lev = 0; lev < this->tree_levels; lev++)
841
+ {
842
+ curr_ix = ix_parent(curr_ix);
843
+ curr_weights[curr_ix] = curr_weights[ix_child(curr_ix)]
844
+ + curr_weights[ix_child(curr_ix) + 1];
845
+ }
846
+ }
847
+
848
+ /* rebuild the tree after getting new weights */
849
+ rebuild_tree:
850
+ for (size_t ix = this->tree_weights.size() - 1; ix > 0; ix--)
851
+ this->tree_weights[ix_parent(ix)] += this->tree_weights[ix];
852
+
853
+ this->n_dropped = this->n_cols - m;
854
+ }
855
+ }
856
+
857
+ template <class ldouble_safe>
858
+ void ColumnSampler<ldouble_safe>::drop_col(size_t col, size_t nobs_left)
859
+ {
860
+ if (!this->has_weights())
861
+ {
862
+ if (this->col_indices[this->last_given] == col)
863
+ {
864
+ std::swap(this->col_indices[this->last_given], this->col_indices[--this->curr_pos]);
865
+ }
866
+
867
+ else if (this->curr_pos > 4*nobs_left)
868
+ {
869
+ return;
870
+ }
871
+
872
+ else
873
+ {
874
+ for (size_t ix = 0; ix < this->curr_pos; ix++)
875
+ {
876
+ if (this->col_indices[ix] == col)
877
+ {
878
+ std::swap(this->col_indices[ix], this->col_indices[--this->curr_pos]);
879
+ break;
880
+ }
881
+ }
882
+ }
883
+
884
+ if (this->curr_col) this->curr_col--;
885
+ }
886
+
887
+ else
888
+ {
889
+ this->n_dropped++;
890
+ size_t curr_ix = col + this->offset;
891
+ this->tree_weights[curr_ix] = 0.;
892
+ for (size_t lev = 0; lev < this->tree_levels; lev++)
893
+ {
894
+ curr_ix = ix_parent(curr_ix);
895
+ this->tree_weights[curr_ix] = this->tree_weights[ix_child(curr_ix)]
896
+ + this->tree_weights[ix_child(curr_ix) + 1];
897
+ }
898
+ }
899
+ }
900
+
901
+ template <class ldouble_safe>
902
+ void ColumnSampler<ldouble_safe>::drop_col(size_t col)
903
+ {
904
+ this->drop_col(col, SIZE_MAX);
905
+ }
906
+
907
+ /* to be used exclusively when initializing the density calculator,
908
+ and only when 'col_indices' is a straight range with no dropped columns */
909
+ template <class ldouble_safe>
910
+ void ColumnSampler<ldouble_safe>::drop_from_tail(size_t col)
911
+ {
912
+ std::swap(this->col_indices[col], this->col_indices[--this->curr_pos]);
913
+ }
914
+
915
+ template <class ldouble_safe>
916
+ void ColumnSampler<ldouble_safe>::prepare_full_pass()
917
+ {
918
+ this->curr_col = 0;
919
+
920
+ if (this->has_weights())
921
+ {
922
+ if (this->col_indices.size() < this->n_cols)
923
+ this->col_indices.resize(this->n_cols);
924
+ this->curr_pos = 0;
925
+ for (size_t col = 0; col < this->n_cols; col++)
926
+ {
927
+ if (this->tree_weights[col + this->offset] > 0)
928
+ this->col_indices[this->curr_pos++] = col;
929
+ }
930
+ }
931
+ }
932
+
933
+ template <class ldouble_safe>
934
+ bool ColumnSampler<ldouble_safe>::sample_col(size_t &col, RNG_engine &rnd_generator)
935
+ {
936
+ if (!this->has_weights())
937
+ {
938
+ switch(this->curr_pos)
939
+ {
940
+ case 0: return false;
941
+ case 1:
942
+ {
943
+ this->last_given = 0;
944
+ col = this->col_indices[0];
945
+ return true;
946
+ }
947
+ default:
948
+ {
949
+ this->last_given = std::uniform_int_distribution<size_t>(0, this->curr_pos-1)(rnd_generator);
950
+ col = this->col_indices[this->last_given];
951
+ return true;
952
+ }
953
+ }
954
+ }
955
+
956
+ else
957
+ {
958
+ /* TODO: here could instead generate only 1 random number from zero to the full weight,
959
+ and then subtract from it as it goes down every level. Would have less precision
960
+ but should still work fine. */
961
+ size_t curr_ix = 0;
962
+ double rnd_subrange, w_left;
963
+ double curr_subrange = this->tree_weights[0];
964
+ if (curr_subrange <= 0)
965
+ return false;
966
+
967
+ for (size_t lev = 0; lev < tree_levels; lev++)
968
+ {
969
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
970
+ w_left = this->tree_weights[ix_child(curr_ix)];
971
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
972
+ curr_subrange = this->tree_weights[curr_ix];
973
+ }
974
+
975
+ col = curr_ix - this->offset;
976
+ return true;
977
+ }
978
+ }
979
+
980
+ template <class ldouble_safe>
981
+ bool ColumnSampler<ldouble_safe>::sample_col(size_t &col)
982
+ {
983
+ if (this->curr_pos == this->curr_col || this->curr_pos == 0)
984
+ return false;
985
+ this->last_given = this->curr_col;
986
+ col = this->col_indices[this->curr_col++];
987
+ return true;
988
+ }
989
+
990
+ template <class ldouble_safe>
991
+ void ColumnSampler<ldouble_safe>::shuffle_remainder(RNG_engine &rnd_generator)
992
+ {
993
+ if (!this->has_weights())
994
+ {
995
+ this->prepare_full_pass();
996
+ std::shuffle(this->col_indices.begin(),
997
+ this->col_indices.begin() + this->curr_pos,
998
+ rnd_generator);
999
+ }
1000
+
1001
+ else
1002
+ {
1003
+ if (this->tree_weights[0] <= 0)
1004
+ return;
1005
+ std::vector<double> curr_weights = this->tree_weights;
1006
+ this->curr_pos = 0;
1007
+ this->curr_col = 0;
1008
+
1009
+ if (this->col_indices.size() < this->n_cols)
1010
+ this->col_indices.resize(this->n_cols);
1011
+
1012
+ double rnd_subrange, w_left;
1013
+ double curr_subrange;
1014
+ size_t curr_ix;
1015
+
1016
+ for (this->curr_pos = 0; this->curr_pos < this->n_cols; this->curr_pos++)
1017
+ {
1018
+ curr_ix = 0;
1019
+ curr_subrange = curr_weights[0];
1020
+ if (curr_subrange <= 0)
1021
+ return;
1022
+
1023
+ for (size_t lev = 0; lev < this->tree_levels; lev++)
1024
+ {
1025
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
1026
+ w_left = curr_weights[ix_child(curr_ix)];
1027
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
1028
+ curr_subrange = curr_weights[curr_ix];
1029
+ }
1030
+
1031
+ /* finally, add element from this iteration */
1032
+ this->col_indices[this->curr_pos] = curr_ix - this->offset;
1033
+
1034
+ /* now remove the weight of the chosen element */
1035
+ curr_weights[curr_ix] = 0;
1036
+ for (size_t lev = 0; lev < this->tree_levels; lev++)
1037
+ {
1038
+ curr_ix = ix_parent(curr_ix);
1039
+ curr_weights[curr_ix] = curr_weights[ix_child(curr_ix)]
1040
+ + curr_weights[ix_child(curr_ix) + 1];
1041
+ }
1042
+ }
1043
+ }
1044
+ }
1045
+
1046
+ template <class ldouble_safe>
1047
+ size_t ColumnSampler<ldouble_safe>::get_remaining_cols()
1048
+ {
1049
+ if (!this->has_weights())
1050
+ return this->curr_pos;
1051
+ else
1052
+ return this->n_cols - this->n_dropped;
1053
+ }
1054
+
1055
+ template <class ldouble_safe>
1056
+ void ColumnSampler<ldouble_safe>::get_array_remaining_cols(std::vector<size_t> &restrict cols)
1057
+ {
1058
+ if (!this->has_weights())
1059
+ {
1060
+ cols.assign(this->col_indices.begin(), this->col_indices.begin() + this->curr_pos);
1061
+ std::sort(cols.begin(), cols.begin() + this->curr_pos);
1062
+ }
1063
+
1064
+ else
1065
+ {
1066
+ size_t n_rem = 0;
1067
+ for (size_t col = 0; col < this->n_cols; col++)
1068
+ {
1069
+ if (this->tree_weights[col + this->offset] > 0)
1070
+ {
1071
+ cols[n_rem++] = col;
1072
+ }
1073
+ }
1074
+ }
1075
+ }
1076
+
1077
+ template<class ldouble_safe, class real_t>
1078
+ bool SingleNodeColumnSampler<ldouble_safe, real_t>::initialize
1079
+ (
1080
+ double *restrict weights,
1081
+ std::vector<size_t> *col_indices,
1082
+ size_t curr_pos,
1083
+ size_t n_sample,
1084
+ bool backup_weights
1085
+ )
1086
+ {
1087
+ if (!curr_pos) return false;
1088
+
1089
+ this->col_indices = col_indices->data();
1090
+ this->curr_pos = curr_pos;
1091
+ this->n_left = this->curr_pos;
1092
+ this->weights_orig = weights;
1093
+
1094
+ if (n_sample > std::max(log2ceil(this->curr_pos), (size_t)3))
1095
+ {
1096
+ this->using_tree = true;
1097
+ this->backup_weights = false;
1098
+
1099
+ if (this->used_weights.empty()) {
1100
+ this->used_weights.reserve(col_indices->size());
1101
+ this->mapped_indices.reserve(col_indices->size());
1102
+ this->tree_weights.reserve(2 * col_indices->size());
1103
+ }
1104
+
1105
+ this->used_weights.resize(this->curr_pos);
1106
+ this->mapped_indices.resize(this->curr_pos);
1107
+
1108
+ for (size_t col = 0; col < this->curr_pos; col++) {
1109
+ this->mapped_indices[col] = this->col_indices[col];
1110
+ this->used_weights[col] = weights[this->col_indices[col]];
1111
+ if (!weights[this->col_indices[col]]) this->n_left--;
1112
+ }
1113
+
1114
+ this->tree_weights.resize(0);
1115
+ build_btree_sampler(this->tree_weights, this->used_weights.data(),
1116
+ this->curr_pos, this->tree_levels, this->offset);
1117
+
1118
+ this->n_inf = 0;
1119
+ if (std::isinf(this->tree_weights[0]))
1120
+ {
1121
+ if (this->mapped_inf_indices.empty())
1122
+ this->mapped_inf_indices.resize(this->curr_pos);
1123
+
1124
+ for (size_t col = 0; col < this->curr_pos; col++)
1125
+ {
1126
+ if (std::isinf(weights[this->col_indices[col]]))
1127
+ {
1128
+ this->mapped_inf_indices[this->n_inf++] = this->col_indices[col];
1129
+ weights[this->col_indices[col]] = 0;
1130
+ }
1131
+
1132
+ else
1133
+ {
1134
+ this->mapped_indices[col - this->n_inf] = this->col_indices[col];
1135
+ this->used_weights[col - this->n_inf] = weights[this->col_indices[col]];
1136
+ }
1137
+ }
1138
+
1139
+ this->tree_weights.resize(0);
1140
+ build_btree_sampler(this->tree_weights, this->used_weights.data(),
1141
+ this->curr_pos - this->n_inf, this->tree_levels, this->offset);
1142
+ }
1143
+
1144
+ this->used_weights.resize(0);
1145
+
1146
+ if (this->tree_weights[0] <= 0 && !this->n_inf)
1147
+ return false;
1148
+ }
1149
+
1150
+ else
1151
+ {
1152
+ this->using_tree = false;
1153
+ this->backup_weights = backup_weights;
1154
+
1155
+ if (this->backup_weights)
1156
+ {
1157
+ if (this->weights_own.empty())
1158
+ this->weights_own.resize(col_indices->size());
1159
+ this->weights_own.assign(weights, weights + this->curr_pos);
1160
+ }
1161
+
1162
+ this->cumw = 0;
1163
+ for (size_t col = 0; col < this->curr_pos; col++) {
1164
+ this->cumw += weights[this->col_indices[col]];
1165
+ if (!weights[this->col_indices[col]]) this->n_left--;
1166
+ }
1167
+
1168
+ if (std::isnan(this->cumw))
1169
+ throw std::runtime_error("NAs encountered. Try using a different value for 'missing_action'.\n");
1170
+
1171
+ /* if it's infinite, will choose among columns with infinite weight first */
1172
+ this->n_inf = 0;
1173
+ if (std::isinf(this->cumw))
1174
+ {
1175
+ if (this->inifinite_weights.empty())
1176
+ this->inifinite_weights.resize(col_indices->size());
1177
+ else
1178
+ this->inifinite_weights.assign(col_indices->size(), false);
1179
+
1180
+ this->cumw = 0;
1181
+ for (size_t col = 0; col < this->curr_pos; col++)
1182
+ {
1183
+ if (std::isinf(weights[this->col_indices[col]])) {
1184
+ this->n_inf++;
1185
+ this->inifinite_weights[this->col_indices[col]] = true;
1186
+ weights[this->col_indices[col]] = 0;
1187
+ }
1188
+
1189
+ else {
1190
+ this->cumw += weights[this->col_indices[col]];
1191
+ }
1192
+ }
1193
+ }
1194
+
1195
+ if (!this->cumw && !this->n_inf) return false;
1196
+ }
1197
+
1198
+ return true;
1199
+ }
1200
+
1201
+ template <class ldouble_safe, class real_t>
1202
+ bool SingleNodeColumnSampler<ldouble_safe, real_t>::sample_col(size_t &col_chosen, RNG_engine &rnd_generator)
1203
+ {
1204
+ if (!this->using_tree)
1205
+ {
1206
+ if (this->backup_weights)
1207
+ this->weights_orig = this->weights_own.data();
1208
+
1209
+ /* if there's infinites, choose uniformly at random from them */
1210
+ if (this->n_inf)
1211
+ {
1212
+ size_t chosen = std::uniform_int_distribution<size_t>(0, this->n_inf-1)(rnd_generator);
1213
+ size_t curr = 0;
1214
+ for (size_t col = 0; col < this->curr_pos; col++)
1215
+ {
1216
+ curr += inifinite_weights[this->col_indices[col]];
1217
+ if (curr == chosen)
1218
+ {
1219
+ col_chosen = this->col_indices[col];
1220
+ this->n_inf--;
1221
+ this->inifinite_weights[col_chosen] = false;
1222
+ this->n_left--;
1223
+ return true;
1224
+ }
1225
+ }
1226
+ assert(0);
1227
+ }
1228
+
1229
+ if (!this->n_left) return false;
1230
+
1231
+ /* due to the way this is calculated, there can be large roundoff errors and even negatives */
1232
+ if (this->cumw <= 0)
1233
+ {
1234
+ this->cumw = 0;
1235
+ for (size_t col = 0; col < this->curr_pos; col++)
1236
+ this->cumw += this->weights_orig[this->col_indices[col]];
1237
+ if (unlikely(this->cumw <= 0))
1238
+ unexpected_error();
1239
+ }
1240
+
1241
+ /* if there are no infinites, choose a column according to weight */
1242
+ ldouble_safe chosen = std::uniform_real_distribution<ldouble_safe>((ldouble_safe)0, this->cumw)(rnd_generator);
1243
+ ldouble_safe cumw_ = 0;
1244
+ for (size_t col = 0; col < this->curr_pos; col++)
1245
+ {
1246
+ cumw_ += this->weights_orig[this->col_indices[col]];
1247
+ if (cumw_ >= chosen)
1248
+ {
1249
+ col_chosen = this->col_indices[col];
1250
+ this->cumw -= this->weights_orig[col_chosen];
1251
+ this->weights_orig[col_chosen] = 0;
1252
+ this->n_left--;
1253
+ return true;
1254
+ }
1255
+ }
1256
+ col_chosen = this->col_indices[this->curr_pos-1];
1257
+ this->cumw -= this->weights_orig[col_chosen];
1258
+ this->weights_orig[col_chosen] = 0;
1259
+ this->n_left--;
1260
+ return true;
1261
+ }
1262
+
1263
+ else
1264
+ {
1265
+ /* if there's infinites, choose uniformly at random from them */
1266
+ if (this->n_inf)
1267
+ {
1268
+ size_t chosen = std::uniform_int_distribution<size_t>(0, this->n_inf-1)(rnd_generator);
1269
+ col_chosen = this->mapped_inf_indices[chosen];
1270
+ std::swap(this->mapped_inf_indices[chosen], this->mapped_inf_indices[--this->n_inf]);
1271
+ this->n_left--;
1272
+ return true;
1273
+ }
1274
+
1275
+ else
1276
+ {
1277
+ /* TODO: should standardize all these tree traversals into one.
1278
+ This one in particular could do with sampling only a single
1279
+ random number as it will not typically require exhausting all
1280
+ options like the usual column sampler. */
1281
+ if (!this->n_left) return false;
1282
+ size_t curr_ix = 0;
1283
+ double rnd_subrange, w_left;
1284
+ double curr_subrange = this->tree_weights[0];
1285
+ if (curr_subrange <= 0)
1286
+ return false;
1287
+
1288
+ for (size_t lev = 0; lev < tree_levels; lev++)
1289
+ {
1290
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
1291
+ w_left = this->tree_weights[ix_child(curr_ix)];
1292
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
1293
+ curr_subrange = this->tree_weights[curr_ix];
1294
+ }
1295
+
1296
+ col_chosen = this->mapped_indices[curr_ix - this->offset];
1297
+
1298
+ this->tree_weights[curr_ix] = 0.;
1299
+ for (size_t lev = 0; lev < this->tree_levels; lev++)
1300
+ {
1301
+ curr_ix = ix_parent(curr_ix);
1302
+ this->tree_weights[curr_ix] = this->tree_weights[ix_child(curr_ix)]
1303
+ + this->tree_weights[ix_child(curr_ix) + 1];
1304
+ }
1305
+
1306
+ this->n_left--;
1307
+ return true;
1308
+ }
1309
+ }
1310
+ }
1311
+
1312
+ template <class ldouble_safe, class real_t>
1313
+ void SingleNodeColumnSampler<ldouble_safe, real_t>::backup(SingleNodeColumnSampler &other, size_t ncols_tot)
1314
+ {
1315
+ other.n_inf = this->n_inf;
1316
+ other.n_left = this->n_left;
1317
+ other.using_tree = this->using_tree;
1318
+
1319
+ if (this->using_tree)
1320
+ {
1321
+ if (other.tree_weights.empty())
1322
+ {
1323
+ other.tree_weights.reserve(ncols_tot);
1324
+ other.mapped_inf_indices.reserve(ncols_tot);
1325
+ }
1326
+ other.tree_weights.assign(this->tree_weights.begin(), this->tree_weights.end());
1327
+ other.mapped_inf_indices.assign(this->mapped_inf_indices.begin(), this->mapped_inf_indices.end());
1328
+ }
1329
+
1330
+ else
1331
+ {
1332
+ other.cumw = this->cumw;
1333
+ if (this->backup_weights)
1334
+ {
1335
+ if (other.weights_own.empty())
1336
+ other.weights_own.reserve(ncols_tot);
1337
+
1338
+ other.weights_own.resize(this->n_left);
1339
+ for (size_t col = 0; col < this->n_left; col++)
1340
+ other.weights_own[col] = this->weights_own[this->col_indices[col]];
1341
+ }
1342
+
1343
+ if (this->inifinite_weights.size())
1344
+ {
1345
+ if (other.inifinite_weights.empty())
1346
+ other.inifinite_weights.reserve(ncols_tot);
1347
+
1348
+ other.inifinite_weights.resize(this->n_left);
1349
+ for (size_t col = 0; col < this->n_left; col++)
1350
+ other.inifinite_weights[col] = this->inifinite_weights[this->col_indices[col]];
1351
+ }
1352
+ }
1353
+ }
1354
+
1355
+ template <class ldouble_safe, class real_t>
1356
+ void SingleNodeColumnSampler<ldouble_safe, real_t>::restore(const SingleNodeColumnSampler &other)
1357
+ {
1358
+ this->n_inf = other.n_inf;
1359
+ this->n_left = other.n_left;
1360
+ this->using_tree = other.using_tree;
1361
+
1362
+ if (this->using_tree)
1363
+ {
1364
+ this->tree_weights.assign(other.tree_weights.begin(), other.tree_weights.end());
1365
+ this->mapped_inf_indices.assign(other.mapped_inf_indices.begin(), other.mapped_inf_indices.end());
1366
+ }
1367
+
1368
+ else
1369
+ {
1370
+ this->cumw = other.cumw;
1371
+ if (this->backup_weights)
1372
+ {
1373
+ for (size_t col = 0; col < this->n_left; col++)
1374
+ this->weights_own[this->col_indices[col]] = other.weights_own[col];
1375
+ }
1376
+
1377
+ if (this->inifinite_weights.size())
1378
+ {
1379
+ for (size_t col = 0; col < this->n_left; col++)
1380
+ this->inifinite_weights[this->col_indices[col]] = other.inifinite_weights[col];
1381
+ }
1382
+ }
1383
+ }
1384
+
1385
+ template <class ldouble_safe, class real_t>
1386
+ void DensityCalculator<ldouble_safe, real_t>::initialize(size_t max_depth, int max_categ, bool reserve_counts, ScoringMetric scoring_metric)
1387
+ {
1388
+ this->multipliers.reserve(max_depth+3);
1389
+ this->multipliers.clear();
1390
+ if (scoring_metric != AdjDensity)
1391
+ this->multipliers.push_back(0);
1392
+ else
1393
+ this->multipliers.push_back(1);
1394
+
1395
+ if (reserve_counts)
1396
+ {
1397
+ this->counts.resize(max_categ);
1398
+ }
1399
+ }
1400
+
1401
+ template <class ldouble_safe, class real_t>
1402
+ template <class InputData>
1403
+ void DensityCalculator<ldouble_safe, real_t>::initialize_bdens(const InputData &input_data,
1404
+ const ModelParams &model_params,
1405
+ std::vector<size_t> &ix_arr,
1406
+ ColumnSampler<ldouble_safe> &col_sampler)
1407
+ {
1408
+ this->fast_bratio = model_params.fast_bratio;
1409
+ if (this->fast_bratio)
1410
+ {
1411
+ this->multipliers.reserve(model_params.max_depth + 3);
1412
+ this->multipliers.push_back(0);
1413
+ }
1414
+
1415
+ if (input_data.range_low != NULL || input_data.ncat_ != NULL)
1416
+ {
1417
+ if (input_data.ncols_numeric)
1418
+ {
1419
+ this->queue_box.reserve(model_params.max_depth+3);
1420
+ this->box_low.assign(input_data.range_low, input_data.range_low + input_data.ncols_numeric);
1421
+ this->box_high.assign(input_data.range_high, input_data.range_high + input_data.ncols_numeric);
1422
+ }
1423
+
1424
+ if (input_data.ncols_categ)
1425
+ {
1426
+ this->queue_ncat.reserve(model_params.max_depth+2);
1427
+ this->ncat.assign(input_data.ncat_, input_data.ncat_ + input_data.ncols_categ);
1428
+ }
1429
+
1430
+ if (!this->fast_bratio)
1431
+ {
1432
+ if (input_data.ncols_numeric)
1433
+ {
1434
+ this->ranges.resize(input_data.ncols_numeric);
1435
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
1436
+ this->ranges[col] = this->box_high[col] - this->box_low[col];
1437
+ }
1438
+
1439
+ if (input_data.ncols_categ)
1440
+ {
1441
+ this->ncat_orig = this->ncat;
1442
+ }
1443
+ }
1444
+
1445
+ return;
1446
+ }
1447
+
1448
+ if (input_data.ncols_numeric)
1449
+ {
1450
+ this->queue_box.reserve(model_params.max_depth+3);
1451
+ this->box_low.resize(input_data.ncols_numeric);
1452
+ this->box_high.resize(input_data.ncols_numeric);
1453
+ if (!this->fast_bratio)
1454
+ this->ranges.resize(input_data.ncols_numeric);
1455
+ }
1456
+ if (input_data.ncols_categ)
1457
+ {
1458
+ this->queue_ncat.reserve(model_params.max_depth+2);
1459
+ }
1460
+ bool unsplittable = false;
1461
+
1462
+ size_t npresent = 0;
1463
+ std::vector<signed char> categ_present;
1464
+ if (input_data.ncols_categ)
1465
+ {
1466
+ categ_present.resize(input_data.max_categ);
1467
+ }
1468
+
1469
+
1470
+ col_sampler.prepare_full_pass();
1471
+ size_t col;
1472
+ while (col_sampler.sample_col(col))
1473
+ {
1474
+ if (col < input_data.ncols_numeric)
1475
+ {
1476
+ if (input_data.Xc_indptr != NULL)
1477
+ {
1478
+ get_range((size_t*)ix_arr.data(), (size_t)0, ix_arr.size()-(size_t)1, col,
1479
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1480
+ model_params.missing_action, this->box_low[col], this->box_high[col], unsplittable);
1481
+ }
1482
+
1483
+ else
1484
+ {
1485
+ get_range((size_t*)ix_arr.data(), input_data.numeric_data + input_data.nrows * col, (size_t)0, ix_arr.size()-(size_t)1,
1486
+ model_params.missing_action, this->box_low[col], this->box_high[col], unsplittable);
1487
+ }
1488
+
1489
+
1490
+ if (unsplittable)
1491
+ {
1492
+ this->box_low[col] = 0;
1493
+ this->box_high[col] = 0;
1494
+ if (!this->fast_bratio)
1495
+ this->ranges[col] = 0;
1496
+ col_sampler.drop_col(col);
1497
+ }
1498
+
1499
+ if (!this->fast_bratio)
1500
+ {
1501
+ this->ranges[col] = (ldouble_safe)this->box_high[col] - (ldouble_safe)this->box_low[col];
1502
+ this->ranges[col] = std::fmax(this->ranges[col], (ldouble_safe)0);
1503
+ }
1504
+ }
1505
+
1506
+ else
1507
+ {
1508
+ get_categs((size_t*)ix_arr.data(),
1509
+ input_data.categ_data + input_data.nrows * (col - input_data.ncols_numeric),
1510
+ (size_t)0, ix_arr.size()-(size_t)1, input_data.ncat[col],
1511
+ model_params.missing_action, categ_present.data(), npresent, unsplittable);
1512
+
1513
+ if (unsplittable)
1514
+ {
1515
+ this->ncat[col - input_data.ncols_numeric] = 1;
1516
+ col_sampler.drop_col(col);
1517
+ }
1518
+
1519
+ else
1520
+ {
1521
+ this->ncat[col - input_data.ncols_numeric] = npresent;
1522
+ }
1523
+ }
1524
+ }
1525
+
1526
+ if (!this->fast_bratio)
1527
+ this->ncat_orig = this->ncat;
1528
+ }
1529
+
1530
+ template<class ldouble_safe, class real_t>
1531
+ template <class InputData>
1532
+ void DensityCalculator<ldouble_safe, real_t>::initialize_bdens_ext(const InputData &input_data,
1533
+ const ModelParams &model_params,
1534
+ std::vector<size_t> &ix_arr,
1535
+ ColumnSampler<ldouble_safe> &col_sampler,
1536
+ bool col_sampler_is_fresh)
1537
+ {
1538
+ this->vals_ext_box.reserve(model_params.max_depth + 3);
1539
+ this->queue_ext_box.reserve(model_params.max_depth + 3);
1540
+ this->vals_ext_box.push_back(0);
1541
+
1542
+ if (input_data.range_low != NULL)
1543
+ {
1544
+ this->box_low.assign(input_data.range_low, input_data.range_low + input_data.ncols_numeric);
1545
+ this->box_high.assign(input_data.range_high, input_data.range_high + input_data.ncols_numeric);
1546
+ return;
1547
+ }
1548
+
1549
+ this->box_low.resize(input_data.ncols_numeric);
1550
+ this->box_high.resize(input_data.ncols_numeric);
1551
+ bool unsplittable = false;
1552
+
1553
+ /* TODO: find out if there's an optimal point for choosing one or the other loop
1554
+ when using 'leave_m_cols' and when using 'prob_pick_col_by_range', then fill in the
1555
+ lines that are commented out. */
1556
+ // if (!input_data.ncols_categ || model_params.ncols_per_tree < input_data.ncols_numeric)
1557
+ if (input_data.ncols_numeric)
1558
+ {
1559
+ col_sampler.prepare_full_pass();
1560
+ size_t col;
1561
+ while (col_sampler.sample_col(col))
1562
+ {
1563
+ if (col >= input_data.ncols_numeric)
1564
+ continue;
1565
+ if (input_data.Xc_indptr != NULL)
1566
+ {
1567
+ get_range((size_t*)ix_arr.data(), (size_t)0, ix_arr.size()-(size_t)1, col,
1568
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1569
+ model_params.missing_action, this->box_low[col], this->box_high[col], unsplittable);
1570
+ }
1571
+
1572
+ else
1573
+ {
1574
+ get_range((size_t*)ix_arr.data(), input_data.numeric_data + input_data.nrows * col, (size_t)0, ix_arr.size()-(size_t)1,
1575
+ model_params.missing_action, this->box_low[col], this->box_high[col], unsplittable);
1576
+ }
1577
+
1578
+ if (unsplittable)
1579
+ {
1580
+ this->box_low[col] = 0;
1581
+ this->box_high[col] = 0;
1582
+ col_sampler.drop_col(col);
1583
+ }
1584
+ }
1585
+ }
1586
+
1587
+ // else if (input_data.ncols_numeric)
1588
+ // {
1589
+ // size_t n_unsplittable = 0;
1590
+ // std::vector<size_t> unsplittable_cols;
1591
+ // if (col_sampler_is_fresh && !col_sampler.has_weights())
1592
+ // unsplittable_cols.reserve(input_data.ncols_numeric);
1593
+
1594
+ // /* TODO: this will do unnecessary calculations when using 'leave_m_cols' */
1595
+ // for (size_t col = 0; col < input_data.ncols_numeric; col++)
1596
+ // {
1597
+ // if (input_data.Xc_indptr != NULL)
1598
+ // {
1599
+ // get_range((size_t*)ix_arr.data(), (size_t)0, ix_arr.size()-(size_t)1, col,
1600
+ // input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1601
+ // model_params.missing_action, this->box_low[col], this->box_high[col], unsplittable);
1602
+ // }
1603
+
1604
+ // else
1605
+ // {
1606
+ // get_range((size_t*)ix_arr.data(), input_data.numeric_data + input_data.nrows * col, (size_t)0, ix_arr.size()-(size_t)1,
1607
+ // model_params.missing_action, this->box_low[col], this->box_high[col], unsplittable);
1608
+ // }
1609
+
1610
+ // if (unsplittable)
1611
+ // {
1612
+ // this->box_low[col] = 0;
1613
+ // this->box_high[col] = 0;
1614
+ // n_unsplittable++;
1615
+ // if (col_sampler.has_weights())
1616
+ // col_sampler.drop_col(col);
1617
+ // else if (col_sampler_is_fresh)
1618
+ // unsplittable_cols.push_back(col);
1619
+ // }
1620
+ // }
1621
+
1622
+ // if (n_unsplittable && col_sampler_is_fresh && !col_sampler.has_weights())
1623
+ // {
1624
+ // #if (__cplusplus >= 202002L)
1625
+ // for (auto col : unsplittable_cols | std::views::reverse)
1626
+ // col_sampler.drop_from_tail(col);
1627
+ // #else
1628
+ // for (size_t inv_col = 0; inv_col < unsplittable_cols.size(); inv_col++)
1629
+ // {
1630
+ // size_t col = unsplittable_cols.size() - inv_col - 1;
1631
+ // col_sampler.drop_from_tail(unsplittable_cols[col]);
1632
+ // }
1633
+ // #endif
1634
+ // }
1635
+
1636
+ // else if (n_unsplittable > model_params.sample_size / 16 && !col_sampler_is_fresh && !col_sampler.has_weights())
1637
+ // {
1638
+ // /* TODO */
1639
+ // }
1640
+ // }
1641
+ }
1642
+
1643
+ template<class ldouble_safe, class real_t>
1644
+ void DensityCalculator<ldouble_safe, real_t>::push_density(double xmin, double xmax, double split_point)
1645
+ {
1646
+ if (std::isinf(xmax) || std::isinf(xmin) || std::isnan(xmin) || std::isnan(xmax) || std::isnan(split_point))
1647
+ {
1648
+ this->multipliers.push_back(0);
1649
+ return;
1650
+ }
1651
+
1652
+ double range = std::fmax(xmax - xmin, std::numeric_limits<double>::min());
1653
+ double dleft = std::fmax(split_point - xmin, std::numeric_limits<double>::min());
1654
+ double dright = std::fmax(xmax - split_point, std::numeric_limits<double>::min());
1655
+ double mult_left = std::log(dleft / range);
1656
+ double mult_right = std::log(dright / range);
1657
+ while (std::isinf(mult_left))
1658
+ {
1659
+ dleft = std::nextafter(dleft, (mult_left < 0)? HUGE_VAL : (-HUGE_VAL));
1660
+ mult_left = std::log(dleft / range);
1661
+ }
1662
+ while (std::isinf(mult_right))
1663
+ {
1664
+ dright = std::nextafter(dright, (mult_right < 0)? HUGE_VAL : (-HUGE_VAL));
1665
+ mult_right = std::log(dright / range);
1666
+ }
1667
+
1668
+ mult_left = std::isnan(mult_left)? 0 : mult_left;
1669
+ mult_right = std::isnan(mult_right)? 0 : mult_right;
1670
+
1671
+ ldouble_safe curr = this->multipliers.back();
1672
+ this->multipliers.push_back(curr + mult_right);
1673
+ this->multipliers.push_back(curr + mult_left);
1674
+ }
1675
+
1676
+ template<class ldouble_safe, class real_t>
1677
+ void DensityCalculator<ldouble_safe, real_t>::push_density(int n_left, int n_present)
1678
+ {
1679
+ this->push_density(0., (double)n_present, (double)n_left);
1680
+ }
1681
+
1682
+ /* For single category splits */
1683
+ template<class ldouble_safe, class real_t>
1684
+ void DensityCalculator<ldouble_safe, real_t>::push_density(size_t counts[], int ncat)
1685
+ {
1686
+ /* this one assumes 'categ_present' has entries 0/1 for missing/present */
1687
+ int n_present = 0;
1688
+ for (int cat = 0; cat < ncat; cat++)
1689
+ n_present += counts[cat] > 0;
1690
+ this->push_density(0., (double)n_present, 1.);
1691
+ }
1692
+
1693
+ /* For single category splits */
1694
+ template<class ldouble_safe, class real_t>
1695
+ void DensityCalculator<ldouble_safe, real_t>::push_density(int n_present)
1696
+ {
1697
+ this->push_density(0., (double)n_present, 1.);
1698
+ }
1699
+
1700
+ /* For binary categorical splits */
1701
+ template<class ldouble_safe, class real_t>
1702
+ void DensityCalculator<ldouble_safe, real_t>::push_density()
1703
+ {
1704
+ this->multipliers.push_back(0);
1705
+ this->multipliers.push_back(0);
1706
+ }
1707
+
1708
+ template<class ldouble_safe, class real_t>
1709
+ void DensityCalculator<ldouble_safe, real_t>::push_adj(double xmin, double xmax, double split_point, double pct_tree_left, ScoringMetric scoring_metric)
1710
+ {
1711
+ double range = std::fmax(xmax - xmin, std::numeric_limits<double>::min());
1712
+ double dleft = std::fmax(split_point - xmin, std::numeric_limits<double>::min());
1713
+ double dright = std::fmax(xmax - split_point, std::numeric_limits<double>::min());
1714
+ double chunk_left = dleft / range;
1715
+ double chunk_right = dright / range;
1716
+ if (std::isinf(xmax) || std::isinf(xmin) || std::isnan(xmin) || std::isnan(xmax) || std::isnan(split_point))
1717
+ {
1718
+ chunk_left = pct_tree_left;
1719
+ chunk_right = 1. - pct_tree_left;
1720
+ goto add_chunks;
1721
+ }
1722
+
1723
+ if (std::isnan(chunk_left) || std::isnan(chunk_right))
1724
+ {
1725
+ chunk_left = 0.5;
1726
+ chunk_right = 0.5;
1727
+ }
1728
+
1729
+ chunk_left = pct_tree_left / chunk_left;
1730
+ chunk_right = (1. - pct_tree_left) / chunk_right;
1731
+
1732
+ add_chunks:
1733
+ chunk_left = 2. / (1. + .5/chunk_left);
1734
+ chunk_right = 2. / (1. + .5/chunk_right);
1735
+ // chunk_left = 2. / (1. + 1./chunk_left);
1736
+ // chunk_right = 2. / (1. + 1./chunk_right);
1737
+ // chunk_left = 2. - std::exp2(1. - chunk_left);
1738
+ // chunk_right = 2. - std::exp2(1. - chunk_right);
1739
+
1740
+ ldouble_safe curr = this->multipliers.back();
1741
+ if (scoring_metric == AdjDepth)
1742
+ {
1743
+ this->multipliers.push_back(curr + chunk_right);
1744
+ this->multipliers.push_back(curr + chunk_left);
1745
+ }
1746
+
1747
+ else
1748
+ {
1749
+ this->multipliers.push_back(std::fmax(curr * chunk_right, (ldouble_safe)std::numeric_limits<double>::epsilon()));
1750
+ this->multipliers.push_back(std::fmax(curr * chunk_left, (ldouble_safe)std::numeric_limits<double>::epsilon()));
1751
+ }
1752
+ }
1753
+
1754
+ template<class ldouble_safe, class real_t>
1755
+ void DensityCalculator<ldouble_safe, real_t>::push_adj(signed char *restrict categ_present, size_t *restrict counts, int ncat, ScoringMetric scoring_metric)
1756
+ {
1757
+ /* this one assumes 'categ_present' has entries -1/0/1 for missing/right/left */
1758
+ int cnt_cat_left = 0;
1759
+ int cnt_cat = 0;
1760
+ size_t cnt = 0;
1761
+ size_t cnt_left = 0;
1762
+ for (int cat = 0; cat < ncat; cat++)
1763
+ {
1764
+ if (counts[cat] > 0)
1765
+ {
1766
+ cnt += counts[cat];
1767
+ cnt_cat_left += categ_present[cat];
1768
+ cnt_left += categ_present[cat]? counts[cat] : 0;
1769
+ cnt_cat++;
1770
+ }
1771
+ }
1772
+
1773
+ double pct_tree_left = (ldouble_safe)cnt_left / (ldouble_safe)cnt;
1774
+ this->push_adj(0., (double)cnt_cat, (double)cnt_cat_left, pct_tree_left, scoring_metric);
1775
+ }
1776
+
1777
+ /* For single category splits */
1778
+ template<class ldouble_safe, class real_t>
1779
+ void DensityCalculator<ldouble_safe, real_t>::push_adj(size_t *restrict counts, int ncat, int chosen_cat, ScoringMetric scoring_metric)
1780
+ {
1781
+ /* this one assumes 'categ_present' has entries 0/1 for missing/present */
1782
+ int cnt_cat = 0;
1783
+ size_t cnt = 0;
1784
+ for (int cat = 0; cat < ncat; cat++)
1785
+ {
1786
+ cnt += counts[cat];
1787
+ cnt_cat += counts[cat] > 0;
1788
+ }
1789
+
1790
+ double pct_tree_left = (ldouble_safe)counts[chosen_cat] / (ldouble_safe)cnt;
1791
+ this->push_adj(0., (double)cnt_cat, 1., pct_tree_left, scoring_metric);
1792
+ }
1793
+
1794
+ /* For binary categorical splits */
1795
+ template<class ldouble_safe, class real_t>
1796
+ void DensityCalculator<ldouble_safe, real_t>::push_adj(double pct_tree_left, ScoringMetric scoring_metric)
1797
+ {
1798
+ this->push_adj(0., 1., 0.5, pct_tree_left, scoring_metric);
1799
+ }
1800
+
1801
+ template<class ldouble_safe, class real_t>
1802
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens(double split_point, size_t col)
1803
+ {
1804
+ if (this->fast_bratio)
1805
+ this->push_bdens_fast_route(split_point, col);
1806
+ else
1807
+ this->push_bdens_internal(split_point, col);
1808
+ }
1809
+
1810
+ template<class ldouble_safe, class real_t>
1811
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_internal(double split_point, size_t col)
1812
+ {
1813
+ this->queue_box.push_back(this->box_high[col]);
1814
+ this->box_high[col] = split_point;
1815
+ }
1816
+
1817
+ template<class ldouble_safe, class real_t>
1818
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_fast_route(double split_point, size_t col)
1819
+ {
1820
+ ldouble_safe curr_range = (ldouble_safe)this->box_high[col] - (ldouble_safe)this->box_low[col];
1821
+ ldouble_safe fraction_left = ((ldouble_safe)split_point - (ldouble_safe)this->box_low[col]) / curr_range;
1822
+ ldouble_safe fraction_right = ((ldouble_safe)this->box_high[col] - (ldouble_safe)split_point) / curr_range;
1823
+ fraction_left = std::fmax(fraction_left, (ldouble_safe)std::numeric_limits<double>::min());
1824
+ fraction_left = std::fmin(fraction_left, (ldouble_safe)(1. - std::numeric_limits<double>::epsilon()));
1825
+ fraction_left = std::log(fraction_left);
1826
+ fraction_left += this->multipliers.back();
1827
+ fraction_right = std::fmax(fraction_right, (ldouble_safe)std::numeric_limits<double>::min());
1828
+ fraction_right = std::fmin(fraction_right, (ldouble_safe)(1. - std::numeric_limits<double>::epsilon()));
1829
+ fraction_right = std::log(fraction_right);
1830
+ fraction_right += this->multipliers.back();
1831
+ this->multipliers.push_back(fraction_right);
1832
+ this->multipliers.push_back(fraction_left);
1833
+
1834
+ this->push_bdens_internal(split_point, col);
1835
+ }
1836
+
1837
+ template<class ldouble_safe, class real_t>
1838
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens(int ncat_branch_left, size_t col)
1839
+ {
1840
+ if (this->fast_bratio)
1841
+ this->push_bdens_fast_route(ncat_branch_left, col);
1842
+ else
1843
+ this->push_bdens_internal(ncat_branch_left, col);
1844
+ }
1845
+
1846
+ template<class ldouble_safe, class real_t>
1847
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_internal(int ncat_branch_left, size_t col)
1848
+ {
1849
+ this->queue_ncat.push_back(this->ncat[col]);
1850
+ this->ncat[col] = ncat_branch_left;
1851
+ }
1852
+
1853
+ template<class ldouble_safe, class real_t>
1854
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_fast_route(int ncat_branch_left, size_t col)
1855
+ {
1856
+ double fraction_left = std::log((double)ncat_branch_left / this->ncat[col]);
1857
+ double fraction_right = std::log((double)(this->ncat[col] - ncat_branch_left) / this->ncat[col]);
1858
+ ldouble_safe curr = this->multipliers.back();
1859
+ this->multipliers.push_back(curr + fraction_right);
1860
+ this->multipliers.push_back(curr + fraction_left);
1861
+
1862
+ this->push_bdens_internal(ncat_branch_left, col);
1863
+ }
1864
+
1865
+ template<class ldouble_safe, class real_t>
1866
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens(const std::vector<signed char> &cat_split, size_t col)
1867
+ {
1868
+ if (this->fast_bratio)
1869
+ this->push_bdens_fast_route(cat_split, col);
1870
+ else
1871
+ this->push_bdens_internal(cat_split, col);
1872
+ }
1873
+
1874
+ template<class ldouble_safe, class real_t>
1875
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_internal(const std::vector<signed char> &cat_split, size_t col)
1876
+ {
1877
+ int ncat_branch_left = 0;
1878
+ for (auto el : cat_split)
1879
+ ncat_branch_left += el == 1;
1880
+ this->push_bdens_internal(ncat_branch_left, col);
1881
+ }
1882
+
1883
+ template<class ldouble_safe, class real_t>
1884
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_fast_route(const std::vector<signed char> &cat_split, size_t col)
1885
+ {
1886
+ int ncat_branch_left = 0;
1887
+ for (auto el : cat_split)
1888
+ ncat_branch_left += el == 1;
1889
+ this->push_bdens_fast_route(ncat_branch_left, col);
1890
+ }
1891
+
1892
+ template<class ldouble_safe, class real_t>
1893
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_ext(const IsoHPlane &hplane, const ModelParams &model_params)
1894
+ {
1895
+ double x1, x2;
1896
+ double xlow = 0, xhigh = 0;
1897
+ size_t col;
1898
+ size_t col_num = 0;
1899
+ size_t col_cat = 0;
1900
+
1901
+ for (size_t col_outer = 0; col_outer < hplane.col_num.size(); col_outer++)
1902
+ {
1903
+ switch (hplane.col_type[col_outer])
1904
+ {
1905
+ case Numeric:
1906
+ {
1907
+ col = hplane.col_num[col_outer];
1908
+ x1 = hplane.coef[col_num] * (this->box_low[col] - hplane.mean[col_num]);
1909
+ x2 = hplane.coef[col_num] * (this->box_high[col] - hplane.mean[col_num]);
1910
+ xlow += std::fmin(x1, x2);
1911
+ xhigh += std::fmax(x1, x2);
1912
+ break;
1913
+ }
1914
+
1915
+ case Categorical:
1916
+ {
1917
+ switch (model_params.cat_split_type)
1918
+ {
1919
+ case SingleCateg:
1920
+ {
1921
+ xlow += std::fmin(hplane.fill_new[col_cat], 0.);
1922
+ xhigh += std::fmax(hplane.fill_new[col_cat], 0.);
1923
+ break;
1924
+ }
1925
+
1926
+ case SubSet:
1927
+ {
1928
+ xlow += *std::min_element(hplane.cat_coef[col_cat].begin(), hplane.cat_coef[col_cat].end());
1929
+ xhigh += *std::max_element(hplane.cat_coef[col_cat].begin(), hplane.cat_coef[col_cat].end());
1930
+ break;
1931
+ }
1932
+ }
1933
+ break;
1934
+ }
1935
+
1936
+ default:
1937
+ {
1938
+ assert(0);
1939
+ }
1940
+ }
1941
+ }
1942
+
1943
+ double chunk_left;
1944
+ double chunk_right;
1945
+ double xdiff = xhigh - xlow;
1946
+
1947
+ if (model_params.scoring_metric != BoxedDensity)
1948
+ {
1949
+ chunk_left = (hplane.split_point - xlow) / xdiff;
1950
+ chunk_right = (xhigh - hplane.split_point) / xdiff;
1951
+ chunk_left = std::fmin(chunk_left, std::numeric_limits<double>::min());
1952
+ chunk_left = std::fmax(chunk_left, 1.-std::numeric_limits<double>::epsilon());
1953
+ chunk_right = std::fmin(chunk_right, std::numeric_limits<double>::min());
1954
+ chunk_right = std::fmax(chunk_right, 1.-std::numeric_limits<double>::epsilon());
1955
+ }
1956
+
1957
+ else
1958
+ {
1959
+ chunk_left = xdiff / (hplane.split_point - xlow);
1960
+ chunk_right = xdiff / (xhigh - hplane.split_point);
1961
+ chunk_left = std::fmin(chunk_left, 1.);
1962
+ chunk_right = std::fmin(chunk_right, 1.);
1963
+ }
1964
+
1965
+ this->queue_ext_box.push_back(std::log(chunk_right) + this->vals_ext_box.back());
1966
+ this->vals_ext_box.push_back(std::log(chunk_left) + this->vals_ext_box.back());
1967
+ }
1968
+
1969
+ template<class ldouble_safe, class real_t>
1970
+ void DensityCalculator<ldouble_safe, real_t>::pop()
1971
+ {
1972
+ this->multipliers.pop_back();
1973
+ }
1974
+
1975
+ template<class ldouble_safe, class real_t>
1976
+ void DensityCalculator<ldouble_safe, real_t>::pop_right()
1977
+ {
1978
+ this->multipliers.pop_back();
1979
+ }
1980
+
1981
+ template<class ldouble_safe, class real_t>
1982
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens(size_t col)
1983
+ {
1984
+ if (this->fast_bratio)
1985
+ this->pop_bdens_fast_route(col);
1986
+ else
1987
+ this->pop_bdens_internal(col);
1988
+ }
1989
+
1990
+ template<class ldouble_safe, class real_t>
1991
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_internal(size_t col)
1992
+ {
1993
+ double old_high = this->queue_box.back();
1994
+ this->queue_box.pop_back();
1995
+ this->queue_box.push_back(this->box_low[col]);
1996
+ this->box_low[col] = this->box_high[col];
1997
+ this->box_high[col] = old_high;
1998
+ }
1999
+
2000
+ template<class ldouble_safe, class real_t>
2001
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_fast_route(size_t col)
2002
+ {
2003
+ this->multipliers.pop_back();
2004
+ this->pop_bdens_internal(col);
2005
+ }
2006
+
2007
+ template<class ldouble_safe, class real_t>
2008
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_right(size_t col)
2009
+ {
2010
+ if (this->fast_bratio)
2011
+ this->pop_bdens_right_fast_route(col);
2012
+ else
2013
+ this->pop_bdens_right_internal(col);
2014
+ }
2015
+
2016
+ template<class ldouble_safe, class real_t>
2017
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_right_internal(size_t col)
2018
+ {
2019
+ double old_low = this->queue_box.back();
2020
+ this->queue_box.pop_back();
2021
+ this->box_low[col] = old_low;
2022
+ }
2023
+
2024
+ template<class ldouble_safe, class real_t>
2025
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_right_fast_route(size_t col)
2026
+ {
2027
+ this->multipliers.pop_back();
2028
+ this->pop_bdens_right_internal(col);
2029
+ }
2030
+
2031
+ template<class ldouble_safe, class real_t>
2032
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_cat(size_t col)
2033
+ {
2034
+ if (this->fast_bratio)
2035
+ this->pop_bdens_cat_fast_route(col);
2036
+ else
2037
+ this->pop_bdens_cat_internal(col);
2038
+ }
2039
+
2040
+ template<class ldouble_safe, class real_t>
2041
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_cat_internal(size_t col)
2042
+ {
2043
+ int old_ncat = this->queue_ncat.back();
2044
+ this->ncat[col] = old_ncat - this->ncat[col];
2045
+ }
2046
+
2047
+ template<class ldouble_safe, class real_t>
2048
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_cat_fast_route(size_t col)
2049
+ {
2050
+ this->multipliers.pop_back();
2051
+ this->pop_bdens_cat_internal(col);
2052
+ }
2053
+
2054
+ template<class ldouble_safe, class real_t>
2055
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_cat_right(size_t col)
2056
+ {
2057
+ if (this->fast_bratio)
2058
+ this->pop_bdens_cat_right_fast_route(col);
2059
+ else
2060
+ this->pop_bdens_cat_right_internal(col);
2061
+ }
2062
+
2063
+ template<class ldouble_safe, class real_t>
2064
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_cat_right_internal(size_t col)
2065
+ {
2066
+ int old_ncat = this->queue_ncat.back();
2067
+ this->queue_ncat.pop_back();
2068
+ this->ncat[col] = old_ncat;
2069
+ }
2070
+
2071
+ template<class ldouble_safe, class real_t>
2072
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_cat_right_fast_route(size_t col)
2073
+ {
2074
+ this->multipliers.pop_back();
2075
+ this->pop_bdens_cat_right_internal(col);
2076
+ }
2077
+
2078
+ template<class ldouble_safe, class real_t>
2079
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_ext()
2080
+ {
2081
+ this->vals_ext_box.pop_back();
2082
+ this->vals_ext_box.push_back(this->queue_ext_box.back());
2083
+ this->queue_ext_box.pop_back();
2084
+ }
2085
+
2086
+ template<class ldouble_safe, class real_t>
2087
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_ext_right()
2088
+ {
2089
+ this->vals_ext_box.pop_back();
2090
+ }
2091
+
2092
+ /* this outputs the logarithm of the density */
2093
+ template<class ldouble_safe, class real_t>
2094
+ double DensityCalculator<ldouble_safe, real_t>::calc_density(ldouble_safe remainder, size_t sample_size)
2095
+ {
2096
+ return std::log(remainder) - std::log((ldouble_safe)sample_size) - this->multipliers.back();
2097
+ }
2098
+
2099
+ template<class ldouble_safe, class real_t>
2100
+ ldouble_safe DensityCalculator<ldouble_safe, real_t>::calc_adj_depth()
2101
+ {
2102
+ ldouble_safe out = this->multipliers.back();
2103
+ return std::fmax(out, (ldouble_safe)std::numeric_limits<double>::min());
2104
+ }
2105
+
2106
+ template<class ldouble_safe, class real_t>
2107
+ double DensityCalculator<ldouble_safe, real_t>::calc_adj_density()
2108
+ {
2109
+ return this->multipliers.back();
2110
+ }
2111
+
2112
+ /* this outputs the logarithm of the density */
2113
+ template<class ldouble_safe, class real_t>
2114
+ ldouble_safe DensityCalculator<ldouble_safe, real_t>::calc_bratio_inv_log()
2115
+ {
2116
+ if (!this->multipliers.empty())
2117
+ return -this->multipliers.back();
2118
+
2119
+ ldouble_safe sum_log_switdh = 0;
2120
+ ldouble_safe ratio_col;
2121
+ for (size_t col = 0; col < this->ranges.size(); col++)
2122
+ {
2123
+ if (!this->ranges[col]) continue;
2124
+ ratio_col = this->ranges[col] / ((ldouble_safe)this->box_high[col] - (ldouble_safe)this->box_low[col]);
2125
+ ratio_col = std::fmax(ratio_col, (ldouble_safe)1);
2126
+ sum_log_switdh += std::log(ratio_col);
2127
+ }
2128
+
2129
+ for (size_t col = 0; col < this->ncat.size(); col++)
2130
+ {
2131
+ if (this->ncat_orig[col] <= 1) continue;
2132
+ sum_log_switdh += std::log((double)this->ncat_orig[col] / (double)this->ncat[col]);
2133
+ }
2134
+
2135
+ return sum_log_switdh;
2136
+ }
2137
+
2138
+ template<class ldouble_safe, class real_t>
2139
+ ldouble_safe DensityCalculator<ldouble_safe, real_t>::calc_bratio_log()
2140
+ {
2141
+ if (!this->multipliers.empty())
2142
+ return this->multipliers.back();
2143
+
2144
+ ldouble_safe sum_log_switdh = 0;
2145
+ ldouble_safe ratio_col;
2146
+ for (size_t col = 0; col < this->ranges.size(); col++)
2147
+ {
2148
+ if (!this->ranges[col]) continue;
2149
+ ratio_col = ((ldouble_safe)this->box_high[col] - (ldouble_safe)this->box_low[col]) / this->ranges[col];
2150
+ ratio_col = std::fmax(ratio_col, (ldouble_safe)std::numeric_limits<double>::min());
2151
+ ratio_col = std::fmin(ratio_col, (ldouble_safe)(1. - std::numeric_limits<double>::epsilon()));
2152
+ sum_log_switdh += std::log(ratio_col);
2153
+ }
2154
+
2155
+ for (size_t col = 0; col < this->ncat.size(); col++)
2156
+ {
2157
+ if (this->ncat_orig[col] <= 1) continue;
2158
+ sum_log_switdh += std::log((double)this->ncat[col] / (double)this->ncat_orig[col]);
2159
+ }
2160
+
2161
+ return sum_log_switdh;
2162
+ }
2163
+
2164
+ /* this does NOT output the logarithm of the density */
2165
+ template<class ldouble_safe, class real_t>
2166
+ double DensityCalculator<ldouble_safe, real_t>::calc_bratio()
2167
+ {
2168
+ return std::exp(this->calc_bratio_log());
2169
+ }
2170
+
2171
/* Floor applied to the log-densities returned by the 'calc_bdens*' functions:
   the logarithm of the smallest positive normal double. */
const double MIN_DENS{std::log(std::numeric_limits<double>::min())};
2172
+
2173
+ /* this outputs the logarithm of the density */
2174
+ template<class ldouble_safe, class real_t>
2175
+ double DensityCalculator<ldouble_safe, real_t>::calc_bdens(ldouble_safe remainder, size_t sample_size)
2176
+ {
2177
+ double out = std::log(remainder) - std::log((ldouble_safe)sample_size) - this->calc_bratio_inv_log();
2178
+ return std::fmax(out, MIN_DENS);
2179
+ }
2180
+
2181
+ /* this outputs the logarithm of the density */
2182
+ template<class ldouble_safe, class real_t>
2183
+ double DensityCalculator<ldouble_safe, real_t>::calc_bdens2(ldouble_safe remainder, size_t sample_size)
2184
+ {
2185
+ double out = std::log(remainder) - std::log((ldouble_safe)sample_size) - this->calc_bratio_log();
2186
+ return std::fmax(out, MIN_DENS);
2187
+ }
2188
+
2189
+ /* this outputs the logarithm of the density */
2190
+ template<class ldouble_safe, class real_t>
2191
+ ldouble_safe DensityCalculator<ldouble_safe, real_t>::calc_bratio_log_ext()
2192
+ {
2193
+ return this->vals_ext_box.back();
2194
+ }
2195
+
2196
+ template<class ldouble_safe, class real_t>
2197
+ double DensityCalculator<ldouble_safe, real_t>::calc_bratio_ext()
2198
+ {
2199
+ double out = std::exp(this->calc_bratio_log_ext());
2200
+ return std::fmax(out, std::numeric_limits<double>::min());
2201
+ }
2202
+
2203
+ /* this outputs the logarithm of the density */
2204
+ template<class ldouble_safe, class real_t>
2205
+ double DensityCalculator<ldouble_safe, real_t>::calc_bdens_ext(ldouble_safe remainder, size_t sample_size)
2206
+ {
2207
+ double out = std::log(remainder) - std::log((ldouble_safe)sample_size) - this->calc_bratio_log_ext();
2208
+ return std::fmax(out, MIN_DENS);
2209
+ }
2210
+
2211
+ template<class ldouble_safe, class real_t>
2212
+ void DensityCalculator<ldouble_safe, real_t>::save_range(double xmin, double xmax)
2213
+ {
2214
+ this->xmin = xmin;
2215
+ this->xmax = xmax;
2216
+ }
2217
+
2218
+ template<class ldouble_safe, class real_t>
2219
+ void DensityCalculator<ldouble_safe, real_t>::restore_range(double &restrict xmin, double &restrict xmax)
2220
+ {
2221
+ xmin = this->xmin;
2222
+ xmax = this->xmax;
2223
+ }
2224
+
2225
+ template<class ldouble_safe, class real_t>
2226
+ void DensityCalculator<ldouble_safe, real_t>::save_counts(size_t *restrict cat_counts, int ncat)
2227
+ {
2228
+ this->counts.assign(cat_counts, cat_counts + ncat);
2229
+ }
2230
+
2231
+ template<class ldouble_safe, class real_t>
2232
+ void DensityCalculator<ldouble_safe, real_t>::save_n_present_and_left(signed char *restrict split_left, int ncat)
2233
+ {
2234
+ this->n_present = 0;
2235
+ this->n_left = 0;
2236
+ for (int cat = 0; cat < ncat; cat++)
2237
+ {
2238
+ this->n_present += split_left[cat] >= 0;
2239
+ this->n_left += split_left[cat] == 1;
2240
+ }
2241
+ }
2242
+
2243
+ template<class ldouble_safe, class real_t>
2244
+ void DensityCalculator<ldouble_safe, real_t>::save_n_present(size_t *restrict cat_counts, int ncat)
2245
+ {
2246
+ this->n_present = 0;
2247
+ for (int cat = 0; cat < ncat; cat++)
2248
+ this->n_present += cat_counts[cat] > 0;
2249
+ }
2250
+
2251
/* For hyperplane intersections.

   Partitions ix_arr[st..end] (inclusive) in place so that the rows whose
   value is <= 'split_point' come first. 'x' is indexed relative to the start
   of the range, i.e. x[0] belongs to the row at position 'st'. Returns the
   position of the first row that was left on the right side. */
size_t divide_subset_split(size_t ix_arr[], double x[], size_t st, size_t end, double split_point) noexcept
{
    const size_t range_start = st;
    size_t next_left = st;
    for (size_t pos = range_start; pos <= end; pos++)
    {
        /* written as a negation so that NaN values stay on the right */
        if (!(x[pos - range_start] <= split_point))
            continue;

        std::swap(ix_arr[next_left], ix_arr[pos]);
        next_left++;
    }
    return next_left;
}
2268
+
2269
+ /* For numerical columns */
2270
+ template <class real_t>
2271
+ void divide_subset_split(size_t *restrict ix_arr, real_t x[], size_t st, size_t end, double split_point,
2272
+ MissingAction missing_action, size_t &restrict st_NA, size_t &restrict end_NA, size_t &restrict split_ix) noexcept
2273
+ {
2274
+ size_t temp;
2275
+
2276
+ /* if NAs are not to be bothered with, just need to do a single pass */
2277
+ if (missing_action == Fail)
2278
+ {
2279
+ /* move to the left if it's l.e. split point */
2280
+ for (size_t row = st; row <= end; row++)
2281
+ {
2282
+ if (x[ix_arr[row]] <= split_point)
2283
+ {
2284
+ temp = ix_arr[st];
2285
+ ix_arr[st] = ix_arr[row];
2286
+ ix_arr[row] = temp;
2287
+ st++;
2288
+ }
2289
+ }
2290
+ split_ix = st;
2291
+ }
2292
+
2293
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
2294
+ else
2295
+ {
2296
+ for (size_t row = st; row <= end; row++)
2297
+ {
2298
+ if (!std::isnan(x[ix_arr[row]]) && x[ix_arr[row]] <= split_point)
2299
+ {
2300
+ temp = ix_arr[st];
2301
+ ix_arr[st] = ix_arr[row];
2302
+ ix_arr[row] = temp;
2303
+ st++;
2304
+ }
2305
+ }
2306
+ st_NA = st;
2307
+
2308
+ for (size_t row = st; row <= end; row++)
2309
+ {
2310
+ if (unlikely(std::isnan(x[ix_arr[row]])))
2311
+ {
2312
+ temp = ix_arr[st];
2313
+ ix_arr[st] = ix_arr[row];
2314
+ ix_arr[row] = temp;
2315
+ st++;
2316
+ }
2317
+ }
2318
+ end_NA = st;
2319
+ }
2320
+ }
2321
+
2322
+ /* For sparse numeric columns */
2323
+ template <class real_t, class sparse_ix>
2324
+ void divide_subset_split(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
2325
+ real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr, double split_point,
2326
+ MissingAction missing_action, size_t &restrict st_NA, size_t &restrict end_NA, size_t &restrict split_ix) noexcept
2327
+ {
2328
+ /* TODO: this is a mess, needs refactoring */
2329
+ /* TODO: when moving zeros, would be better to instead move by '>' (opposite as in here) */
2330
+ /* TODO: should create an extra version to go along with 'predict' that would
2331
+ add the range penalty right here to spare operations. */
2332
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
2333
+ {
2334
+ if (missing_action == Fail)
2335
+ {
2336
+ split_ix = (0 <= split_point)? (end+1) : st;
2337
+ }
2338
+
2339
+ else
2340
+ {
2341
+ st_NA = (0 <= split_point)? (end+1) : st;
2342
+ end_NA = (0 <= split_point)? (end+1) : st;
2343
+ }
2344
+
2345
+ }
2346
+
2347
+ size_t st_col = Xc_indptr[col_num];
2348
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
2349
+ size_t curr_pos = st_col;
2350
+ size_t ind_end_col = Xc_ind[end_col];
2351
+ size_t temp;
2352
+ bool move_zeros = 0 <= split_point;
2353
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
2354
+
2355
+ if (move_zeros && ptr_st > ix_arr + st)
2356
+ st = ptr_st - ix_arr;
2357
+
2358
+ if (missing_action == Fail)
2359
+ {
2360
+ if (move_zeros)
2361
+ {
2362
+ for (size_t *row = ptr_st;
2363
+ row != ix_arr + end + 1;
2364
+ )
2365
+ {
2366
+ if (curr_pos >= end_col + 1)
2367
+ {
2368
+ for (size_t *r = row; r <= ix_arr + end; r++)
2369
+ {
2370
+ temp = ix_arr[st];
2371
+ ix_arr[st] = *r;
2372
+ *r = temp;
2373
+ st++;
2374
+ }
2375
+ break;
2376
+ }
2377
+
2378
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
2379
+ {
2380
+ if (Xc[curr_pos] <= split_point)
2381
+ {
2382
+ temp = ix_arr[st];
2383
+ ix_arr[st] = *row;
2384
+ *row = temp;
2385
+ st++;
2386
+ }
2387
+ if (curr_pos == end_col && row < ix_arr + end)
2388
+ {
2389
+ for (size_t *r = row + 1; r <= ix_arr + end; r++)
2390
+ {
2391
+ temp = ix_arr[st];
2392
+ ix_arr[st] = *r;
2393
+ *r = temp;
2394
+ st++;
2395
+ }
2396
+ }
2397
+ if (row == ix_arr + end || curr_pos == end_col) break;
2398
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
2399
+ }
2400
+
2401
+ else
2402
+ {
2403
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
2404
+ {
2405
+ while (row <= ix_arr + end && Xc_ind[curr_pos] > (sparse_ix)(*row))
2406
+ {
2407
+ temp = ix_arr[st];
2408
+ ix_arr[st] = *row;
2409
+ *row = temp;
2410
+ st++; row++;
2411
+ }
2412
+ }
2413
+
2414
+ else
2415
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
2416
+ }
2417
+ }
2418
+ }
2419
+
2420
+ else /* don't move zeros */
2421
+ {
2422
+ for (size_t *row = ptr_st;
2423
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
2424
+ )
2425
+ {
2426
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
2427
+ {
2428
+ if (Xc[curr_pos] <= split_point)
2429
+ {
2430
+ temp = ix_arr[st];
2431
+ ix_arr[st] = *row;
2432
+ *row = temp;
2433
+ st++;
2434
+ }
2435
+ if (row == ix_arr + end || curr_pos == end_col) break;
2436
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
2437
+ }
2438
+
2439
+ else
2440
+ {
2441
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
2442
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
2443
+ else
2444
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
2445
+ }
2446
+ }
2447
+ }
2448
+
2449
+ split_ix = st;
2450
+ }
2451
+
2452
+ else /* can have NAs */
2453
+ {
2454
+
2455
+ bool has_NAs = false;
2456
+ if (move_zeros)
2457
+ {
2458
+ for (size_t *row = ptr_st;
2459
+ row != ix_arr + end + 1;
2460
+ )
2461
+ {
2462
+ if (curr_pos >= end_col + 1)
2463
+ {
2464
+ for (size_t *r = row; r <= ix_arr + end; r++)
2465
+ {
2466
+ temp = ix_arr[st];
2467
+ ix_arr[st] = *r;
2468
+ *r = temp;
2469
+ st++;
2470
+ }
2471
+ break;
2472
+ }
2473
+
2474
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
2475
+ {
2476
+ if (unlikely(std::isnan(Xc[curr_pos])))
2477
+ has_NAs = true;
2478
+ else if (Xc[curr_pos] <= split_point)
2479
+ {
2480
+ temp = ix_arr[st];
2481
+ ix_arr[st] = *row;
2482
+ *row = temp;
2483
+ st++;
2484
+ }
2485
+ if (curr_pos == end_col && row < ix_arr + end)
2486
+ for (size_t *r = row + 1; r <= ix_arr + end; r++)
2487
+ {
2488
+ temp = ix_arr[st];
2489
+ ix_arr[st] = *r;
2490
+ *r = temp;
2491
+ st++;
2492
+ }
2493
+ if (row == ix_arr + end || curr_pos == end_col) break;
2494
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
2495
+ }
2496
+
2497
+ else
2498
+ {
2499
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
2500
+ {
2501
+ while (row <= ix_arr + end && Xc_ind[curr_pos] > (sparse_ix)(*row))
2502
+ {
2503
+ temp = ix_arr[st];
2504
+ ix_arr[st] = *row;
2505
+ *row = temp;
2506
+ st++; row++;
2507
+ }
2508
+ }
2509
+
2510
+ else
2511
+ {
2512
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
2513
+ }
2514
+ }
2515
+ }
2516
+ }
2517
+
2518
+ else /* don't move zeros */
2519
+ {
2520
+ for (size_t *row = ptr_st;
2521
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
2522
+ )
2523
+ {
2524
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
2525
+ {
2526
+ if (unlikely(std::isnan(Xc[curr_pos]))) has_NAs = true;
2527
+ if (!std::isnan(Xc[curr_pos]) && Xc[curr_pos] <= split_point)
2528
+ {
2529
+ temp = ix_arr[st];
2530
+ ix_arr[st] = *row;
2531
+ *row = temp;
2532
+ st++;
2533
+ }
2534
+ if (row == ix_arr + end || curr_pos == end_col) break;
2535
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
2536
+ }
2537
+
2538
+ else
2539
+ {
2540
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
2541
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
2542
+ else
2543
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
2544
+ }
2545
+ }
2546
+ }
2547
+
2548
+
2549
+ st_NA = st;
2550
+ if (has_NAs)
2551
+ {
2552
+ curr_pos = st_col;
2553
+ std::sort(ix_arr + st, ix_arr + end + 1);
2554
+ for (size_t *row = ix_arr + st;
2555
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
2556
+ )
2557
+ {
2558
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
2559
+ {
2560
+ if (unlikely(std::isnan(Xc[curr_pos])))
2561
+ {
2562
+ temp = ix_arr[st];
2563
+ ix_arr[st] = *row;
2564
+ *row = temp;
2565
+ st++;
2566
+ }
2567
+ if (row == ix_arr + end || curr_pos == end_col) break;
2568
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
2569
+ }
2570
+
2571
+ else
2572
+ {
2573
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
2574
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
2575
+ else
2576
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
2577
+ }
2578
+ }
2579
+ }
2580
+ end_NA = st;
2581
+
2582
+ }
2583
+
2584
+ }
2585
+
2586
+ /* For categorical columns split by subset */
2587
+ void divide_subset_split(size_t *restrict ix_arr, int x[], size_t st, size_t end, signed char split_categ[],
2588
+ MissingAction missing_action, size_t &restrict st_NA, size_t &restrict end_NA, size_t &restrict split_ix) noexcept
2589
+ {
2590
+ size_t temp;
2591
+
2592
+ /* if NAs are not to be bothered with, just need to do a single pass */
2593
+ if (missing_action == Fail)
2594
+ {
2595
+ /* move to the left if it's l.e. than the split point */
2596
+ for (size_t row = st; row <= end; row++)
2597
+ {
2598
+ if (split_categ[ x[ix_arr[row]] ] == 1)
2599
+ {
2600
+ temp = ix_arr[st];
2601
+ ix_arr[st] = ix_arr[row];
2602
+ ix_arr[row] = temp;
2603
+ st++;
2604
+ }
2605
+ }
2606
+ split_ix = st;
2607
+ }
2608
+
2609
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
2610
+ else
2611
+ {
2612
+ for (size_t row = st; row <= end; row++)
2613
+ {
2614
+ if (x[ix_arr[row]] >= 0 && split_categ[ x[ix_arr[row]] ] == 1)
2615
+ {
2616
+ temp = ix_arr[st];
2617
+ ix_arr[st] = ix_arr[row];
2618
+ ix_arr[row] = temp;
2619
+ st++;
2620
+ }
2621
+ }
2622
+ st_NA = st;
2623
+
2624
+ for (size_t row = st; row <= end; row++)
2625
+ {
2626
+ if (x[ix_arr[row]] < 0)
2627
+ {
2628
+ temp = ix_arr[st];
2629
+ ix_arr[st] = ix_arr[row];
2630
+ ix_arr[row] = temp;
2631
+ st++;
2632
+ }
2633
+ }
2634
+ end_NA = st;
2635
+ }
2636
+ }
2637
+
2638
+ /* For categorical columns split by subset, used at prediction time (with similarity) */
2639
+ void divide_subset_split(size_t *restrict ix_arr, int x[], size_t st, size_t end, signed char split_categ[],
2640
+ int ncat, MissingAction missing_action, NewCategAction new_cat_action,
2641
+ bool move_new_to_left, size_t &restrict st_NA, size_t &restrict end_NA, size_t &restrict split_ix) noexcept
2642
+ {
2643
+ size_t temp;
2644
+ int cval;
2645
+
2646
+ /* if NAs are not to be bothered with, just need to do a single pass */
2647
+ if (missing_action == Fail && new_cat_action != Weighted)
2648
+ {
2649
+ /* in this case, will need to fill 'split_ix', otherwise need to fill 'st_NA' and 'end_NA' */
2650
+ if (new_cat_action == Smallest && move_new_to_left)
2651
+ {
2652
+ for (size_t row = st; row <= end; row++)
2653
+ {
2654
+ cval = x[ix_arr[row]];
2655
+ if (cval >= ncat || split_categ[cval] == 1 || split_categ[cval] == (-1))
2656
+ {
2657
+ temp = ix_arr[st];
2658
+ ix_arr[st] = ix_arr[row];
2659
+ ix_arr[row] = temp;
2660
+ st++;
2661
+ }
2662
+ }
2663
+ }
2664
+
2665
+ else if (new_cat_action == Random)
2666
+ {
2667
+ for (size_t row = st; row <= end; row++)
2668
+ {
2669
+ cval = x[ix_arr[row]];
2670
+ cval = (cval >= ncat)? (cval % ncat) : cval;
2671
+ if (split_categ[cval] == 1)
2672
+ {
2673
+ temp = ix_arr[st];
2674
+ ix_arr[st] = ix_arr[row];
2675
+ ix_arr[row] = temp;
2676
+ st++;
2677
+ }
2678
+ }
2679
+ }
2680
+
2681
+ else
2682
+ {
2683
+ for (size_t row = st; row <= end; row++)
2684
+ {
2685
+ cval = x[ix_arr[row]];
2686
+ if (cval < ncat && split_categ[cval] == 1)
2687
+ {
2688
+ temp = ix_arr[st];
2689
+ ix_arr[st] = ix_arr[row];
2690
+ ix_arr[row] = temp;
2691
+ st++;
2692
+ }
2693
+ }
2694
+ }
2695
+
2696
+ split_ix = st;
2697
+ }
2698
+
2699
+ /* if there are new categories, and their direction was decided at random,
2700
+ can just reuse what was randomly decided for previous columns by taking
2701
+ a remainder w.r.t. the number of previous columns. Note however that this
2702
+ will not be an unbiased decision if the model used a gain criterion. */
2703
+ else if (new_cat_action == Random)
2704
+ {
2705
+ if (missing_action == Impute && !move_new_to_left)
2706
+ {
2707
+ for (size_t row = st; row <= end; row++)
2708
+ {
2709
+ cval = x[ix_arr[row]];
2710
+ cval = (cval >= ncat)? (cval % ncat) : cval;
2711
+ if (cval < 0 || split_categ[cval] == 1)
2712
+ {
2713
+ temp = ix_arr[st];
2714
+ ix_arr[st] = ix_arr[row];
2715
+ ix_arr[row] = temp;
2716
+ st++;
2717
+ }
2718
+ }
2719
+ }
2720
+
2721
+ else
2722
+ {
2723
+ for (size_t row = st; row <= end; row++)
2724
+ {
2725
+ cval = x[ix_arr[row]];
2726
+ cval = (cval >= ncat)? (cval % ncat) : cval;
2727
+ if (cval >= 0 && split_categ[cval] == 1)
2728
+ {
2729
+ temp = ix_arr[st];
2730
+ ix_arr[st] = ix_arr[row];
2731
+ ix_arr[row] = temp;
2732
+ st++;
2733
+ }
2734
+ }
2735
+ }
2736
+ st_NA = st;
2737
+
2738
+ if (!(missing_action == Impute && !move_new_to_left))
2739
+ {
2740
+ for (size_t row = st; row <= end; row++)
2741
+ {
2742
+ if (unlikely(x[ix_arr[row]] < 0))
2743
+ {
2744
+ temp = ix_arr[st];
2745
+ ix_arr[st] = ix_arr[row];
2746
+ ix_arr[row] = temp;
2747
+ st++;
2748
+ }
2749
+ }
2750
+ }
2751
+ end_NA = st;
2752
+ }
2753
+
2754
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
2755
+ else
2756
+ {
2757
+ /* Note: if having 'new_cat_action'='Smallest' and 'missing_action'='Impute', missing values
2758
+ and new categories will necessarily go into different branches, thus it's possible to do
2759
+ all the movements in one pass if certain conditions match. */
2760
+
2761
+ if (new_cat_action == Smallest && move_new_to_left)
2762
+ {
2763
+ for (size_t row = st; row <= end; row++)
2764
+ {
2765
+ cval = x[ix_arr[row]];
2766
+ if (cval >= 0 && (cval >= ncat || split_categ[cval] == 1 || split_categ[cval] == (-1)))
2767
+ {
2768
+ temp = ix_arr[st];
2769
+ ix_arr[st] = ix_arr[row];
2770
+ ix_arr[row] = temp;
2771
+ st++;
2772
+ }
2773
+ }
2774
+ }
2775
+
2776
+ else if (missing_action == Impute && !move_new_to_left)
2777
+ {
2778
+ for (size_t row = st; row <= end; row++)
2779
+ {
2780
+ cval = x[ix_arr[row]];
2781
+ if (cval < ncat && (cval < 0 || split_categ[cval] == 1))
2782
+ {
2783
+ temp = ix_arr[st];
2784
+ ix_arr[st] = ix_arr[row];
2785
+ ix_arr[row] = temp;
2786
+ st++;
2787
+ }
2788
+ }
2789
+ }
2790
+
2791
+ else
2792
+ {
2793
+ for (size_t row = st; row <= end; row++)
2794
+ {
2795
+ cval = x[ix_arr[row]];
2796
+ if (cval >= 0 && cval < ncat && split_categ[cval] == 1)
2797
+ {
2798
+ temp = ix_arr[st];
2799
+ ix_arr[st] = ix_arr[row];
2800
+ ix_arr[row] = temp;
2801
+ st++;
2802
+ }
2803
+ }
2804
+ }
2805
+
2806
+ st_NA = st;
2807
+
2808
+ if (new_cat_action == Weighted && missing_action == Divide)
2809
+ {
2810
+ for (size_t row = st; row <= end; row++)
2811
+ {
2812
+ cval = x[ix_arr[row]];
2813
+ if (cval < 0 || cval >= ncat || split_categ[cval] == (-1))
2814
+ {
2815
+ temp = ix_arr[st];
2816
+ ix_arr[st] = ix_arr[row];
2817
+ ix_arr[row] = temp;
2818
+ st++;
2819
+ }
2820
+ }
2821
+ }
2822
+
2823
+ else if (new_cat_action == Weighted)
2824
+ {
2825
+ for (size_t row = st; row <= end; row++)
2826
+ {
2827
+ cval = x[ix_arr[row]];
2828
+ if (cval >= 0 && (cval >= ncat || split_categ[cval] == (-1)))
2829
+ {
2830
+ temp = ix_arr[st];
2831
+ ix_arr[st] = ix_arr[row];
2832
+ ix_arr[row] = temp;
2833
+ st++;
2834
+ }
2835
+ }
2836
+ }
2837
+
2838
+ else if (missing_action == Divide)
2839
+ {
2840
+ for (size_t row = st; row <= end; row++)
2841
+ {
2842
+ if (unlikely(x[ix_arr[row]] < 0))
2843
+ {
2844
+ temp = ix_arr[st];
2845
+ ix_arr[st] = ix_arr[row];
2846
+ ix_arr[row] = temp;
2847
+ st++;
2848
+ }
2849
+ }
2850
+ }
2851
+
2852
+ end_NA = st;
2853
+ }
2854
+ }
2855
+
2856
+ /* For categoricals split on a single category */
2857
+ void divide_subset_split(size_t *restrict ix_arr, int x[], size_t st, size_t end, int split_categ,
2858
+ MissingAction missing_action, size_t &restrict st_NA, size_t &restrict end_NA, size_t &restrict split_ix) noexcept
2859
+ {
2860
+ size_t temp;
2861
+
2862
+ /* if NAs are not to be bothered with, just need to do a single pass */
2863
+ if (missing_action == Fail)
2864
+ {
2865
+ /* move to the left if it's equal to the chosen category */
2866
+ for (size_t row = st; row <= end; row++)
2867
+ {
2868
+ if (x[ix_arr[row]] == split_categ)
2869
+ {
2870
+ temp = ix_arr[st];
2871
+ ix_arr[st] = ix_arr[row];
2872
+ ix_arr[row] = temp;
2873
+ st++;
2874
+ }
2875
+ }
2876
+ split_ix = st;
2877
+ }
2878
+
2879
+ /* otherwise, first put to the left all equal to chosen and not NA, then all NAs to the end of the left */
2880
+ else
2881
+ {
2882
+ for (size_t row = st; row <= end; row++)
2883
+ {
2884
+ if (x[ix_arr[row]] == split_categ)
2885
+ {
2886
+ temp = ix_arr[st];
2887
+ ix_arr[st] = ix_arr[row];
2888
+ ix_arr[row] = temp;
2889
+ st++;
2890
+ }
2891
+ }
2892
+ st_NA = st;
2893
+
2894
+ for (size_t row = st; row <= end; row++)
2895
+ {
2896
+ if (unlikely(x[ix_arr[row]] < 0))
2897
+ {
2898
+ temp = ix_arr[st];
2899
+ ix_arr[st] = ix_arr[row];
2900
+ ix_arr[row] = temp;
2901
+ st++;
2902
+ }
2903
+ }
2904
+ end_NA = st;
2905
+ }
2906
+ }
2907
+
2908
+ /* For categoricals split on sub-set that turned out to have 2 categories only (prediction-time) */
2909
+ void divide_subset_split(size_t *restrict ix_arr, int x[], size_t st, size_t end,
2910
+ MissingAction missing_action, NewCategAction new_cat_action,
2911
+ bool move_new_to_left, size_t &restrict st_NA, size_t &restrict end_NA, size_t &restrict split_ix) noexcept
2912
+ {
2913
+ size_t temp;
2914
+
2915
+ /* if NAs are not to be bothered with, just need to do a single pass */
2916
+ if (missing_action == Fail)
2917
+ {
2918
+ /* move to the left if it's l.e. than the split point */
2919
+ if (new_cat_action == Smallest && move_new_to_left)
2920
+ {
2921
+ for (size_t row = st; row <= end; row++)
2922
+ {
2923
+ if (x[ix_arr[row]] == 0 || x[ix_arr[row]] > 1)
2924
+ {
2925
+ temp = ix_arr[st];
2926
+ ix_arr[st] = ix_arr[row];
2927
+ ix_arr[row] = temp;
2928
+ st++;
2929
+ }
2930
+ }
2931
+ }
2932
+
2933
+ else
2934
+ {
2935
+ for (size_t row = st; row <= end; row++)
2936
+ {
2937
+ if (x[ix_arr[row]] == 0)
2938
+ {
2939
+ temp = ix_arr[st];
2940
+ ix_arr[st] = ix_arr[row];
2941
+ ix_arr[row] = temp;
2942
+ st++;
2943
+ }
2944
+ }
2945
+ }
2946
+ split_ix = st;
2947
+ }
2948
+
2949
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
2950
+ else
2951
+ {
2952
+ if (new_cat_action == Smallest && move_new_to_left)
2953
+ {
2954
+ for (size_t row = st; row <= end; row++)
2955
+ {
2956
+ if (x[ix_arr[row]] == 0 || x[ix_arr[row]] > 1)
2957
+ {
2958
+ temp = ix_arr[st];
2959
+ ix_arr[st] = ix_arr[row];
2960
+ ix_arr[row] = temp;
2961
+ st++;
2962
+ }
2963
+ }
2964
+ st_NA = st;
2965
+
2966
+ for (size_t row = st; row <= end; row++)
2967
+ {
2968
+ if (unlikely(x[ix_arr[row]] < 0))
2969
+ {
2970
+ temp = ix_arr[st];
2971
+ ix_arr[st] = ix_arr[row];
2972
+ ix_arr[row] = temp;
2973
+ st++;
2974
+ }
2975
+ }
2976
+ end_NA = st;
2977
+ }
2978
+
2979
+ else
2980
+ {
2981
+ for (size_t row = st; row <= end; row++)
2982
+ {
2983
+ if (x[ix_arr[row]] == 0)
2984
+ {
2985
+ temp = ix_arr[st];
2986
+ ix_arr[st] = ix_arr[row];
2987
+ ix_arr[row] = temp;
2988
+ st++;
2989
+ }
2990
+ }
2991
+ st_NA = st;
2992
+
2993
+ for (size_t row = st; row <= end; row++)
2994
+ {
2995
+ if (unlikely(x[ix_arr[row]] < 0))
2996
+ {
2997
+ temp = ix_arr[st];
2998
+ ix_arr[st] = ix_arr[row];
2999
+ ix_arr[row] = temp;
3000
+ st++;
3001
+ }
3002
+ }
3003
+ end_NA = st;
3004
+ }
3005
+ }
3006
+ }
3007
+
3008
+ /* for regular numeric columns */
3009
+ template <class real_t>
3010
+ void get_range(size_t ix_arr[], real_t *restrict x, size_t st, size_t end,
3011
+ MissingAction missing_action, double &restrict xmin, double &restrict xmax, bool &unsplittable) noexcept
3012
+ {
3013
+ xmin = HUGE_VAL;
3014
+ xmax = -HUGE_VAL;
3015
+ double xval;
3016
+
3017
+ if (missing_action == Fail)
3018
+ {
3019
+ for (size_t row = st; row <= end; row++)
3020
+ {
3021
+ xval = x[ix_arr[row]];
3022
+ xmin = (xval < xmin)? xval : xmin;
3023
+ xmax = (xval > xmax)? xval : xmax;
3024
+ }
3025
+ }
3026
+
3027
+
3028
+ else
3029
+ {
3030
+ for (size_t row = st; row <= end; row++)
3031
+ {
3032
+ xval = x[ix_arr[row]];
3033
+ xmin = std::fmin(xmin, xval);
3034
+ xmax = std::fmax(xmax, xval);
3035
+ }
3036
+ }
3037
+
3038
+ unsplittable = (xmin == xmax) || (xmin == HUGE_VAL && xmax == -HUGE_VAL) || std::isnan(xmin) || std::isnan(xmax);
3039
+ }
3040
+
3041
+ template <class real_t>
3042
+ void get_range(real_t *restrict x, size_t n,
3043
+ MissingAction missing_action, double &restrict xmin, double &restrict xmax, bool &unsplittable) noexcept
3044
+ {
3045
+ xmin = HUGE_VAL;
3046
+ xmax = -HUGE_VAL;
3047
+
3048
+ if (missing_action == Fail)
3049
+ {
3050
+ for (size_t row = 0; row < n; row++)
3051
+ {
3052
+ xmin = (x[row] < xmin)? x[row] : xmin;
3053
+ xmax = (x[row] > xmax)? x[row] : xmax;
3054
+ }
3055
+ }
3056
+
3057
+
3058
+ else
3059
+ {
3060
+ for (size_t row = 0; row < n; row++)
3061
+ {
3062
+ xmin = std::fmin(xmin, x[row]);
3063
+ xmax = std::fmax(xmax, x[row]);
3064
+ }
3065
+ }
3066
+
3067
+ unsplittable = (xmin == xmax) || (xmin == HUGE_VAL && xmax == -HUGE_VAL) || std::isnan(xmin) || std::isnan(xmax);
3068
+ }
3069
+
3070
+ /* for sparse inputs */
3071
+ template <class real_t, class sparse_ix>
3072
+ void get_range(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
3073
+ real_t *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
3074
+ MissingAction missing_action, double &restrict xmin, double &restrict xmax, bool &unsplittable) noexcept
3075
+ {
3076
+ /* ix_arr must already be sorted beforehand */
3077
+ xmin = HUGE_VAL;
3078
+ xmax = -HUGE_VAL;
3079
+
3080
+ size_t st_col = Xc_indptr[col_num];
3081
+ size_t end_col = Xc_indptr[col_num + 1];
3082
+ size_t nnz_col = end_col - st_col;
3083
+ end_col--;
3084
+ size_t curr_pos = st_col;
3085
+
3086
+ if (!nnz_col ||
3087
+ Xc_ind[st_col] > (sparse_ix)ix_arr[end] ||
3088
+ (sparse_ix)ix_arr[st] > Xc_ind[end_col]
3089
+ )
3090
+ {
3091
+ unsplittable = true;
3092
+ return;
3093
+ }
3094
+
3095
+ if (nnz_col < end - st + 1 ||
3096
+ Xc_ind[st_col] > (sparse_ix)ix_arr[st] ||
3097
+ Xc_ind[end_col] < (sparse_ix)ix_arr[end]
3098
+ )
3099
+ {
3100
+ xmin = 0;
3101
+ xmax = 0;
3102
+ }
3103
+
3104
+ size_t ind_end_col = Xc_ind[end_col];
3105
+ size_t nmatches = 0;
3106
+
3107
+ if (missing_action == Fail)
3108
+ {
3109
+ for (size_t *row = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
3110
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
3111
+ )
3112
+ {
3113
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
3114
+ {
3115
+ nmatches++;
3116
+ xmin = (Xc[curr_pos] < xmin)? Xc[curr_pos] : xmin;
3117
+ xmax = (Xc[curr_pos] > xmax)? Xc[curr_pos] : xmax;
3118
+ if (row == ix_arr + end || curr_pos == end_col) break;
3119
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3120
+ }
3121
+
3122
+ else
3123
+ {
3124
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
3125
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
3126
+ else
3127
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
3128
+ }
3129
+ }
3130
+ }
3131
+
3132
+ else /* can have NAs */
3133
+ {
3134
+ for (size_t *row = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
3135
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
3136
+ )
3137
+ {
3138
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
3139
+ {
3140
+ nmatches++;
3141
+ xmin = std::fmin(xmin, Xc[curr_pos]);
3142
+ xmax = std::fmax(xmax, Xc[curr_pos]);
3143
+ if (row == ix_arr + end || curr_pos == end_col) break;
3144
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3145
+ }
3146
+
3147
+ else
3148
+ {
3149
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
3150
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
3151
+ else
3152
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
3153
+ }
3154
+ }
3155
+
3156
+ }
3157
+
3158
+ if (nmatches < (end - st + 1))
3159
+ {
3160
+ xmin = std::fmin(xmin, 0);
3161
+ xmax = std::fmax(xmax, 0);
3162
+ }
3163
+
3164
+ unsplittable = (xmin == xmax) || (xmin == HUGE_VAL && xmax == -HUGE_VAL) || std::isnan(xmin) || std::isnan(xmax);
3165
+ }
3166
+
3167
+ template <class real_t, class sparse_ix>
3168
+ void get_range(size_t col_num, size_t nrows,
3169
+ real_t *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
3170
+ MissingAction missing_action, double &restrict xmin, double &restrict xmax, bool &unsplittable) noexcept
3171
+ {
3172
+ xmin = HUGE_VAL;
3173
+ xmax = -HUGE_VAL;
3174
+
3175
+ if ((size_t)(Xc_indptr[col_num+1] - Xc_indptr[col_num]) < nrows)
3176
+ {
3177
+ xmin = 0;
3178
+ xmax = 0;
3179
+ }
3180
+
3181
+ if (missing_action == Fail)
3182
+ {
3183
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num+1]; ix++)
3184
+ {
3185
+ xmin = (Xc[ix] < xmin)? Xc[ix] : xmin;
3186
+ xmax = (Xc[ix] > xmax)? Xc[ix] : xmax;
3187
+ }
3188
+ }
3189
+
3190
+ else
3191
+ {
3192
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num+1]; ix++)
3193
+ {
3194
+ if (unlikely(std::isinf(Xc[ix]))) continue;
3195
+ xmin = std::fmin(xmin, Xc[ix]);
3196
+ xmax = std::fmax(xmax, Xc[ix]);
3197
+ }
3198
+ }
3199
+
3200
+ unsplittable = (xmin == xmax) || (xmin == HUGE_VAL && xmax == -HUGE_VAL) || std::isnan(xmin) || std::isnan(xmax);
3201
+ }
3202
+
3203
+
3204
+ void get_categs(size_t *restrict ix_arr, int x[], size_t st, size_t end, int ncat,
3205
+ MissingAction missing_action, signed char categs[], size_t &restrict npresent, bool &unsplittable) noexcept
3206
+ {
3207
+ std::fill(categs, categs + ncat, -1);
3208
+ npresent = 0;
3209
+ for (size_t row = st; row <= end; row++)
3210
+ if (likely(x[ix_arr[row]] >= 0))
3211
+ categs[x[ix_arr[row]]] = 1;
3212
+
3213
+ npresent = std::accumulate(categs,
3214
+ categs + ncat,
3215
+ (size_t)0,
3216
+ [](const size_t a, const signed char b){return a + (b > 0);}
3217
+ );
3218
+
3219
+ unsplittable = npresent < 2;
3220
+ }
3221
+
3222
+ template <class real_t>
3223
+ bool check_more_than_two_unique_values(size_t ix_arr[], size_t st, size_t end, real_t x[], MissingAction missing_action)
3224
+ {
3225
+ if (end - st <= 1) return false;
3226
+
3227
+ if (missing_action == Fail)
3228
+ {
3229
+ real_t x0 = x[ix_arr[st]];
3230
+ for (size_t ix = st+1; ix <= end; ix++)
3231
+ {
3232
+ if (x[ix_arr[ix]] != x0) return true;
3233
+ }
3234
+ }
3235
+
3236
+ else
3237
+ {
3238
+ real_t x0;
3239
+ size_t ix;
3240
+ for (ix = st; ix <= end; ix++)
3241
+ {
3242
+ if (likely(!is_na_or_inf(x[ix_arr[ix]])))
3243
+ {
3244
+ x0 = x[ix_arr[ix]];
3245
+ ix++;
3246
+ break;
3247
+ }
3248
+ }
3249
+
3250
+ for (; ix <= end; ix++)
3251
+ {
3252
+ if (!is_na_or_inf(x[ix_arr[ix]]) && x[ix_arr[ix]] != x0)
3253
+ return true;
3254
+ }
3255
+ }
3256
+
3257
+ return false;
3258
+ }
3259
+
3260
+ bool check_more_than_two_unique_values(size_t ix_arr[], size_t st, size_t end, int x[], MissingAction missing_action)
3261
+ {
3262
+ if (end - st <= 1) return false;
3263
+
3264
+ if (missing_action == Fail)
3265
+ {
3266
+ int x0 = x[ix_arr[st]];
3267
+ for (size_t ix = st+1; ix <= end; ix++)
3268
+ {
3269
+ if (x[ix_arr[ix]] != x0) return true;
3270
+ }
3271
+ }
3272
+
3273
+ else
3274
+ {
3275
+ int x0;
3276
+ size_t ix;
3277
+ for (ix = st; ix <= end; ix++)
3278
+ {
3279
+ if (x[ix_arr[ix]] >= 0)
3280
+ {
3281
+ x0 = x[ix_arr[ix]];
3282
+ ix++;
3283
+ break;
3284
+ }
3285
+ }
3286
+
3287
+ for (; ix <= end; ix++)
3288
+ {
3289
+ if (x[ix_arr[ix]] >= 0 && x[ix_arr[ix]] != x0)
3290
+ return true;
3291
+ }
3292
+ }
3293
+
3294
+ return false;
3295
+ }
3296
+
3297
+ template <class real_t, class sparse_ix>
3298
+ bool check_more_than_two_unique_values(size_t *restrict ix_arr, size_t st, size_t end, size_t col,
3299
+ sparse_ix *restrict Xc_indptr, sparse_ix *restrict Xc_ind, real_t *restrict Xc,
3300
+ MissingAction missing_action)
3301
+ {
3302
+ if (end - st <= 1) return false;
3303
+ if (Xc_indptr[col+1] == Xc_indptr[col]) return false;
3304
+ bool has_zeros = (end - st + 1) > (size_t)(Xc_indptr[col+1] - Xc_indptr[col]);
3305
+ if (has_zeros && !is_na_or_inf(Xc[Xc_indptr[col]]) && Xc[Xc_indptr[col]] != 0) return true;
3306
+
3307
+ size_t st_col = Xc_indptr[col];
3308
+ size_t end_col = Xc_indptr[col + 1] - 1;
3309
+ size_t curr_pos = st_col;
3310
+ size_t ind_end_col = Xc_ind[end_col];
3311
+
3312
+ /* 'ix_arr' should be sorted beforehand */
3313
+ /* TODO: refactor this */
3314
+ real_t x0 = 0;
3315
+ size_t *row;
3316
+ for (row = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
3317
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
3318
+ )
3319
+ {
3320
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
3321
+ {
3322
+ if (is_na_or_inf(Xc[curr_pos]) || (has_zeros && Xc[curr_pos] == 0))
3323
+ {
3324
+ if (row == ix_arr + end || curr_pos == end_col) return false;
3325
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3326
+ }
3327
+
3328
+ x0 = Xc[curr_pos];
3329
+ if (has_zeros) return true;
3330
+ else if (x0 == 0) has_zeros = true;
3331
+ if (row == ix_arr + end || curr_pos == end_col) return false;
3332
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3333
+ break;
3334
+ }
3335
+
3336
+ else
3337
+ {
3338
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
3339
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
3340
+ else
3341
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
3342
+ }
3343
+ }
3344
+
3345
+ for (;
3346
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
3347
+ )
3348
+ {
3349
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
3350
+ {
3351
+ if (is_na_or_inf(Xc[curr_pos]) || (has_zeros && Xc[curr_pos] == 0))
3352
+ {
3353
+ if (row == ix_arr + end || curr_pos == end_col) break;
3354
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3355
+ }
3356
+
3357
+ else if (Xc[curr_pos] != x0)
3358
+ {
3359
+ return true;
3360
+ }
3361
+
3362
+ if (row == ix_arr + end || curr_pos == end_col) break;
3363
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3364
+ }
3365
+
3366
+ else
3367
+ {
3368
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
3369
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
3370
+ else
3371
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
3372
+ }
3373
+ }
3374
+
3375
+ return false;
3376
+ }
3377
+
3378
+ template <class real_t, class sparse_ix>
3379
+ bool check_more_than_two_unique_values(size_t nrows, size_t col,
3380
+ sparse_ix *restrict Xc_indptr, sparse_ix *restrict Xc_ind, real_t *restrict Xc,
3381
+ MissingAction missing_action)
3382
+ {
3383
+ if (nrows <= 1) return false;
3384
+ if (Xc_indptr[col+1] == Xc_indptr[col]) return false;
3385
+ bool has_zeros = nrows > (size_t)(Xc_indptr[col+1] - Xc_indptr[col]);
3386
+ if (has_zeros && !is_na_or_inf(Xc[Xc_indptr[col]]) && Xc[Xc_indptr[col]] != 0) return true;
3387
+
3388
+ real_t x0 = 0;
3389
+ sparse_ix ix;
3390
+ for (ix = Xc_indptr[col]; ix < Xc_indptr[col+1]; ix++)
3391
+ {
3392
+ if (!is_na_or_inf(Xc[ix]))
3393
+ {
3394
+ if (has_zeros && Xc[ix] == 0) continue;
3395
+ if (has_zeros) return true;
3396
+ else if (Xc[ix] == 0) has_zeros = true;
3397
+ x0 = Xc[ix];
3398
+ ix++;
3399
+ break;
3400
+ }
3401
+ }
3402
+
3403
+ for (ix = Xc_indptr[col]; ix < Xc_indptr[col+1]; ix++)
3404
+ {
3405
+ if (!is_na_or_inf(Xc[ix]))
3406
+ {
3407
+ if (has_zeros && Xc[ix] == 0) continue;
3408
+ if (Xc[ix] != x0) return true;
3409
+ }
3410
+ }
3411
+
3412
+ return false;
3413
+ }
3414
+
3415
+ void count_categs(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat, size_t *restrict counts)
3416
+ {
3417
+ std::fill(counts, counts + ncat, (size_t)0);
3418
+ for (size_t row = st; row <= end; row++)
3419
+ if (likely(x[ix_arr[row]] >= 0))
3420
+ counts[x[ix_arr[row]]]++;
3421
+ }
3422
+
3423
+ int count_ncateg_in_col(const int x[], const size_t n, const int ncat, unsigned char buffer[])
3424
+ {
3425
+ memset(buffer, 0, ncat*sizeof(char));
3426
+ for (size_t ix = 0; ix < n; ix++)
3427
+ {
3428
+ if (likely(x[ix] >= 0)) buffer[x[ix]] = true;
3429
+ }
3430
+
3431
+ int ncat_present = 0;
3432
+ for (int cat = 0; cat < ncat; cat++)
3433
+ ncat_present += buffer[cat];
3434
+ return ncat_present;
3435
+ }
3436
+
3437
+ template <class ldouble_safe>
3438
+ ldouble_safe calculate_sum_weights(std::vector<size_t> &ix_arr, size_t st, size_t end, size_t curr_depth,
3439
+ std::vector<double> &weights_arr, hashed_map<size_t, double> &weights_map)
3440
+ {
3441
+ if (curr_depth > 0 && !weights_arr.empty())
3442
+ return std::accumulate(ix_arr.begin() + st,
3443
+ ix_arr.begin() + end + 1,
3444
+ (ldouble_safe)0,
3445
+ [&weights_arr](const ldouble_safe a, const size_t ix){return a + weights_arr[ix];});
3446
+ else if (curr_depth > 0 && !weights_map.empty())
3447
+ return std::accumulate(ix_arr.begin() + st,
3448
+ ix_arr.begin() + end + 1,
3449
+ (ldouble_safe)0,
3450
+ [&weights_map](const ldouble_safe a, const size_t ix){return a + weights_map[ix];});
3451
+ else
3452
+ return -HUGE_VAL;
3453
+ }
3454
+
3455
/* Reorders ix_arr[st..end] so that the rows whose value in 'x' is NA/Inf (per
   'is_na_or_inf') come first. Returns the position of the first row after them. */
template <class real_t>
size_t move_NAs_to_front(size_t ix_arr[], size_t st, size_t end, real_t x[])
{
    size_t next_slot = st;

    for (size_t pos = st; pos <= end; pos++)
    {
        if (unlikely(is_na_or_inf(x[ix_arr[pos]])))
        {
            std::swap(ix_arr[next_slot], ix_arr[pos]);
            next_slot++;
        }
    }

    return next_slot;
}
3474
+
3475
+ template <class real_t, class sparse_ix>
3476
+ size_t move_NAs_to_front(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num, real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr)
3477
+ {
3478
+ size_t st_non_na = st;
3479
+ size_t temp;
3480
+
3481
+ size_t st_col = Xc_indptr[col_num];
3482
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
3483
+ size_t curr_pos = st_col;
3484
+ size_t ind_end_col = Xc_ind[end_col];
3485
+ std::sort(ix_arr + st, ix_arr + end + 1);
3486
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
3487
+
3488
+ for (size_t *row = ptr_st;
3489
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
3490
+ )
3491
+ {
3492
+ if (Xc_ind[curr_pos] == *row)
3493
+ {
3494
+ if (unlikely(is_na_or_inf(Xc[curr_pos])))
3495
+ {
3496
+ temp = ix_arr[st_non_na];
3497
+ ix_arr[st_non_na] = *row;
3498
+ *row = temp;
3499
+ st_non_na++;
3500
+ }
3501
+
3502
+ if (row == ix_arr + end || curr_pos == end_col) break;
3503
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3504
+ }
3505
+
3506
+ else
3507
+ {
3508
+ if (Xc_ind[curr_pos] > *row)
3509
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
3510
+ else
3511
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
3512
+ }
3513
+ }
3514
+
3515
+ return st_non_na;
3516
+ }
3517
+
3518
+ size_t move_NAs_to_front(size_t ix_arr[], size_t st, size_t end, int x[])
3519
+ {
3520
+ size_t st_non_na = st;
3521
+ size_t temp;
3522
+
3523
+ for (size_t row = st; row <= end; row++)
3524
+ {
3525
+ if (unlikely(x[ix_arr[row]] < 0))
3526
+ {
3527
+ temp = ix_arr[st_non_na];
3528
+ ix_arr[st_non_na] = ix_arr[row];
3529
+ ix_arr[row] = temp;
3530
+ st_non_na++;
3531
+ }
3532
+ }
3533
+
3534
+ return st_non_na;
3535
+ }
3536
+
3537
/* Swaps each entry in ix_arr[st_left, st) with an entry taken from just before 'curr_pos',
   walking 'curr_pos' downwards by one per swap. Returns the final value of 'curr_pos'. */
size_t center_NAs(size_t ix_arr[], size_t st_left, size_t st, size_t curr_pos)
{
    for (size_t src = st_left; src < st; src++)
    {
        --curr_pos;
        std::swap(ix_arr[curr_pos], ix_arr[src]);
    }

    return curr_pos;
}
3549
+
3550
/* FIXME / TODO: this calculation would not take weight into account */
/* Median imputation over an index range.
   Expected input:
   - 'ix_arr[st_orig, st)' holds the indices whose value is missing (NA/Inf).
   - 'ix_arr[st, end]' (inclusive) holds the remaining indices, sorted in
     ascending order of 'x'.
   Output:
   - '*xmedian' receives the median of the non-missing values.
   - 'buffer_imputed_x[i]' receives 'x[i]' for non-missing indices and the
     median for missing ones.
   - 'ix_arr[st_orig, end]' is rearranged in place as
     [lower half, missing indices, rest], so that it remains sorted with
     respect to the imputed values. */
template <class real_t>
void fill_NAs_with_median(size_t *restrict ix_arr, size_t st_orig, size_t st, size_t end, real_t *restrict x,
                          double *restrict buffer_imputed_x, double *restrict xmedian)
{
    const size_t n_present = end - st + 1;
    /* after the branch below, this is the last position of the lower half */
    size_t last_lower = st + div2(n_present);

    if ((n_present % 2) != 0)
    {
        *xmedian = x[ix_arr[last_lower]];
        last_lower--;
    }

    else
    {
        last_lower--;
        const double below = x[ix_arr[last_lower]];
        const double above = x[ix_arr[last_lower + (size_t)1]];
        *xmedian = below + (above - below) / 2.;
    }

    for (size_t pos = st_orig; pos < st; pos++)
        buffer_imputed_x[ix_arr[pos]] = (*xmedian);
    for (size_t pos = st; pos <= end; pos++)
        buffer_imputed_x[ix_arr[pos]] = x[ix_arr[pos]];

    /* In-place re-sort of 'ix_arr'.
       Step 1: swap missing indices (taken from the front) with lower-half
       indices (taken from the back of the lower half, going down). The
       lower-half indices that land at the front arrive in descending order. */
    const size_t n_lower = last_lower - st + 1;
    const size_t n_move = std::min(st - st_orig, n_lower);
    size_t dest = last_lower;
    for (size_t k = 0; k < n_move; k++)
    {
        std::swap(ix_arr[dest], ix_arr[st_orig + k]);
        dest--;
    }

    /* Step 2: put the indices that were brought to the front back in ascending order */
    std::reverse(ix_arr + st_orig, ix_arr + st_orig + n_move);

    /* Step 3: those were the largest entries of the lower half; rotate them
       behind the lower-half entries that never moved */
    const size_t n_unmoved = n_lower - n_move;
    std::rotate(ix_arr + st_orig,
                ix_arr + st_orig + n_move,
                ix_arr + st_orig + n_move + n_unmoved);
}
3609
+
3610
/* Extracts one column of a CSC matrix into a dense buffer, restricted to the
   rows listed in 'ix_arr[st..end]' (inclusive).
   - 'ix_arr[st..end]' must be sorted in ascending order (it is binary-searched).
   - 'Xc', 'Xc_ind', 'Xc_indptr' are the CSC values, row indices and column
     pointers; row indices within the column must be sorted ascending as well.
   - 'buffer_arr' must have room for 'end - st + 1' entries; entry 'k' receives
     the value of row 'ix_arr[st + k]' in column 'col_num', or zero when that
     row has no stored entry. */
template <class real_t, class sparse_ix>
void todense(size_t *restrict ix_arr, size_t st, size_t end,
             size_t col_num, real_t *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
             double *restrict buffer_arr)
{
    std::fill(buffer_arr, buffer_arr + (end - st + 1), (double)0);

    /* A column without stored entries is all zeros. Return before computing
       'end_col', which would underflow / index outside of the column. */
    if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
        return;

    size_t st_col  = Xc_indptr[col_num];
    size_t end_col = Xc_indptr[col_num + 1] - 1;
    size_t curr_pos = st_col;
    size_t ind_end_col = Xc_ind[end_col];
    size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);

    /* Merge-like walk over two sorted sequences (requested rows and stored
       rows), skipping ahead with binary searches on whichever side is behind. */
    for (size_t *row = ptr_st;
         row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
        )
    {
        if (Xc_ind[curr_pos] == (sparse_ix)(*row))
        {
            buffer_arr[row - (ix_arr + st)] = Xc[curr_pos];
            if (row == ix_arr + end || curr_pos == end_col) break;
            curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
        }

        else
        {
            if (Xc_ind[curr_pos] > (sparse_ix)(*row))
                row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
            else
                curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
        }
    }
}
3643
+
3644
+
3645
/* Converts a dense matrix from column-major to row-major storage.
   - 'X' holds 'nrows * ncols' values in column-major order, i.e. element
     (row, col) is at 'X[row + col*nrows]'.
   - 'X_row_major' is resized to 'nrows * ncols' and filled so that element
     (row, col) is at 'X_row_major[col + row*ncols]'. */
template <class real_t>
void colmajor_to_rowmajor(real_t *restrict X, size_t nrows, size_t ncols, std::vector<double> &X_row_major)
{
    X_row_major.resize(nrows * ncols);
    for (size_t row = 0; row < nrows; row++)
        for (size_t col = 0; col < ncols; col++)
            X_row_major[col + row*ncols] = X[row + col*nrows];
}
3653
+
3654
/* Converts a sparse matrix from CSC (column-major) to CSR (row-major) format.
   - 'Xc', 'Xc_ind', 'Xc_indptr': CSC values, row indices and column pointers.
     The column pointers are expected to start at zero, so that
     'Xc_indptr[ncols]' is the number of stored entries.
   - 'Xr', 'Xr_ind', 'Xr_indptr': receive the CSR values, column indices and
     row pointers ('nrows + 1' entries). Any previous contents are discarded.
   Entries within each output row are ordered by ascending column. */
template <class real_t, class sparse_ix>
void colmajor_to_rowmajor(real_t *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
                          size_t nrows, size_t ncols,
                          std::vector<double> &Xr, std::vector<size_t> &Xr_ind, std::vector<size_t> &Xr_indptr)
{
    const size_t nnz = Xc_indptr[ncols];

    /* Count the entries of each row, then accumulate into row pointers */
    Xr_indptr.assign(nrows + 1, (size_t)0);
    for (size_t ix = 0; ix < nnz; ix++)
        Xr_indptr[(size_t)Xc_ind[ix] + 1]++;
    for (size_t row = 0; row < nrows; row++)
        Xr_indptr[row + 1] += Xr_indptr[row];

    /* Scatter values into their rows. Visiting the columns in order keeps
       the column indices sorted within each row. */
    std::vector<size_t> next_slot(Xr_indptr.begin(), Xr_indptr.begin() + nrows);
    Xr.resize(nnz);
    Xr_ind.resize(nnz);
    for (size_t col = 0; col < ncols; col++)
    {
        const size_t col_end = Xc_indptr[col + 1];
        for (size_t ix = Xc_indptr[col]; ix < col_end; ix++)
        {
            const size_t dest = next_slot[(size_t)Xc_ind[ix]]++;
            Xr[dest] = Xc[ix];
            Xr_ind[dest] = col;
        }
    }
}
3705
+
3706
+
3707
/* Set to 'true' by the SIGINT handler 'set_interrup_global_variable', read by
   'check_interrupt_switch', and reset to 'false' by 'SignalSwitcher'. */
+ bool interrupt_switch = false;
3708
/* 'true' while a 'SignalSwitcher' has installed the custom SIGINT handler and
   not yet restored the previous one (set in its constructor, cleared in
   'restore_handle'). */
+ bool handle_is_locked = false;
3709
+
3710
+ /* Function to handle interrupt signals */
3711
+ void set_interrup_global_variable(int s)
3712
+ {
3713
+ #pragma omp critical
3714
+ {
3715
+ interrupt_switch = true;
3716
+ }
3717
+ }
3718
+
3719
+ void check_interrupt_switch(SignalSwitcher &ss)
3720
+ {
3721
+ if (interrupt_switch)
3722
+ {
3723
+ ss.restore_handle();
3724
+ fprintf(stderr, "Error: procedure was interrupted\n");
3725
+ raise(SIGINT);
3726
+ #ifdef _FOR_R
3727
+ Rcpp::checkUserInterrupt();
3728
+ #elif !defined(DONT_THROW_ON_INTERRUPT)
3729
+ throw std::runtime_error("Error: procedure was interrupted.\n");
3730
+ #endif
3731
+ }
3732
+ }
3733
+
3734
+ #ifdef _FOR_PYTHON
3735
/* Accessors compiled only when '_FOR_PYTHON' is defined (the 'cy_' prefix
   suggests they are meant to be called from Cython -- TODO confirm). */
/* Returns the current value of the global 'interrupt_switch' flag. */
+ bool cy_check_interrupt_switch()
3736
+ {
3737
+ return interrupt_switch;
3738
+ }
3739
+
3740
/* Resets the global 'interrupt_switch' flag to 'false'. */
+ void cy_tick_off_interrupt_switch()
3741
+ {
3742
+ interrupt_switch = false;
3743
+ }
3744
+ #endif
3745
+
3746
+ SignalSwitcher::SignalSwitcher()
3747
+ {
3748
+ #pragma omp critical
3749
+ {
3750
+ if (!handle_is_locked)
3751
+ {
3752
+ handle_is_locked = true;
3753
+ interrupt_switch = false;
3754
+ this->old_sig = signal(SIGINT, set_interrup_global_variable);
3755
+ this->is_active = true;
3756
+ }
3757
+
3758
+ else {
3759
+ this->is_active = false;
3760
+ }
3761
+ }
3762
+ }
3763
+
3764
/* Destructor: clears the pending-interrupt flag (only for the active instance,
   and not in '_FOR_PYTHON' builds), then restores the previous SIGINT handler. */
+ SignalSwitcher::~SignalSwitcher()
3765
+ {
3766
/* In '_FOR_PYTHON' builds the flag is left untouched here; presumably it is
   read and reset through the 'cy_*' accessors instead -- TODO confirm. */
+ #ifndef _FOR_PYTHON
3767
+ #pragma omp critical
3768
+ {
3769
+ if (this->is_active && handle_is_locked)
3770
+ interrupt_switch = false;
3771
+ }
3772
+ #endif
3773
/* No-op if this instance is not the one holding the handler */
+ this->restore_handle();
3774
+ }
3775
+
3776
+ void SignalSwitcher::restore_handle()
3777
+ {
3778
+ #pragma omp critical
3779
+ {
3780
+ if (this->is_active && handle_is_locked)
3781
+ {
3782
+ signal(SIGINT, this->old_sig);
3783
+ this->is_active = false;
3784
+ handle_is_locked = false;
3785
+ }
3786
+ }
3787
+ }
3788
+
3789
/* Tells whether 'long double' is wider than 'double' in this build.
   Always 'false' when compiled with 'NO_LONG_DOUBLE' defined. */
bool has_long_double()
{
    #ifdef NO_LONG_DOUBLE
    return false;
    #else
    return sizeof(double) < sizeof(long double);
    #endif
}
3797
+
3798
+ /* Return the #def'd constants from standard header. This is in order to determine if the return
3799
+ value from the 'fit_model' function is a success or failure within Cython, which does not
3800
+ allow importing #def'd macro values. */
3801
/* Returns the value of the standard 'EXIT_SUCCESS' macro. */
+ int return_EXIT_SUCCESS()
3802
+ {
3803
+ return EXIT_SUCCESS;
3804
+ }
3805
/* Returns the value of the standard 'EXIT_FAILURE' macro, as a callable
   function for code that cannot read the macro directly. */
+ int return_EXIT_FAILURE()
3806
+ {
3807
+ return EXIT_FAILURE;
3808
+ }