isotree 0.2.2 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (152) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2116 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +132 -0
  14. data/vendor/isotree/src/RcppExports.cpp +594 -57
  15. data/vendor/isotree/src/Rwrapper.cpp +2452 -304
  16. data/vendor/isotree/src/c_interface.cpp +958 -0
  17. data/vendor/isotree/src/crit.hpp +4236 -0
  18. data/vendor/isotree/src/digamma.hpp +184 -0
  19. data/vendor/isotree/src/dist.hpp +1886 -0
  20. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  21. data/vendor/isotree/src/extended.hpp +1444 -0
  22. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  23. data/vendor/isotree/src/fit_model.hpp +2401 -0
  24. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  25. data/vendor/isotree/src/helpers_iforest.hpp +814 -0
  26. data/vendor/isotree/src/{impute.cpp → impute.hpp} +382 -123
  27. data/vendor/isotree/src/indexer.cpp +515 -0
  28. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  29. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  30. data/vendor/isotree/src/isoforest.hpp +1659 -0
  31. data/vendor/isotree/src/isotree.hpp +1815 -394
  32. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  33. data/vendor/isotree/src/merge_models.cpp +159 -16
  34. data/vendor/isotree/src/mult.hpp +1321 -0
  35. data/vendor/isotree/src/oop_interface.cpp +844 -0
  36. data/vendor/isotree/src/oop_interface.hpp +278 -0
  37. data/vendor/isotree/src/other_helpers.hpp +219 -0
  38. data/vendor/isotree/src/predict.hpp +1932 -0
  39. data/vendor/isotree/src/python_helpers.hpp +114 -0
  40. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  41. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  42. data/vendor/isotree/src/robinmap/README.md +483 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1639 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  46. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  47. data/vendor/isotree/src/serialize.cpp +4316 -139
  48. data/vendor/isotree/src/sql.cpp +143 -61
  49. data/vendor/isotree/src/subset_models.cpp +174 -0
  50. data/vendor/isotree/src/utils.hpp +3786 -0
  51. data/vendor/isotree/src/xoshiro.hpp +463 -0
  52. data/vendor/isotree/src/ziggurat.hpp +405 -0
  53. metadata +40 -105
  54. data/vendor/cereal/LICENSE +0 -24
  55. data/vendor/cereal/README.md +0 -85
  56. data/vendor/cereal/include/cereal/access.hpp +0 -351
  57. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  58. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  59. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  60. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  61. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  62. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  63. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  65. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  66. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  67. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  68. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  69. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  70. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  71. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  72. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  74. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  76. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  77. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  78. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  79. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  91. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  92. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  94. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  96. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  97. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  98. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  99. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  100. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  101. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  102. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  103. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  104. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  105. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  106. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  107. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  111. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  112. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  113. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  114. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  115. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  116. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  117. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  118. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  119. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  120. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  121. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  122. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  123. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  124. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  125. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  126. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  127. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  128. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  129. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  130. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  131. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  132. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  133. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  134. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  135. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  136. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  137. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  138. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  139. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  140. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  141. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  142. data/vendor/cereal/include/cereal/version.hpp +0 -52
  143. data/vendor/isotree/src/Makevars +0 -4
  144. data/vendor/isotree/src/crit.cpp +0 -912
  145. data/vendor/isotree/src/dist.cpp +0 -749
  146. data/vendor/isotree/src/extended.cpp +0 -790
  147. data/vendor/isotree/src/fit_model.cpp +0 -1090
  148. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  149. data/vendor/isotree/src/isoforest.cpp +0 -771
  150. data/vendor/isotree/src/mult.cpp +0 -607
  151. data/vendor/isotree/src/predict.cpp +0 -853
  152. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -0,0 +1,3786 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
21
+ * [8] Cortes, David.
22
+ * "Distance approximation using Isolation Forests."
23
+ * arXiv preprint arXiv:1910.12362 (2019).
24
+ * [9] Cortes, David.
25
+ * "Imputing missing values with unsupervised random trees."
26
+ * arXiv preprint arXiv:1911.06646 (2019).
27
+ * [10] https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom
28
+ * [11] Cortes, David.
29
+ * "Revisiting randomized choices in isolation forests."
30
+ * arXiv preprint arXiv:2110.13402 (2021).
31
+ * [12] Guha, Sudipto, et al.
32
+ * "Robust random cut forest based anomaly detection on streams."
33
+ * International conference on machine learning. PMLR, 2016.
34
+ * [13] Cortes, David.
35
+ * "Isolation forests: looking beyond tree depth."
36
+ * arXiv preprint arXiv:2111.11639 (2021).
37
+ * [14] Ting, Kai Ming, Yue Zhu, and Zhi-Hua Zhou.
38
+ * "Isolation kernel and its effect on SVM"
39
+ * Proceedings of the 24th ACM SIGKDD
40
+ * International Conference on Knowledge Discovery & Data Mining. 2018.
41
+ *
42
+ * BSD 2-Clause License
43
+ * Copyright (c) 2019-2022, David Cortes
44
+ * All rights reserved.
45
+ * Redistribution and use in source and binary forms, with or without
46
+ * modification, are permitted provided that the following conditions are met:
47
+ * * Redistributions of source code must retain the above copyright notice, this
48
+ * list of conditions and the following disclaimer.
49
+ * * Redistributions in binary form must reproduce the above copyright notice,
50
+ * this list of conditions and the following disclaimer in the documentation
51
+ * and/or other materials provided with the distribution.
52
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
53
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
54
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
55
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
56
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
57
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
58
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
59
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
60
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
61
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
62
+ */
63
+ #include "isotree.hpp"
64
+
65
/* ceil(log2(x)) done with bit-wise operations ensures perfect precision (and it's faster too)
   https://stackoverflow.com/questions/2589096/find-most-significant-bit-left-most-that-is-set-in-a-bit-array
   https://stackoverflow.com/questions/11376288/fast-computing-of-log2-for-64-bit-integers

   log2ceil(v) returns the smallest k such that 2^k >= v (callers such as build_btree_sampler
   use it to size perfectly-balanced trees). Both v = 0 and v = 1 map to 0:
   the De Bruijn variants compute floor(log2(v - 1)) + 1, which is meaningless for v - 1 == 0
   (they used to return 1 on 32-bit builds and 64 on 64-bit builds), and the portable
   fallbacks shifted by an out-of-range amount for v = 0. */
#if SIZE_MAX == UINT32_MAX /* 32-bit systems */
constexpr static const uint32_t MultiplyDeBruijnBitPosition[32] =
{
    0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30,
    8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31
};
size_t log2ceil( size_t v )
{
    if (v <= 1) return 0;
    v--;
    v |= v >> 1; // first round down to one less than a power of 2
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;

    return MultiplyDeBruijnBitPosition[( uint32_t )( v * 0x07C4ACDDU ) >> 27] + 1;
}
#elif SIZE_MAX == UINT64_MAX /* 64-bit systems */
constexpr static const uint64_t tab64[64] = {
    63,  0, 58,  1, 59, 47, 53,  2,
    60, 39, 48, 27, 54, 33, 42,  3,
    61, 51, 37, 40, 49, 18, 28, 20,
    55, 30, 34, 11, 43, 14, 22,  4,
    62, 57, 46, 52, 38, 26, 32, 41,
    50, 36, 17, 19, 29, 10, 13, 21,
    56, 45, 25, 31, 35, 16,  9, 12,
    44, 24, 15,  8, 23,  7,  6,  5};

size_t log2ceil(size_t value)
{
    if (value <= 1) return 0;
    value--;
    /* smear the highest set bit downwards: value becomes 2^(k+1) - 1 */
    value |= value >> 1;
    value |= value >> 2;
    value |= value >> 4;
    value |= value >> 8;
    value |= value >> 16;
    value |= value >> 32;
    /* isolate the highest set bit and map it to its position via the De Bruijn table */
    return tab64[((uint64_t)((value - (value >> 1))*0x07EDD5E59A4E28C2)) >> 58] + 1;
}
#else /* other architectures - might be much slower */
#if (__cplusplus >= 202002L)
#include <bit>
size_t log2ceil(size_t value)
{
    if (value <= 1) return 0;
    size_t out = std::numeric_limits<size_t>::digits - std::countl_zero(value);
    out -= (value == ((size_t)1 << (out-1)));
    return out;
}
#else
size_t log2ceil(size_t value)
{
    if (value <= 1) return 0;
    size_t value_ = value;
    size_t out = 0;
    while (value >= 1) {
        value = value >> 1;
        out++;
    }
    out -= (value_ == ((size_t)1 << (out-1)));
    return out;
}
#endif
#endif
130
+
131
+ /* adapted from cephes */
132
+ #define EULERS_GAMMA 0.577215664901532860606512
133
+ #include "digamma.hpp"
134
+
135
+ /* http://fredrik-j.blogspot.com/2009/02/how-not-to-compute-harmonic-numbers.html
136
+ https://en.wikipedia.org/wiki/Harmonic_number
137
+ https://github.com/scikit-learn/scikit-learn/pull/19087 */
138
+ template <class ldouble_safe>
139
+ double harmonic(size_t n)
140
+ {
141
+ ldouble_safe temp = (ldouble_safe)1 / square((ldouble_safe)n);
142
+ return - (ldouble_safe)0.5 * temp * ( (ldouble_safe)1/(ldouble_safe)6 - temp * ((ldouble_safe)1/(ldouble_safe)60 - ((ldouble_safe)1/(ldouble_safe)126)*temp) )
143
+ + (ldouble_safe)0.5 * ((ldouble_safe)1/(ldouble_safe)n)
144
+ + std::log((ldouble_safe)n) + (ldouble_safe)EULERS_GAMMA;
145
+ }
146
+
147
/* Exact sum 1/a + 1/(a+1) + ... + 1/(b-1), computed by recursively splitting the range in
   halves (pairwise summation). Assumes 'a' and 'b' hold integer values.
   Usage for getting harmonic(n): harmonic_recursive((double)1, (double)(n + 1)).
   An empty range (b <= a, e.g. n = 0 in the usage above) returns 0 instead of recursing
   forever on harmonic_recursive(a, a). */
double harmonic_recursive(double a, double b)
{
    if (b <= a) return 0.;
    if (b == a + 1) return 1. / a;
    double m = std::floor((a + b) / 2.);
    return harmonic_recursive(a, m) + harmonic_recursive(m, b);
}
154
+
155
+ /* https://stats.stackexchange.com/questions/423542/isolation-forest-and-average-expected-depth-formula
156
+ https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom */
157
+ #include "exp_depth_table.hpp"
158
+ template <class ldouble_safe>
159
+ double expected_avg_depth(size_t sample_size)
160
+ {
161
+ if (likely(sample_size <= N_PRECALC_EXP_DEPTH)) {
162
+ return exp_depth_table[sample_size - 1];
163
+ }
164
+ return 2. * (harmonic<ldouble_safe>(sample_size) - 1.);
165
+ }
166
+
167
+ /* Note: H(x) = psi(x+1) + gamma */
168
+ template <class ldouble_safe>
169
+ double expected_avg_depth(ldouble_safe approx_sample_size)
170
+ {
171
+ if (approx_sample_size <= 1)
172
+ return 0;
173
+ else if (approx_sample_size < (ldouble_safe)INT32_MAX)
174
+ return 2. * (digamma(approx_sample_size + 1.) + EULERS_GAMMA - 1.);
175
+ else {
176
+ ldouble_safe temp = (ldouble_safe)1 / square(approx_sample_size);
177
+ return (ldouble_safe)2 * std::log(approx_sample_size) + (ldouble_safe)2*((ldouble_safe)EULERS_GAMMA - (ldouble_safe)1)
178
+ + ((ldouble_safe)1/approx_sample_size)
179
+ - temp * ( (ldouble_safe)1/(ldouble_safe)6 - temp * ((ldouble_safe)1/(ldouble_safe)60 - ((ldouble_safe)1/(ldouble_safe)126)*temp) );
180
+ }
181
+ }
182
+
183
/* https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree */
#define THRESHOLD_EXACT_S 87670 /* difference is <5e-4 */

/* Continues the recurrence
       s(i) = s(i-1) + (3i - 4 - i*s(i-1)) / (i*(i-1))
   from a known value 'curr' = s(n_curr) up to s(n_final). For n_final >= 1360 the
   sequence is nearly flat, so precomputed step values are returned instead of iterating. */
double expected_separation_depth_hotstart(double curr, size_t n_curr, size_t n_final)
{
    if (n_final >= 1360)
    {
        /* Note on the chosen precision: when calling it on smaller sample sizes,
           the standard error of the separation depth will be larger, thus it's less
           critical to get it right down to the smallest possible precision, while for
           larger samples the standard error of the separation depth will be smaller */
        if (n_final >= THRESHOLD_EXACT_S)
            return 3;
        else if (n_final >= 40774)
            return 2.999;
        else if (n_final >= 18844)
            return 2.998;
        else if (n_final >= 11956)
            return 2.997;
        else if (n_final >= 8643)
            return 2.996;
        else if (n_final >= 6713)
            return 2.995;
        else if (n_final >= 4229)
            return 2.9925;
        else if (n_final >= 3040)
            return 2.99;
        else if (n_final >= 2724)
            return 2.989;
        else if (n_final >= 1902)
            return 2.985;
        else
            return 2.98;
    }

    for (size_t i = n_curr + 1; i <= n_final; i++)
        curr += (-curr * (double)i + 3. * (double)i - 4.) / ((double)i * ((double)(i-1)));
    return curr;
}

/* Expected separation depth s(n) for an integer number of elements: exact values for
   n <= 10, then the recurrence above (or its flat-tail approximation). */
double expected_separation_depth(size_t n)
{
    switch(n)
    {
        case 0: return 0.;
        case 1: return 0.;
        case 2: return 1.;
        case 3: return 1. + (1./3.);
        case 4: return 1. + (1./3.) + (2./9.);
        case 5: return 1.71666666667;
        case 6: return 1.84;
        case 7: return 1.93809524;
        case 8: return 2.01836735;
        case 9: return 2.08551587;
        case 10: return 2.14268078;
        default:
        {
            if (n >= THRESHOLD_EXACT_S)
                return 3;
            else
                return expected_separation_depth_hotstart((double)2.14268078, (size_t)10, n);
        }
    }
}

/* Expected separation depth for a fractional number of elements, by linear interpolation
   between s(floor(n)) and s(ceil(n)).
   Previously this returned s_l + diff * s_u (nearly doubling the value for diff -> 1)
   instead of s_l + diff * (s_u - s_l). It also divided by zero for n <= 1; since
   s(0) = s(1) = 0, that range now returns 0. */
template <class ldouble_safe>
double expected_separation_depth(ldouble_safe n)
{
    if (n >= THRESHOLD_EXACT_S)
        return 3;
    if (n <= 1)
        return 0;
    double s_l = expected_separation_depth((size_t) std::floor(n));
    ldouble_safe u = std::ceil(n);
    /* one step of the recurrence from floor(n) to u = floor(n) + 1 (irrelevant when n is integral) */
    double s_u = s_l + (-s_l * u + 3. * u - 4.) / (u * (u - 1.));
    double diff = n - std::floor(n);
    return s_l + diff * (s_u - s_l);
}
260
+
261
+ void increase_comb_counter(size_t ix_arr[], size_t st, size_t end, size_t n, double counter[], double exp_remainder)
262
+ {
263
+ size_t i, j;
264
+ size_t ncomb = calc_ncomb(n);
265
+ if (exp_remainder <= 1)
266
+ for (size_t el1 = st; el1 < end; el1++)
267
+ {
268
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
269
+ {
270
+ // counter[i * (n - (i+1)/2) + j - i - 1]++; /* beaware integer division */
271
+ i = ix_arr[el1]; j = ix_arr[el2];
272
+ counter[ix_comb(i, j, n, ncomb)]++;
273
+ }
274
+ }
275
+ else
276
+ for (size_t el1 = st; el1 < end; el1++)
277
+ {
278
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
279
+ {
280
+ i = ix_arr[el1]; j = ix_arr[el2];
281
+ counter[ix_comb(i, j, n, ncomb)] += exp_remainder;
282
+ }
283
+ }
284
+ }
285
+
286
+ void increase_comb_counter(size_t ix_arr[], size_t st, size_t end, size_t n,
287
+ double *restrict counter, double *restrict weights, double exp_remainder)
288
+ {
289
+ size_t i, j;
290
+ size_t ncomb = calc_ncomb(n);
291
+ if (exp_remainder <= 1)
292
+ for (size_t el1 = st; el1 < end; el1++)
293
+ {
294
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
295
+ {
296
+ i = ix_arr[el1]; j = ix_arr[el2];
297
+ counter[ix_comb(i, j, n, ncomb)] += weights[i] * weights[j];
298
+ }
299
+ }
300
+ else
301
+ for (size_t el1 = st; el1 < end; el1++)
302
+ {
303
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
304
+ {
305
+ i = ix_arr[el1]; j = ix_arr[el2];
306
+ counter[ix_comb(i, j, n, ncomb)] += weights[i] * weights[j] * exp_remainder;
307
+ }
308
+ }
309
+ }
310
+
311
+ /* Note to self: don't try merge this into a template with the one above, as the other one has 'restrict' qualifier */
312
+ void increase_comb_counter(size_t ix_arr[], size_t st, size_t end, size_t n,
313
+ double counter[], hashed_map<size_t, double> &weights, double exp_remainder)
314
+ {
315
+ size_t i, j;
316
+ size_t ncomb = calc_ncomb(n);
317
+ if (exp_remainder <= 1)
318
+ for (size_t el1 = st; el1 < end; el1++)
319
+ {
320
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
321
+ {
322
+ i = ix_arr[el1]; j = ix_arr[el2];
323
+ counter[ix_comb(i, j, n, ncomb)] += weights[i] * weights[j];
324
+ }
325
+ }
326
+ else
327
+ for (size_t el1 = st; el1 < end; el1++)
328
+ {
329
+ for (size_t el2 = el1 + 1; el2 <= end; el2++)
330
+ {
331
+ i = ix_arr[el1]; j = ix_arr[el2];
332
+ counter[ix_comb(i, j, n, ncomb)] += weights[i] * weights[j] * exp_remainder;
333
+ }
334
+ }
335
+ }
336
+
337
/* Cross-group counter update. ix_arr[st..end] (inclusive) is split at the first entry that
   is >= split_ix, located with std::lower_bound (so the range must be partitioned with the
   entries < split_ix first, e.g. sorted). Every pair formed by a first-group row r1 and a
   second-group row r2 adds 1 (when exp_remainder <= 1) or exp_remainder to
   counter[r1 * (n - split_ix) + r2 - split_ix]: a dense block with one row per
   first-group index and one column per second-group index. */
void increase_comb_counter_in_groups(size_t ix_arr[], size_t st, size_t end, size_t split_ix, size_t n,
                                     double counter[], double exp_remainder)
{
    size_t *second_begin = std::lower_bound(ix_arr + st, ix_arr + end + 1, split_ix);
    const size_t boundary = st + (size_t)std::distance(ix_arr + st, second_begin);
    const size_t ncols = n - split_ix;

    if (exp_remainder > 1)
    {
        for (size_t pos1 = st; pos1 < boundary; pos1++)
            for (size_t pos2 = boundary; pos2 <= end; pos2++)
                counter[ix_arr[pos1] * ncols + ix_arr[pos2] - split_ix] += exp_remainder;
        return;
    }

    for (size_t pos1 = st; pos1 < boundary; pos1++)
        for (size_t pos2 = boundary; pos2 <= end; pos2++)
            counter[ix_arr[pos1] * ncols + ix_arr[pos2] - split_ix] += 1.;
}
353
+
354
+ void increase_comb_counter_in_groups(size_t ix_arr[], size_t st, size_t end, size_t split_ix, size_t n,
355
+ double *restrict counter, double *restrict weights, double exp_remainder)
356
+ {
357
+ size_t *ptr_split_ix = std::lower_bound(ix_arr + st, ix_arr + end + 1, split_ix);
358
+ size_t n_group = std::distance(ix_arr + st, ptr_split_ix);
359
+ n = n - split_ix;
360
+
361
+ if (exp_remainder <= 1)
362
+ for (size_t ix1 = st; ix1 < st + n_group; ix1++)
363
+ for (size_t ix2 = st + n_group; ix2 <= end; ix2++)
364
+ counter[ix_arr[ix1] * n + ix_arr[ix2] - split_ix]
365
+ +=
366
+ weights[ix_arr[ix1]] * weights[ix_arr[ix2]];
367
+ else
368
+ for (size_t ix1 = st; ix1 < st + n_group; ix1++)
369
+ for (size_t ix2 = st + n_group; ix2 <= end; ix2++)
370
+ counter[ix_arr[ix1] * n + ix_arr[ix2] - split_ix]
371
+ +=
372
+ weights[ix_arr[ix1]] * weights[ix_arr[ix2]] * exp_remainder;
373
+ }
374
+
375
+ void tmat_to_dense(double *restrict tmat, double *restrict dmat, size_t n, double fill_diag)
376
+ {
377
+ size_t ncomb = calc_ncomb(n);
378
+ for (size_t i = 0; i < (n-1); i++)
379
+ {
380
+ for (size_t j = i + 1; j < n; j++)
381
+ {
382
+ // dmat[i + j * n] = dmat[j + i * n] = tmat[i * (n - (i+1)/2) + j - i - 1];
383
+ dmat[i + j * n] = dmat[j + i * n] = tmat[ix_comb(i, j, n, ncomb)];
384
+ }
385
+ }
386
+ for (size_t i = 0; i < n; i++)
387
+ dmat[i + i * n] = fill_diag;
388
+ }
389
+
390
+ template <class real_t>
391
+ void build_btree_sampler(std::vector<double> &btree_weights, real_t *restrict sample_weights,
392
+ size_t nrows, size_t &restrict log2_n, size_t &restrict btree_offset)
393
+ {
394
+ /* build a perfectly-balanced binary search tree in which each node will
395
+ hold the sum of the weights of its children */
396
+ log2_n = log2ceil(nrows);
397
+ if (btree_weights.empty())
398
+ btree_weights.resize(pow2(log2_n + 1), 0);
399
+ else
400
+ btree_weights.assign(btree_weights.size(), 0);
401
+ btree_offset = pow2(log2_n) - 1;
402
+
403
+ for (size_t ix = 0; ix < nrows; ix++)
404
+ btree_weights[ix + btree_offset] = std::fmax(0., sample_weights[ix]);
405
+ for (size_t ix = btree_weights.size() - 1; ix > 0; ix--)
406
+ btree_weights[ix_parent(ix)] += btree_weights[ix];
407
+
408
+ if (std::isnan(btree_weights[0]) || btree_weights[0] <= 0)
409
+ {
410
+ print_errmsg("Numeric precision error with sample weights, will not use them.\n");
411
+ log2_n = 0;
412
+ btree_weights.clear();
413
+ btree_weights.shrink_to_fit();
414
+ }
415
+ }
416
+
417
+ template <class real_t, class ldouble_safe>
418
+ void sample_random_rows(std::vector<size_t> &restrict ix_arr, size_t nrows, bool with_replacement,
419
+ RNG_engine &rnd_generator, std::vector<size_t> &restrict ix_all,
420
+ real_t *restrict sample_weights, std::vector<double> &restrict btree_weights,
421
+ size_t log2_n, size_t btree_offset, std::vector<bool> &is_repeated)
422
+ {
423
+ size_t ntake = ix_arr.size();
424
+
425
+ /* if with replacement, just generate random uniform numbers */
426
+ if (with_replacement)
427
+ {
428
+ if (sample_weights == NULL)
429
+ {
430
+ std::uniform_int_distribution<size_t> runif(0, nrows - 1);
431
+ for (size_t &ix : ix_arr)
432
+ ix = runif(rnd_generator);
433
+ }
434
+
435
+ else
436
+ {
437
+ std::discrete_distribution<size_t> runif(sample_weights, sample_weights + nrows);
438
+ for (size_t &ix : ix_arr)
439
+ ix = runif(rnd_generator);
440
+ }
441
+ }
442
+
443
+ /* if all the elements are needed, don't bother with any sampling */
444
+ else if (ntake == nrows)
445
+ {
446
+ std::iota(ix_arr.begin(), ix_arr.end(), (size_t)0);
447
+ }
448
+
449
+
450
+ /* if there are sample weights, use binary trees to keep track and update weight
451
+ https://stackoverflow.com/questions/57599509/c-random-non-repeated-integers-with-weights */
452
+ else if (sample_weights != NULL)
453
+ {
454
+ /* TODO: here could instead generate only 1 random number from zero to the full weight,
455
+ and then subtract from it as it goes down every level. Would have less precision
456
+ but should still work fine. */
457
+
458
+ double rnd_subrange, w_left;
459
+ double curr_subrange;
460
+ size_t curr_ix;
461
+ for (size_t &ix : ix_arr)
462
+ {
463
+ /* go down the tree by drawing a random number and
464
+ checking if it falls in the left or right ranges */
465
+ curr_ix = 0;
466
+ curr_subrange = btree_weights[0];
467
+ for (size_t lev = 0; lev < log2_n; lev++)
468
+ {
469
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
470
+ w_left = btree_weights[ix_child(curr_ix)];
471
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
472
+ curr_subrange = btree_weights[curr_ix];
473
+ }
474
+
475
+ /* finally, determine element to choose in this iteration */
476
+ ix = curr_ix - btree_offset;
477
+
478
+ /* now remove the weight of the chosen element */
479
+ btree_weights[curr_ix] = 0;
480
+ for (size_t lev = 0; lev < log2_n; lev++)
481
+ {
482
+ curr_ix = ix_parent(curr_ix);
483
+ btree_weights[curr_ix] = btree_weights[ix_child(curr_ix)]
484
+ + btree_weights[ix_child(curr_ix) + 1];
485
+ }
486
+ }
487
+ }
488
+
489
+ /* if no sample weights and not with replacement (most common case expected),
490
+ then use different algorithms depending on the sampled fraction */
491
+ else
492
+ {
493
+
494
+ /* if sampling a larger fraction, fill an array enumerating the rows, shuffle, and take first N */
495
+ if (ntake >= (nrows / 2))
496
+ {
497
+
498
+ if (ix_all.empty())
499
+ ix_all.resize(nrows);
500
+
501
+ /* in order for random seeds to always be reproducible, don't re-use previous shuffles */
502
+ std::iota(ix_all.begin(), ix_all.end(), (size_t)0);
503
+
504
+ /* If the number of sampled elements is large, do a full shuffle, enjoy simd-instructs when copying over */
505
+ if (ntake >= ((nrows * 3)/4))
506
+ {
507
+ std::shuffle(ix_all.begin(), ix_all.end(), rnd_generator);
508
+ ix_arr.assign(ix_all.begin(), ix_all.begin() + ntake);
509
+ }
510
+
511
+ /* otherwise, do only a partial shuffle (use Yates algorithm) and copy elements along the way */
512
+ else
513
+ {
514
+ size_t chosen;
515
+ for (size_t i = nrows - 1; i >= nrows - ntake; i--)
516
+ {
517
+ chosen = std::uniform_int_distribution<size_t>(0, i)(rnd_generator);
518
+ ix_arr[nrows - i - 1] = ix_all[chosen];
519
+ ix_all[chosen] = ix_all[i];
520
+ }
521
+ }
522
+
523
+ }
524
+
525
+ /* If the sample size is small, use Floyd's random sampling algorithm
526
+ https://stackoverflow.com/questions/2394246/algorithm-to-select-a-single-random-combination-of-values */
527
+ else
528
+ {
529
+
530
+ size_t candidate;
531
+
532
+ /* if the sample size is relatively large, use a temporary boolean vector */
533
+ if (((ldouble_safe)ntake / (ldouble_safe)nrows) > (1. / 50.))
534
+ {
535
+
536
+ if (is_repeated.empty())
537
+ is_repeated.resize(nrows, false);
538
+ else
539
+ is_repeated.assign(is_repeated.size(), false);
540
+
541
+ for (size_t rnd_ix = nrows - ntake; rnd_ix < nrows; rnd_ix++)
542
+ {
543
+ candidate = std::uniform_int_distribution<size_t>(0, rnd_ix)(rnd_generator);
544
+ if (is_repeated[candidate])
545
+ {
546
+ ix_arr[ntake - (nrows - rnd_ix)] = rnd_ix;
547
+ is_repeated[rnd_ix] = true;
548
+ }
549
+
550
+ else
551
+ {
552
+ ix_arr[ntake - (nrows - rnd_ix)] = candidate;
553
+ is_repeated[candidate] = true;
554
+ }
555
+ }
556
+
557
+ }
558
+
559
+ /* if the sample size is very small, use an unordered set */
560
+ else
561
+ {
562
+
563
+ hashed_set<size_t> repeated_set;
564
+ repeated_set.reserve(ntake);
565
+ for (size_t rnd_ix = nrows - ntake; rnd_ix < nrows; rnd_ix++)
566
+ {
567
+ candidate = std::uniform_int_distribution<size_t>(0, rnd_ix)(rnd_generator);
568
+ if (repeated_set.find(candidate) == repeated_set.end()) /* TODO: switch to C++20 'contains' */
569
+ {
570
+ ix_arr[ntake - (nrows - rnd_ix)] = candidate;
571
+ repeated_set.insert(candidate);
572
+ }
573
+
574
+ else
575
+ {
576
+ ix_arr[ntake - (nrows - rnd_ix)] = rnd_ix;
577
+ repeated_set.insert(rnd_ix);
578
+ }
579
+ }
580
+
581
+ }
582
+
583
+ }
584
+
585
+ }
586
+ }
587
+
588
+ /* https://stackoverflow.com/questions/57599509/c-random-non-repeated-integers-with-weights */
589
+ template <class real_t>
590
+ void weighted_shuffle(size_t *restrict outp, size_t n, real_t *restrict weights, double *restrict buffer_arr, RNG_engine &rnd_generator)
591
+ {
592
+ /* determine smallest power of two that is larger than N */
593
+ size_t tree_levels = log2ceil(n);
594
+
595
+ /* initialize vector with place-holders for perfectly-balanced tree */
596
+ std::fill(buffer_arr, buffer_arr + pow2(tree_levels + 1), (double)0);
597
+
598
+ /* compute sums for the tree leaves at each node */
599
+ size_t offset = pow2(tree_levels) - 1;
600
+ for (size_t ix = 0; ix < n; ix++) {
601
+ buffer_arr[ix + offset] = std::fmax(0., weights[ix]);
602
+ }
603
+ for (size_t ix = pow2(tree_levels+1) - 1; ix > 0; ix--) {
604
+ buffer_arr[ix_parent(ix)] += buffer_arr[ix];
605
+ }
606
+
607
+ /* if the weights are invalid, produce an unweighted shuffle */
608
+ if (std::isnan(buffer_arr[0]) || buffer_arr[0] <= 0)
609
+ {
610
+ std::iota(outp, outp + n, (size_t)0);
611
+ std::shuffle(outp, outp + n, rnd_generator);
612
+ return;
613
+ }
614
+
615
+ /* sample according to uniform distribution */
616
+ double rnd_subrange, w_left;
617
+ double curr_subrange;
618
+ size_t curr_ix;
619
+
620
+ for (size_t el = 0; el < n; el++)
621
+ {
622
+ /* go down the tree by drawing a random number and
623
+ checking if it falls in the left or right sub-ranges */
624
+ curr_ix = 0;
625
+ curr_subrange = buffer_arr[0];
626
+ for (size_t lev = 0; lev < tree_levels; lev++)
627
+ {
628
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
629
+ w_left = buffer_arr[ix_child(curr_ix)];
630
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
631
+ curr_subrange = buffer_arr[curr_ix];
632
+ }
633
+
634
+ /* finally, add element from this iteration */
635
+ outp[el] = curr_ix - offset;
636
+
637
+ /* now remove the weight of the chosen element */
638
+ buffer_arr[curr_ix] = 0;
639
+ for (size_t lev = 0; lev < tree_levels; lev++)
640
+ {
641
+ curr_ix = ix_parent(curr_ix);
642
+ buffer_arr[curr_ix] = buffer_arr[ix_child(curr_ix)]
643
+ + buffer_arr[ix_child(curr_ix) + 1];
644
+ }
645
+ }
646
+ }
647
+
648
+ /* Goualard, Frédéric. "Drawing random floating-point numbers from an interval."
649
+ ACM Transactions on Modeling and Computer Simulation (TOMACS) 32.3 (2022): 1-24. */
650
+ [[gnu::flatten]]
651
+ double sample_random_uniform(double xmin, double xmax, RNG_engine &rng) noexcept
652
+ {
653
+ const double random_unit = UniformUnitInterval(0, 1)(rng);
654
+ const double half_min = 0.5 * xmin;
655
+ const double half_max = 0.5 * xmax;
656
+ double out = 2. * (half_min + random_unit * (half_max - half_min));
657
+ if (unlikely(out >= xmax)) {
658
+ if (unlikely(xmax == xmin)) return xmin;
659
+ out = std::nextafter(xmax, xmin);
660
+ }
661
+ out = std::fmax(out, xmin);
662
+ return out;
663
+ }
664
+
665
+ template <class ldouble_safe>
666
+ template <class other_t>
667
+ ColumnSampler<ldouble_safe>& ColumnSampler<ldouble_safe>::operator=(const ColumnSampler<other_t> &other)
668
+ {
669
+ this->col_indices = other.col_indices;
670
+ this->tree_weights = other.tree_weights;
671
+ this->curr_pos = other.curr_pos;
672
+ this->curr_col = other.curr_col;
673
+ this->last_given = other.last_given;
674
+ this->n_cols = other.n_cols;
675
+ this->tree_levels = other.tree_levels;
676
+ this->offset = other.offset;
677
+ this->n_dropped = other.n_dropped;
678
+ return *this;
679
+ }
680
+
681
+ /* This one samples with replacement. When using weights, the algorithm is the
682
+ same as for the row sampler, but keeping the weights after taking each iteration. */
683
+ /* TODO: this column sampler could use coroutines from C++20 once compilers implement them. */
684
+ template <class ldouble_safe>
685
+ template <class real_t>
686
+ void ColumnSampler<ldouble_safe>::initialize(real_t weights[], size_t n_cols)
687
+ {
688
+ this->n_cols = n_cols;
689
+ this->tree_levels = log2ceil(n_cols);
690
+ if (this->tree_weights.empty())
691
+ this->tree_weights.resize(pow2(this->tree_levels + 1), 0);
692
+ else {
693
+ if (this->tree_weights.size() != pow2(this->tree_levels + 1))
694
+ this->tree_weights.resize(this->tree_levels);
695
+ std::fill(this->tree_weights.begin(), this->tree_weights.end(), 0.);
696
+ }
697
+
698
+ /* compute sums for the tree leaves at each node */
699
+ this->offset = pow2(this->tree_levels) - 1;
700
+ for (size_t ix = 0; ix < this->n_cols; ix++)
701
+ this->tree_weights[ix + this->offset] = std::fmax(0., weights[ix]);
702
+ for (size_t ix = this->tree_weights.size() - 1; ix > 0; ix--)
703
+ this->tree_weights[ix_parent(ix)] += this->tree_weights[ix];
704
+
705
+ /* if the weights are invalid, make it an unweighted sampler */
706
+ if (unlikely(std::isnan(this->tree_weights[0]) || this->tree_weights[0] <= 0))
707
+ {
708
+ this->drop_weights();
709
+ }
710
+
711
+ this->n_dropped = 0;
712
+ }
713
+
714
+ template <class ldouble_safe>
715
+ void ColumnSampler<ldouble_safe>::drop_weights()
716
+ {
717
+ this->tree_weights.clear();
718
+ this->tree_weights.shrink_to_fit();
719
+ this->initialize(n_cols);
720
+ this->n_dropped = 0;
721
+ }
722
+
723
+ template <class ldouble_safe>
724
+ bool ColumnSampler<ldouble_safe>::has_weights()
725
+ {
726
+ return !this->tree_weights.empty();
727
+ }
728
+
729
+ template <class ldouble_safe>
730
+ void ColumnSampler<ldouble_safe>::initialize(size_t n_cols)
731
+ {
732
+ if (!this->has_weights())
733
+ {
734
+ this->n_cols = n_cols;
735
+ this->curr_pos = n_cols;
736
+ this->col_indices.resize(n_cols);
737
+ std::iota(this->col_indices.begin(), this->col_indices.end(), (size_t)0);
738
+ }
739
+ }
740
+
741
+ /* TODO: this one should instead call the same function for sampling rows,
742
+ and should be done at the time of initialization so as to avoid allocating
743
+ and filling the whole array. That way it'd be faster and use less memory. */
744
+ template <class ldouble_safe>
745
+ void ColumnSampler<ldouble_safe>::leave_m_cols(size_t m, RNG_engine &rnd_generator)
746
+ {
747
+ if (m == 0 || m >= this->n_cols)
748
+ return;
749
+
750
+ if (!this->has_weights())
751
+ {
752
+ size_t chosen;
753
+ if (m <= this->n_cols / 4)
754
+ {
755
+ for (this->curr_pos = 0; this->curr_pos < m; this->curr_pos++)
756
+ {
757
+ chosen = std::uniform_int_distribution<size_t>(0, this->n_cols - this->curr_pos - 1)(rnd_generator);
758
+ std::swap(this->col_indices[this->curr_pos + chosen], this->col_indices[this->curr_pos]);
759
+ }
760
+ }
761
+
762
+ else if ((ldouble_safe)m >= (ldouble_safe)(3./4.) * (ldouble_safe)this->n_cols)
763
+ {
764
+ for (this->curr_pos = this->n_cols-1; this->curr_pos > this->n_cols - m; this->curr_pos--)
765
+ {
766
+ chosen = std::uniform_int_distribution<size_t>(0, this->curr_pos)(rnd_generator);
767
+ std::swap(this->col_indices[chosen], this->col_indices[this->curr_pos]);
768
+ }
769
+ this->curr_pos = m;
770
+ }
771
+
772
+ else
773
+ {
774
+ std::shuffle(this->col_indices.begin(), this->col_indices.end(), rnd_generator);
775
+ this->curr_pos = m;
776
+ }
777
+ }
778
+
779
+ else
780
+ {
781
+ std::vector<double> curr_weights = this->tree_weights;
782
+ std::fill(this->tree_weights.begin(), this->tree_weights.end(), 0.);
783
+ double rnd_subrange, w_left;
784
+ double curr_subrange;
785
+ size_t curr_ix;
786
+
787
+ for (size_t col = 0; col < m; col++)
788
+ {
789
+ curr_ix = 0;
790
+ curr_subrange = curr_weights[0];
791
+ if (curr_subrange <= 0)
792
+ {
793
+ if (col == 0)
794
+ {
795
+ this->drop_weights();
796
+ return;
797
+ }
798
+
799
+ else
800
+ {
801
+ m = col;
802
+ goto rebuild_tree;
803
+ }
804
+ }
805
+
806
+ for (size_t lev = 0; lev < this->tree_levels; lev++)
807
+ {
808
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
809
+ w_left = curr_weights[ix_child(curr_ix)];
810
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
811
+ curr_subrange = curr_weights[curr_ix];
812
+ }
813
+
814
+ this->tree_weights[curr_ix] = curr_weights[curr_ix];
815
+
816
+ /* now remove the weight of the chosen element */
817
+ curr_weights[curr_ix] = 0;
818
+ for (size_t lev = 0; lev < this->tree_levels; lev++)
819
+ {
820
+ curr_ix = ix_parent(curr_ix);
821
+ curr_weights[curr_ix] = curr_weights[ix_child(curr_ix)]
822
+ + curr_weights[ix_child(curr_ix) + 1];
823
+ }
824
+ }
825
+
826
+ /* rebuild the tree after getting new weights */
827
+ rebuild_tree:
828
+ for (size_t ix = this->tree_weights.size() - 1; ix > 0; ix--)
829
+ this->tree_weights[ix_parent(ix)] += this->tree_weights[ix];
830
+
831
+ this->n_dropped = this->n_cols - m;
832
+ }
833
+ }
834
+
835
+ template <class ldouble_safe>
836
+ void ColumnSampler<ldouble_safe>::drop_col(size_t col, size_t nobs_left)
837
+ {
838
+ if (!this->has_weights())
839
+ {
840
+ if (this->col_indices[this->last_given] == col)
841
+ {
842
+ std::swap(this->col_indices[this->last_given], this->col_indices[--this->curr_pos]);
843
+ }
844
+
845
+ else if (this->curr_pos > 4*nobs_left)
846
+ {
847
+ return;
848
+ }
849
+
850
+ else
851
+ {
852
+ for (size_t ix = 0; ix < this->curr_pos; ix++)
853
+ {
854
+ if (this->col_indices[ix] == col)
855
+ {
856
+ std::swap(this->col_indices[ix], this->col_indices[--this->curr_pos]);
857
+ break;
858
+ }
859
+ }
860
+ }
861
+
862
+ if (this->curr_col) this->curr_col--;
863
+ }
864
+
865
+ else
866
+ {
867
+ this->n_dropped++;
868
+ size_t curr_ix = col + this->offset;
869
+ this->tree_weights[curr_ix] = 0.;
870
+ for (size_t lev = 0; lev < this->tree_levels; lev++)
871
+ {
872
+ curr_ix = ix_parent(curr_ix);
873
+ this->tree_weights[curr_ix] = this->tree_weights[ix_child(curr_ix)]
874
+ + this->tree_weights[ix_child(curr_ix) + 1];
875
+ }
876
+ }
877
+ }
878
+
879
+ template <class ldouble_safe>
880
+ void ColumnSampler<ldouble_safe>::drop_col(size_t col)
881
+ {
882
+ this->drop_col(col, SIZE_MAX);
883
+ }
884
+
885
+ /* to be used exclusively when initializing the density calculator,
886
+ and only when 'col_indices' is a straight range with no dropped columns */
887
+ template <class ldouble_safe>
888
+ void ColumnSampler<ldouble_safe>::drop_from_tail(size_t col)
889
+ {
890
+ std::swap(this->col_indices[col], this->col_indices[--this->curr_pos]);
891
+ }
892
+
893
+ template <class ldouble_safe>
894
+ void ColumnSampler<ldouble_safe>::prepare_full_pass()
895
+ {
896
+ this->curr_col = 0;
897
+
898
+ if (this->has_weights())
899
+ {
900
+ if (this->col_indices.size() < this->n_cols)
901
+ this->col_indices.resize(this->n_cols);
902
+ this->curr_pos = 0;
903
+ for (size_t col = 0; col < this->n_cols; col++)
904
+ {
905
+ if (this->tree_weights[col + this->offset] > 0)
906
+ this->col_indices[this->curr_pos++] = col;
907
+ }
908
+ }
909
+ }
910
+
911
+ template <class ldouble_safe>
912
+ bool ColumnSampler<ldouble_safe>::sample_col(size_t &col, RNG_engine &rnd_generator)
913
+ {
914
+ if (!this->has_weights())
915
+ {
916
+ switch(this->curr_pos)
917
+ {
918
+ case 0: return false;
919
+ case 1:
920
+ {
921
+ this->last_given = 0;
922
+ col = this->col_indices[0];
923
+ return true;
924
+ }
925
+ default:
926
+ {
927
+ this->last_given = std::uniform_int_distribution<size_t>(0, this->curr_pos-1)(rnd_generator);
928
+ col = this->col_indices[this->last_given];
929
+ return true;
930
+ }
931
+ }
932
+ }
933
+
934
+ else
935
+ {
936
+ /* TODO: here could instead generate only 1 random number from zero to the full weight,
937
+ and then subtract from it as it goes down every level. Would have less precision
938
+ but should still work fine. */
939
+ size_t curr_ix = 0;
940
+ double rnd_subrange, w_left;
941
+ double curr_subrange = this->tree_weights[0];
942
+ if (curr_subrange <= 0)
943
+ return false;
944
+
945
+ for (size_t lev = 0; lev < tree_levels; lev++)
946
+ {
947
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
948
+ w_left = this->tree_weights[ix_child(curr_ix)];
949
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
950
+ curr_subrange = this->tree_weights[curr_ix];
951
+ }
952
+
953
+ col = curr_ix - this->offset;
954
+ return true;
955
+ }
956
+ }
957
+
958
+ template <class ldouble_safe>
959
+ bool ColumnSampler<ldouble_safe>::sample_col(size_t &col)
960
+ {
961
+ if (this->curr_pos == this->curr_col || this->curr_pos == 0)
962
+ return false;
963
+ this->last_given = this->curr_col;
964
+ col = this->col_indices[this->curr_col++];
965
+ return true;
966
+ }
967
+
968
+ template <class ldouble_safe>
969
+ void ColumnSampler<ldouble_safe>::shuffle_remainder(RNG_engine &rnd_generator)
970
+ {
971
+ if (!this->has_weights())
972
+ {
973
+ this->prepare_full_pass();
974
+ std::shuffle(this->col_indices.begin(),
975
+ this->col_indices.begin() + this->curr_pos,
976
+ rnd_generator);
977
+ }
978
+
979
+ else
980
+ {
981
+ if (this->tree_weights[0] <= 0)
982
+ return;
983
+ std::vector<double> curr_weights = this->tree_weights;
984
+ this->curr_pos = 0;
985
+ this->curr_col = 0;
986
+
987
+ if (this->col_indices.size() < this->n_cols)
988
+ this->col_indices.resize(this->n_cols);
989
+
990
+ double rnd_subrange, w_left;
991
+ double curr_subrange;
992
+ size_t curr_ix;
993
+
994
+ for (this->curr_pos = 0; this->curr_pos < this->n_cols; this->curr_pos++)
995
+ {
996
+ curr_ix = 0;
997
+ curr_subrange = curr_weights[0];
998
+ if (curr_subrange <= 0)
999
+ return;
1000
+
1001
+ for (size_t lev = 0; lev < this->tree_levels; lev++)
1002
+ {
1003
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
1004
+ w_left = curr_weights[ix_child(curr_ix)];
1005
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
1006
+ curr_subrange = curr_weights[curr_ix];
1007
+ }
1008
+
1009
+ /* finally, add element from this iteration */
1010
+ this->col_indices[this->curr_pos] = curr_ix - this->offset;
1011
+
1012
+ /* now remove the weight of the chosen element */
1013
+ curr_weights[curr_ix] = 0;
1014
+ for (size_t lev = 0; lev < this->tree_levels; lev++)
1015
+ {
1016
+ curr_ix = ix_parent(curr_ix);
1017
+ curr_weights[curr_ix] = curr_weights[ix_child(curr_ix)]
1018
+ + curr_weights[ix_child(curr_ix) + 1];
1019
+ }
1020
+ }
1021
+ }
1022
+ }
1023
+
1024
+ template <class ldouble_safe>
1025
+ size_t ColumnSampler<ldouble_safe>::get_remaining_cols()
1026
+ {
1027
+ if (!this->has_weights())
1028
+ return this->curr_pos;
1029
+ else
1030
+ return this->n_cols - this->n_dropped;
1031
+ }
1032
+
1033
+ template <class ldouble_safe>
1034
+ void ColumnSampler<ldouble_safe>::get_array_remaining_cols(std::vector<size_t> &restrict cols)
1035
+ {
1036
+ if (!this->has_weights())
1037
+ {
1038
+ cols.assign(this->col_indices.begin(), this->col_indices.begin() + this->curr_pos);
1039
+ std::sort(cols.begin(), cols.begin() + this->curr_pos);
1040
+ }
1041
+
1042
+ else
1043
+ {
1044
+ size_t n_rem = 0;
1045
+ for (size_t col = 0; col < this->n_cols; col++)
1046
+ {
1047
+ if (this->tree_weights[col + this->offset] > 0)
1048
+ {
1049
+ cols[n_rem++] = col;
1050
+ }
1051
+ }
1052
+ }
1053
+ }
1054
+
1055
+ template<class ldouble_safe, class real_t>
1056
+ bool SingleNodeColumnSampler<ldouble_safe, real_t>::initialize
1057
+ (
1058
+ double *restrict weights,
1059
+ std::vector<size_t> *col_indices,
1060
+ size_t curr_pos,
1061
+ size_t n_sample,
1062
+ bool backup_weights
1063
+ )
1064
+ {
1065
+ if (!curr_pos) return false;
1066
+
1067
+ this->col_indices = col_indices->data();
1068
+ this->curr_pos = curr_pos;
1069
+ this->n_left = this->curr_pos;
1070
+ this->weights_orig = weights;
1071
+
1072
+ if (n_sample > std::max(log2ceil(this->curr_pos), (size_t)3))
1073
+ {
1074
+ this->using_tree = true;
1075
+ this->backup_weights = false;
1076
+
1077
+ if (this->used_weights.empty()) {
1078
+ this->used_weights.reserve(col_indices->size());
1079
+ this->mapped_indices.reserve(col_indices->size());
1080
+ this->tree_weights.reserve(2 * col_indices->size());
1081
+ }
1082
+
1083
+ this->used_weights.resize(this->curr_pos);
1084
+ this->mapped_indices.resize(this->curr_pos);
1085
+
1086
+ for (size_t col = 0; col < this->curr_pos; col++) {
1087
+ this->mapped_indices[col] = this->col_indices[col];
1088
+ this->used_weights[col] = weights[this->col_indices[col]];
1089
+ if (!weights[this->col_indices[col]]) this->n_left--;
1090
+ }
1091
+
1092
+ this->tree_weights.resize(0);
1093
+ build_btree_sampler(this->tree_weights, this->used_weights.data(),
1094
+ this->curr_pos, this->tree_levels, this->offset);
1095
+
1096
+ this->n_inf = 0;
1097
+ if (std::isinf(this->tree_weights[0]))
1098
+ {
1099
+ if (this->mapped_inf_indices.empty())
1100
+ this->mapped_inf_indices.resize(this->curr_pos);
1101
+
1102
+ for (size_t col = 0; col < this->curr_pos; col++)
1103
+ {
1104
+ if (std::isinf(weights[this->col_indices[col]]))
1105
+ {
1106
+ this->mapped_inf_indices[this->n_inf++] = this->col_indices[col];
1107
+ weights[this->col_indices[col]] = 0;
1108
+ }
1109
+
1110
+ else
1111
+ {
1112
+ this->mapped_indices[col - this->n_inf] = this->col_indices[col];
1113
+ this->used_weights[col - this->n_inf] = weights[this->col_indices[col]];
1114
+ }
1115
+ }
1116
+
1117
+ this->tree_weights.resize(0);
1118
+ build_btree_sampler(this->tree_weights, this->used_weights.data(),
1119
+ this->curr_pos - this->n_inf, this->tree_levels, this->offset);
1120
+ }
1121
+
1122
+ this->used_weights.resize(0);
1123
+
1124
+ if (this->tree_weights[0] <= 0 && !this->n_inf)
1125
+ return false;
1126
+ }
1127
+
1128
+ else
1129
+ {
1130
+ this->using_tree = false;
1131
+ this->backup_weights = backup_weights;
1132
+
1133
+ if (this->backup_weights)
1134
+ {
1135
+ if (this->weights_own.empty())
1136
+ this->weights_own.resize(col_indices->size());
1137
+ this->weights_own.assign(weights, weights + this->curr_pos);
1138
+ }
1139
+
1140
+ this->cumw = 0;
1141
+ for (size_t col = 0; col < this->curr_pos; col++) {
1142
+ this->cumw += weights[this->col_indices[col]];
1143
+ if (!weights[this->col_indices[col]]) this->n_left--;
1144
+ }
1145
+
1146
+ if (std::isnan(this->cumw))
1147
+ throw std::runtime_error("NAs encountered. Try using a different value for 'missing_action'.\n");
1148
+
1149
+ /* if it's infinite, will choose among columns with infinite weight first */
1150
+ this->n_inf = 0;
1151
+ if (std::isinf(this->cumw))
1152
+ {
1153
+ if (this->inifinite_weights.empty())
1154
+ this->inifinite_weights.resize(col_indices->size());
1155
+ else
1156
+ this->inifinite_weights.assign(col_indices->size(), false);
1157
+
1158
+ this->cumw = 0;
1159
+ for (size_t col = 0; col < this->curr_pos; col++)
1160
+ {
1161
+ if (std::isinf(weights[this->col_indices[col]])) {
1162
+ this->n_inf++;
1163
+ this->inifinite_weights[this->col_indices[col]] = true;
1164
+ weights[this->col_indices[col]] = 0;
1165
+ }
1166
+
1167
+ else {
1168
+ this->cumw += weights[this->col_indices[col]];
1169
+ }
1170
+ }
1171
+ }
1172
+
1173
+ if (!this->cumw && !this->n_inf) return false;
1174
+ }
1175
+
1176
+ return true;
1177
+ }
1178
+
1179
+ template <class ldouble_safe, class real_t>
1180
+ bool SingleNodeColumnSampler<ldouble_safe, real_t>::sample_col(size_t &col_chosen, RNG_engine &rnd_generator)
1181
+ {
1182
+ if (!this->using_tree)
1183
+ {
1184
+ if (this->backup_weights)
1185
+ this->weights_orig = this->weights_own.data();
1186
+
1187
+ /* if there's infinites, choose uniformly at random from them */
1188
+ if (this->n_inf)
1189
+ {
1190
+ size_t chosen = std::uniform_int_distribution<size_t>(0, this->n_inf-1)(rnd_generator);
1191
+ size_t curr = 0;
1192
+ for (size_t col = 0; col < this->curr_pos; col++)
1193
+ {
1194
+ curr += inifinite_weights[this->col_indices[col]];
1195
+ if (curr == chosen)
1196
+ {
1197
+ col_chosen = this->col_indices[col];
1198
+ this->n_inf--;
1199
+ this->inifinite_weights[col_chosen] = false;
1200
+ this->n_left--;
1201
+ return true;
1202
+ }
1203
+ }
1204
+ assert(0);
1205
+ }
1206
+
1207
+ if (!this->n_left) return false;
1208
+
1209
+ /* due to the way this is calculated, there can be large roundoff errors and even negatives */
1210
+ if (this->cumw <= 0)
1211
+ {
1212
+ this->cumw = 0;
1213
+ for (size_t col = 0; col < this->curr_pos; col++)
1214
+ this->cumw += this->weights_orig[this->col_indices[col]];
1215
+ if (unlikely(this->cumw <= 0))
1216
+ unexpected_error();
1217
+ }
1218
+
1219
+ /* if there are no infinites, choose a column according to weight */
1220
+ ldouble_safe chosen = std::uniform_real_distribution<ldouble_safe>((ldouble_safe)0, this->cumw)(rnd_generator);
1221
+ ldouble_safe cumw_ = 0;
1222
+ for (size_t col = 0; col < this->curr_pos; col++)
1223
+ {
1224
+ cumw_ += this->weights_orig[this->col_indices[col]];
1225
+ if (cumw_ >= chosen)
1226
+ {
1227
+ col_chosen = this->col_indices[col];
1228
+ this->cumw -= this->weights_orig[col_chosen];
1229
+ this->weights_orig[col_chosen] = 0;
1230
+ this->n_left--;
1231
+ return true;
1232
+ }
1233
+ }
1234
+ col_chosen = this->col_indices[this->curr_pos-1];
1235
+ this->cumw -= this->weights_orig[col_chosen];
1236
+ this->weights_orig[col_chosen] = 0;
1237
+ this->n_left--;
1238
+ return true;
1239
+ }
1240
+
1241
+ else
1242
+ {
1243
+ /* if there's infinites, choose uniformly at random from them */
1244
+ if (this->n_inf)
1245
+ {
1246
+ size_t chosen = std::uniform_int_distribution<size_t>(0, this->n_inf-1)(rnd_generator);
1247
+ col_chosen = this->mapped_inf_indices[chosen];
1248
+ std::swap(this->mapped_inf_indices[chosen], this->mapped_inf_indices[--this->n_inf]);
1249
+ this->n_left--;
1250
+ return true;
1251
+ }
1252
+
1253
+ else
1254
+ {
1255
+ /* TODO: should standardize all these tree traversals into one.
1256
+ This one in particular could do with sampling only a single
1257
+ random number as it will not typically require exhausting all
1258
+ options like the usual column sampler. */
1259
+ if (!this->n_left) return false;
1260
+ size_t curr_ix = 0;
1261
+ double rnd_subrange, w_left;
1262
+ double curr_subrange = this->tree_weights[0];
1263
+ if (curr_subrange <= 0)
1264
+ return false;
1265
+
1266
+ for (size_t lev = 0; lev < tree_levels; lev++)
1267
+ {
1268
+ rnd_subrange = std::uniform_real_distribution<double>(0., curr_subrange)(rnd_generator);
1269
+ w_left = this->tree_weights[ix_child(curr_ix)];
1270
+ curr_ix = ix_child(curr_ix) + (rnd_subrange >= w_left);
1271
+ curr_subrange = this->tree_weights[curr_ix];
1272
+ }
1273
+
1274
+ col_chosen = this->mapped_indices[curr_ix - this->offset];
1275
+
1276
+ this->tree_weights[curr_ix] = 0.;
1277
+ for (size_t lev = 0; lev < this->tree_levels; lev++)
1278
+ {
1279
+ curr_ix = ix_parent(curr_ix);
1280
+ this->tree_weights[curr_ix] = this->tree_weights[ix_child(curr_ix)]
1281
+ + this->tree_weights[ix_child(curr_ix) + 1];
1282
+ }
1283
+
1284
+ this->n_left--;
1285
+ return true;
1286
+ }
1287
+ }
1288
+ }
1289
+
1290
+ template <class ldouble_safe, class real_t>
1291
+ void SingleNodeColumnSampler<ldouble_safe, real_t>::backup(SingleNodeColumnSampler &other, size_t ncols_tot)
1292
+ {
1293
+ other.n_inf = this->n_inf;
1294
+ other.n_left = this->n_left;
1295
+ other.using_tree = this->using_tree;
1296
+
1297
+ if (this->using_tree)
1298
+ {
1299
+ if (other.tree_weights.empty())
1300
+ {
1301
+ other.tree_weights.reserve(ncols_tot);
1302
+ other.mapped_inf_indices.reserve(ncols_tot);
1303
+ }
1304
+ other.tree_weights.assign(this->tree_weights.begin(), this->tree_weights.end());
1305
+ other.mapped_inf_indices.assign(this->mapped_inf_indices.begin(), this->mapped_inf_indices.end());
1306
+ }
1307
+
1308
+ else
1309
+ {
1310
+ other.cumw = this->cumw;
1311
+ if (this->backup_weights)
1312
+ {
1313
+ if (other.weights_own.empty())
1314
+ other.weights_own.reserve(ncols_tot);
1315
+
1316
+ other.weights_own.resize(this->n_left);
1317
+ for (size_t col = 0; col < this->n_left; col++)
1318
+ other.weights_own[col] = this->weights_own[this->col_indices[col]];
1319
+ }
1320
+
1321
+ if (this->inifinite_weights.size())
1322
+ {
1323
+ if (other.inifinite_weights.empty())
1324
+ other.inifinite_weights.reserve(ncols_tot);
1325
+
1326
+ other.inifinite_weights.resize(this->n_left);
1327
+ for (size_t col = 0; col < this->n_left; col++)
1328
+ other.inifinite_weights[col] = this->inifinite_weights[this->col_indices[col]];
1329
+ }
1330
+ }
1331
+ }
1332
+
1333
+ template <class ldouble_safe, class real_t>
1334
+ void SingleNodeColumnSampler<ldouble_safe, real_t>::restore(const SingleNodeColumnSampler &other)
1335
+ {
1336
+ this->n_inf = other.n_inf;
1337
+ this->n_left = other.n_left;
1338
+ this->using_tree = other.using_tree;
1339
+
1340
+ if (this->using_tree)
1341
+ {
1342
+ this->tree_weights.assign(other.tree_weights.begin(), other.tree_weights.end());
1343
+ this->mapped_inf_indices.assign(other.mapped_inf_indices.begin(), other.mapped_inf_indices.end());
1344
+ }
1345
+
1346
+ else
1347
+ {
1348
+ this->cumw = other.cumw;
1349
+ if (this->backup_weights)
1350
+ {
1351
+ for (size_t col = 0; col < this->n_left; col++)
1352
+ this->weights_own[this->col_indices[col]] = other.weights_own[col];
1353
+ }
1354
+
1355
+ if (this->inifinite_weights.size())
1356
+ {
1357
+ for (size_t col = 0; col < this->n_left; col++)
1358
+ this->inifinite_weights[this->col_indices[col]] = other.inifinite_weights[col];
1359
+ }
1360
+ }
1361
+ }
1362
+
1363
+ template <class ldouble_safe, class real_t>
1364
+ void DensityCalculator<ldouble_safe, real_t>::initialize(size_t max_depth, int max_categ, bool reserve_counts, ScoringMetric scoring_metric)
1365
+ {
1366
+ this->multipliers.reserve(max_depth+3);
1367
+ this->multipliers.clear();
1368
+ if (scoring_metric != AdjDensity)
1369
+ this->multipliers.push_back(0);
1370
+ else
1371
+ this->multipliers.push_back(1);
1372
+
1373
+ if (reserve_counts)
1374
+ {
1375
+ this->counts.resize(max_categ);
1376
+ }
1377
+ }
1378
+
1379
+ template <class ldouble_safe, class real_t>
1380
+ template <class InputData>
1381
+ void DensityCalculator<ldouble_safe, real_t>::initialize_bdens(const InputData &input_data,
1382
+ const ModelParams &model_params,
1383
+ std::vector<size_t> &ix_arr,
1384
+ ColumnSampler<ldouble_safe> &col_sampler)
1385
+ {
1386
+ this->fast_bratio = model_params.fast_bratio;
1387
+ if (this->fast_bratio)
1388
+ {
1389
+ this->multipliers.reserve(model_params.max_depth + 3);
1390
+ this->multipliers.push_back(0);
1391
+ }
1392
+
1393
+ if (input_data.range_low != NULL || input_data.ncat_ != NULL)
1394
+ {
1395
+ if (input_data.ncols_numeric)
1396
+ {
1397
+ this->queue_box.reserve(model_params.max_depth+3);
1398
+ this->box_low.assign(input_data.range_low, input_data.range_low + input_data.ncols_numeric);
1399
+ this->box_high.assign(input_data.range_high, input_data.range_high + input_data.ncols_numeric);
1400
+ }
1401
+
1402
+ if (input_data.ncols_categ)
1403
+ {
1404
+ this->queue_ncat.reserve(model_params.max_depth+2);
1405
+ this->ncat.assign(input_data.ncat_, input_data.ncat_ + input_data.ncols_categ);
1406
+ }
1407
+
1408
+ if (!this->fast_bratio)
1409
+ {
1410
+ if (input_data.ncols_numeric)
1411
+ {
1412
+ this->ranges.resize(input_data.ncols_numeric);
1413
+ for (size_t col = 0; col < input_data.ncols_numeric; col++)
1414
+ this->ranges[col] = this->box_high[col] - this->box_low[col];
1415
+ }
1416
+
1417
+ if (input_data.ncols_categ)
1418
+ {
1419
+ this->ncat_orig = this->ncat;
1420
+ }
1421
+ }
1422
+
1423
+ return;
1424
+ }
1425
+
1426
+ if (input_data.ncols_numeric)
1427
+ {
1428
+ this->queue_box.reserve(model_params.max_depth+3);
1429
+ this->box_low.resize(input_data.ncols_numeric);
1430
+ this->box_high.resize(input_data.ncols_numeric);
1431
+ if (!this->fast_bratio)
1432
+ this->ranges.resize(input_data.ncols_numeric);
1433
+ }
1434
+ if (input_data.ncols_categ)
1435
+ {
1436
+ this->queue_ncat.reserve(model_params.max_depth+2);
1437
+ }
1438
+ bool unsplittable = false;
1439
+
1440
+ size_t npresent = 0;
1441
+ std::vector<signed char> categ_present;
1442
+ if (input_data.ncols_categ)
1443
+ {
1444
+ categ_present.resize(input_data.max_categ);
1445
+ }
1446
+
1447
+
1448
+ col_sampler.prepare_full_pass();
1449
+ size_t col;
1450
+ while (col_sampler.sample_col(col))
1451
+ {
1452
+ if (col < input_data.ncols_numeric)
1453
+ {
1454
+ if (input_data.Xc_indptr != NULL)
1455
+ {
1456
+ get_range((size_t*)ix_arr.data(), (size_t)0, ix_arr.size()-(size_t)1, col,
1457
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1458
+ model_params.missing_action, this->box_low[col], this->box_high[col], unsplittable);
1459
+ }
1460
+
1461
+ else
1462
+ {
1463
+ get_range((size_t*)ix_arr.data(), input_data.numeric_data + input_data.nrows * col, (size_t)0, ix_arr.size()-(size_t)1,
1464
+ model_params.missing_action, this->box_low[col], this->box_high[col], unsplittable);
1465
+ }
1466
+
1467
+
1468
+ if (unsplittable)
1469
+ {
1470
+ this->box_low[col] = 0;
1471
+ this->box_high[col] = 0;
1472
+ if (!this->fast_bratio)
1473
+ this->ranges[col] = 0;
1474
+ col_sampler.drop_col(col);
1475
+ }
1476
+
1477
+ if (!this->fast_bratio)
1478
+ {
1479
+ this->ranges[col] = (ldouble_safe)this->box_high[col] - (ldouble_safe)this->box_low[col];
1480
+ this->ranges[col] = std::fmax(this->ranges[col], (ldouble_safe)0);
1481
+ }
1482
+ }
1483
+
1484
+ else
1485
+ {
1486
+ get_categs((size_t*)ix_arr.data(),
1487
+ input_data.categ_data + input_data.nrows * (col - input_data.ncols_numeric),
1488
+ (size_t)0, ix_arr.size()-(size_t)1, input_data.ncat[col],
1489
+ model_params.missing_action, categ_present.data(), npresent, unsplittable);
1490
+
1491
+ if (unsplittable)
1492
+ {
1493
+ this->ncat[col - input_data.ncols_numeric] = 1;
1494
+ col_sampler.drop_col(col);
1495
+ }
1496
+
1497
+ else
1498
+ {
1499
+ this->ncat[col - input_data.ncols_numeric] = npresent;
1500
+ }
1501
+ }
1502
+ }
1503
+
1504
+ if (!this->fast_bratio)
1505
+ this->ncat_orig = this->ncat;
1506
+ }
1507
+
1508
+ template<class ldouble_safe, class real_t>
1509
+ template <class InputData>
1510
+ void DensityCalculator<ldouble_safe, real_t>::initialize_bdens_ext(const InputData &input_data,
1511
+ const ModelParams &model_params,
1512
+ std::vector<size_t> &ix_arr,
1513
+ ColumnSampler<ldouble_safe> &col_sampler,
1514
+ bool col_sampler_is_fresh)
1515
+ {
1516
+ this->vals_ext_box.reserve(model_params.max_depth + 3);
1517
+ this->queue_ext_box.reserve(model_params.max_depth + 3);
1518
+ this->vals_ext_box.push_back(0);
1519
+
1520
+ if (input_data.range_low != NULL)
1521
+ {
1522
+ this->box_low.assign(input_data.range_low, input_data.range_low + input_data.ncols_numeric);
1523
+ this->box_high.assign(input_data.range_high, input_data.range_high + input_data.ncols_numeric);
1524
+ return;
1525
+ }
1526
+
1527
+ this->box_low.resize(input_data.ncols_numeric);
1528
+ this->box_high.resize(input_data.ncols_numeric);
1529
+ bool unsplittable = false;
1530
+
1531
+ /* TODO: find out if there's an optimal point for choosing one or the other loop
1532
+ when using 'leave_m_cols' and when using 'prob_pick_col_by_range', then fill in the
1533
+ lines that are commented out. */
1534
+ // if (!input_data.ncols_categ || model_params.ncols_per_tree < input_data.ncols_numeric)
1535
+ if (input_data.ncols_numeric)
1536
+ {
1537
+ col_sampler.prepare_full_pass();
1538
+ size_t col;
1539
+ while (col_sampler.sample_col(col))
1540
+ {
1541
+ if (col >= input_data.ncols_numeric)
1542
+ continue;
1543
+ if (input_data.Xc_indptr != NULL)
1544
+ {
1545
+ get_range((size_t*)ix_arr.data(), (size_t)0, ix_arr.size()-(size_t)1, col,
1546
+ input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1547
+ model_params.missing_action, this->box_low[col], this->box_high[col], unsplittable);
1548
+ }
1549
+
1550
+ else
1551
+ {
1552
+ get_range((size_t*)ix_arr.data(), input_data.numeric_data + input_data.nrows * col, (size_t)0, ix_arr.size()-(size_t)1,
1553
+ model_params.missing_action, this->box_low[col], this->box_high[col], unsplittable);
1554
+ }
1555
+
1556
+ if (unsplittable)
1557
+ {
1558
+ this->box_low[col] = 0;
1559
+ this->box_high[col] = 0;
1560
+ col_sampler.drop_col(col);
1561
+ }
1562
+ }
1563
+ }
1564
+
1565
+ // else if (input_data.ncols_numeric)
1566
+ // {
1567
+ // size_t n_unsplittable = 0;
1568
+ // std::vector<size_t> unsplittable_cols;
1569
+ // if (col_sampler_is_fresh && !col_sampler.has_weights())
1570
+ // unsplittable_cols.reserve(input_data.ncols_numeric);
1571
+
1572
+ // /* TODO: this will do unnecessary calculations when using 'leave_m_cols' */
1573
+ // for (size_t col = 0; col < input_data.ncols_numeric; col++)
1574
+ // {
1575
+ // if (input_data.Xc_indptr != NULL)
1576
+ // {
1577
+ // get_range((size_t*)ix_arr.data(), (size_t)0, ix_arr.size()-(size_t)1, col,
1578
+ // input_data.Xc, input_data.Xc_ind, input_data.Xc_indptr,
1579
+ // model_params.missing_action, this->box_low[col], this->box_high[col], unsplittable);
1580
+ // }
1581
+
1582
+ // else
1583
+ // {
1584
+ // get_range((size_t*)ix_arr.data(), input_data.numeric_data + input_data.nrows * col, (size_t)0, ix_arr.size()-(size_t)1,
1585
+ // model_params.missing_action, this->box_low[col], this->box_high[col], unsplittable);
1586
+ // }
1587
+
1588
+ // if (unsplittable)
1589
+ // {
1590
+ // this->box_low[col] = 0;
1591
+ // this->box_high[col] = 0;
1592
+ // n_unsplittable++;
1593
+ // if (col_sampler.has_weights())
1594
+ // col_sampler.drop_col(col);
1595
+ // else if (col_sampler_is_fresh)
1596
+ // unsplittable_cols.push_back(col);
1597
+ // }
1598
+ // }
1599
+
1600
+ // if (n_unsplittable && col_sampler_is_fresh && !col_sampler.has_weights())
1601
+ // {
1602
+ // #if (__cplusplus >= 202002L)
1603
+ // for (auto col : unsplittable_cols | std::views::reverse)
1604
+ // col_sampler.drop_from_tail(col);
1605
+ // #else
1606
+ // for (size_t inv_col = 0; inv_col < unsplittable_cols.size(); inv_col++)
1607
+ // {
1608
+ // size_t col = unsplittable_cols.size() - inv_col - 1;
1609
+ // col_sampler.drop_from_tail(unsplittable_cols[col]);
1610
+ // }
1611
+ // #endif
1612
+ // }
1613
+
1614
+ // else if (n_unsplittable > model_params.sample_size / 16 && !col_sampler_is_fresh && !col_sampler.has_weights())
1615
+ // {
1616
+ // /* TODO */
1617
+ // }
1618
+ // }
1619
+ }
1620
+
1621
+ template<class ldouble_safe, class real_t>
1622
+ void DensityCalculator<ldouble_safe, real_t>::push_density(double xmin, double xmax, double split_point)
1623
+ {
1624
+ if (std::isinf(xmax) || std::isinf(xmin) || std::isnan(xmin) || std::isnan(xmax) || std::isnan(split_point))
1625
+ {
1626
+ this->multipliers.push_back(0);
1627
+ return;
1628
+ }
1629
+
1630
+ double range = std::fmax(xmax - xmin, std::numeric_limits<double>::min());
1631
+ double dleft = std::fmax(split_point - xmin, std::numeric_limits<double>::min());
1632
+ double dright = std::fmax(xmax - split_point, std::numeric_limits<double>::min());
1633
+ double mult_left = std::log(dleft / range);
1634
+ double mult_right = std::log(dright / range);
1635
+ while (std::isinf(mult_left))
1636
+ {
1637
+ dleft = std::nextafter(dleft, (mult_left < 0)? HUGE_VAL : (-HUGE_VAL));
1638
+ mult_left = std::log(dleft / range);
1639
+ }
1640
+ while (std::isinf(mult_right))
1641
+ {
1642
+ dright = std::nextafter(dright, (mult_right < 0)? HUGE_VAL : (-HUGE_VAL));
1643
+ mult_right = std::log(dright / range);
1644
+ }
1645
+
1646
+ mult_left = std::isnan(mult_left)? 0 : mult_left;
1647
+ mult_right = std::isnan(mult_right)? 0 : mult_right;
1648
+
1649
+ ldouble_safe curr = this->multipliers.back();
1650
+ this->multipliers.push_back(curr + mult_right);
1651
+ this->multipliers.push_back(curr + mult_left);
1652
+ }
1653
+
1654
+ template<class ldouble_safe, class real_t>
1655
+ void DensityCalculator<ldouble_safe, real_t>::push_density(int n_left, int n_present)
1656
+ {
1657
+ this->push_density(0., (double)n_present, (double)n_left);
1658
+ }
1659
+
1660
+ /* For single category splits */
1661
+ template<class ldouble_safe, class real_t>
1662
+ void DensityCalculator<ldouble_safe, real_t>::push_density(size_t counts[], int ncat)
1663
+ {
1664
+ /* this one assumes 'categ_present' has entries 0/1 for missing/present */
1665
+ int n_present = 0;
1666
+ for (int cat = 0; cat < ncat; cat++)
1667
+ n_present += counts[cat] > 0;
1668
+ this->push_density(0., (double)n_present, 1.);
1669
+ }
1670
+
1671
+ /* For single category splits */
1672
+ template<class ldouble_safe, class real_t>
1673
+ void DensityCalculator<ldouble_safe, real_t>::push_density(int n_present)
1674
+ {
1675
+ this->push_density(0., (double)n_present, 1.);
1676
+ }
1677
+
1678
+ /* For binary categorical splits */
1679
+ template<class ldouble_safe, class real_t>
1680
+ void DensityCalculator<ldouble_safe, real_t>::push_density()
1681
+ {
1682
+ this->multipliers.push_back(0);
1683
+ this->multipliers.push_back(0);
1684
+ }
1685
+
1686
+ template<class ldouble_safe, class real_t>
1687
+ void DensityCalculator<ldouble_safe, real_t>::push_adj(double xmin, double xmax, double split_point, double pct_tree_left, ScoringMetric scoring_metric)
1688
+ {
1689
+ double range = std::fmax(xmax - xmin, std::numeric_limits<double>::min());
1690
+ double dleft = std::fmax(split_point - xmin, std::numeric_limits<double>::min());
1691
+ double dright = std::fmax(xmax - split_point, std::numeric_limits<double>::min());
1692
+ double chunk_left = dleft / range;
1693
+ double chunk_right = dright / range;
1694
+ if (std::isinf(xmax) || std::isinf(xmin) || std::isnan(xmin) || std::isnan(xmax) || std::isnan(split_point))
1695
+ {
1696
+ chunk_left = pct_tree_left;
1697
+ chunk_right = 1. - pct_tree_left;
1698
+ goto add_chunks;
1699
+ }
1700
+
1701
+ if (std::isnan(chunk_left) || std::isnan(chunk_right))
1702
+ {
1703
+ chunk_left = 0.5;
1704
+ chunk_right = 0.5;
1705
+ }
1706
+
1707
+ chunk_left = pct_tree_left / chunk_left;
1708
+ chunk_right = (1. - pct_tree_left) / chunk_right;
1709
+
1710
+ add_chunks:
1711
+ chunk_left = 2. / (1. + .5/chunk_left);
1712
+ chunk_right = 2. / (1. + .5/chunk_right);
1713
+ // chunk_left = 2. / (1. + 1./chunk_left);
1714
+ // chunk_right = 2. / (1. + 1./chunk_right);
1715
+ // chunk_left = 2. - std::exp2(1. - chunk_left);
1716
+ // chunk_right = 2. - std::exp2(1. - chunk_right);
1717
+
1718
+ ldouble_safe curr = this->multipliers.back();
1719
+ if (scoring_metric == AdjDepth)
1720
+ {
1721
+ this->multipliers.push_back(curr + chunk_right);
1722
+ this->multipliers.push_back(curr + chunk_left);
1723
+ }
1724
+
1725
+ else
1726
+ {
1727
+ this->multipliers.push_back(std::fmax(curr * chunk_right, (ldouble_safe)std::numeric_limits<double>::epsilon()));
1728
+ this->multipliers.push_back(std::fmax(curr * chunk_left, (ldouble_safe)std::numeric_limits<double>::epsilon()));
1729
+ }
1730
+ }
1731
+
1732
+ template<class ldouble_safe, class real_t>
1733
+ void DensityCalculator<ldouble_safe, real_t>::push_adj(signed char *restrict categ_present, size_t *restrict counts, int ncat, ScoringMetric scoring_metric)
1734
+ {
1735
+ /* this one assumes 'categ_present' has entries -1/0/1 for missing/right/left */
1736
+ int cnt_cat_left = 0;
1737
+ int cnt_cat = 0;
1738
+ size_t cnt = 0;
1739
+ size_t cnt_left = 0;
1740
+ for (int cat = 0; cat < ncat; cat++)
1741
+ {
1742
+ if (counts[cat] > 0)
1743
+ {
1744
+ cnt += counts[cat];
1745
+ cnt_cat_left += categ_present[cat];
1746
+ cnt_left += categ_present[cat]? counts[cat] : 0;
1747
+ cnt_cat++;
1748
+ }
1749
+ }
1750
+
1751
+ double pct_tree_left = (ldouble_safe)cnt_left / (ldouble_safe)cnt;
1752
+ this->push_adj(0., (double)cnt_cat, (double)cnt_cat_left, pct_tree_left, scoring_metric);
1753
+ }
1754
+
1755
+ /* For single category splits */
1756
+ template<class ldouble_safe, class real_t>
1757
+ void DensityCalculator<ldouble_safe, real_t>::push_adj(size_t *restrict counts, int ncat, int chosen_cat, ScoringMetric scoring_metric)
1758
+ {
1759
+ /* this one assumes 'categ_present' has entries 0/1 for missing/present */
1760
+ int cnt_cat = 0;
1761
+ size_t cnt = 0;
1762
+ for (int cat = 0; cat < ncat; cat++)
1763
+ {
1764
+ cnt += counts[cat];
1765
+ cnt_cat += counts[cat] > 0;
1766
+ }
1767
+
1768
+ double pct_tree_left = (ldouble_safe)counts[chosen_cat] / (ldouble_safe)cnt;
1769
+ this->push_adj(0., (double)cnt_cat, 1., pct_tree_left, scoring_metric);
1770
+ }
1771
+
1772
+ /* For binary categorical splits */
1773
+ template<class ldouble_safe, class real_t>
1774
+ void DensityCalculator<ldouble_safe, real_t>::push_adj(double pct_tree_left, ScoringMetric scoring_metric)
1775
+ {
1776
+ this->push_adj(0., 1., 0.5, pct_tree_left, scoring_metric);
1777
+ }
1778
+
1779
+ template<class ldouble_safe, class real_t>
1780
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens(double split_point, size_t col)
1781
+ {
1782
+ if (this->fast_bratio)
1783
+ this->push_bdens_fast_route(split_point, col);
1784
+ else
1785
+ this->push_bdens_internal(split_point, col);
1786
+ }
1787
+
1788
+ template<class ldouble_safe, class real_t>
1789
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_internal(double split_point, size_t col)
1790
+ {
1791
+ this->queue_box.push_back(this->box_high[col]);
1792
+ this->box_high[col] = split_point;
1793
+ }
1794
+
1795
+ template<class ldouble_safe, class real_t>
1796
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_fast_route(double split_point, size_t col)
1797
+ {
1798
+ ldouble_safe curr_range = (ldouble_safe)this->box_high[col] - (ldouble_safe)this->box_low[col];
1799
+ ldouble_safe fraction_left = ((ldouble_safe)split_point - (ldouble_safe)this->box_low[col]) / curr_range;
1800
+ ldouble_safe fraction_right = ((ldouble_safe)this->box_high[col] - (ldouble_safe)split_point) / curr_range;
1801
+ fraction_left = std::fmax(fraction_left, (ldouble_safe)std::numeric_limits<double>::min());
1802
+ fraction_left = std::fmin(fraction_left, (ldouble_safe)(1. - std::numeric_limits<double>::epsilon()));
1803
+ fraction_left = std::log(fraction_left);
1804
+ fraction_left += this->multipliers.back();
1805
+ fraction_right = std::fmax(fraction_right, (ldouble_safe)std::numeric_limits<double>::min());
1806
+ fraction_right = std::fmin(fraction_right, (ldouble_safe)(1. - std::numeric_limits<double>::epsilon()));
1807
+ fraction_right = std::log(fraction_right);
1808
+ fraction_right += this->multipliers.back();
1809
+ this->multipliers.push_back(fraction_right);
1810
+ this->multipliers.push_back(fraction_left);
1811
+
1812
+ this->push_bdens_internal(split_point, col);
1813
+ }
1814
+
1815
+ template<class ldouble_safe, class real_t>
1816
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens(int ncat_branch_left, size_t col)
1817
+ {
1818
+ if (this->fast_bratio)
1819
+ this->push_bdens_fast_route(ncat_branch_left, col);
1820
+ else
1821
+ this->push_bdens_internal(ncat_branch_left, col);
1822
+ }
1823
+
1824
+ template<class ldouble_safe, class real_t>
1825
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_internal(int ncat_branch_left, size_t col)
1826
+ {
1827
+ this->queue_ncat.push_back(this->ncat[col]);
1828
+ this->ncat[col] = ncat_branch_left;
1829
+ }
1830
+
1831
+ template<class ldouble_safe, class real_t>
1832
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_fast_route(int ncat_branch_left, size_t col)
1833
+ {
1834
+ double fraction_left = std::log((double)ncat_branch_left / this->ncat[col]);
1835
+ double fraction_right = std::log((double)(this->ncat[col] - ncat_branch_left) / this->ncat[col]);
1836
+ ldouble_safe curr = this->multipliers.back();
1837
+ this->multipliers.push_back(curr + fraction_right);
1838
+ this->multipliers.push_back(curr + fraction_left);
1839
+
1840
+ this->push_bdens_internal(ncat_branch_left, col);
1841
+ }
1842
+
1843
+ template<class ldouble_safe, class real_t>
1844
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens(const std::vector<signed char> &cat_split, size_t col)
1845
+ {
1846
+ if (this->fast_bratio)
1847
+ this->push_bdens_fast_route(cat_split, col);
1848
+ else
1849
+ this->push_bdens_internal(cat_split, col);
1850
+ }
1851
+
1852
+ template<class ldouble_safe, class real_t>
1853
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_internal(const std::vector<signed char> &cat_split, size_t col)
1854
+ {
1855
+ int ncat_branch_left = 0;
1856
+ for (auto el : cat_split)
1857
+ ncat_branch_left += el == 1;
1858
+ this->push_bdens_internal(ncat_branch_left, col);
1859
+ }
1860
+
1861
+ template<class ldouble_safe, class real_t>
1862
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_fast_route(const std::vector<signed char> &cat_split, size_t col)
1863
+ {
1864
+ int ncat_branch_left = 0;
1865
+ for (auto el : cat_split)
1866
+ ncat_branch_left += el == 1;
1867
+ this->push_bdens_fast_route(ncat_branch_left, col);
1868
+ }
1869
+
1870
+ template<class ldouble_safe, class real_t>
1871
+ void DensityCalculator<ldouble_safe, real_t>::push_bdens_ext(const IsoHPlane &hplane, const ModelParams &model_params)
1872
+ {
1873
+ double x1, x2;
1874
+ double xlow = 0, xhigh = 0;
1875
+ size_t col;
1876
+ size_t col_num = 0;
1877
+ size_t col_cat = 0;
1878
+
1879
+ for (size_t col_outer = 0; col_outer < hplane.col_num.size(); col_outer++)
1880
+ {
1881
+ switch (hplane.col_type[col_outer])
1882
+ {
1883
+ case Numeric:
1884
+ {
1885
+ col = hplane.col_num[col_outer];
1886
+ x1 = hplane.coef[col_num] * (this->box_low[col] - hplane.mean[col_num]);
1887
+ x2 = hplane.coef[col_num] * (this->box_high[col] - hplane.mean[col_num]);
1888
+ xlow += std::fmin(x1, x2);
1889
+ xhigh += std::fmax(x1, x2);
1890
+ break;
1891
+ }
1892
+
1893
+ case Categorical:
1894
+ {
1895
+ switch (model_params.cat_split_type)
1896
+ {
1897
+ case SingleCateg:
1898
+ {
1899
+ xlow += std::fmin(hplane.fill_new[col_cat], 0.);
1900
+ xhigh += std::fmax(hplane.fill_new[col_cat], 0.);
1901
+ break;
1902
+ }
1903
+
1904
+ case SubSet:
1905
+ {
1906
+ xlow += *std::min_element(hplane.cat_coef[col_cat].begin(), hplane.cat_coef[col_cat].end());
1907
+ xhigh += *std::max_element(hplane.cat_coef[col_cat].begin(), hplane.cat_coef[col_cat].end());
1908
+ break;
1909
+ }
1910
+ }
1911
+ break;
1912
+ }
1913
+
1914
+ default:
1915
+ {
1916
+ assert(0);
1917
+ }
1918
+ }
1919
+ }
1920
+
1921
+ double chunk_left;
1922
+ double chunk_right;
1923
+ double xdiff = xhigh - xlow;
1924
+
1925
+ if (model_params.scoring_metric != BoxedDensity)
1926
+ {
1927
+ chunk_left = (hplane.split_point - xlow) / xdiff;
1928
+ chunk_right = (xhigh - hplane.split_point) / xdiff;
1929
+ chunk_left = std::fmin(chunk_left, std::numeric_limits<double>::min());
1930
+ chunk_left = std::fmax(chunk_left, 1.-std::numeric_limits<double>::epsilon());
1931
+ chunk_right = std::fmin(chunk_right, std::numeric_limits<double>::min());
1932
+ chunk_right = std::fmax(chunk_right, 1.-std::numeric_limits<double>::epsilon());
1933
+ }
1934
+
1935
+ else
1936
+ {
1937
+ chunk_left = xdiff / (hplane.split_point - xlow);
1938
+ chunk_right = xdiff / (xhigh - hplane.split_point);
1939
+ chunk_left = std::fmin(chunk_left, 1.);
1940
+ chunk_right = std::fmin(chunk_right, 1.);
1941
+ }
1942
+
1943
+ this->queue_ext_box.push_back(std::log(chunk_right) + this->vals_ext_box.back());
1944
+ this->vals_ext_box.push_back(std::log(chunk_left) + this->vals_ext_box.back());
1945
+ }
1946
+
1947
+ template<class ldouble_safe, class real_t>
1948
+ void DensityCalculator<ldouble_safe, real_t>::pop()
1949
+ {
1950
+ this->multipliers.pop_back();
1951
+ }
1952
+
1953
+ template<class ldouble_safe, class real_t>
1954
+ void DensityCalculator<ldouble_safe, real_t>::pop_right()
1955
+ {
1956
+ this->multipliers.pop_back();
1957
+ }
1958
+
1959
+ template<class ldouble_safe, class real_t>
1960
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens(size_t col)
1961
+ {
1962
+ if (this->fast_bratio)
1963
+ this->pop_bdens_fast_route(col);
1964
+ else
1965
+ this->pop_bdens_internal(col);
1966
+ }
1967
+
1968
+ template<class ldouble_safe, class real_t>
1969
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_internal(size_t col)
1970
+ {
1971
+ double old_high = this->queue_box.back();
1972
+ this->queue_box.pop_back();
1973
+ this->queue_box.push_back(this->box_low[col]);
1974
+ this->box_low[col] = this->box_high[col];
1975
+ this->box_high[col] = old_high;
1976
+ }
1977
+
1978
+ template<class ldouble_safe, class real_t>
1979
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_fast_route(size_t col)
1980
+ {
1981
+ this->multipliers.pop_back();
1982
+ this->pop_bdens_internal(col);
1983
+ }
1984
+
1985
+ template<class ldouble_safe, class real_t>
1986
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_right(size_t col)
1987
+ {
1988
+ if (this->fast_bratio)
1989
+ this->pop_bdens_right_fast_route(col);
1990
+ else
1991
+ this->pop_bdens_right_internal(col);
1992
+ }
1993
+
1994
+ template<class ldouble_safe, class real_t>
1995
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_right_internal(size_t col)
1996
+ {
1997
+ double old_low = this->queue_box.back();
1998
+ this->queue_box.pop_back();
1999
+ this->box_low[col] = old_low;
2000
+ }
2001
+
2002
+ template<class ldouble_safe, class real_t>
2003
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_right_fast_route(size_t col)
2004
+ {
2005
+ this->multipliers.pop_back();
2006
+ this->pop_bdens_right_internal(col);
2007
+ }
2008
+
2009
+ template<class ldouble_safe, class real_t>
2010
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_cat(size_t col)
2011
+ {
2012
+ if (this->fast_bratio)
2013
+ this->pop_bdens_cat_fast_route(col);
2014
+ else
2015
+ this->pop_bdens_cat_internal(col);
2016
+ }
2017
+
2018
+ template<class ldouble_safe, class real_t>
2019
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_cat_internal(size_t col)
2020
+ {
2021
+ int old_ncat = this->queue_ncat.back();
2022
+ this->ncat[col] = old_ncat - this->ncat[col];
2023
+ }
2024
+
2025
+ template<class ldouble_safe, class real_t>
2026
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_cat_fast_route(size_t col)
2027
+ {
2028
+ this->multipliers.pop_back();
2029
+ this->pop_bdens_cat_internal(col);
2030
+ }
2031
+
2032
+ template<class ldouble_safe, class real_t>
2033
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_cat_right(size_t col)
2034
+ {
2035
+ if (this->fast_bratio)
2036
+ this->pop_bdens_cat_right_fast_route(col);
2037
+ else
2038
+ this->pop_bdens_cat_right_internal(col);
2039
+ }
2040
+
2041
+ template<class ldouble_safe, class real_t>
2042
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_cat_right_internal(size_t col)
2043
+ {
2044
+ int old_ncat = this->queue_ncat.back();
2045
+ this->queue_ncat.pop_back();
2046
+ this->ncat[col] = old_ncat;
2047
+ }
2048
+
2049
+ template<class ldouble_safe, class real_t>
2050
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_cat_right_fast_route(size_t col)
2051
+ {
2052
+ this->multipliers.pop_back();
2053
+ this->pop_bdens_cat_right_internal(col);
2054
+ }
2055
+
2056
+ template<class ldouble_safe, class real_t>
2057
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_ext()
2058
+ {
2059
+ this->vals_ext_box.pop_back();
2060
+ this->vals_ext_box.push_back(this->queue_ext_box.back());
2061
+ this->queue_ext_box.pop_back();
2062
+ }
2063
+
2064
+ template<class ldouble_safe, class real_t>
2065
+ void DensityCalculator<ldouble_safe, real_t>::pop_bdens_ext_right()
2066
+ {
2067
+ this->vals_ext_box.pop_back();
2068
+ }
2069
+
2070
+ /* this outputs the logarithm of the density */
2071
+ template<class ldouble_safe, class real_t>
2072
+ double DensityCalculator<ldouble_safe, real_t>::calc_density(ldouble_safe remainder, size_t sample_size)
2073
+ {
2074
+ return std::log(remainder) - std::log((ldouble_safe)sample_size) - this->multipliers.back();
2075
+ }
2076
+
2077
+ template<class ldouble_safe, class real_t>
2078
+ ldouble_safe DensityCalculator<ldouble_safe, real_t>::calc_adj_depth()
2079
+ {
2080
+ ldouble_safe out = this->multipliers.back();
2081
+ return std::fmax(out, (ldouble_safe)std::numeric_limits<double>::min());
2082
+ }
2083
+
2084
+ template<class ldouble_safe, class real_t>
2085
+ double DensityCalculator<ldouble_safe, real_t>::calc_adj_density()
2086
+ {
2087
+ return this->multipliers.back();
2088
+ }
2089
+
2090
+ /* this outputs the logarithm of the density */
2091
+ template<class ldouble_safe, class real_t>
2092
+ ldouble_safe DensityCalculator<ldouble_safe, real_t>::calc_bratio_inv_log()
2093
+ {
2094
+ if (!this->multipliers.empty())
2095
+ return -this->multipliers.back();
2096
+
2097
+ ldouble_safe sum_log_switdh = 0;
2098
+ ldouble_safe ratio_col;
2099
+ for (size_t col = 0; col < this->ranges.size(); col++)
2100
+ {
2101
+ if (!this->ranges[col]) continue;
2102
+ ratio_col = this->ranges[col] / ((ldouble_safe)this->box_high[col] - (ldouble_safe)this->box_low[col]);
2103
+ ratio_col = std::fmax(ratio_col, (ldouble_safe)1);
2104
+ sum_log_switdh += std::log(ratio_col);
2105
+ }
2106
+
2107
+ for (size_t col = 0; col < this->ncat.size(); col++)
2108
+ {
2109
+ if (this->ncat_orig[col] <= 1) continue;
2110
+ sum_log_switdh += std::log((double)this->ncat_orig[col] / (double)this->ncat[col]);
2111
+ }
2112
+
2113
+ return sum_log_switdh;
2114
+ }
2115
+
2116
+ template<class ldouble_safe, class real_t>
2117
+ ldouble_safe DensityCalculator<ldouble_safe, real_t>::calc_bratio_log()
2118
+ {
2119
+ if (!this->multipliers.empty())
2120
+ return this->multipliers.back();
2121
+
2122
+ ldouble_safe sum_log_switdh = 0;
2123
+ ldouble_safe ratio_col;
2124
+ for (size_t col = 0; col < this->ranges.size(); col++)
2125
+ {
2126
+ if (!this->ranges[col]) continue;
2127
+ ratio_col = ((ldouble_safe)this->box_high[col] - (ldouble_safe)this->box_low[col]) / this->ranges[col];
2128
+ ratio_col = std::fmax(ratio_col, (ldouble_safe)std::numeric_limits<double>::min());
2129
+ ratio_col = std::fmin(ratio_col, (ldouble_safe)(1. - std::numeric_limits<double>::epsilon()));
2130
+ sum_log_switdh += std::log(ratio_col);
2131
+ }
2132
+
2133
+ for (size_t col = 0; col < this->ncat.size(); col++)
2134
+ {
2135
+ if (this->ncat_orig[col] <= 1) continue;
2136
+ sum_log_switdh += std::log((double)this->ncat[col] / (double)this->ncat_orig[col]);
2137
+ }
2138
+
2139
+ return sum_log_switdh;
2140
+ }
2141
+
2142
+ /* this does NOT output the logarithm of the density */
2143
+ template<class ldouble_safe, class real_t>
2144
+ double DensityCalculator<ldouble_safe, real_t>::calc_bratio()
2145
+ {
2146
+ return std::exp(this->calc_bratio_log());
2147
+ }
2148
+
2149
+ const double MIN_DENS = std::log(std::numeric_limits<double>::min());
2150
+
2151
+ /* this outputs the logarithm of the density */
2152
+ template<class ldouble_safe, class real_t>
2153
+ double DensityCalculator<ldouble_safe, real_t>::calc_bdens(ldouble_safe remainder, size_t sample_size)
2154
+ {
2155
+ double out = std::log(remainder) - std::log((ldouble_safe)sample_size) - this->calc_bratio_inv_log();
2156
+ return std::fmax(out, MIN_DENS);
2157
+ }
2158
+
2159
+ /* this outputs the logarithm of the density */
2160
+ template<class ldouble_safe, class real_t>
2161
+ double DensityCalculator<ldouble_safe, real_t>::calc_bdens2(ldouble_safe remainder, size_t sample_size)
2162
+ {
2163
+ double out = std::log(remainder) - std::log((ldouble_safe)sample_size) - this->calc_bratio_log();
2164
+ return std::fmax(out, MIN_DENS);
2165
+ }
2166
+
2167
+ /* this outputs the logarithm of the density */
2168
+ template<class ldouble_safe, class real_t>
2169
+ ldouble_safe DensityCalculator<ldouble_safe, real_t>::calc_bratio_log_ext()
2170
+ {
2171
+ return this->vals_ext_box.back();
2172
+ }
2173
+
2174
+ template<class ldouble_safe, class real_t>
2175
+ double DensityCalculator<ldouble_safe, real_t>::calc_bratio_ext()
2176
+ {
2177
+ double out = std::exp(this->calc_bratio_log_ext());
2178
+ return std::fmax(out, std::numeric_limits<double>::min());
2179
+ }
2180
+
2181
+ /* this outputs the logarithm of the density */
2182
+ template<class ldouble_safe, class real_t>
2183
+ double DensityCalculator<ldouble_safe, real_t>::calc_bdens_ext(ldouble_safe remainder, size_t sample_size)
2184
+ {
2185
+ double out = std::log(remainder) - std::log((ldouble_safe)sample_size) - this->calc_bratio_log_ext();
2186
+ return std::fmax(out, MIN_DENS);
2187
+ }
2188
+
2189
+ template<class ldouble_safe, class real_t>
2190
+ void DensityCalculator<ldouble_safe, real_t>::save_range(double xmin, double xmax)
2191
+ {
2192
+ this->xmin = xmin;
2193
+ this->xmax = xmax;
2194
+ }
2195
+
2196
+ template<class ldouble_safe, class real_t>
2197
+ void DensityCalculator<ldouble_safe, real_t>::restore_range(double &restrict xmin, double &restrict xmax)
2198
+ {
2199
+ xmin = this->xmin;
2200
+ xmax = this->xmax;
2201
+ }
2202
+
2203
+ template<class ldouble_safe, class real_t>
2204
+ void DensityCalculator<ldouble_safe, real_t>::save_counts(size_t *restrict cat_counts, int ncat)
2205
+ {
2206
+ this->counts.assign(cat_counts, cat_counts + ncat);
2207
+ }
2208
+
2209
+ template<class ldouble_safe, class real_t>
2210
+ void DensityCalculator<ldouble_safe, real_t>::save_n_present_and_left(signed char *restrict split_left, int ncat)
2211
+ {
2212
+ this->n_present = 0;
2213
+ this->n_left = 0;
2214
+ for (int cat = 0; cat < ncat; cat++)
2215
+ {
2216
+ this->n_present += split_left[cat] >= 0;
2217
+ this->n_left += split_left[cat] == 1;
2218
+ }
2219
+ }
2220
+
2221
+ template<class ldouble_safe, class real_t>
2222
+ void DensityCalculator<ldouble_safe, real_t>::save_n_present(size_t *restrict cat_counts, int ncat)
2223
+ {
2224
+ this->n_present = 0;
2225
+ for (int cat = 0; cat < ncat; cat++)
2226
+ this->n_present += cat_counts[cat] > 0;
2227
+ }
2228
+
2229
/* For hyperplane intersections.

   Partition 'ix_arr[st..end]' (inclusive) so that every entry whose value is
   <= 'split_point' comes first; entries moved to the front keep their relative
   order.

   'x' is indexed by position relative to 'st', not by the values in 'ix_arr':
   'x[0]' belongs to 'ix_arr[st]' as it was on entry. NaN values never go left.

   Returns the position one past the last left-side entry ('st' when none go left). */
size_t divide_subset_split(size_t ix_arr[], double x[], size_t st, size_t end, double split_point) noexcept
{
    size_t next_left = st;
    const double *value = x;
    for (size_t pos = st; pos <= end; pos++, value++)
    {
        if (!(*value <= split_point))
            continue;
        std::iter_swap(ix_arr + next_left, ix_arr + pos);
        next_left++;
    }
    return next_left;
}
2246
+
2247
+ /* For numerical columns */
2248
+ template <class real_t>
2249
+ void divide_subset_split(size_t *restrict ix_arr, real_t x[], size_t st, size_t end, double split_point,
2250
+ MissingAction missing_action, size_t &restrict st_NA, size_t &restrict end_NA, size_t &restrict split_ix) noexcept
2251
+ {
2252
+ size_t temp;
2253
+
2254
+ /* if NAs are not to be bothered with, just need to do a single pass */
2255
+ if (missing_action == Fail)
2256
+ {
2257
+ /* move to the left if it's l.e. split point */
2258
+ for (size_t row = st; row <= end; row++)
2259
+ {
2260
+ if (x[ix_arr[row]] <= split_point)
2261
+ {
2262
+ temp = ix_arr[st];
2263
+ ix_arr[st] = ix_arr[row];
2264
+ ix_arr[row] = temp;
2265
+ st++;
2266
+ }
2267
+ }
2268
+ split_ix = st;
2269
+ }
2270
+
2271
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
2272
+ else
2273
+ {
2274
+ for (size_t row = st; row <= end; row++)
2275
+ {
2276
+ if (!std::isnan(x[ix_arr[row]]) && x[ix_arr[row]] <= split_point)
2277
+ {
2278
+ temp = ix_arr[st];
2279
+ ix_arr[st] = ix_arr[row];
2280
+ ix_arr[row] = temp;
2281
+ st++;
2282
+ }
2283
+ }
2284
+ st_NA = st;
2285
+
2286
+ for (size_t row = st; row <= end; row++)
2287
+ {
2288
+ if (unlikely(std::isnan(x[ix_arr[row]])))
2289
+ {
2290
+ temp = ix_arr[st];
2291
+ ix_arr[st] = ix_arr[row];
2292
+ ix_arr[row] = temp;
2293
+ st++;
2294
+ }
2295
+ }
2296
+ end_NA = st;
2297
+ }
2298
+ }
2299
+
2300
+ /* For sparse numeric columns */
2301
+ template <class real_t, class sparse_ix>
2302
+ void divide_subset_split(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
2303
+ real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr, double split_point,
2304
+ MissingAction missing_action, size_t &restrict st_NA, size_t &restrict end_NA, size_t &restrict split_ix) noexcept
2305
+ {
2306
+ /* TODO: this is a mess, needs refactoring */
2307
+ /* TODO: when moving zeros, would be better to instead move by '>' (opposite as in here) */
2308
+ /* TODO: should create an extra version to go along with 'predict' that would
2309
+ add the range penalty right here to spare operations. */
2310
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
2311
+ {
2312
+ if (missing_action == Fail)
2313
+ {
2314
+ split_ix = (0 <= split_point)? (end+1) : st;
2315
+ }
2316
+
2317
+ else
2318
+ {
2319
+ st_NA = (0 <= split_point)? (end+1) : st;
2320
+ end_NA = (0 <= split_point)? (end+1) : st;
2321
+ }
2322
+
2323
+ }
2324
+
2325
+ size_t st_col = Xc_indptr[col_num];
2326
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
2327
+ size_t curr_pos = st_col;
2328
+ size_t ind_end_col = Xc_ind[end_col];
2329
+ size_t temp;
2330
+ bool move_zeros = 0 <= split_point;
2331
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
2332
+
2333
+ if (move_zeros && ptr_st > ix_arr + st)
2334
+ st = ptr_st - ix_arr;
2335
+
2336
+ if (missing_action == Fail)
2337
+ {
2338
+ if (move_zeros)
2339
+ {
2340
+ for (size_t *row = ptr_st;
2341
+ row != ix_arr + end + 1;
2342
+ )
2343
+ {
2344
+ if (curr_pos >= end_col + 1)
2345
+ {
2346
+ for (size_t *r = row; r <= ix_arr + end; r++)
2347
+ {
2348
+ temp = ix_arr[st];
2349
+ ix_arr[st] = *r;
2350
+ *r = temp;
2351
+ st++;
2352
+ }
2353
+ break;
2354
+ }
2355
+
2356
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
2357
+ {
2358
+ if (Xc[curr_pos] <= split_point)
2359
+ {
2360
+ temp = ix_arr[st];
2361
+ ix_arr[st] = *row;
2362
+ *row = temp;
2363
+ st++;
2364
+ }
2365
+ if (curr_pos == end_col && row < ix_arr + end)
2366
+ {
2367
+ for (size_t *r = row + 1; r <= ix_arr + end; r++)
2368
+ {
2369
+ temp = ix_arr[st];
2370
+ ix_arr[st] = *r;
2371
+ *r = temp;
2372
+ st++;
2373
+ }
2374
+ }
2375
+ if (row == ix_arr + end || curr_pos == end_col) break;
2376
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
2377
+ }
2378
+
2379
+ else
2380
+ {
2381
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
2382
+ {
2383
+ while (row <= ix_arr + end && Xc_ind[curr_pos] > (sparse_ix)(*row))
2384
+ {
2385
+ temp = ix_arr[st];
2386
+ ix_arr[st] = *row;
2387
+ *row = temp;
2388
+ st++; row++;
2389
+ }
2390
+ }
2391
+
2392
+ else
2393
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
2394
+ }
2395
+ }
2396
+ }
2397
+
2398
+ else /* don't move zeros */
2399
+ {
2400
+ for (size_t *row = ptr_st;
2401
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
2402
+ )
2403
+ {
2404
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
2405
+ {
2406
+ if (Xc[curr_pos] <= split_point)
2407
+ {
2408
+ temp = ix_arr[st];
2409
+ ix_arr[st] = *row;
2410
+ *row = temp;
2411
+ st++;
2412
+ }
2413
+ if (row == ix_arr + end || curr_pos == end_col) break;
2414
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
2415
+ }
2416
+
2417
+ else
2418
+ {
2419
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
2420
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
2421
+ else
2422
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
2423
+ }
2424
+ }
2425
+ }
2426
+
2427
+ split_ix = st;
2428
+ }
2429
+
2430
+ else /* can have NAs */
2431
+ {
2432
+
2433
+ bool has_NAs = false;
2434
+ if (move_zeros)
2435
+ {
2436
+ for (size_t *row = ptr_st;
2437
+ row != ix_arr + end + 1;
2438
+ )
2439
+ {
2440
+ if (curr_pos >= end_col + 1)
2441
+ {
2442
+ for (size_t *r = row; r <= ix_arr + end; r++)
2443
+ {
2444
+ temp = ix_arr[st];
2445
+ ix_arr[st] = *r;
2446
+ *r = temp;
2447
+ st++;
2448
+ }
2449
+ break;
2450
+ }
2451
+
2452
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
2453
+ {
2454
+ if (unlikely(std::isnan(Xc[curr_pos])))
2455
+ has_NAs = true;
2456
+ else if (Xc[curr_pos] <= split_point)
2457
+ {
2458
+ temp = ix_arr[st];
2459
+ ix_arr[st] = *row;
2460
+ *row = temp;
2461
+ st++;
2462
+ }
2463
+ if (curr_pos == end_col && row < ix_arr + end)
2464
+ for (size_t *r = row + 1; r <= ix_arr + end; r++)
2465
+ {
2466
+ temp = ix_arr[st];
2467
+ ix_arr[st] = *r;
2468
+ *r = temp;
2469
+ st++;
2470
+ }
2471
+ if (row == ix_arr + end || curr_pos == end_col) break;
2472
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
2473
+ }
2474
+
2475
+ else
2476
+ {
2477
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
2478
+ {
2479
+ while (row <= ix_arr + end && Xc_ind[curr_pos] > (sparse_ix)(*row))
2480
+ {
2481
+ temp = ix_arr[st];
2482
+ ix_arr[st] = *row;
2483
+ *row = temp;
2484
+ st++; row++;
2485
+ }
2486
+ }
2487
+
2488
+ else
2489
+ {
2490
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
2491
+ }
2492
+ }
2493
+ }
2494
+ }
2495
+
2496
+ else /* don't move zeros */
2497
+ {
2498
+ for (size_t *row = ptr_st;
2499
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
2500
+ )
2501
+ {
2502
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
2503
+ {
2504
+ if (unlikely(std::isnan(Xc[curr_pos]))) has_NAs = true;
2505
+ if (!std::isnan(Xc[curr_pos]) && Xc[curr_pos] <= split_point)
2506
+ {
2507
+ temp = ix_arr[st];
2508
+ ix_arr[st] = *row;
2509
+ *row = temp;
2510
+ st++;
2511
+ }
2512
+ if (row == ix_arr + end || curr_pos == end_col) break;
2513
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
2514
+ }
2515
+
2516
+ else
2517
+ {
2518
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
2519
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
2520
+ else
2521
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
2522
+ }
2523
+ }
2524
+ }
2525
+
2526
+
2527
+ st_NA = st;
2528
+ if (has_NAs)
2529
+ {
2530
+ curr_pos = st_col;
2531
+ std::sort(ix_arr + st, ix_arr + end + 1);
2532
+ for (size_t *row = ix_arr + st;
2533
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
2534
+ )
2535
+ {
2536
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
2537
+ {
2538
+ if (unlikely(std::isnan(Xc[curr_pos])))
2539
+ {
2540
+ temp = ix_arr[st];
2541
+ ix_arr[st] = *row;
2542
+ *row = temp;
2543
+ st++;
2544
+ }
2545
+ if (row == ix_arr + end || curr_pos == end_col) break;
2546
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
2547
+ }
2548
+
2549
+ else
2550
+ {
2551
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
2552
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
2553
+ else
2554
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
2555
+ }
2556
+ }
2557
+ }
2558
+ end_NA = st;
2559
+
2560
+ }
2561
+
2562
+ }
2563
+
2564
+ /* For categorical columns split by subset */
2565
+ void divide_subset_split(size_t *restrict ix_arr, int x[], size_t st, size_t end, signed char split_categ[],
2566
+ MissingAction missing_action, size_t &restrict st_NA, size_t &restrict end_NA, size_t &restrict split_ix) noexcept
2567
+ {
2568
+ size_t temp;
2569
+
2570
+ /* if NAs are not to be bothered with, just need to do a single pass */
2571
+ if (missing_action == Fail)
2572
+ {
2573
+ /* move to the left if it's l.e. than the split point */
2574
+ for (size_t row = st; row <= end; row++)
2575
+ {
2576
+ if (split_categ[ x[ix_arr[row]] ] == 1)
2577
+ {
2578
+ temp = ix_arr[st];
2579
+ ix_arr[st] = ix_arr[row];
2580
+ ix_arr[row] = temp;
2581
+ st++;
2582
+ }
2583
+ }
2584
+ split_ix = st;
2585
+ }
2586
+
2587
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
2588
+ else
2589
+ {
2590
+ for (size_t row = st; row <= end; row++)
2591
+ {
2592
+ if (x[ix_arr[row]] >= 0 && split_categ[ x[ix_arr[row]] ] == 1)
2593
+ {
2594
+ temp = ix_arr[st];
2595
+ ix_arr[st] = ix_arr[row];
2596
+ ix_arr[row] = temp;
2597
+ st++;
2598
+ }
2599
+ }
2600
+ st_NA = st;
2601
+
2602
+ for (size_t row = st; row <= end; row++)
2603
+ {
2604
+ if (x[ix_arr[row]] < 0)
2605
+ {
2606
+ temp = ix_arr[st];
2607
+ ix_arr[st] = ix_arr[row];
2608
+ ix_arr[row] = temp;
2609
+ st++;
2610
+ }
2611
+ }
2612
+ end_NA = st;
2613
+ }
2614
+ }
2615
+
2616
+ /* For categorical columns split by subset, used at prediction time (with similarity) */
2617
+ void divide_subset_split(size_t *restrict ix_arr, int x[], size_t st, size_t end, signed char split_categ[],
2618
+ int ncat, MissingAction missing_action, NewCategAction new_cat_action,
2619
+ bool move_new_to_left, size_t &restrict st_NA, size_t &restrict end_NA, size_t &restrict split_ix) noexcept
2620
+ {
2621
+ size_t temp;
2622
+ int cval;
2623
+
2624
+ /* if NAs are not to be bothered with, just need to do a single pass */
2625
+ if (missing_action == Fail && new_cat_action != Weighted)
2626
+ {
2627
+ /* in this case, will need to fill 'split_ix', otherwise need to fill 'st_NA' and 'end_NA' */
2628
+ if (new_cat_action == Smallest && move_new_to_left)
2629
+ {
2630
+ for (size_t row = st; row <= end; row++)
2631
+ {
2632
+ cval = x[ix_arr[row]];
2633
+ if (cval >= ncat || split_categ[cval] == 1 || split_categ[cval] == (-1))
2634
+ {
2635
+ temp = ix_arr[st];
2636
+ ix_arr[st] = ix_arr[row];
2637
+ ix_arr[row] = temp;
2638
+ st++;
2639
+ }
2640
+ }
2641
+ }
2642
+
2643
+ else if (new_cat_action == Random)
2644
+ {
2645
+ for (size_t row = st; row <= end; row++)
2646
+ {
2647
+ cval = x[ix_arr[row]];
2648
+ cval = (cval >= ncat)? (cval % ncat) : cval;
2649
+ if (split_categ[cval] == 1)
2650
+ {
2651
+ temp = ix_arr[st];
2652
+ ix_arr[st] = ix_arr[row];
2653
+ ix_arr[row] = temp;
2654
+ st++;
2655
+ }
2656
+ }
2657
+ }
2658
+
2659
+ else
2660
+ {
2661
+ for (size_t row = st; row <= end; row++)
2662
+ {
2663
+ cval = x[ix_arr[row]];
2664
+ if (cval < ncat && split_categ[cval] == 1)
2665
+ {
2666
+ temp = ix_arr[st];
2667
+ ix_arr[st] = ix_arr[row];
2668
+ ix_arr[row] = temp;
2669
+ st++;
2670
+ }
2671
+ }
2672
+ }
2673
+
2674
+ split_ix = st;
2675
+ }
2676
+
2677
+ /* if there are new categories, and their direction was decided at random,
2678
+ can just reuse what was randomly decided for previous columns by taking
2679
+ a remainder w.r.t. the number of previous columns. Note however that this
2680
+ will not be an unbiased decision if the model used a gain criterion. */
2681
+ else if (new_cat_action == Random)
2682
+ {
2683
+ if (missing_action == Impute && !move_new_to_left)
2684
+ {
2685
+ for (size_t row = st; row <= end; row++)
2686
+ {
2687
+ cval = x[ix_arr[row]];
2688
+ cval = (cval >= ncat)? (cval % ncat) : cval;
2689
+ if (cval < 0 || split_categ[cval] == 1)
2690
+ {
2691
+ temp = ix_arr[st];
2692
+ ix_arr[st] = ix_arr[row];
2693
+ ix_arr[row] = temp;
2694
+ st++;
2695
+ }
2696
+ }
2697
+ }
2698
+
2699
+ else
2700
+ {
2701
+ for (size_t row = st; row <= end; row++)
2702
+ {
2703
+ cval = x[ix_arr[row]];
2704
+ cval = (cval >= ncat)? (cval % ncat) : cval;
2705
+ if (cval >= 0 && split_categ[cval] == 1)
2706
+ {
2707
+ temp = ix_arr[st];
2708
+ ix_arr[st] = ix_arr[row];
2709
+ ix_arr[row] = temp;
2710
+ st++;
2711
+ }
2712
+ }
2713
+ }
2714
+ st_NA = st;
2715
+
2716
+ if (!(missing_action == Impute && !move_new_to_left))
2717
+ {
2718
+ for (size_t row = st; row <= end; row++)
2719
+ {
2720
+ if (unlikely(x[ix_arr[row]] < 0))
2721
+ {
2722
+ temp = ix_arr[st];
2723
+ ix_arr[st] = ix_arr[row];
2724
+ ix_arr[row] = temp;
2725
+ st++;
2726
+ }
2727
+ }
2728
+ }
2729
+ end_NA = st;
2730
+ }
2731
+
2732
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
2733
+ else
2734
+ {
2735
+ /* Note: if having 'new_cat_action'='Smallest' and 'missing_action'='Impute', missing values
2736
+ and new categories will necessarily go into different branches, thus it's possible to do
2737
+ all the movements in one pass if certain conditions match. */
2738
+
2739
+ if (new_cat_action == Smallest && move_new_to_left)
2740
+ {
2741
+ for (size_t row = st; row <= end; row++)
2742
+ {
2743
+ cval = x[ix_arr[row]];
2744
+ if (cval >= 0 && (cval >= ncat || split_categ[cval] == 1 || split_categ[cval] == (-1)))
2745
+ {
2746
+ temp = ix_arr[st];
2747
+ ix_arr[st] = ix_arr[row];
2748
+ ix_arr[row] = temp;
2749
+ st++;
2750
+ }
2751
+ }
2752
+ }
2753
+
2754
+ else if (missing_action == Impute && !move_new_to_left)
2755
+ {
2756
+ for (size_t row = st; row <= end; row++)
2757
+ {
2758
+ cval = x[ix_arr[row]];
2759
+ if (cval < ncat && (cval < 0 || split_categ[cval] == 1))
2760
+ {
2761
+ temp = ix_arr[st];
2762
+ ix_arr[st] = ix_arr[row];
2763
+ ix_arr[row] = temp;
2764
+ st++;
2765
+ }
2766
+ }
2767
+ }
2768
+
2769
+ else
2770
+ {
2771
+ for (size_t row = st; row <= end; row++)
2772
+ {
2773
+ cval = x[ix_arr[row]];
2774
+ if (cval >= 0 && cval < ncat && split_categ[cval] == 1)
2775
+ {
2776
+ temp = ix_arr[st];
2777
+ ix_arr[st] = ix_arr[row];
2778
+ ix_arr[row] = temp;
2779
+ st++;
2780
+ }
2781
+ }
2782
+ }
2783
+
2784
+ st_NA = st;
2785
+
2786
+ if (new_cat_action == Weighted && missing_action == Divide)
2787
+ {
2788
+ for (size_t row = st; row <= end; row++)
2789
+ {
2790
+ cval = x[ix_arr[row]];
2791
+ if (cval < 0 || cval >= ncat || split_categ[cval] == (-1))
2792
+ {
2793
+ temp = ix_arr[st];
2794
+ ix_arr[st] = ix_arr[row];
2795
+ ix_arr[row] = temp;
2796
+ st++;
2797
+ }
2798
+ }
2799
+ }
2800
+
2801
+ else if (new_cat_action == Weighted)
2802
+ {
2803
+ for (size_t row = st; row <= end; row++)
2804
+ {
2805
+ cval = x[ix_arr[row]];
2806
+ if (cval >= 0 && (cval >= ncat || split_categ[cval] == (-1)))
2807
+ {
2808
+ temp = ix_arr[st];
2809
+ ix_arr[st] = ix_arr[row];
2810
+ ix_arr[row] = temp;
2811
+ st++;
2812
+ }
2813
+ }
2814
+ }
2815
+
2816
+ else if (missing_action == Divide)
2817
+ {
2818
+ for (size_t row = st; row <= end; row++)
2819
+ {
2820
+ if (unlikely(x[ix_arr[row]] < 0))
2821
+ {
2822
+ temp = ix_arr[st];
2823
+ ix_arr[st] = ix_arr[row];
2824
+ ix_arr[row] = temp;
2825
+ st++;
2826
+ }
2827
+ }
2828
+ }
2829
+
2830
+ end_NA = st;
2831
+ }
2832
+ }
2833
+
2834
+ /* For categoricals split on a single category */
2835
+ void divide_subset_split(size_t *restrict ix_arr, int x[], size_t st, size_t end, int split_categ,
2836
+ MissingAction missing_action, size_t &restrict st_NA, size_t &restrict end_NA, size_t &restrict split_ix) noexcept
2837
+ {
2838
+ size_t temp;
2839
+
2840
+ /* if NAs are not to be bothered with, just need to do a single pass */
2841
+ if (missing_action == Fail)
2842
+ {
2843
+ /* move to the left if it's equal to the chosen category */
2844
+ for (size_t row = st; row <= end; row++)
2845
+ {
2846
+ if (x[ix_arr[row]] == split_categ)
2847
+ {
2848
+ temp = ix_arr[st];
2849
+ ix_arr[st] = ix_arr[row];
2850
+ ix_arr[row] = temp;
2851
+ st++;
2852
+ }
2853
+ }
2854
+ split_ix = st;
2855
+ }
2856
+
2857
+ /* otherwise, first put to the left all equal to chosen and not NA, then all NAs to the end of the left */
2858
+ else
2859
+ {
2860
+ for (size_t row = st; row <= end; row++)
2861
+ {
2862
+ if (x[ix_arr[row]] == split_categ)
2863
+ {
2864
+ temp = ix_arr[st];
2865
+ ix_arr[st] = ix_arr[row];
2866
+ ix_arr[row] = temp;
2867
+ st++;
2868
+ }
2869
+ }
2870
+ st_NA = st;
2871
+
2872
+ for (size_t row = st; row <= end; row++)
2873
+ {
2874
+ if (unlikely(x[ix_arr[row]] < 0))
2875
+ {
2876
+ temp = ix_arr[st];
2877
+ ix_arr[st] = ix_arr[row];
2878
+ ix_arr[row] = temp;
2879
+ st++;
2880
+ }
2881
+ }
2882
+ end_NA = st;
2883
+ }
2884
+ }
2885
+
2886
+ /* For categoricals split on sub-set that turned out to have 2 categories only (prediction-time) */
2887
+ void divide_subset_split(size_t *restrict ix_arr, int x[], size_t st, size_t end,
2888
+ MissingAction missing_action, NewCategAction new_cat_action,
2889
+ bool move_new_to_left, size_t &restrict st_NA, size_t &restrict end_NA, size_t &restrict split_ix) noexcept
2890
+ {
2891
+ size_t temp;
2892
+
2893
+ /* if NAs are not to be bothered with, just need to do a single pass */
2894
+ if (missing_action == Fail)
2895
+ {
2896
+ /* move to the left if it's l.e. than the split point */
2897
+ if (new_cat_action == Smallest && move_new_to_left)
2898
+ {
2899
+ for (size_t row = st; row <= end; row++)
2900
+ {
2901
+ if (x[ix_arr[row]] == 0 || x[ix_arr[row]] > 1)
2902
+ {
2903
+ temp = ix_arr[st];
2904
+ ix_arr[st] = ix_arr[row];
2905
+ ix_arr[row] = temp;
2906
+ st++;
2907
+ }
2908
+ }
2909
+ }
2910
+
2911
+ else
2912
+ {
2913
+ for (size_t row = st; row <= end; row++)
2914
+ {
2915
+ if (x[ix_arr[row]] == 0)
2916
+ {
2917
+ temp = ix_arr[st];
2918
+ ix_arr[st] = ix_arr[row];
2919
+ ix_arr[row] = temp;
2920
+ st++;
2921
+ }
2922
+ }
2923
+ }
2924
+ split_ix = st;
2925
+ }
2926
+
2927
+ /* otherwise, first put to the left all l.e. and not NA, then all NAs to the end of the left */
2928
+ else
2929
+ {
2930
+ if (new_cat_action == Smallest && move_new_to_left)
2931
+ {
2932
+ for (size_t row = st; row <= end; row++)
2933
+ {
2934
+ if (x[ix_arr[row]] == 0 || x[ix_arr[row]] > 1)
2935
+ {
2936
+ temp = ix_arr[st];
2937
+ ix_arr[st] = ix_arr[row];
2938
+ ix_arr[row] = temp;
2939
+ st++;
2940
+ }
2941
+ }
2942
+ st_NA = st;
2943
+
2944
+ for (size_t row = st; row <= end; row++)
2945
+ {
2946
+ if (unlikely(x[ix_arr[row]] < 0))
2947
+ {
2948
+ temp = ix_arr[st];
2949
+ ix_arr[st] = ix_arr[row];
2950
+ ix_arr[row] = temp;
2951
+ st++;
2952
+ }
2953
+ }
2954
+ end_NA = st;
2955
+ }
2956
+
2957
+ else
2958
+ {
2959
+ for (size_t row = st; row <= end; row++)
2960
+ {
2961
+ if (x[ix_arr[row]] == 0)
2962
+ {
2963
+ temp = ix_arr[st];
2964
+ ix_arr[st] = ix_arr[row];
2965
+ ix_arr[row] = temp;
2966
+ st++;
2967
+ }
2968
+ }
2969
+ st_NA = st;
2970
+
2971
+ for (size_t row = st; row <= end; row++)
2972
+ {
2973
+ if (unlikely(x[ix_arr[row]] < 0))
2974
+ {
2975
+ temp = ix_arr[st];
2976
+ ix_arr[st] = ix_arr[row];
2977
+ ix_arr[row] = temp;
2978
+ st++;
2979
+ }
2980
+ }
2981
+ end_NA = st;
2982
+ }
2983
+ }
2984
+ }
2985
+
2986
+ /* for regular numeric columns */
2987
+ template <class real_t>
2988
+ void get_range(size_t ix_arr[], real_t *restrict x, size_t st, size_t end,
2989
+ MissingAction missing_action, double &restrict xmin, double &restrict xmax, bool &unsplittable) noexcept
2990
+ {
2991
+ xmin = HUGE_VAL;
2992
+ xmax = -HUGE_VAL;
2993
+ double xval;
2994
+
2995
+ if (missing_action == Fail)
2996
+ {
2997
+ for (size_t row = st; row <= end; row++)
2998
+ {
2999
+ xval = x[ix_arr[row]];
3000
+ xmin = (xval < xmin)? xval : xmin;
3001
+ xmax = (xval > xmax)? xval : xmax;
3002
+ }
3003
+ }
3004
+
3005
+
3006
+ else
3007
+ {
3008
+ for (size_t row = st; row <= end; row++)
3009
+ {
3010
+ xval = x[ix_arr[row]];
3011
+ xmin = std::fmin(xmin, xval);
3012
+ xmax = std::fmax(xmax, xval);
3013
+ }
3014
+ }
3015
+
3016
+ unsplittable = (xmin == xmax) || (xmin == HUGE_VAL && xmax == -HUGE_VAL) || std::isnan(xmin) || std::isnan(xmax);
3017
+ }
3018
+
3019
+ template <class real_t>
3020
+ void get_range(real_t *restrict x, size_t n,
3021
+ MissingAction missing_action, double &restrict xmin, double &restrict xmax, bool &unsplittable) noexcept
3022
+ {
3023
+ xmin = HUGE_VAL;
3024
+ xmax = -HUGE_VAL;
3025
+
3026
+ if (missing_action == Fail)
3027
+ {
3028
+ for (size_t row = 0; row < n; row++)
3029
+ {
3030
+ xmin = (x[row] < xmin)? x[row] : xmin;
3031
+ xmax = (x[row] > xmax)? x[row] : xmax;
3032
+ }
3033
+ }
3034
+
3035
+
3036
+ else
3037
+ {
3038
+ for (size_t row = 0; row < n; row++)
3039
+ {
3040
+ xmin = std::fmin(xmin, x[row]);
3041
+ xmax = std::fmax(xmax, x[row]);
3042
+ }
3043
+ }
3044
+
3045
+ unsplittable = (xmin == xmax) || (xmin == HUGE_VAL && xmax == -HUGE_VAL) || std::isnan(xmin) || std::isnan(xmax);
3046
+ }
3047
+
3048
+ /* for sparse inputs */
3049
+ template <class real_t, class sparse_ix>
3050
+ void get_range(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
3051
+ real_t *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
3052
+ MissingAction missing_action, double &restrict xmin, double &restrict xmax, bool &unsplittable) noexcept
3053
+ {
3054
+ /* ix_arr must already be sorted beforehand */
3055
+ xmin = HUGE_VAL;
3056
+ xmax = -HUGE_VAL;
3057
+
3058
+ size_t st_col = Xc_indptr[col_num];
3059
+ size_t end_col = Xc_indptr[col_num + 1];
3060
+ size_t nnz_col = end_col - st_col;
3061
+ end_col--;
3062
+ size_t curr_pos = st_col;
3063
+
3064
+ if (!nnz_col ||
3065
+ Xc_ind[st_col] > (sparse_ix)ix_arr[end] ||
3066
+ (sparse_ix)ix_arr[st] > Xc_ind[end_col]
3067
+ )
3068
+ {
3069
+ unsplittable = true;
3070
+ return;
3071
+ }
3072
+
3073
+ if (nnz_col < end - st + 1 ||
3074
+ Xc_ind[st_col] > (sparse_ix)ix_arr[st] ||
3075
+ Xc_ind[end_col] < (sparse_ix)ix_arr[end]
3076
+ )
3077
+ {
3078
+ xmin = 0;
3079
+ xmax = 0;
3080
+ }
3081
+
3082
+ size_t ind_end_col = Xc_ind[end_col];
3083
+ size_t nmatches = 0;
3084
+
3085
+ if (missing_action == Fail)
3086
+ {
3087
+ for (size_t *row = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
3088
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
3089
+ )
3090
+ {
3091
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
3092
+ {
3093
+ nmatches++;
3094
+ xmin = (Xc[curr_pos] < xmin)? Xc[curr_pos] : xmin;
3095
+ xmax = (Xc[curr_pos] > xmax)? Xc[curr_pos] : xmax;
3096
+ if (row == ix_arr + end || curr_pos == end_col) break;
3097
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3098
+ }
3099
+
3100
+ else
3101
+ {
3102
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
3103
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
3104
+ else
3105
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
3106
+ }
3107
+ }
3108
+ }
3109
+
3110
+ else /* can have NAs */
3111
+ {
3112
+ for (size_t *row = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
3113
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
3114
+ )
3115
+ {
3116
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
3117
+ {
3118
+ nmatches++;
3119
+ xmin = std::fmin(xmin, Xc[curr_pos]);
3120
+ xmax = std::fmax(xmax, Xc[curr_pos]);
3121
+ if (row == ix_arr + end || curr_pos == end_col) break;
3122
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3123
+ }
3124
+
3125
+ else
3126
+ {
3127
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
3128
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
3129
+ else
3130
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
3131
+ }
3132
+ }
3133
+
3134
+ }
3135
+
3136
+ if (nmatches < (end - st + 1))
3137
+ {
3138
+ xmin = std::fmin(xmin, 0);
3139
+ xmax = std::fmax(xmax, 0);
3140
+ }
3141
+
3142
+ unsplittable = (xmin == xmax) || (xmin == HUGE_VAL && xmax == -HUGE_VAL) || std::isnan(xmin) || std::isnan(xmax);
3143
+ }
3144
+
3145
+ template <class real_t, class sparse_ix>
3146
+ void get_range(size_t col_num, size_t nrows,
3147
+ real_t *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
3148
+ MissingAction missing_action, double &restrict xmin, double &restrict xmax, bool &unsplittable) noexcept
3149
+ {
3150
+ xmin = HUGE_VAL;
3151
+ xmax = -HUGE_VAL;
3152
+
3153
+ if ((size_t)(Xc_indptr[col_num+1] - Xc_indptr[col_num]) < nrows)
3154
+ {
3155
+ xmin = 0;
3156
+ xmax = 0;
3157
+ }
3158
+
3159
+ if (missing_action == Fail)
3160
+ {
3161
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num+1]; ix++)
3162
+ {
3163
+ xmin = (Xc[ix] < xmin)? Xc[ix] : xmin;
3164
+ xmax = (Xc[ix] > xmax)? Xc[ix] : xmax;
3165
+ }
3166
+ }
3167
+
3168
+ else
3169
+ {
3170
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num+1]; ix++)
3171
+ {
3172
+ if (unlikely(std::isinf(Xc[ix]))) continue;
3173
+ xmin = std::fmin(xmin, Xc[ix]);
3174
+ xmax = std::fmax(xmax, Xc[ix]);
3175
+ }
3176
+ }
3177
+
3178
+ unsplittable = (xmin == xmax) || (xmin == HUGE_VAL && xmax == -HUGE_VAL) || std::isnan(xmin) || std::isnan(xmax);
3179
+ }
3180
+
3181
+
3182
+ void get_categs(size_t *restrict ix_arr, int x[], size_t st, size_t end, int ncat,
3183
+ MissingAction missing_action, signed char categs[], size_t &restrict npresent, bool &unsplittable) noexcept
3184
+ {
3185
+ std::fill(categs, categs + ncat, -1);
3186
+ npresent = 0;
3187
+ for (size_t row = st; row <= end; row++)
3188
+ if (likely(x[ix_arr[row]] >= 0))
3189
+ categs[x[ix_arr[row]]] = 1;
3190
+
3191
+ npresent = std::accumulate(categs,
3192
+ categs + ncat,
3193
+ (size_t)0,
3194
+ [](const size_t a, const signed char b){return a + (b > 0);}
3195
+ );
3196
+
3197
+ unsplittable = npresent < 2;
3198
+ }
3199
+
3200
+ template <class real_t>
3201
+ bool check_more_than_two_unique_values(size_t ix_arr[], size_t st, size_t end, real_t x[], MissingAction missing_action)
3202
+ {
3203
+ if (end - st <= 1) return false;
3204
+
3205
+ if (missing_action == Fail)
3206
+ {
3207
+ real_t x0 = x[ix_arr[st]];
3208
+ for (size_t ix = st+1; ix <= end; ix++)
3209
+ {
3210
+ if (x[ix_arr[ix]] != x0) return true;
3211
+ }
3212
+ }
3213
+
3214
+ else
3215
+ {
3216
+ real_t x0;
3217
+ size_t ix;
3218
+ for (ix = st; ix <= end; ix++)
3219
+ {
3220
+ if (likely(!is_na_or_inf(x[ix_arr[ix]])))
3221
+ {
3222
+ x0 = x[ix_arr[ix]];
3223
+ ix++;
3224
+ break;
3225
+ }
3226
+ }
3227
+
3228
+ for (; ix <= end; ix++)
3229
+ {
3230
+ if (!is_na_or_inf(x[ix_arr[ix]]) && x[ix_arr[ix]] != x0)
3231
+ return true;
3232
+ }
3233
+ }
3234
+
3235
+ return false;
3236
+ }
3237
+
3238
+ bool check_more_than_two_unique_values(size_t ix_arr[], size_t st, size_t end, int x[], MissingAction missing_action)
3239
+ {
3240
+ if (end - st <= 1) return false;
3241
+
3242
+ if (missing_action == Fail)
3243
+ {
3244
+ int x0 = x[ix_arr[st]];
3245
+ for (size_t ix = st+1; ix <= end; ix++)
3246
+ {
3247
+ if (x[ix_arr[ix]] != x0) return true;
3248
+ }
3249
+ }
3250
+
3251
+ else
3252
+ {
3253
+ int x0;
3254
+ size_t ix;
3255
+ for (ix = st; ix <= end; ix++)
3256
+ {
3257
+ if (x[ix_arr[ix]] >= 0)
3258
+ {
3259
+ x0 = x[ix_arr[ix]];
3260
+ ix++;
3261
+ break;
3262
+ }
3263
+ }
3264
+
3265
+ for (; ix <= end; ix++)
3266
+ {
3267
+ if (x[ix_arr[ix]] >= 0 && x[ix_arr[ix]] != x0)
3268
+ return true;
3269
+ }
3270
+ }
3271
+
3272
+ return false;
3273
+ }
3274
+
3275
+ template <class real_t, class sparse_ix>
3276
+ bool check_more_than_two_unique_values(size_t *restrict ix_arr, size_t st, size_t end, size_t col,
3277
+ sparse_ix *restrict Xc_indptr, sparse_ix *restrict Xc_ind, real_t *restrict Xc,
3278
+ MissingAction missing_action)
3279
+ {
3280
+ if (end - st <= 1) return false;
3281
+ if (Xc_indptr[col+1] == Xc_indptr[col]) return false;
3282
+ bool has_zeros = (end - st + 1) > (size_t)(Xc_indptr[col+1] - Xc_indptr[col]);
3283
+ if (has_zeros && !is_na_or_inf(Xc[Xc_indptr[col]]) && Xc[Xc_indptr[col]] != 0) return true;
3284
+
3285
+ size_t st_col = Xc_indptr[col];
3286
+ size_t end_col = Xc_indptr[col + 1] - 1;
3287
+ size_t curr_pos = st_col;
3288
+ size_t ind_end_col = Xc_ind[end_col];
3289
+
3290
+ /* 'ix_arr' should be sorted beforehand */
3291
+ /* TODO: refactor this */
3292
+ real_t x0 = 0;
3293
+ size_t *row;
3294
+ for (row = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
3295
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
3296
+ )
3297
+ {
3298
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
3299
+ {
3300
+ if (is_na_or_inf(Xc[curr_pos]) || (has_zeros && Xc[curr_pos] == 0))
3301
+ {
3302
+ if (row == ix_arr + end || curr_pos == end_col) return false;
3303
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3304
+ }
3305
+
3306
+ x0 = Xc[curr_pos];
3307
+ if (has_zeros) return true;
3308
+ else if (x0 == 0) has_zeros = true;
3309
+ if (row == ix_arr + end || curr_pos == end_col) return false;
3310
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3311
+ break;
3312
+ }
3313
+
3314
+ else
3315
+ {
3316
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
3317
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
3318
+ else
3319
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
3320
+ }
3321
+ }
3322
+
3323
+ for (;
3324
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
3325
+ )
3326
+ {
3327
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
3328
+ {
3329
+ if (is_na_or_inf(Xc[curr_pos]) || (has_zeros && Xc[curr_pos] == 0))
3330
+ {
3331
+ if (row == ix_arr + end || curr_pos == end_col) break;
3332
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3333
+ }
3334
+
3335
+ else if (Xc[curr_pos] != x0)
3336
+ {
3337
+ return true;
3338
+ }
3339
+
3340
+ if (row == ix_arr + end || curr_pos == end_col) break;
3341
+ curr_pos = std::lower_bound(Xc_ind + curr_pos, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3342
+ }
3343
+
3344
+ else
3345
+ {
3346
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
3347
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
3348
+ else
3349
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
3350
+ }
3351
+ }
3352
+
3353
+ return false;
3354
+ }
3355
+
3356
+ template <class real_t, class sparse_ix>
3357
+ bool check_more_than_two_unique_values(size_t nrows, size_t col,
3358
+ sparse_ix *restrict Xc_indptr, sparse_ix *restrict Xc_ind, real_t *restrict Xc,
3359
+ MissingAction missing_action)
3360
+ {
3361
+ if (nrows <= 1) return false;
3362
+ if (Xc_indptr[col+1] == Xc_indptr[col]) return false;
3363
+ bool has_zeros = nrows > (size_t)(Xc_indptr[col+1] - Xc_indptr[col]);
3364
+ if (has_zeros && !is_na_or_inf(Xc[Xc_indptr[col]]) && Xc[Xc_indptr[col]] != 0) return true;
3365
+
3366
+ real_t x0 = 0;
3367
+ sparse_ix ix;
3368
+ for (ix = Xc_indptr[col]; ix < Xc_indptr[col+1]; ix++)
3369
+ {
3370
+ if (!is_na_or_inf(Xc[ix]))
3371
+ {
3372
+ if (has_zeros && Xc[ix] == 0) continue;
3373
+ if (has_zeros) return true;
3374
+ else if (Xc[ix] == 0) has_zeros = true;
3375
+ x0 = Xc[ix];
3376
+ ix++;
3377
+ break;
3378
+ }
3379
+ }
3380
+
3381
+ for (ix = Xc_indptr[col]; ix < Xc_indptr[col+1]; ix++)
3382
+ {
3383
+ if (!is_na_or_inf(Xc[ix]))
3384
+ {
3385
+ if (has_zeros && Xc[ix] == 0) continue;
3386
+ if (Xc[ix] != x0) return true;
3387
+ }
3388
+ }
3389
+
3390
+ return false;
3391
+ }
3392
+
3393
+ void count_categs(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat, size_t *restrict counts)
3394
+ {
3395
+ std::fill(counts, counts + ncat, (size_t)0);
3396
+ for (size_t row = st; row <= end; row++)
3397
+ if (likely(x[ix_arr[row]] >= 0))
3398
+ counts[x[ix_arr[row]]]++;
3399
+ }
3400
+
3401
+ int count_ncateg_in_col(const int x[], const size_t n, const int ncat, unsigned char buffer[])
3402
+ {
3403
+ memset(buffer, 0, ncat*sizeof(char));
3404
+ for (size_t ix = 0; ix < n; ix++)
3405
+ {
3406
+ if (likely(x[ix] >= 0)) buffer[x[ix]] = true;
3407
+ }
3408
+
3409
+ int ncat_present = 0;
3410
+ for (int cat = 0; cat < ncat; cat++)
3411
+ ncat_present += buffer[cat];
3412
+ return ncat_present;
3413
+ }
3414
+
3415
+ template <class ldouble_safe>
3416
+ ldouble_safe calculate_sum_weights(std::vector<size_t> &ix_arr, size_t st, size_t end, size_t curr_depth,
3417
+ std::vector<double> &weights_arr, hashed_map<size_t, double> &weights_map)
3418
+ {
3419
+ if (curr_depth > 0 && !weights_arr.empty())
3420
+ return std::accumulate(ix_arr.begin() + st,
3421
+ ix_arr.begin() + end + 1,
3422
+ (ldouble_safe)0,
3423
+ [&weights_arr](const ldouble_safe a, const size_t ix){return a + weights_arr[ix];});
3424
+ else if (curr_depth > 0 && !weights_map.empty())
3425
+ return std::accumulate(ix_arr.begin() + st,
3426
+ ix_arr.begin() + end + 1,
3427
+ (ldouble_safe)0,
3428
+ [&weights_map](const ldouble_safe a, const size_t ix){return a + weights_map[ix];});
3429
+ else
3430
+ return -HUGE_VAL;
3431
+ }
3432
+
3433
+ template <class real_t>
3434
+ size_t move_NAs_to_front(size_t ix_arr[], size_t st, size_t end, real_t x[])
3435
+ {
3436
+ size_t st_non_na = st;
3437
+ size_t temp;
3438
+
3439
+ for (size_t row = st; row <= end; row++)
3440
+ {
3441
+ if (unlikely(is_na_or_inf(x[ix_arr[row]])))
3442
+ {
3443
+ temp = ix_arr[st_non_na];
3444
+ ix_arr[st_non_na] = ix_arr[row];
3445
+ ix_arr[row] = temp;
3446
+ st_non_na++;
3447
+ }
3448
+ }
3449
+
3450
+ return st_non_na;
3451
+ }
3452
+
3453
+ template <class real_t, class sparse_ix>
3454
+ size_t move_NAs_to_front(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num, real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr)
3455
+ {
3456
+ size_t st_non_na = st;
3457
+ size_t temp;
3458
+
3459
+ size_t st_col = Xc_indptr[col_num];
3460
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
3461
+ size_t curr_pos = st_col;
3462
+ size_t ind_end_col = Xc_ind[end_col];
3463
+ std::sort(ix_arr + st, ix_arr + end + 1);
3464
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
3465
+
3466
+ for (size_t *row = ptr_st;
3467
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
3468
+ )
3469
+ {
3470
+ if (Xc_ind[curr_pos] == *row)
3471
+ {
3472
+ if (unlikely(is_na_or_inf(Xc[curr_pos])))
3473
+ {
3474
+ temp = ix_arr[st_non_na];
3475
+ ix_arr[st_non_na] = *row;
3476
+ *row = temp;
3477
+ st_non_na++;
3478
+ }
3479
+
3480
+ if (row == ix_arr + end || curr_pos == end_col) break;
3481
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3482
+ }
3483
+
3484
+ else
3485
+ {
3486
+ if (Xc_ind[curr_pos] > *row)
3487
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
3488
+ else
3489
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
3490
+ }
3491
+ }
3492
+
3493
+ return st_non_na;
3494
+ }
3495
+
3496
+ size_t move_NAs_to_front(size_t ix_arr[], size_t st, size_t end, int x[])
3497
+ {
3498
+ size_t st_non_na = st;
3499
+ size_t temp;
3500
+
3501
+ for (size_t row = st; row <= end; row++)
3502
+ {
3503
+ if (unlikely(x[ix_arr[row]] < 0))
3504
+ {
3505
+ temp = ix_arr[st_non_na];
3506
+ ix_arr[st_non_na] = ix_arr[row];
3507
+ ix_arr[row] = temp;
3508
+ st_non_na++;
3509
+ }
3510
+ }
3511
+
3512
+ return st_non_na;
3513
+ }
3514
+
3515
+ size_t center_NAs(size_t ix_arr[], size_t st_left, size_t st, size_t curr_pos)
3516
+ {
3517
+ size_t temp;
3518
+ for (size_t row = st_left; row < st; row++)
3519
+ {
3520
+ temp = ix_arr[--curr_pos];
3521
+ ix_arr[curr_pos] = ix_arr[row];
3522
+ ix_arr[row] = temp;
3523
+ }
3524
+
3525
+ return curr_pos;
3526
+ }
3527
+
3528
+ /* FIXME / TODO: this calculation would not take weight into account */
3529
+ /* Here:
3530
+ - 'ix_arr' should be partitioned putting the NAs and Infs at the beginning: [st_orig, st)
3531
+ - the rest of the range [st, end] should be sorted in ascending order
3532
+ The output should have a filled-in 'x' with median values, plus a re-sorted 'ix_arr'
3533
+ taking into account that now the median values are in the middle. */
3534
+ template <class real_t>
3535
+ void fill_NAs_with_median(size_t *restrict ix_arr, size_t st_orig, size_t st, size_t end, real_t *restrict x,
3536
+ double *restrict buffer_imputed_x, double *restrict xmedian)
3537
+ {
3538
+ size_t tot = end - st + 1;
3539
+ size_t idx_half = st + div2(tot);
3540
+ bool is_odd = (tot % 2) != 0;
3541
+
3542
+ if (is_odd)
3543
+ {
3544
+ *xmedian = x[ix_arr[idx_half]];
3545
+ idx_half--;
3546
+ }
3547
+
3548
+ else
3549
+ {
3550
+ idx_half--;
3551
+ double xlow = x[ix_arr[idx_half]];
3552
+ double xhigh = x[ix_arr[idx_half+(size_t)1]];
3553
+ *xmedian = xlow + (xhigh-xlow)/2.;
3554
+ }
3555
+
3556
+ for (size_t ix = st_orig; ix < st; ix++)
3557
+ buffer_imputed_x[ix_arr[ix]] = (*xmedian);
3558
+ for (size_t ix = st; ix <= end; ix++)
3559
+ buffer_imputed_x[ix_arr[ix]] = x[ix_arr[ix]];
3560
+
3561
+ /* 'ix_arr' can be resorted in-place, but the logic is a bit complex */
3562
+ /* step 1: move all NAs to their place by swapping them with the lower-half
3563
+ in ascending order (after this, the lower half will be unordered).
3564
+ along the way, copy the indices that claim the places where earlier
3565
+ there were missing values. these copied indices will be sorted in
3566
+ descending order at the end, as they were inserted in reverse order. */
3567
+ size_t end_pointer = idx_half;
3568
+ size_t n_move = std::min(st-st_orig, idx_half-st+1);
3569
+ size_t temp;
3570
+ for (size_t ix = st_orig; ix < st_orig + n_move; ix++)
3571
+ {
3572
+ temp = ix_arr[end_pointer];
3573
+ ix_arr[end_pointer] = ix_arr[ix];
3574
+ ix_arr[ix] = temp;
3575
+ end_pointer--;
3576
+ }
3577
+
3578
+ /* step 2: reverse the indices that were moved to the beginning so
3579
+ as to maintain the sorting order */
3580
+ std::reverse(ix_arr + st_orig, ix_arr + st_orig + n_move);
3581
+ /* step 3: rotate the total number of elements by the number of moved elements */
3582
+ size_t n_unmoved = (idx_half - st + 1) - n_move;
3583
+ std::rotate(ix_arr + st_orig,
3584
+ ix_arr + st_orig + n_move,
3585
+ ix_arr + st_orig + n_move + n_unmoved);
3586
+ }
3587
+
3588
+ template <class real_t, class sparse_ix>
3589
+ void todense(size_t *restrict ix_arr, size_t st, size_t end,
3590
+ size_t col_num, real_t *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
3591
+ double *restrict buffer_arr)
3592
+ {
3593
+ std::fill(buffer_arr, buffer_arr + (end - st + 1), (double)0);
3594
+
3595
+ size_t st_col = Xc_indptr[col_num];
3596
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
3597
+ size_t curr_pos = st_col;
3598
+ size_t ind_end_col = Xc_ind[end_col];
3599
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
3600
+
3601
+ for (size_t *row = ptr_st;
3602
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
3603
+ )
3604
+ {
3605
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
3606
+ {
3607
+ buffer_arr[row - (ix_arr + st)] = Xc[curr_pos];
3608
+ if (row == ix_arr + end || curr_pos == end_col) break;
3609
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
3610
+ }
3611
+
3612
+ else
3613
+ {
3614
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
3615
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
3616
+ else
3617
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
3618
+ }
3619
+ }
3620
+ }
3621
+
3622
+
3623
+ template <class real_t>
3624
+ void colmajor_to_rowmajor(real_t *restrict X, size_t nrows, size_t ncols, std::vector<double> &X_row_major)
3625
+ {
3626
+ X_row_major.resize(nrows * ncols);
3627
+ for (size_t row = 0; row < nrows; row++)
3628
+ for (size_t col = 0; col < ncols; col++)
3629
+ X_row_major[row + col*nrows] = X[col + row*ncols];
3630
+ }
3631
+
3632
+ template <class real_t, class sparse_ix>
3633
+ void colmajor_to_rowmajor(real_t *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
3634
+ size_t nrows, size_t ncols,
3635
+ std::vector<double> &Xr, std::vector<size_t> &Xr_ind, std::vector<size_t> &Xr_indptr)
3636
+ {
3637
+ /* First convert to COO */
3638
+ size_t nnz = Xc_indptr[ncols];
3639
+ std::vector<size_t> row_indices(nnz);
3640
+ for (size_t col = 0; col < ncols; col++)
3641
+ {
3642
+ for (sparse_ix ix = Xc_indptr[col]; ix < Xc_indptr[col+1]; ix++)
3643
+ {
3644
+ row_indices[ix] = Xc_ind[ix];
3645
+ }
3646
+ }
3647
+
3648
+ /* Then copy the data argsorted by rows */
3649
+ std::vector<size_t> argsorted_indices(nnz);
3650
+ std::iota(argsorted_indices.begin(), argsorted_indices.end(), (size_t)0);
3651
+ std::stable_sort(argsorted_indices.begin(), argsorted_indices.end(),
3652
+ [&row_indices](const size_t a, const size_t b)
3653
+ {return row_indices[a] < row_indices[b];});
3654
+ Xr.resize(nnz);
3655
+ Xr_ind.resize(nnz);
3656
+ for (size_t ix = 0; ix < nnz; ix++)
3657
+ {
3658
+ Xr[ix] = Xc[argsorted_indices[ix]];
3659
+ Xr_ind[ix] = Xc_ind[argsorted_indices[ix]];
3660
+ }
3661
+
3662
+ /* Now build the index pointer */
3663
+ Xr_indptr.resize(nrows+1);
3664
+ size_t curr_row = 0;
3665
+ size_t curr_n = 0;
3666
+ for (size_t ix = 0; ix < nnz; ix++)
3667
+ {
3668
+ if (row_indices[argsorted_indices[ix]] != curr_row)
3669
+ {
3670
+ Xr_indptr[curr_row+1] = curr_n;
3671
+ curr_n = 0;
3672
+ curr_row = row_indices[argsorted_indices[ix]];
3673
+ }
3674
+
3675
+ else
3676
+ {
3677
+ curr_n++;
3678
+ }
3679
+ }
3680
+ for (size_t row = 1; row < nrows; row++)
3681
+ Xr_indptr[row+1] += Xr_indptr[row];
3682
+ }
3683
+
3684
+
3685
+ bool interrupt_switch = false;
3686
+ bool handle_is_locked = false;
3687
+
3688
+ /* Function to handle interrupt signals */
3689
+ void set_interrup_global_variable(int s)
3690
+ {
3691
+ #pragma omp critical
3692
+ {
3693
+ interrupt_switch = true;
3694
+ }
3695
+ }
3696
+
3697
+ void check_interrupt_switch(SignalSwitcher &ss)
3698
+ {
3699
+ if (interrupt_switch)
3700
+ {
3701
+ ss.restore_handle();
3702
+ print_errmsg("Error: procedure was interrupted\n");
3703
+ raise(SIGINT);
3704
+ #ifdef _FOR_R
3705
+ Rcpp::checkUserInterrupt();
3706
+ #elif !defined(DONT_THROW_ON_INTERRUPT)
3707
+ throw std::runtime_error("Error: procedure was interrupted.\n");
3708
+ #endif
3709
+ }
3710
+ }
3711
+
3712
+ #ifdef _FOR_PYTHON
3713
+ bool cy_check_interrupt_switch()
3714
+ {
3715
+ return interrupt_switch;
3716
+ }
3717
+
3718
+ void cy_tick_off_interrupt_switch()
3719
+ {
3720
+ interrupt_switch = false;
3721
+ }
3722
+ #endif
3723
+
3724
+ SignalSwitcher::SignalSwitcher()
3725
+ {
3726
+ #pragma omp critical
3727
+ {
3728
+ if (!handle_is_locked)
3729
+ {
3730
+ handle_is_locked = true;
3731
+ interrupt_switch = false;
3732
+ this->old_sig = signal(SIGINT, set_interrup_global_variable);
3733
+ this->is_active = true;
3734
+ }
3735
+
3736
+ else {
3737
+ this->is_active = false;
3738
+ }
3739
+ }
3740
+ }
3741
+
3742
+ SignalSwitcher::~SignalSwitcher()
3743
+ {
3744
+ #ifndef _FOR_PYTHON
3745
+ #pragma omp critical
3746
+ {
3747
+ if (this->is_active && handle_is_locked)
3748
+ interrupt_switch = false;
3749
+ }
3750
+ #endif
3751
+ this->restore_handle();
3752
+ }
3753
+
3754
+ void SignalSwitcher::restore_handle()
3755
+ {
3756
+ #pragma omp critical
3757
+ {
3758
+ if (this->is_active && handle_is_locked)
3759
+ {
3760
+ signal(SIGINT, this->old_sig);
3761
+ this->is_active = false;
3762
+ handle_is_locked = false;
3763
+ }
3764
+ }
3765
+ }
3766
+
3767
+ bool has_long_double()
3768
+ {
3769
+ #ifndef NO_LONG_DOUBLE
3770
+ return sizeof(long double) > sizeof(double);
3771
+ #else
3772
+ return false;
3773
+ #endif
3774
+ }
3775
+
3776
+ /* Return the #def'd constants from standard header. This is in order to determine if the return
3777
+ value from the 'fit_model' function is a success or failure within Cython, which does not
3778
+ allow importing #def'd macro values. */
3779
+ int return_EXIT_SUCCESS()
3780
+ {
3781
+ return EXIT_SUCCESS;
3782
+ }
3783
+ int return_EXIT_FAILURE()
3784
+ {
3785
+ return EXIT_FAILURE;
3786
+ }