isotree 0.2.2 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (152) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2116 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +132 -0
  14. data/vendor/isotree/src/RcppExports.cpp +594 -57
  15. data/vendor/isotree/src/Rwrapper.cpp +2452 -304
  16. data/vendor/isotree/src/c_interface.cpp +958 -0
  17. data/vendor/isotree/src/crit.hpp +4236 -0
  18. data/vendor/isotree/src/digamma.hpp +184 -0
  19. data/vendor/isotree/src/dist.hpp +1886 -0
  20. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  21. data/vendor/isotree/src/extended.hpp +1444 -0
  22. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  23. data/vendor/isotree/src/fit_model.hpp +2401 -0
  24. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  25. data/vendor/isotree/src/helpers_iforest.hpp +814 -0
  26. data/vendor/isotree/src/{impute.cpp → impute.hpp} +382 -123
  27. data/vendor/isotree/src/indexer.cpp +515 -0
  28. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  29. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  30. data/vendor/isotree/src/isoforest.hpp +1659 -0
  31. data/vendor/isotree/src/isotree.hpp +1815 -394
  32. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  33. data/vendor/isotree/src/merge_models.cpp +159 -16
  34. data/vendor/isotree/src/mult.hpp +1321 -0
  35. data/vendor/isotree/src/oop_interface.cpp +844 -0
  36. data/vendor/isotree/src/oop_interface.hpp +278 -0
  37. data/vendor/isotree/src/other_helpers.hpp +219 -0
  38. data/vendor/isotree/src/predict.hpp +1932 -0
  39. data/vendor/isotree/src/python_helpers.hpp +114 -0
  40. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  41. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  42. data/vendor/isotree/src/robinmap/README.md +483 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1639 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  46. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  47. data/vendor/isotree/src/serialize.cpp +4316 -139
  48. data/vendor/isotree/src/sql.cpp +143 -61
  49. data/vendor/isotree/src/subset_models.cpp +174 -0
  50. data/vendor/isotree/src/utils.hpp +3786 -0
  51. data/vendor/isotree/src/xoshiro.hpp +463 -0
  52. data/vendor/isotree/src/ziggurat.hpp +405 -0
  53. metadata +40 -105
  54. data/vendor/cereal/LICENSE +0 -24
  55. data/vendor/cereal/README.md +0 -85
  56. data/vendor/cereal/include/cereal/access.hpp +0 -351
  57. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  58. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  59. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  60. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  61. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  62. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  63. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  65. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  66. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  67. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  68. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  69. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  70. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  71. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  72. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  74. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  76. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  77. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  78. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  79. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  91. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  92. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  94. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  96. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  97. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  98. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  99. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  100. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  101. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  102. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  103. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  104. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  105. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  106. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  107. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  111. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  112. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  113. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  114. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  115. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  116. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  117. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  118. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  119. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  120. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  121. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  122. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  123. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  124. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  125. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  126. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  127. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  128. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  129. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  130. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  131. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  132. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  133. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  134. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  135. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  136. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  137. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  138. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  139. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  140. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  141. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  142. data/vendor/cereal/include/cereal/version.hpp +0 -52
  143. data/vendor/isotree/src/Makevars +0 -4
  144. data/vendor/isotree/src/crit.cpp +0 -912
  145. data/vendor/isotree/src/dist.cpp +0 -749
  146. data/vendor/isotree/src/extended.cpp +0 -790
  147. data/vendor/isotree/src/fit_model.cpp +0 -1090
  148. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  149. data/vendor/isotree/src/isoforest.cpp +0 -771
  150. data/vendor/isotree/src/mult.cpp +0 -607
  151. data/vendor/isotree/src/predict.cpp +0 -853
  152. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -0,0 +1,4236 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
21
+ * [8] Cortes, David.
22
+ * "Distance approximation using Isolation Forests."
23
+ * arXiv preprint arXiv:1910.12362 (2019).
24
+ * [9] Cortes, David.
25
+ * "Imputing missing values with unsupervised random trees."
26
+ * arXiv preprint arXiv:1911.06646 (2019).
27
+ * [10] https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom
28
+ * [11] Cortes, David.
29
+ * "Revisiting randomized choices in isolation forests."
30
+ * arXiv preprint arXiv:2110.13402 (2021).
31
+ * [12] Guha, Sudipto, et al.
32
+ * "Robust random cut forest based anomaly detection on streams."
33
+ * International conference on machine learning. PMLR, 2016.
34
+ * [13] Cortes, David.
35
+ * "Isolation forests: looking beyond tree depth."
36
+ * arXiv preprint arXiv:2111.11639 (2021).
37
+ * [14] Ting, Kai Ming, Yue Zhu, and Zhi-Hua Zhou.
38
+ * "Isolation kernel and its effect on SVM"
39
+ * Proceedings of the 24th ACM SIGKDD
40
+ * International Conference on Knowledge Discovery & Data Mining. 2018.
41
+ *
42
+ * BSD 2-Clause License
43
+ * Copyright (c) 2019-2022, David Cortes
44
+ * All rights reserved.
45
+ * Redistribution and use in source and binary forms, with or without
46
+ * modification, are permitted provided that the following conditions are met:
47
+ * * Redistributions of source code must retain the above copyright notice, this
48
+ * list of conditions and the following disclaimer.
49
+ * * Redistributions in binary form must reproduce the above copyright notice,
50
+ * this list of conditions and the following disclaimer in the documentation
51
+ * and/or other materials provided with the distribution.
52
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
53
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
54
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
55
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
56
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
57
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
58
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
59
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
60
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
61
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
62
+ */
63
+ #include "isotree.hpp"
64
+
65
+ /* TODO: should the kurtosis calculation impute values when using ndim=1 + missing_action=Impute?
66
+ It should be the theoretically correct approach, but will cause the kurtosis to increase
67
+ significantly if there is a large number of missing values, which would lead to prefer
68
+ splitting on columns with mostly missing values. */
69
+
70
+ /* TODO: this kurtosis caps the minimum values to zero, but the minimum achievable value is 1,
71
+ see how imprecise results are used outside of the function in the different kinds of calculations
72
+ that use kurtosis and maybe change the logic. */
73
+
74
+ #define pw1(x) ((x))
75
+ #define pw2(x) ((x) * (x))
76
+ #define pw3(x) ((x) * (x) * (x))
77
+ #define pw4(x) ((x) * (x) * (x) * (x))
78
+
79
+ /* https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics */
80
+ template <class real_t, class ldouble_safe>
81
+ double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, real_t x[], MissingAction missing_action)
82
+ {
83
+ ldouble_safe m = 0;
84
+ ldouble_safe M2 = 0, M3 = 0, M4 = 0;
85
+ ldouble_safe delta, delta_s, delta_div;
86
+ ldouble_safe diff, n;
87
+ ldouble_safe out;
88
+
89
+ if (missing_action == Fail)
90
+ {
91
+ for (size_t row = st; row <= end; row++)
92
+ {
93
+ n = (ldouble_safe)(row - st + 1);
94
+
95
+ delta = x[ix_arr[row]] - m;
96
+ delta_div = delta / n;
97
+ delta_s = delta_div * delta_div;
98
+ diff = delta * (delta_div * (ldouble_safe)(row - st));
99
+
100
+ m += delta_div;
101
+ M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
102
+ M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
103
+ M2 += diff;
104
+ }
105
+
106
+ if (unlikely(!is_na_or_inf(M2) && M2 <= 0))
107
+ {
108
+ if (!check_more_than_two_unique_values(ix_arr, st, end, x, missing_action))
109
+ return -HUGE_VAL;
110
+ }
111
+
112
+ out = ( M4 / M2 ) * ( (ldouble_safe)(end - st + 1) / M2 );
113
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
114
+ }
115
+
116
+ else
117
+ {
118
+ size_t cnt = 0;
119
+ for (size_t row = st; row <= end; row++)
120
+ {
121
+ if (likely(!is_na_or_inf(x[ix_arr[row]])))
122
+ {
123
+ cnt++;
124
+ n = (ldouble_safe) cnt;
125
+
126
+ delta = x[ix_arr[row]] - m;
127
+ delta_div = delta / n;
128
+ delta_s = delta_div * delta_div;
129
+ diff = delta * (delta_div * (ldouble_safe)(cnt - 1));
130
+
131
+ m += delta_div;
132
+ M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
133
+ M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
134
+ M2 += diff;
135
+ }
136
+ }
137
+
138
+ if (unlikely(cnt == 0)) return -HUGE_VAL;
139
+ if (unlikely(!is_na_or_inf(M2) && M2 <= 0))
140
+ {
141
+ if (!check_more_than_two_unique_values(ix_arr, st, end, x, missing_action))
142
+ return -HUGE_VAL;
143
+ }
144
+
145
+ out = ( M4 / M2 ) * ( (ldouble_safe)cnt / M2 );
146
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
147
+ }
148
+ }
149
+
150
+ template <class real_t, class ldouble_safe>
151
+ double calc_kurtosis(real_t x[], size_t n, MissingAction missing_action)
152
+ {
153
+ ldouble_safe m = 0;
154
+ ldouble_safe M2 = 0, M3 = 0, M4 = 0;
155
+ ldouble_safe delta, delta_s, delta_div;
156
+ ldouble_safe diff, n_;
157
+ ldouble_safe out;
158
+
159
+ if (missing_action == Fail)
160
+ {
161
+ for (size_t row = 0; row < n; row++)
162
+ {
163
+ n_ = (ldouble_safe)(row + 1);
164
+
165
+ delta = x[row] - m;
166
+ delta_div = delta / n_;
167
+ delta_s = delta_div * delta_div;
168
+ diff = delta * (delta_div * (ldouble_safe)row);
169
+
170
+ m += delta_div;
171
+ M4 += diff * delta_s * (n_ * n_ - 3 * n_ + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
172
+ M3 += diff * delta_div * (n_ - 2) - 3 * delta_div * M2;
173
+ M2 += diff;
174
+ }
175
+
176
+ out = ( M4 / M2 ) * ( (ldouble_safe)n / M2 );
177
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
178
+ }
179
+
180
+ else
181
+ {
182
+ size_t cnt = 0;
183
+ for (size_t row = 0; row < n; row++)
184
+ {
185
+ if (likely(!is_na_or_inf(x[row])))
186
+ {
187
+ cnt++;
188
+ n_ = (ldouble_safe) cnt;
189
+
190
+ delta = x[row] - m;
191
+ delta_div = delta / n_;
192
+ delta_s = delta_div * delta_div;
193
+ diff = delta * (delta_div * (ldouble_safe)(cnt - 1));
194
+
195
+ m += delta_div;
196
+ M4 += diff * delta_s * (n_ * n_ - 3 * n_ + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
197
+ M3 += diff * delta_div * (n_ - 2) - 3 * delta_div * M2;
198
+ M2 += diff;
199
+ }
200
+ }
201
+
202
+ if (unlikely(cnt == 0)) return -HUGE_VAL;
203
+
204
+ out = ( M4 / M2 ) * ( (ldouble_safe)cnt / M2 );
205
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
206
+ }
207
+ }
208
+
209
+ /* TODO: is this algorithm correct? */
210
+ template <class real_t, class mapping, class ldouble_safe>
211
+ double calc_kurtosis_weighted(size_t ix_arr[], size_t st, size_t end, real_t x[],
212
+ MissingAction missing_action, mapping &restrict w)
213
+ {
214
+ ldouble_safe m = 0;
215
+ ldouble_safe M2 = 0, M3 = 0, M4 = 0;
216
+ ldouble_safe delta, delta_s, delta_div;
217
+ ldouble_safe diff;
218
+ ldouble_safe n = 0;
219
+ ldouble_safe out;
220
+ ldouble_safe n_prev = 0.;
221
+ ldouble_safe w_this;
222
+
223
+ for (size_t row = st; row <= end; row++)
224
+ {
225
+ if (likely(!is_na_or_inf(x[ix_arr[row]])))
226
+ {
227
+ w_this = w[ix_arr[row]];
228
+ n += w_this;
229
+
230
+ delta = x[ix_arr[row]] - m;
231
+ delta_div = delta / n;
232
+ delta_s = delta_div * delta_div;
233
+ diff = delta * (delta_div * n_prev);
234
+ n_prev = n;
235
+
236
+ m += w_this * (delta_div);
237
+ M4 += w_this * (diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3);
238
+ M3 += w_this * (diff * delta_div * (n - 2) - 3 * delta_div * M2);
239
+ M2 += w_this * (diff);
240
+ }
241
+ }
242
+
243
+ if (unlikely(n <= 0)) return -HUGE_VAL;
244
+ if (unlikely(!is_na_or_inf(M2) && M2 <= std::numeric_limits<double>::epsilon()))
245
+ {
246
+ if (!check_more_than_two_unique_values(ix_arr, st, end, x, missing_action))
247
+ return -HUGE_VAL;
248
+ }
249
+
250
+ out = ( M4 / M2 ) * ( n / M2 );
251
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
252
+ }
253
+
254
+ template <class real_t, class ldouble_safe>
255
+ double calc_kurtosis_weighted(real_t *restrict x, size_t n_, MissingAction missing_action, real_t *restrict w)
256
+ {
257
+ ldouble_safe m = 0;
258
+ ldouble_safe M2 = 0, M3 = 0, M4 = 0;
259
+ ldouble_safe delta, delta_s, delta_div;
260
+ ldouble_safe diff;
261
+ ldouble_safe n = 0;
262
+ ldouble_safe out;
263
+ ldouble_safe n_prev = 0.;
264
+ ldouble_safe w_this;
265
+
266
+ for (size_t row = 0; row < n_; row++)
267
+ {
268
+ if (likely(!is_na_or_inf(x[row])))
269
+ {
270
+ w_this = w[row];
271
+ n += w_this;
272
+
273
+ delta = x[row] - m;
274
+ delta_div = delta / n;
275
+ delta_s = delta_div * delta_div;
276
+ diff = delta * (delta_div * n_prev);
277
+ n_prev = n;
278
+
279
+ m += w_this * (delta_div);
280
+ M4 += w_this * (diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3);
281
+ M3 += w_this * (diff * delta_div * (n - 2) - 3 * delta_div * M2);
282
+ M2 += w_this * (diff);
283
+ }
284
+ }
285
+
286
+ if (unlikely(n <= 0)) return -HUGE_VAL;
287
+
288
+ out = ( M4 / M2 ) * ( n / M2 );
289
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
290
+ }
291
+
292
+
293
+ /* TODO: make these compensated sums */
294
+ /* TODO: can this use the same algorithm as above but with a correction at the end,
295
+ like it was done for the variance? */
296
+ template <class real_t, class sparse_ix, class ldouble_safe>
297
+ double calc_kurtosis(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
298
+ real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
299
+ MissingAction missing_action)
300
+ {
301
+ /* ix_arr must be already sorted beforehand */
302
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
303
+ return -HUGE_VAL;
304
+
305
+ ldouble_safe s1 = 0;
306
+ ldouble_safe s2 = 0;
307
+ ldouble_safe s3 = 0;
308
+ ldouble_safe s4 = 0;
309
+ ldouble_safe x_sq;
310
+ size_t cnt = end - st + 1;
311
+
312
+ if (unlikely(cnt <= 1)) return -HUGE_VAL;
313
+
314
+ size_t st_col = Xc_indptr[col_num];
315
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
316
+ size_t curr_pos = st_col;
317
+ size_t ind_end_col = Xc_ind[end_col];
318
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
319
+
320
+ ldouble_safe xval;
321
+
322
+ if (missing_action != Fail)
323
+ {
324
+ for (size_t *row = ptr_st;
325
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
326
+ )
327
+ {
328
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
329
+ {
330
+ xval = Xc[curr_pos];
331
+ if (unlikely(is_na_or_inf(xval)))
332
+ {
333
+ cnt--;
334
+ }
335
+
336
+ else
337
+ {
338
+ /* TODO: is it safe to use FMA here? some calculations rely on assuming that
339
+ some of these 's' are larger than the others. Would this procedure be guaranteed
340
+ to preserve such differences if done with a mixture of sums and FMAs? */
341
+ x_sq = square(xval);
342
+ s1 += xval;
343
+ s2 = std::fma(xval, xval, s2);
344
+ s3 = std::fma(x_sq, xval, s3);
345
+ s4 = std::fma(x_sq, x_sq, s4);
346
+ // s1 += pw1(xval);
347
+ // s2 += pw2(xval);
348
+ // s3 += pw3(xval);
349
+ // s4 += pw4(xval);
350
+ }
351
+
352
+ if (row == ix_arr + end || curr_pos == end_col) break;
353
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
354
+ }
355
+
356
+ else
357
+ {
358
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
359
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
360
+ else
361
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
362
+ }
363
+ }
364
+
365
+ if (unlikely(cnt <= (end - st + 1) - (Xc_indptr[col_num+1] - Xc_indptr[col_num]))) return -HUGE_VAL;
366
+ }
367
+
368
+ else
369
+ {
370
+ for (size_t *row = ptr_st;
371
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
372
+ )
373
+ {
374
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
375
+ {
376
+ xval = Xc[curr_pos];
377
+ x_sq = square(xval);
378
+ s1 += xval;
379
+ s2 = std::fma(xval, xval, s2);
380
+ s3 = std::fma(x_sq, xval, s3);
381
+ s4 = std::fma(x_sq, x_sq, s4);
382
+ // s1 += pw1(xval);
383
+ // s2 += pw2(xval);
384
+ // s3 += pw3(xval);
385
+ // s4 += pw4(xval);
386
+
387
+ if (row == ix_arr + end || curr_pos == end_col) break;
388
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
389
+ }
390
+
391
+ else
392
+ {
393
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
394
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
395
+ else
396
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
397
+ }
398
+ }
399
+ }
400
+
401
+ if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
402
+ ldouble_safe cnt_l = (ldouble_safe) cnt;
403
+ ldouble_safe sn = s1 / cnt_l;
404
+ ldouble_safe v = s2 / cnt_l - pw2(sn);
405
+ if (unlikely(std::isnan(v))) return -HUGE_VAL;
406
+ if (
407
+ v <= std::numeric_limits<double>::epsilon() &&
408
+ !check_more_than_two_unique_values(ix_arr, st, end, col_num,
409
+ Xc_indptr, Xc_ind, Xc,
410
+ missing_action)
411
+ )
412
+ return -HUGE_VAL;
413
+ if (unlikely(v <= 0)) return 0.;
414
+ ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt_l * pw4(sn)) / (cnt_l * pw2(v));
415
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
416
+ }
417
+
418
+ template <class real_t, class sparse_ix, class ldouble_safe>
419
+ double calc_kurtosis(size_t col_num, size_t nrows,
420
+ real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
421
+ MissingAction missing_action)
422
+ {
423
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
424
+ return -HUGE_VAL;
425
+
426
+ ldouble_safe s1 = 0;
427
+ ldouble_safe s2 = 0;
428
+ ldouble_safe s3 = 0;
429
+ ldouble_safe s4 = 0;
430
+ ldouble_safe x_sq;
431
+ size_t cnt = nrows;
432
+
433
+ if (unlikely(cnt <= 1)) return -HUGE_VAL;
434
+
435
+ ldouble_safe xval;
436
+
437
+ if (missing_action != Fail)
438
+ {
439
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num+1]; ix++)
440
+ {
441
+ xval = Xc[ix];
442
+ if (unlikely(is_na_or_inf(xval)))
443
+ {
444
+ cnt--;
445
+ }
446
+
447
+ else
448
+ {
449
+ x_sq = square(xval);
450
+ s1 += xval;
451
+ s2 = std::fma(xval, xval, s2);
452
+ s3 = std::fma(x_sq, xval, s3);
453
+ s4 = std::fma(x_sq, x_sq, s4);
454
+ // s1 += pw1(xval);
455
+ // s2 += pw2(xval);
456
+ // s3 += pw3(xval);
457
+ // s4 += pw4(xval);
458
+ }
459
+ }
460
+
461
+ if (cnt <= (nrows) - (Xc_indptr[col_num+1] - Xc_indptr[col_num])) return -HUGE_VAL;
462
+ }
463
+
464
+ else
465
+ {
466
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num+1]; ix++)
467
+ {
468
+ xval = Xc[ix];
469
+ x_sq = square(xval);
470
+ s1 += xval;
471
+ s2 = std::fma(xval, xval, s2);
472
+ s3 = std::fma(x_sq, xval, s3);
473
+ s4 = std::fma(x_sq, x_sq, s4);
474
+ // s1 += pw1(xval);
475
+ // s2 += pw2(xval);
476
+ // s3 += pw3(xval);
477
+ // s4 += pw4(xval);
478
+ }
479
+ }
480
+
481
+ if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
482
+ ldouble_safe cnt_l = (ldouble_safe) cnt;
483
+ ldouble_safe sn = s1 / cnt_l;
484
+ ldouble_safe v = s2 / cnt_l - pw2(sn);
485
+ if (unlikely(std::isnan(v))) return -HUGE_VAL;
486
+ if (
487
+ v <= std::numeric_limits<double>::epsilon() &&
488
+ !check_more_than_two_unique_values(nrows, col_num,
489
+ Xc_indptr, Xc_ind, Xc,
490
+ missing_action)
491
+ )
492
+ return -HUGE_VAL;
493
+ if (unlikely(v <= 0)) return 0.;
494
+ ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt_l * pw4(sn)) / (cnt_l * pw2(v));
495
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
496
+ }
497
+
498
+
499
+ template <class real_t, class sparse_ix, class mapping, class ldouble_safe>
500
+ double calc_kurtosis_weighted(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
501
+ real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
502
+ MissingAction missing_action, mapping &restrict w)
503
+ {
504
+ /* ix_arr must be already sorted beforehand */
505
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
506
+ return -HUGE_VAL;
507
+
508
+ ldouble_safe s1 = 0;
509
+ ldouble_safe s2 = 0;
510
+ ldouble_safe s3 = 0;
511
+ ldouble_safe s4 = 0;
512
+ ldouble_safe x_sq;
513
+ ldouble_safe w_this;
514
+ ldouble_safe cnt = 0;
515
+ for (size_t row = st; row <= end; row++)
516
+ cnt += w[ix_arr[row]];
517
+
518
+ if (unlikely(cnt <= 0)) return -HUGE_VAL;
519
+
520
+ size_t st_col = Xc_indptr[col_num];
521
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
522
+ size_t curr_pos = st_col;
523
+ size_t ind_end_col = Xc_ind[end_col];
524
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
525
+
526
+ ldouble_safe xval;
527
+
528
+ if (missing_action != Fail)
529
+ {
530
+ for (size_t *row = ptr_st;
531
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
532
+ )
533
+ {
534
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
535
+ {
536
+ w_this = w[*row];
537
+ xval = Xc[curr_pos];
538
+
539
+ if (unlikely(is_na_or_inf(xval)))
540
+ {
541
+ cnt -= w_this;
542
+ }
543
+
544
+ else
545
+ {
546
+ x_sq = xval * xval;
547
+ s1 = std::fma(w_this, xval, s1);
548
+ s2 = std::fma(w_this, x_sq, s2);
549
+ s3 = std::fma(w_this, x_sq*xval, s3);
550
+ s4 = std::fma(w_this, x_sq*x_sq, s4);
551
+ // s1 += w_this * pw1(xval);
552
+ // s2 += w_this * pw2(xval);
553
+ // s3 += w_this * pw3(xval);
554
+ // s4 += w_this * pw4(xval);
555
+ }
556
+
557
+ if (row == ix_arr + end || curr_pos == end_col) break;
558
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
559
+ }
560
+
561
+ else
562
+ {
563
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
564
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
565
+ else
566
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
567
+ }
568
+ }
569
+
570
+ if (unlikely(cnt <= 0)) return -HUGE_VAL;
571
+ }
572
+
573
+ else
574
+ {
575
+ for (size_t *row = ptr_st;
576
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
577
+ )
578
+ {
579
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
580
+ {
581
+ w_this = w[*row];
582
+ xval = Xc[curr_pos];
583
+
584
+ x_sq = xval * xval;
585
+ s1 = std::fma(w_this, xval, s1);
586
+ s2 = std::fma(w_this, x_sq, s2);
587
+ s3 = std::fma(w_this, x_sq*xval, s3);
588
+ s4 = std::fma(w_this, x_sq*x_sq, s4);
589
+ // s1 += w_this * pw1(xval);
590
+ // s2 += w_this * pw2(xval);
591
+ // s3 += w_this * pw3(xval);
592
+ // s4 += w_this * pw4(xval);
593
+
594
+ if (row == ix_arr + end || curr_pos == end_col) break;
595
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
596
+ }
597
+
598
+ else
599
+ {
600
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
601
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
602
+ else
603
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
604
+ }
605
+ }
606
+ }
607
+
608
+ if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
609
+ ldouble_safe sn = s1 / cnt;
610
+ ldouble_safe v = s2 / cnt - pw2(sn);
611
+ if (unlikely(std::isnan(v))) return -HUGE_VAL;
612
+ if (
613
+ v <= std::numeric_limits<double>::epsilon() &&
614
+ !check_more_than_two_unique_values(ix_arr, st, end, col_num,
615
+ Xc_indptr, Xc_ind, Xc,
616
+ missing_action)
617
+ )
618
+ return -HUGE_VAL;
619
+ if (v <= 0) return 0.;
620
+ ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt * pw4(sn)) / (cnt * pw2(v));
621
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
622
+ }
623
+
624
+ template <class real_t, class sparse_ix, class ldouble_safe>
625
+ double calc_kurtosis_weighted(size_t col_num, size_t nrows,
626
+ real_t *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
627
+ MissingAction missing_action, real_t *restrict w)
628
+ {
629
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
630
+ return -HUGE_VAL;
631
+
632
+ ldouble_safe s1 = 0;
633
+ ldouble_safe s2 = 0;
634
+ ldouble_safe s3 = 0;
635
+ ldouble_safe s4 = 0;
636
+ ldouble_safe x_sq;
637
+ ldouble_safe w_this;
638
+ ldouble_safe cnt = nrows - (Xc_indptr[col_num + 1] - Xc_indptr[col_num]);
639
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num + 1]; ix++)
640
+ cnt += w[Xc_ind[ix]];
641
+
642
+ if (unlikely(cnt <= 0)) return -HUGE_VAL;
643
+
644
+ ldouble_safe xval;
645
+
646
+ if (missing_action != Fail)
647
+ {
648
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num + 1]; ix++)
649
+ {
650
+ w_this = w[Xc_ind[ix]];
651
+ xval = Xc[ix];
652
+
653
+ if (unlikely(is_na_or_inf(xval)))
654
+ {
655
+ cnt -= w_this;
656
+ }
657
+
658
+ else
659
+ {
660
+ x_sq = xval * xval;
661
+ s1 = std::fma(w_this, xval, s1);
662
+ s2 = std::fma(w_this, x_sq, s2);
663
+ s3 = std::fma(w_this, x_sq*xval, s3);
664
+ s4 = std::fma(w_this, x_sq*x_sq, s4);
665
+ // s1 += w_this * pw1(xval);
666
+ // s2 += w_this * pw2(xval);
667
+ // s3 += w_this * pw3(xval);
668
+ // s4 += w_this * pw4(xval);
669
+ }
670
+ }
671
+
672
+ if (cnt <= 0) return -HUGE_VAL;
673
+ }
674
+
675
+ else
676
+ {
677
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num + 1]; ix++)
678
+ {
679
+ w_this = w[Xc_ind[ix]];
680
+ xval = Xc[ix];
681
+
682
+ x_sq = xval * xval;
683
+ s1 = std::fma(w_this, xval, s1);
684
+ s2 = std::fma(w_this, x_sq, s2);
685
+ s3 = std::fma(w_this, x_sq*xval, s3);
686
+ s4 = std::fma(w_this, x_sq*x_sq, s4);
687
+ // s1 += w_this * pw1(xval);
688
+ // s2 += w_this * pw2(xval);
689
+ // s3 += w_this * pw3(xval);
690
+ // s4 += w_this * pw4(xval);
691
+ }
692
+ }
693
+
694
+ if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
695
+ ldouble_safe sn = s1 / cnt;
696
+ ldouble_safe v = s2 / cnt - pw2(sn);
697
+ if (unlikely(std::isnan(v))) return -HUGE_VAL;
698
+ if (
699
+ v <= std::numeric_limits<double>::epsilon() &&
700
+ !check_more_than_two_unique_values(nrows, col_num,
701
+ Xc_indptr, Xc_ind, Xc,
702
+ missing_action)
703
+ )
704
+ return -HUGE_VAL;
705
+ if (unlikely(v <= 0)) return -HUGE_VAL;
706
+ ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt * pw4(sn)) / (cnt * pw2(v));
707
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
708
+ }
709
+
710
+
711
+ template <class ldouble_safe>
712
+ double calc_kurtosis_internal(size_t cnt, int x[], int ncat, size_t buffer_cnt[], double buffer_prob[],
713
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
714
+ {
715
+ /* This calculation proceeds as follows:
716
+ - If splitting by subsets, it will assign a random weight ~Unif(0,1) to
717
+ each category, and approximate kurtosis by sampling from such distribution
718
+ with the same probabilities as given by the current counts.
719
+ - If splitting by isolating one category, will binarize at each categorical level,
720
+ assume the values are zero or one, and output the average assuming each categorical
721
+ level has equal probability of being picked.
722
+ (Note that both are misleading heuristics, but might be better than random)
723
+ */
724
+ double sum_kurt = 0;
725
+
726
+ cnt -= buffer_cnt[ncat];
727
+ if (cnt <= 1) return -HUGE_VAL;
728
+ ldouble_safe cnt_l = (ldouble_safe) cnt;
729
+ for (int cat = 0; cat < ncat; cat++)
730
+ buffer_prob[cat] = buffer_cnt[cat] / cnt_l;
731
+
732
+ switch (cat_split_type)
733
+ {
734
+ case SubSet:
735
+ {
736
+ ldouble_safe temp_v;
737
+ ldouble_safe s1, s2, s3, s4;
738
+ ldouble_safe coef;
739
+ ldouble_safe coef2;
740
+ ldouble_safe w_this;
741
+ UniformUnitInterval runif(0, 1);
742
+ size_t ntry = 50;
743
+ for (size_t iternum = 0; iternum < 50; iternum++)
744
+ {
745
+ s1 = 0; s2 = 0; s3 = 0; s4 = 0;
746
+ for (int cat = 0; cat < ncat; cat++)
747
+ {
748
+ coef = runif(rnd_generator);
749
+ coef2 = coef * coef;
750
+ w_this = buffer_prob[cat];
751
+ s1 = std::fma(w_this, coef, s1);
752
+ s2 = std::fma(w_this, coef2, s2);
753
+ s3 = std::fma(w_this, coef2*coef, s3);
754
+ s4 = std::fma(w_this, coef2*coef2, s4);
755
+ // s1 += buffer_prob[cat] * pw1(coef);
756
+ // s2 += buffer_prob[cat] * pw2(coef);
757
+ // s3 += buffer_prob[cat] * pw3(coef);
758
+ // s4 += buffer_prob[cat] * pw4(coef);
759
+ }
760
+ temp_v = s2 - pw2(s1);
761
+ if (temp_v <= 0)
762
+ ntry--;
763
+ else
764
+ sum_kurt += (s4 - 4 * s3 * pw1(s1) + 6 * s2 * pw2(s1) - 4 * s1 * pw3(s1) + pw4(s1)) / pw2(temp_v);
765
+ }
766
+ if (unlikely(!ntry))
767
+ return -HUGE_VAL;
768
+ else if (unlikely(is_na_or_inf(sum_kurt)))
769
+ return -HUGE_VAL;
770
+ else
771
+ return std::fmax(sum_kurt, 0.) / (double)ntry;
772
+ }
773
+
774
+ case SingleCateg:
775
+ {
776
+ double p;
777
+ int ncat_present = ncat;
778
+ for (int cat = 0; cat < ncat; cat++)
779
+ {
780
+ p = buffer_prob[cat];
781
+ if (p == 0)
782
+ ncat_present--;
783
+ else
784
+ sum_kurt += (p - 4 * p * pw1(p) + 6 * p * pw2(p) - 4 * p * pw3(p) + pw4(p)) / pw2(p - pw2(p));
785
+ }
786
+ if (ncat_present <= 1)
787
+ return -HUGE_VAL;
788
+ else if (unlikely(is_na_or_inf(sum_kurt)))
789
+ return -HUGE_VAL;
790
+ else
791
+ return std::fmax(sum_kurt, 0.) / (double)ncat_present;
792
+ }
793
+ }
794
+
795
+ unreachable();
796
+ return -1; /* this will never be reached, but CRAN complains otherwise */
797
+ }
798
+
799
+ template <class ldouble_safe>
800
+ double calc_kurtosis(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat, size_t *restrict buffer_cnt, double buffer_prob[],
801
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
802
+ {
803
+ /* This calculation proceeds as follows:
804
+ - If splitting by subsets, it will assign a random weight ~Unif(0,1) to
805
+ each category, and approximate kurtosis by sampling from such distribution
806
+ with the same probabilities as given by the current counts.
807
+ - If splitting by isolating one category, will binarize at each categorical level,
808
+ assume the values are zero or one, and output the average assuming each categorical
809
+ level has equal probability of being picked.
810
+ (Note that both are misleading heuristics, but might be better than random)
811
+ */
812
+ size_t cnt = end - st + 1;
813
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
814
+
815
+ if (missing_action == Fail)
816
+ {
817
+ for (size_t row = st; row <= end; row++)
818
+ buffer_cnt[x[ix_arr[row]]]++;
819
+ }
820
+
821
+ else
822
+ {
823
+ for (size_t row = st; row <= end; row++)
824
+ {
825
+ if (likely(x[ix_arr[row]] >= 0))
826
+ buffer_cnt[x[ix_arr[row]]]++;
827
+ else
828
+ buffer_cnt[ncat]++;
829
+ }
830
+ }
831
+
832
+ return calc_kurtosis_internal<ldouble_safe>(
833
+ cnt, x, ncat, buffer_cnt, buffer_prob,
834
+ missing_action, cat_split_type, rnd_generator);
835
+ }
836
+
837
+ template <class ldouble_safe>
838
+ double calc_kurtosis(size_t nrows, int x[], int ncat, size_t buffer_cnt[], double buffer_prob[],
839
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
840
+ {
841
+ size_t cnt = nrows;
842
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
843
+
844
+ if (missing_action == Fail)
845
+ {
846
+ for (size_t row = 0; row < nrows; row++)
847
+ buffer_cnt[x[row]]++;
848
+ }
849
+
850
+ else
851
+ {
852
+ for (size_t row = 0; row < nrows; row++)
853
+ {
854
+ if (likely(x[row] >= 0))
855
+ buffer_cnt[x[row]]++;
856
+ else
857
+ buffer_cnt[ncat]++;
858
+ }
859
+ }
860
+
861
+ return calc_kurtosis_internal<ldouble_safe>(
862
+ cnt, x, ncat, buffer_cnt, buffer_prob,
863
+ missing_action, cat_split_type, rnd_generator);
864
+ }
865
+
866
+
867
+ /* TODO: this one should get a buffer preallocated from outside */
868
+ template <class mapping, class ldouble_safe>
869
+ double calc_kurtosis_weighted_internal(std::vector<ldouble_safe> &buffer_cnt, int x[], int ncat,
870
+ double buffer_prob[], MissingAction missing_action, CategSplit cat_split_type,
871
+ RNG_engine &rnd_generator, mapping &restrict w)
872
+ {
873
+ double sum_kurt = 0;
874
+
875
+ ldouble_safe cnt = std::accumulate(buffer_cnt.begin(), buffer_cnt.end(), (ldouble_safe)0);
876
+
877
+ cnt -= buffer_cnt[ncat];
878
+ if (unlikely(cnt <= 1)) return -HUGE_VAL;
879
+ for (int cat = 0; cat < ncat; cat++)
880
+ buffer_prob[cat] = buffer_cnt[cat] / cnt;
881
+
882
+ switch (cat_split_type)
883
+ {
884
+ case SubSet:
885
+ {
886
+ ldouble_safe temp_v;
887
+ ldouble_safe s1, s2, s3, s4;
888
+ ldouble_safe coef, coef2;
889
+ ldouble_safe w_this;
890
+ UniformUnitInterval runif(0, 1);
891
+ size_t ntry = 50;
892
+ for (size_t iternum = 0; iternum < 50; iternum++)
893
+ {
894
+ s1 = 0; s2 = 0; s3 = 0; s4 = 0;
895
+ for (int cat = 0; cat < ncat; cat++)
896
+ {
897
+ coef = runif(rnd_generator);
898
+ coef2 = coef * coef;
899
+ w_this = buffer_prob[cat];
900
+ s1 = std::fma(w_this, coef, s1);
901
+ s2 = std::fma(w_this, coef2, s2);
902
+ s3 = std::fma(w_this, coef2*coef, s3);
903
+ s4 = std::fma(w_this, coef2*coef2, s4);
904
+ // s1 += buffer_prob[cat] * pw1(coef);
905
+ // s2 += buffer_prob[cat] * pw2(coef);
906
+ // s3 += buffer_prob[cat] * pw3(coef);
907
+ // s4 += buffer_prob[cat] * pw4(coef);
908
+ }
909
+ temp_v = s2 - pw2(s1);
910
+ if (unlikely(temp_v <= 0))
911
+ ntry--;
912
+ else
913
+ sum_kurt += (s4 - 4 * s3 * pw1(s1) + 6 * s2 * pw2(s1) - 4 * s1 * pw3(s1) + pw4(s1)) / pw2(temp_v);
914
+ }
915
+ if (unlikely(!ntry))
916
+ return -HUGE_VAL;
917
+ else if (unlikely(is_na_or_inf(sum_kurt)))
918
+ return -HUGE_VAL;
919
+ else
920
+ return std::fmax(sum_kurt, 0.) / (double)ntry;
921
+ }
922
+
923
+ case SingleCateg:
924
+ {
925
+ double p;
926
+ int ncat_present = ncat;
927
+ for (int cat = 0; cat < ncat; cat++)
928
+ {
929
+ p = buffer_prob[cat];
930
+ if (p == 0)
931
+ ncat_present--;
932
+ else
933
+ sum_kurt += (p - 4 * p * pw1(p) + 6 * p * pw2(p) - 4 * p * pw3(p) + pw4(p)) / pw2(p - pw2(p));
934
+ }
935
+ if (ncat_present <= 1)
936
+ return -HUGE_VAL;
937
+ else if (unlikely(is_na_or_inf(sum_kurt)))
938
+ return -HUGE_VAL;
939
+ else
940
+ return std::fmax(sum_kurt, 0.) / (double)ncat_present;
941
+ }
942
+ }
943
+
944
+ unreachable();
945
+ return -1; /* this will never be reached, but CRAN complains otherwise */
946
+ }
947
+
948
+ template <class mapping, class ldouble_safe>
949
+ double calc_kurtosis_weighted(size_t ix_arr[], size_t st, size_t end, int x[], int ncat, double buffer_prob[],
950
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator,
951
+ mapping &restrict w)
952
+ {
953
+ std::vector<ldouble_safe> buffer_cnt(ncat+1, 0.);
954
+ ldouble_safe w_this;
955
+
956
+ for (size_t row = st; row <= end; row++)
957
+ {
958
+ w_this = w[ix_arr[row]];
959
+ if (likely(x[ix_arr[row]] >= 0))
960
+ buffer_cnt[x[ix_arr[row]]] += w_this;
961
+ else
962
+ buffer_cnt[ncat] += w_this;
963
+ }
964
+
965
+ return calc_kurtosis_weighted_internal<mapping, ldouble_safe>(
966
+ buffer_cnt, x, ncat,
967
+ buffer_prob, missing_action, cat_split_type,
968
+ rnd_generator, w);
969
+ }
970
+
971
+ template <class real_t, class ldouble_safe>
972
+ double calc_kurtosis_weighted(size_t nrows, int x[], int ncat, double *restrict buffer_prob,
973
+ MissingAction missing_action, CategSplit cat_split_type,
974
+ RNG_engine &rnd_generator, real_t *restrict w)
975
+ {
976
+ std::vector<ldouble_safe> buffer_cnt(ncat+1, 0.);
977
+ ldouble_safe w_this;
978
+
979
+ for (size_t row = 0; row < nrows; row++)
980
+ {
981
+ w_this = w[row];
982
+ if (likely(x[row] >= 0))
983
+ buffer_cnt[x[row]] += w_this;
984
+ else
985
+ buffer_cnt[ncat] += w_this;
986
+ }
987
+
988
+ return calc_kurtosis_weighted_internal<real_t *restrict, ldouble_safe>(
989
+ buffer_cnt, x, ncat,
990
+ buffer_prob, missing_action, cat_split_type,
991
+ rnd_generator, w);
992
+ }
993
+
994
/* Square root of
       sum_i p_i/3  -  sum_i p_i^2/3  -  sum_{i<j} p_i*p_j/2
   taken over the 'n' categories whose indices are listed in 'pos', with 'p'
   holding one proportion per category. A negative total is clamped to zero
   before the square root. Returns 0 when 'n' <= 1. */
template <class int_t, class ldouble_safe>
double expected_sd_cat(double p[], size_t n, int_t pos[])
{
    if (n <= 1) return 0;

    /* the first two categories are handled together to seed the accumulator */
    const double p_first = p[pos[0]];
    const double p_second = p[pos[1]];
    ldouble_safe variance = -(p_first * p_first) / 3.0 - p_first * p_second / 2.0 + p_first / 3.0 - (p_second * p_second) / 3.0 + p_second / 3.0;

    for (size_t outer = 2; outer < n; outer++)
    {
        const double p_outer = p[pos[outer]];
        variance += p_outer / 3.0 - (p_outer * p_outer) / 3.0;
        /* cross terms against every category that came before */
        for (size_t inner = 0; inner < outer; inner++)
            variance -= p_outer * p[pos[inner]] / 2.0;
    }

    return std::sqrt(std::fmax(variance, (ldouble_safe)0));
}
1008
+
1009
+ template <class number, class int_t, class ldouble_safe>
1010
+ double expected_sd_cat(number *restrict counts, double *restrict p, size_t n, int_t *restrict pos)
1011
+ {
1012
+ if (n <= 1) return 0;
1013
+
1014
+ number tot = std::accumulate(pos, pos + n, (number)0, [&counts](number tot, const size_t ix){return tot + counts[ix];});
1015
+ ldouble_safe cnt_div = (ldouble_safe) tot;
1016
+ for (size_t cat = 0; cat < n; cat++)
1017
+ p[pos[cat]] = (ldouble_safe)counts[pos[cat]] / cnt_div;
1018
+
1019
+ return expected_sd_cat<int_t, ldouble_safe>(p, n, pos);
1020
+ }
1021
+
1022
+ template <class number, class int_t, class ldouble_safe>
1023
+ double expected_sd_cat_single(number *restrict counts, double *restrict p, size_t n, int_t *restrict pos, size_t cat_exclude, number cnt)
1024
+ {
1025
+ if (cat_exclude == 0)
1026
+ return expected_sd_cat<number, int_t, ldouble_safe>(counts, p, n-1, pos + 1);
1027
+
1028
+ else if (cat_exclude == (n-1))
1029
+ return expected_sd_cat<number, int_t, ldouble_safe>(counts, p, n-1, pos);
1030
+
1031
+ size_t ix_exclude = pos[cat_exclude];
1032
+
1033
+ ldouble_safe cnt_div = (ldouble_safe) (cnt - counts[ix_exclude]);
1034
+ for (size_t cat = 0; cat < n; cat++)
1035
+ p[pos[cat]] = (ldouble_safe)counts[pos[cat]] / cnt_div;
1036
+
1037
+ ldouble_safe cum_var;
1038
+ if (cat_exclude != 1)
1039
+ cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[1]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[1]]) / 3.0 + p[pos[1]] / 3.0;
1040
+ else
1041
+ cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[2]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[2]]) / 3.0 + p[pos[2]] / 3.0;
1042
+ for (size_t cat1 = (cat_exclude == 1)? 3 : 2; cat1 < n; cat1++)
1043
+ {
1044
+ if (pos[cat1] == ix_exclude) continue;
1045
+ cum_var += p[pos[cat1]] / 3.0 - square(p[pos[cat1]]) / 3.0;
1046
+ for (size_t cat2 = 0; cat2 < cat1; cat2++)
1047
+ {
1048
+ if (pos[cat2] == ix_exclude) continue;
1049
+ cum_var -= p[pos[cat1]] * p[pos[cat2]] / 2.0;
1050
+ }
1051
+
1052
+ }
1053
+ return std::sqrt(std::fmax(cum_var, (ldouble_safe)0));
1054
+ }
1055
+
1056
+ template <class number, class int_t, class ldouble_safe>
1057
+ double expected_sd_cat_internal(int ncat, number *restrict buffer_cnt, ldouble_safe cnt_l,
1058
+ int_t *restrict buffer_pos, double *restrict buffer_prob)
1059
+ {
1060
+ /* move zero-valued to the beginning */
1061
+ std::iota(buffer_pos, buffer_pos + ncat, (int_t)0);
1062
+ int_t st_pos = 0;
1063
+ int ncat_present = 0;
1064
+ int_t temp;
1065
+ for (int cat = 0; cat < ncat; cat++)
1066
+ {
1067
+ if (buffer_cnt[cat])
1068
+ {
1069
+ ncat_present++;
1070
+ buffer_prob[cat] = (ldouble_safe) buffer_cnt[cat] / cnt_l;
1071
+ }
1072
+
1073
+ else
1074
+ {
1075
+ temp = buffer_pos[st_pos];
1076
+ buffer_pos[st_pos] = buffer_pos[cat];
1077
+ buffer_pos[cat] = temp;
1078
+ st_pos++;
1079
+ }
1080
+ }
1081
+
1082
+ if (ncat_present <= 1) return 0;
1083
+ return expected_sd_cat<int_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
1084
+ }
1085
+
1086
+
1087
+ template <class int_t, class ldouble_safe>
1088
+ double expected_sd_cat(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat,
1089
+ MissingAction missing_action,
1090
+ size_t *restrict buffer_cnt, int_t *restrict buffer_pos, double buffer_prob[])
1091
+ {
1092
+ /* generate counts */
1093
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
1094
+ size_t cnt = end - st + 1;
1095
+
1096
+ if (missing_action != Fail)
1097
+ {
1098
+ int xval;
1099
+ for (size_t row = st; row <= end; row++)
1100
+ {
1101
+ xval = x[ix_arr[row]];
1102
+ if (unlikely(xval < 0))
1103
+ buffer_cnt[ncat]++;
1104
+ else
1105
+ buffer_cnt[xval]++;
1106
+ }
1107
+ cnt -= buffer_cnt[ncat];
1108
+ if (cnt == 0) return 0;
1109
+ }
1110
+
1111
+ else
1112
+ {
1113
+ for (size_t row = st; row <= end; row++)
1114
+ {
1115
+ if (likely(x[ix_arr[row]] >= 0)) buffer_cnt[x[ix_arr[row]]]++;
1116
+ }
1117
+ }
1118
+
1119
+ return expected_sd_cat_internal<size_t, int_t, ldouble_safe>(ncat, buffer_cnt, cnt, buffer_pos, buffer_prob);
1120
+ }
1121
+
1122
+ template <class mapping, class int_t, class ldouble_safe>
1123
+ double expected_sd_cat_weighted(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat,
1124
+ MissingAction missing_action, mapping &restrict w,
1125
+ double *restrict buffer_cnt, int_t *restrict buffer_pos, double *restrict buffer_prob)
1126
+ {
1127
+ /* generate counts */
1128
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, 0.);
1129
+ ldouble_safe cnt = 0;
1130
+
1131
+ if (missing_action != Fail)
1132
+ {
1133
+ int xval;
1134
+ double w_this;
1135
+ for (size_t row = st; row <= end; row++)
1136
+ {
1137
+ xval = x[ix_arr[row]];
1138
+ w_this = w[ix_arr[row]];
1139
+
1140
+ if (unlikely(xval < 0)) {
1141
+ buffer_cnt[ncat] += w_this;
1142
+ }
1143
+ else {
1144
+ buffer_cnt[xval] += w_this;
1145
+ cnt += w_this;
1146
+ }
1147
+ }
1148
+ if (cnt == 0) return 0;
1149
+ }
1150
+
1151
+ else
1152
+ {
1153
+ for (size_t row = st; row <= end; row++)
1154
+ {
1155
+ if (likely(x[ix_arr[row]] >= 0))
1156
+ {
1157
+ buffer_cnt[x[ix_arr[row]]] += w[ix_arr[row]];
1158
+ }
1159
+ }
1160
+ for (int cat = 0; cat < ncat; cat++)
1161
+ cnt += buffer_cnt[cat];
1162
+ if (unlikely(cnt == 0)) return 0;
1163
+ }
1164
+
1165
+ return expected_sd_cat_internal<double, int_t, ldouble_safe>(ncat, buffer_cnt, cnt, buffer_pos, buffer_prob);
1166
+ }
1167
+
1168
+ /* Note: this isn't exactly comparable to the pooled gain from numeric variables,
1169
+ but among all the possible options, this is what happens to end up in the most
1170
+ similar scale when considering standardized gain. */
1171
/* Returns
       ( base_info - (L - s_left) - (R - s_right) ) / cnt
   where L is cnt_left*log(cnt_left) and R is cnt_right*log(cnt_right), each of
   them taken as zero when the corresponding count is <= 1. */
template <class number, class ldouble_safe>
double categ_gain(number cnt_left, number cnt_right,
                  ldouble_safe s_left, ldouble_safe s_right,
                  ldouble_safe base_info, ldouble_safe cnt)
{
    ldouble_safe nlogn_left = 0;
    if (!(cnt_left <= 1))
        nlogn_left = (ldouble_safe)cnt_left * std::log((ldouble_safe)cnt_left);

    ldouble_safe nlogn_right = 0;
    if (!(cnt_right <= 1))
        nlogn_right = (ldouble_safe)cnt_right * std::log((ldouble_safe)cnt_right);

    const ldouble_safe numerator = base_info - (nlogn_left - s_left) - (nlogn_right - s_right);
    return numerator / cnt;
}
1182
+
1183
+
1184
+ /* A couple notes about gain calculation:
1185
+
1186
+ Here one wants to find the best split point, maximizing either:
1187
+ (1/sigma) * (sigma - (1/n)*(n_left*sigma_left + n_right*sigma_right))
1188
+ or:
1189
+ (1/sigma) * (sigma - (1/2)*(sigma_left + sigma_right))
1190
+
1191
+ All the algorithms here use the sorted-indices approach, which is
1192
+ an exact method (note that there's still room for optimization by adding the
1193
+ unsorted approach for small sample sizes and for sparse matrices).
1194
+
1195
+ A naive approach would move observations one at a time from right
1196
+ to left using this formula:
1197
+ sigma = (ssq - s^2/n) / n
1198
+ ssq = sum(x^2)
1199
+ s = sum(x)
1200
+ But such approach has poor numerical precision, and this library is
1201
+ aimed precisely at cases in which there are outliers in the data.
1202
+ It's possible to improve the numerical precision by standardizing the
1203
+ data beforehand, but this library uses instead a more exact two-pass
1204
+ sigma calculation observation-by-observation (from left to right and
1205
+ from right to left, keeping the calculations of the first pass in an
1206
+ array and calculating gain in the second pass), but there's
1207
+ other methods too.
1208
+
1209
+ If one is aiming at maximizing the pooled gain, it's possible to
1210
+ simplify either the gain or the increase in gain without involving
1211
+ 'ssq'. Assuming one already has 'ssq' and 's' calculated for the left and
1212
+ right partitions, and one wants to move one observation from right to left,
1213
+ the following hold:
1214
+ s_right = s - s_left
1215
+ ssq_right = ssq - ssq_left
1216
+ n_right = n - n_left
1217
+ If one then moves observation x, these are updated as follows:
1218
+ s_left_new = s_left + x
1219
+ s_right_new = s - s_left - x
1220
+ ssq_left_new = ssq_left + x^2
1221
+ ssq_right_new = ssq - ssq_left - x^2
1222
+ n_left_new = n_left + 1
1223
+ n_right_new = n - n_left - 1
1224
+ Gain is then:
1225
+ (1/sigma) * (sigma - (1/n)*({ssq_left_new - s_left_new^2/n_left_new} + {ssq_right_new - s_right_new^2/n_right_new}))
1226
+ Which simplifies to:
1227
+ 1 - (1/(sigma*n))(ssq - ( (s_left + x)^2/(n_left+1) + (s - (s_left + x))^2/(n - (n_left+1)) ))
1228
+ Since 'sigma', 'n', and 'ssq' are constant, they can be ignored when determining the
1229
+ maximum gain - that is, one is interested in finding the point that maximizes:
1230
+ (s_left+x)^2/(n_left+1) + (s-(s_left+x))^2/(n-(n_left+1))
1231
+ And this calculation will be robust-enough when dealing with numbers that were
1232
+ already standardized beforehand, as the extended model does at each step.
1233
+ Note however that, when fitting this model, one is usually interested in evaluating
1234
+ the actual gain, standardized by the standard deviation, as it will try different
1235
+ linear combinations which will give different standard deviations, so this simpler
1236
+ formula cannot be applied unless only one linear combination is probed.
1237
+
1238
+ One can also look at:
1239
+ diff_gain = (1/sigma) * (gain_new - gain)
1240
+ Which can be simplified to something that doesn't include sums of squares:
1241
+ (1/(sigma*n))*( -s_left^2/n_left - (s-s_left)^2/(n-n_left) + (s_left+x)^2/(n_left+1) + (s-(s_left+x))^2/(n-(n_left+1)) )
1242
+ And this calculation would in theory allow getting the actual standardized gain.
1243
+ In practice however, this calculation can have poor numerical precision when the
1244
+ sample size is large, so the functions here do not even attempt at calculating it,
1245
+ and this is the reason why the two-pass approach is preferred.
1246
+
1247
+ The averaged SD formula unfortunately doesn't reduce to something that would involve
1248
+ only sums.
1249
+ */
1250
+
1251
+ /* TODO: maybe it's not a good idea to use the two-pass approach with un-standardized
1252
+ variables at large sample sizes (ndim=1), considering that they come in sorted order.
1253
+ Maybe it should instead use sums of centered squares: sigma = sqrt(sum((x-mean(x))^2)/n)
1254
+ The sums of centered squares method is also likely to be more precise. */
1255
+
1256
/* Halfway point between 'x' and 'y', computed as x + (y-x)/2.

   If the computed value is not strictly below 'y', it is moved one
   representable step towards 'y' and used only when that lands strictly
   between 'x' and 'y'; otherwise 'x' itself is returned. */
template <class real_t>
double midpoint(real_t x, real_t y)
{
    const real_t half_gap = (y-x)/(real_t)2;
    real_t mid = x + half_gap;
    if (likely((double)mid < (double)y))
        return mid;

    mid = std::nextafter(mid, y);
    const bool strictly_inside = (mid > x && mid < y);
    return strictly_inside? mid : x;
}
1270
+
1271
/* Calls 'midpoint' with the smaller argument first; when 'x' is not strictly
   below 'y', the arguments are passed in swapped order. */
template <class real_t>
double midpoint_with_reorder(real_t x, real_t y)
{
    const bool already_ordered = (x < y);
    return already_ordered? midpoint(x, y) : midpoint(y, x);
}
1279
+
1280
/* Gain based on the plain average of the two sides' standard deviations:
       1 - (sd_left + sd_right) / (2 * sd) */
#define sd_gain(sd, sd_left, sd_right) \
    (1. - ((sd_left) + (sd_right)) / (2. * (sd)))
/* Gain based on the average of the two sides' standard deviations weighted by
   each side's share of 'cnt':
       1 - (1/sd) * ( (cnt_left/cnt)*sd_left + (cnt_right/cnt)*sd_right )
   The expansion casts the counts to 'real_t', which must name a type wherever
   the macro is used. */
#define pooled_gain(sd, cnt, sd_left, sd_right, cnt_left, cnt_right) \
    (1. - (1./(sd))*( (((real_t)(cnt_left))/(cnt))*(sd_left) + (((real_t)(cnt_right))/(cnt))*(sd_right) ))
1283
+
1284
+
1285
+ /* TODO: make this a compensated sum */
1286
+ template <class real_t, class real_t_>
1287
+ double find_split_rel_gain_t(real_t_ *restrict x, size_t n, double &restrict split_point)
1288
+ {
1289
+ real_t this_gain;
1290
+ real_t best_gain = -HUGE_VAL;
1291
+ real_t x1 = 0, x2 = 0;
1292
+ real_t sum_left = 0, sum_right = 0, sum_tot = 0;
1293
+ for (size_t row = 0; row < n; row++)
1294
+ sum_tot += x[row];
1295
+ for (size_t row = 0; row < n-1; row++)
1296
+ {
1297
+ sum_left += x[row];
1298
+ if (x[row] == x[row+1])
1299
+ continue;
1300
+
1301
+ sum_right = sum_tot - sum_left;
1302
+ this_gain = sum_left * (sum_left / (real_t)(row+1))
1303
+ + sum_right * (sum_right / (real_t)(n-row-1));
1304
+ if (this_gain > best_gain)
1305
+ {
1306
+ best_gain = this_gain;
1307
+ x1 = x[row]; x2 = x[row+1];
1308
+ }
1309
+ }
1310
+
1311
+ if (best_gain <= -HUGE_VAL)
1312
+ return best_gain;
1313
+ split_point = midpoint(x1, x2);
1314
+ return std::fmax((double)best_gain, std::numeric_limits<double>::epsilon());
1315
+ }
1316
+
1317
+ template <class real_t_, class ldouble_safe>
1318
+ double find_split_rel_gain(real_t_ *restrict x, size_t n, double &restrict split_point)
1319
+ {
1320
+ if (n < THRESHOLD_LONG_DOUBLE)
1321
+ return find_split_rel_gain_t<double, real_t_>(x, n, split_point);
1322
+ else
1323
+ return find_split_rel_gain_t<ldouble_safe, real_t_>(x, n, split_point);
1324
+ }
1325
+
1326
+ /* Note: there is no 'weighted' version of 'find_split_rel_gain' with unindexed 'x', because calling it would
1327
+ imply having to argsort the 'x' values in order to sort the weights, which is less efficient. */
1328
+
1329
+ template <class real_t, class real_t_>
1330
+ double find_split_rel_gain_t(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &split_point, size_t &restrict split_ix)
1331
+ {
1332
+ real_t this_gain;
1333
+ real_t best_gain = -HUGE_VAL;
1334
+ split_ix = 0; /* <- avoid out-of-bounds at the end */
1335
+ real_t sum_left = 0, sum_right = 0, sum_tot = 0;
1336
+ for (size_t row = st; row <= end; row++)
1337
+ sum_tot += x[ix_arr[row]] - xmean;
1338
+ for (size_t row = st; row < end; row++)
1339
+ {
1340
+ sum_left += x[ix_arr[row]] - xmean;
1341
+ if (x[ix_arr[row]] == x[ix_arr[row+1]])
1342
+ continue;
1343
+
1344
+ sum_right = sum_tot - sum_left;
1345
+ this_gain = sum_left * (sum_left / (real_t)(row - st + 1))
1346
+ + sum_right * (sum_right / (real_t)(end - row));
1347
+ if (this_gain > best_gain)
1348
+ {
1349
+ best_gain = this_gain;
1350
+ split_ix = row;
1351
+ }
1352
+ }
1353
+
1354
+ if (best_gain <= -HUGE_VAL)
1355
+ return best_gain;
1356
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1357
+ return std::fmax((double)best_gain, std::numeric_limits<double>::epsilon());
1358
+ }
1359
+
1360
+ template <class real_t_, class ldouble_safe>
1361
+ double find_split_rel_gain(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &restrict split_point, size_t &restrict split_ix)
1362
+ {
1363
+ if ((end-st+1) < THRESHOLD_LONG_DOUBLE)
1364
+ return find_split_rel_gain_t<double, real_t_>(x, xmean, ix_arr, st, end, split_point, split_ix);
1365
+ else
1366
+ return find_split_rel_gain_t<ldouble_safe, real_t_>(x, xmean, ix_arr, st, end, split_point, split_ix);
1367
+ }
1368
+
1369
+ template <class real_t, class real_t_, class mapping>
1370
+ double find_split_rel_gain_weighted_t(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &split_point, size_t &restrict split_ix, mapping &restrict w)
1371
+ {
1372
+ real_t this_gain;
1373
+ real_t best_gain = -HUGE_VAL;
1374
+ split_ix = 0; /* <- avoid out-of-bounds at the end */
1375
+ real_t sum_left = 0, sum_right = 0, sum_tot = 0, sumw = 0, sumw_tot = 0;
1376
+
1377
+ for (size_t row = st; row <= end; row++)
1378
+ sumw_tot += w[ix_arr[row]];
1379
+ for (size_t row = st; row <= end; row++)
1380
+ sum_tot += x[ix_arr[row]] - xmean;
1381
+ for (size_t row = st; row < end; row++)
1382
+ {
1383
+ sumw += w[ix_arr[row]];
1384
+ sum_left += x[ix_arr[row]] - xmean;
1385
+ if (x[ix_arr[row]] == x[ix_arr[row+1]])
1386
+ continue;
1387
+
1388
+ sum_right = sum_tot - sum_left;
1389
+ this_gain = sum_left * (sum_left / sumw)
1390
+ + sum_right * (sum_right / (sumw_tot - sumw));
1391
+ if (this_gain > best_gain)
1392
+ {
1393
+ best_gain = this_gain;
1394
+ split_ix = row;
1395
+ }
1396
+ }
1397
+
1398
+ if (best_gain <= -HUGE_VAL)
1399
+ return best_gain;
1400
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1401
+ return std::fmax((double)best_gain, std::numeric_limits<double>::epsilon());
1402
+ }
1403
+
1404
+ template <class real_t_, class mapping, class ldouble_safe>
1405
+ double find_split_rel_gain_weighted(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
1406
+ {
1407
+ if ((end-st+1) < THRESHOLD_LONG_DOUBLE)
1408
+ return find_split_rel_gain_weighted_t<double, real_t_, mapping>(x, xmean, ix_arr, st, end, split_point, split_ix, w);
1409
+ else
1410
+ return find_split_rel_gain_weighted_t<ldouble_safe, real_t_, mapping>(x, xmean, ix_arr, st, end, split_point, split_ix, w);
1411
+ }
1412
+
1413
+
1414
/* Walks 'x[0..n)' from the last element to the first, updating a running mean
   and a running sum of squared deviations one element at a time.

   For every k in 1..n-1, 'sd_arr[k]' receives the standard deviation (dividing
   by the number of elements) of 'x[k..n)'; the entry for the single last element
   is written as exactly 0. 'sd_arr[0]' is not written: the standard deviation of
   the whole array is the return value instead. */
template <class real_t, class real_t_>
real_t calc_sd_right_to_left(real_t_ *restrict x, size_t n, double *restrict sd_arr)
{
    real_t mean = 0;
    real_t ssq = 0;
    real_t prev_mean = x[n-1];

    for (size_t seen = 1; seen < n; seen++)
    {
        const size_t pos = n - seen;
        mean += (x[pos] - mean) / (real_t)seen;
        ssq += (x[pos] - mean) * (x[pos] - prev_mean);
        prev_mean = mean;
        if (seen == 1)
            sd_arr[pos] = 0.;
        else
            sd_arr[pos] = std::sqrt(ssq / (real_t)seen);
    }

    /* the first element completes the full-array statistics */
    mean += (x[0] - mean) / (real_t)n;
    ssq += (x[0] - mean) * (x[0] - prev_mean);
    return std::sqrt(ssq / (real_t)n);
}
1431
+
1432
/* Weighted version of the right-to-left running standard deviation, visiting
   the elements of 'x' in the order given by 'sorted_ix' (last position first).

   For every k in 1..n-1, 'sd_arr[k]' receives the weighted standard deviation of
   the elements at positions k..n-1 of 'sorted_ix'; the entry for the single last
   position is written as exactly 0. 'sd_arr[0]' is not written: the weighted
   standard deviation over all 'n' elements is the return value. 'cumw' is set to
   the sum of the weights of all 'n' elements.

   The updates applied for an element with value v and weight wt are:
       W    += wt
       mean += wt * (v - mean) / W
       ssq  += wt * (v - mean_new) * (v - mean_old)
   and the same update is used for the first position as for all the others. */
template <class real_t_, class ldouble_safe>
ldouble_safe calc_sd_right_to_left_weighted(real_t_ *restrict x, size_t n, double *restrict sd_arr,
                                            double *restrict w, ldouble_safe &cumw, size_t *restrict sorted_ix)
{
    ldouble_safe running_mean = 0;
    ldouble_safe running_ssq = 0;
    ldouble_safe mean_prev = x[sorted_ix[n-1]];
    ldouble_safe cnt = 0;
    double w_this;
    for (size_t row = 0; row < n-1; row++)
    {
        w_this = w[sorted_ix[n-row-1]];
        cnt += w_this;
        running_mean += w_this * (x[sorted_ix[n-row-1]] - running_mean) / cnt;
        running_ssq += w_this * ((x[sorted_ix[n-row-1]] - running_mean) * (x[sorted_ix[n-row-1]] - mean_prev));
        mean_prev = running_mean;
        sd_arr[n-row-1] = (row == 0)? 0. : std::sqrt(running_ssq / cnt);
    }
    w_this = w[sorted_ix[0]];
    cnt += w_this;
    /* the mean update needs the element's weight here as well, otherwise the
       result is only right when this weight equals one */
    running_mean += w_this * (x[sorted_ix[0]] - running_mean) / cnt;
    running_ssq += w_this * ((x[sorted_ix[0]] - running_mean) * (x[sorted_ix[0]] - mean_prev));
    cumw = cnt;
    return std::sqrt(running_ssq / cnt);
}
1457
+
1458
/* Right-to-left running standard deviation over the rows 'ix_arr[st..end]'
   (both ends inclusive), using the values 'x[ix] - xmean'.

   With n = end - st + 1, for every k in 1..n-1 'sd_arr[k]' receives the standard
   deviation (dividing by the number of elements) of the rows at positions
   st+k..end; the entry for the single last row is written as exactly 0.
   'sd_arr[0]' is not written: the standard deviation over the whole range is the
   return value instead. Note that 'sd_arr' is indexed relative to 'st'. */
template <class real_t, class real_t_>
real_t calc_sd_right_to_left(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr)
{
    const size_t n = end - st + 1;
    real_t mean = 0;
    real_t ssq = 0;
    real_t prev_mean = x[ix_arr[end]] - xmean;

    for (size_t seen = 1; seen < n; seen++)
    {
        const size_t ix = ix_arr[end - (seen - 1)];
        mean += ((x[ix] - xmean) - mean) / (real_t)seen;
        ssq += ((x[ix] - xmean) - mean) * ((x[ix] - xmean) - prev_mean);
        prev_mean = mean;
        if (seen == 1)
            sd_arr[n - seen] = 0.;
        else
            sd_arr[n - seen] = std::sqrt(ssq / (real_t)seen);
    }

    /* the first row of the range completes the full-range statistics */
    const size_t ix_first = ix_arr[st];
    mean += ((x[ix_first] - xmean) - mean) / (real_t)n;
    ssq += ((x[ix_first] - xmean) - mean) * ((x[ix_first] - xmean) - prev_mean);
    return std::sqrt(ssq / (real_t)n);
}
1476
+
1477
/* Weighted right-to-left running standard deviation over the rows
   'ix_arr[st..end]' (both ends inclusive), using the values 'x[ix] - xmean' and
   the weights 'w[ix]'.

   With n = end - st + 1, for every k in 1..n-1 'sd_arr[k]' receives the weighted
   standard deviation of the rows at positions st+k..end; the entry for the
   single last row is written as exactly 0. 'sd_arr[0]' is not written: the
   weighted standard deviation over the whole range is the return value. 'cumw'
   is set to the sum of the weights over the whole range.

   The updates applied for a row with centered value v and weight wt are:
       W    += wt
       mean += wt * (v - mean) / W
       ssq  += wt * (v - mean_new) * (v - mean_old)
   and the same update is used for the first row of the range as for all the
   others. The previous mean is kept in 'ldouble_safe' so that it is not rounded
   to the input type between steps. */
template <class real_t_, class mapping, class ldouble_safe>
ldouble_safe calc_sd_right_to_left_weighted(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end,
                                            double *restrict sd_arr, mapping &restrict w, ldouble_safe &cumw)
{
    ldouble_safe running_mean = 0;
    ldouble_safe running_ssq = 0;
    ldouble_safe mean_prev = x[ix_arr[end]] - xmean;
    size_t n = end - st + 1;
    ldouble_safe cnt = 0;
    double w_this;
    for (size_t row = 0; row < n-1; row++)
    {
        w_this = w[ix_arr[end-row]];
        cnt += w_this;
        running_mean += w_this * ((x[ix_arr[end-row]] - xmean) - running_mean) / cnt;
        running_ssq += w_this * (((x[ix_arr[end-row]] - xmean) - running_mean) * ((x[ix_arr[end-row]] - xmean) - mean_prev));
        mean_prev = running_mean;
        sd_arr[n-row-1] = (row == 0)? 0. : std::sqrt(running_ssq / cnt);
    }
    w_this = w[ix_arr[st]];
    cnt += w_this;
    /* the mean update needs the row's weight here as well, otherwise the result
       is only right when this weight equals one */
    running_mean += w_this * ((x[ix_arr[st]] - xmean) - running_mean) / cnt;
    running_ssq += w_this * (((x[ix_arr[st]] - xmean) - running_mean) * ((x[ix_arr[st]] - xmean) - mean_prev));
    cumw = cnt;
    return std::sqrt(running_ssq / cnt);
}
1503
+
1504
+ template <class real_t, class real_t_>
1505
+ double find_split_std_gain_t(real_t_ *restrict x, size_t n, double *restrict sd_arr,
1506
+ GainCriterion criterion, double min_gain, double &restrict split_point)
1507
+ {
1508
+ real_t full_sd = calc_sd_right_to_left<real_t>(x, n, sd_arr);
1509
+ real_t running_mean = 0;
1510
+ real_t running_ssq = 0;
1511
+ real_t mean_prev = x[0];
1512
+ real_t best_gain = -HUGE_VAL;
1513
+ real_t this_sd, this_gain;
1514
+ real_t n_ = (real_t)n;
1515
+ size_t best_ix = 0;
1516
+ for (size_t row = 0; row < n-1; row++)
1517
+ {
1518
+ running_mean += (x[row] - running_mean) / (real_t)(row+1);
1519
+ running_ssq += (x[row] - running_mean) * (x[row] - mean_prev);
1520
+ mean_prev = running_mean;
1521
+ if (x[row] == x[row+1])
1522
+ continue;
1523
+
1524
+ this_sd = (row == 0)? 0. : std::sqrt(running_ssq / (real_t)(row+1));
1525
+ this_gain = (criterion == Pooled)?
1526
+ pooled_gain(full_sd, n_, this_sd, sd_arr[row+1], row+1, n-row-1)
1527
+ :
1528
+ sd_gain(full_sd, this_sd, sd_arr[row+1]);
1529
+ if (this_gain > best_gain && this_gain > min_gain)
1530
+ {
1531
+ best_gain = this_gain;
1532
+ best_ix = row;
1533
+ }
1534
+ }
1535
+
1536
+ if (best_gain > -HUGE_VAL)
1537
+ split_point = midpoint(x[best_ix], x[best_ix+1]);
1538
+
1539
+ return best_gain;
1540
+ }
1541
+
1542
+ template <class real_t_, class ldouble_safe>
1543
+ double find_split_std_gain(real_t_ *restrict x, size_t n, double *restrict sd_arr,
1544
+ GainCriterion criterion, double min_gain, double &restrict split_point)
1545
+ {
1546
+ if (n < THRESHOLD_LONG_DOUBLE)
1547
+ return find_split_std_gain_t<double, real_t_>(x, n, sd_arr, criterion, min_gain, split_point);
1548
+ else
1549
+ return find_split_std_gain_t<ldouble_safe, real_t_>(x, n, sd_arr, criterion, min_gain, split_point);
1550
+ }
1551
+
1552
+ template <class real_t, class ldouble_safe>
1553
+ double find_split_std_gain_weighted(real_t *restrict x, size_t n, double *restrict sd_arr,
1554
+ GainCriterion criterion, double min_gain, double &restrict split_point,
1555
+ double *restrict w, size_t *restrict sorted_ix)
1556
+ {
1557
+ ldouble_safe cumw;
1558
+ double full_sd = calc_sd_right_to_left_weighted(x, n, sd_arr, w, cumw, sorted_ix);
1559
+ ldouble_safe running_mean = 0;
1560
+ ldouble_safe running_ssq = 0;
1561
+ ldouble_safe mean_prev = x[sorted_ix[0]];
1562
+ double best_gain = -HUGE_VAL;
1563
+ double this_sd, this_gain;
1564
+ double w_this;
1565
+ ldouble_safe currw = 0;
1566
+ size_t best_ix = 0;
1567
+
1568
+ for (size_t row = 0; row < n-1; row++)
1569
+ {
1570
+ w_this = w[sorted_ix[row]];
1571
+ currw += w_this;
1572
+ running_mean += w_this * (x[sorted_ix[row]] - running_mean) / currw;
1573
+ running_ssq += w_this * ((x[sorted_ix[row]] - running_mean) * (x[sorted_ix[row]] - mean_prev));
1574
+ mean_prev = running_mean;
1575
+ if (x[sorted_ix[row]] == x[sorted_ix[row+1]])
1576
+ continue;
1577
+
1578
+ this_sd = (row == 0)? 0. : std::sqrt(running_ssq / currw);
1579
+ this_gain = (criterion == Pooled)?
1580
+ pooled_gain(full_sd, cumw, this_sd, sd_arr[row+1], currw, cumw-currw)
1581
+ :
1582
+ sd_gain(full_sd, this_sd, sd_arr[row+1]);
1583
+ if (this_gain > best_gain && this_gain > min_gain)
1584
+ {
1585
+ best_gain = this_gain;
1586
+ best_ix = row;
1587
+ }
1588
+ }
1589
+
1590
+ if (best_gain > -HUGE_VAL)
1591
+ split_point = midpoint(x[sorted_ix[best_ix]], x[sorted_ix[best_ix+1]]);
1592
+
1593
+ return best_gain;
1594
+ }
1595
+
1596
+ template <class real_t, class real_t_>
1597
+ double find_split_std_gain_t(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr,
1598
+ GainCriterion criterion, double min_gain, double &restrict split_point, size_t &restrict split_ix)
1599
+ {
1600
+ real_t full_sd = calc_sd_right_to_left<real_t>(x, xmean, ix_arr, st, end, sd_arr);
1601
+ real_t running_mean = 0;
1602
+ real_t running_ssq = 0;
1603
+ real_t mean_prev = x[ix_arr[st]] - xmean;
1604
+ real_t best_gain = -HUGE_VAL;
1605
+ real_t n = (real_t)(end - st + 1);
1606
+ real_t this_sd, this_gain;
1607
+ split_ix = st;
1608
+ for (size_t row = st; row < end; row++)
1609
+ {
1610
+ running_mean += ((x[ix_arr[row]] - xmean) - running_mean) / (real_t)(row-st+1);
1611
+ running_ssq += ((x[ix_arr[row]] - xmean) - running_mean) * ((x[ix_arr[row]] - xmean) - mean_prev);
1612
+ mean_prev = running_mean;
1613
+ if (x[ix_arr[row]] == x[ix_arr[row+1]])
1614
+ continue;
1615
+
1616
+ this_sd = (row == st)? 0. : std::sqrt(running_ssq / (real_t)(row-st+1));
1617
+ this_gain = (criterion == Pooled)?
1618
+ pooled_gain(full_sd, n, this_sd, sd_arr[row-st+1], row-st+1, end-row)
1619
+ :
1620
+ sd_gain(full_sd, this_sd, sd_arr[row-st+1]);
1621
+ if (this_gain > best_gain && this_gain > min_gain)
1622
+ {
1623
+ best_gain = this_gain;
1624
+ split_ix = row;
1625
+ }
1626
+ }
1627
+
1628
+ if (best_gain > -HUGE_VAL)
1629
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1630
+
1631
+ return best_gain;
1632
+ }
1633
+
1634
+ template <class real_t_, class ldouble_safe>
1635
+ double find_split_std_gain(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr,
1636
+ GainCriterion criterion, double min_gain, double &restrict split_point, size_t &restrict split_ix)
1637
+ {
1638
+ if ((end-st+1) < THRESHOLD_LONG_DOUBLE)
1639
+ return find_split_std_gain_t<double, real_t_>(x, xmean, ix_arr, st, end, sd_arr, criterion, min_gain, split_point, split_ix);
1640
+ else
1641
+ return find_split_std_gain_t<ldouble_safe, real_t_>(x, xmean, ix_arr, st, end, sd_arr, criterion, min_gain, split_point, split_ix);
1642
+ }
1643
+
1644
+ template <class real_t, class mapping, class ldouble_safe>
1645
+ double find_split_std_gain_weighted(real_t *restrict x, real_t xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr,
1646
+ GainCriterion criterion, double min_gain, double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
1647
+ {
1648
+ ldouble_safe cumw;
1649
+ double full_sd = calc_sd_right_to_left_weighted(x, xmean, ix_arr, st, end, sd_arr, w, cumw);
1650
+ ldouble_safe running_mean = 0;
1651
+ ldouble_safe running_ssq = 0;
1652
+ ldouble_safe mean_prev = x[ix_arr[st]] - xmean;
1653
+ double best_gain = -HUGE_VAL;
1654
+ ldouble_safe currw = 0;
1655
+ double this_sd, this_gain;
1656
+ double w_this;
1657
+ split_ix = st;
1658
+
1659
+ for (size_t row = st; row < end; row++)
1660
+ {
1661
+ w_this = w[ix_arr[row]];
1662
+ currw += w_this;
1663
+ running_mean += w_this * ((x[ix_arr[row]] - xmean) - running_mean) / currw;
1664
+ running_ssq += w_this * (((x[ix_arr[row]] - xmean) - running_mean) * ((x[ix_arr[row]] - xmean) - mean_prev));
1665
+ mean_prev = running_mean;
1666
+ if (x[ix_arr[row]] == x[ix_arr[row+1]])
1667
+ continue;
1668
+
1669
+ this_sd = (row == st)? 0. : std::sqrt(running_ssq / currw);
1670
+ this_gain = (criterion == Pooled)?
1671
+ pooled_gain(full_sd, cumw, this_sd, sd_arr[row-st+1], currw, cumw-currw)
1672
+ :
1673
+ sd_gain(full_sd, this_sd, sd_arr[row-st+1]);
1674
+ if (this_gain > best_gain && this_gain > min_gain)
1675
+ {
1676
+ best_gain = this_gain;
1677
+ split_ix = row;
1678
+ }
1679
+ }
1680
+
1681
+ if (best_gain > -HUGE_VAL)
1682
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1683
+
1684
+ return best_gain;
1685
+ }
1686
+
1687
+ #ifndef _FOR_R
1688
+ #if defined(__clang__)
1689
+ #pragma clang diagnostic push
1690
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
1691
+ #endif
1692
+ #endif
1693
+
1694
+ #ifndef _FOR_R
1695
+ [[gnu::optimize("Ofast")]]
1696
+ #endif
1697
+ static inline void xpy1(double *restrict x, double *restrict y, size_t n)
1698
+ {
1699
+ for (size_t ix = 0; ix < n; ix++) y[ix] += x[ix];
1700
+ }
1701
+
1702
+ #ifndef _FOR_R
1703
+ [[gnu::optimize("Ofast")]]
1704
+ #endif
1705
+ static inline void axpy1(const double a, double *restrict x, double *restrict y, size_t n)
1706
+ {
1707
+ for (size_t ix = 0; ix < n; ix++) y[ix] = std::fma(a, x[ix], y[ix]);
1708
+ }
1709
+
1710
+ #ifndef _FOR_R
1711
+ [[gnu::optimize("Ofast")]]
1712
+ #endif
1713
+ static inline void xpy1(double *restrict xval, size_t ind[], size_t nnz, double *restrict y)
1714
+ {
1715
+ for (size_t ix = 0; ix < nnz; ix++) y[ind[ix]] += xval[ix];
1716
+ }
1717
+
1718
+ #ifndef _FOR_R
1719
+ [[gnu::optimize("Ofast")]]
1720
+ #endif
1721
+ static inline void axpy1(const double a, double *restrict xval, size_t ind[], size_t nnz, double *restrict y)
1722
+ {
1723
+ for (size_t ix = 0; ix < nnz; ix++) y[ind[ix]] = std::fma(a, xval[ix], y[ind[ix]]);
1724
+ }
1725
+
1726
+ #ifndef _FOR_R
1727
+ #if defined(__clang__)
1728
+ #pragma clang diagnostic pop
1729
+ #endif
1730
+ #endif
1731
+
1732
+ template <class real_t, class ldouble_safe>
1733
+ double find_split_full_gain(real_t *restrict x, size_t st, size_t end, size_t *restrict ix_arr,
1734
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
1735
+ double *restrict X_row_major, size_t ncols,
1736
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
1737
+ double *restrict buffer_sum_left, double *restrict buffer_sum_tot,
1738
+ size_t &restrict split_ix, double &restrict split_point,
1739
+ bool x_uses_ix_arr)
1740
+ {
1741
+ if (end <= st) return -HUGE_VAL;
1742
+ if (cols_use != NULL && ncols_use && (double)ncols_use / (double)ncols < 0.1)
1743
+ force_cols_use = true;
1744
+
1745
+ memset(buffer_sum_tot, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
1746
+ if (Xr_indptr == NULL)
1747
+ {
1748
+ if (force_cols_use)
1749
+ {
1750
+ double *restrict ptr_row;
1751
+ for (size_t row = st; row <= end; row++)
1752
+ {
1753
+ ptr_row = X_row_major + ix_arr[row]*ncols;
1754
+ for (size_t col = 0; col < ncols_use; col++)
1755
+ buffer_sum_tot[col] += ptr_row[cols_use[col]];
1756
+ }
1757
+ }
1758
+
1759
+ else
1760
+ {
1761
+ for (size_t row = st; row <= end; row++)
1762
+ xpy1(X_row_major + ix_arr[row]*ncols, buffer_sum_tot, ncols);
1763
+ }
1764
+ }
1765
+
1766
+ else
1767
+ {
1768
+ if (force_cols_use)
1769
+ {
1770
+ size_t *curr_begin;
1771
+ size_t *row_end;
1772
+ size_t *curr_col;
1773
+ double *restrict Xr_this;
1774
+ size_t *cols_end = cols_use + ncols_use;
1775
+ for (size_t row = st; row <= end; row++)
1776
+ {
1777
+ curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
1778
+ row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
1779
+ if (curr_begin == row_end) continue;
1780
+ curr_col = cols_use;
1781
+ Xr_this = Xr + Xr_indptr[ix_arr[row]];
1782
+
1783
+ while (curr_col < cols_end && curr_begin < row_end)
1784
+ {
1785
+ if (*curr_begin == *curr_col)
1786
+ {
1787
+ buffer_sum_tot[std::distance(cols_use, curr_col)] += Xr_this[std::distance(curr_begin, row_end)];
1788
+ curr_col++;
1789
+ curr_begin++;
1790
+ }
1791
+
1792
+ else
1793
+ {
1794
+ if (*curr_begin > *curr_col)
1795
+ curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
1796
+ else
1797
+ curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
1798
+ }
1799
+ }
1800
+ }
1801
+ }
1802
+
1803
+ else
1804
+ {
1805
+ size_t ptr_this;
1806
+ for (size_t row = st; row <= end; row++)
1807
+ {
1808
+ ptr_this = Xr_indptr[ix_arr[row]];
1809
+ xpy1(Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_tot);
1810
+ }
1811
+ }
1812
+ }
1813
+
1814
+ double best_gain = -HUGE_VAL;
1815
+ double this_gain;
1816
+ double sl, sr;
1817
+ double dl, dr;
1818
+ double vleft, vright;
1819
+ memset(buffer_sum_left, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
1820
+ if (Xr_indptr == NULL)
1821
+ {
1822
+ if (!force_cols_use)
1823
+ {
1824
+ for (size_t row = st; row < end; row++)
1825
+ {
1826
+ xpy1(X_row_major + ix_arr[row]*ncols, buffer_sum_left, ncols);
1827
+ if (x_uses_ix_arr) {
1828
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
1829
+ }
1830
+ else {
1831
+ if (unlikely(x[row] == x[row+1])) continue;
1832
+ }
1833
+
1834
+ vleft = 0;
1835
+ vright = 0;
1836
+ dl = (double)(row-st+1);
1837
+ dr = (double)(end-row);
1838
+ for (size_t col = 0; col < ncols; col++)
1839
+ {
1840
+ sl = buffer_sum_left[col];
1841
+ vleft += sl * (sl / dl);
1842
+ sr = buffer_sum_tot[col] - sl;
1843
+ vright += sr * (sr / dr);
1844
+ }
1845
+
1846
+ this_gain = vleft + vright;
1847
+ if (this_gain > best_gain)
1848
+ {
1849
+ best_gain = this_gain;
1850
+ split_ix = row;
1851
+ }
1852
+ }
1853
+ }
1854
+
1855
+ else
1856
+ {
1857
+ double *restrict ptr_row;
1858
+ for (size_t row = st; row < end; row++)
1859
+ {
1860
+ ptr_row = X_row_major + ix_arr[row]*ncols;
1861
+ for (size_t col = 0; col < ncols_use; col++)
1862
+ buffer_sum_left[col] += ptr_row[cols_use[col]];
1863
+ if (x_uses_ix_arr) {
1864
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
1865
+ }
1866
+ else {
1867
+ if (unlikely(x[row] == x[row+1])) continue;
1868
+ }
1869
+
1870
+ vleft = 0;
1871
+ vright = 0;
1872
+ dl = (double)(row-st+1);
1873
+ dr = (double)(end-row);
1874
+ for (size_t col = 0; col < ncols_use; col++)
1875
+ {
1876
+ sl = buffer_sum_left[col];
1877
+ vleft += sl * (sl / dl);
1878
+ sr = buffer_sum_tot[col] - sl;
1879
+ vright += sr * (sr / dr);
1880
+ }
1881
+
1882
+ this_gain = vleft + vright;
1883
+ if (this_gain > best_gain)
1884
+ {
1885
+ best_gain = this_gain;
1886
+ split_ix = row;
1887
+ }
1888
+ }
1889
+ }
1890
+ }
1891
+
1892
+ else
1893
+ {
1894
+ if (!force_cols_use)
1895
+ {
1896
+ size_t ptr_this;
1897
+ for (size_t row = st; row < end; row++)
1898
+ {
1899
+ ptr_this = Xr_indptr[ix_arr[row]];
1900
+ xpy1(Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_left);
1901
+ if (x_uses_ix_arr) {
1902
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
1903
+ }
1904
+ else {
1905
+ if (unlikely(x[row] == x[row+1])) continue;
1906
+ }
1907
+
1908
+ vleft = 0;
1909
+ vright = 0;
1910
+ dl = (double)(row-st+1);
1911
+ dr = (double)(end-row);
1912
+ for (size_t col = 0; col < ncols; col++)
1913
+ {
1914
+ sl = buffer_sum_left[col];
1915
+ vleft += sl * (sl / dl);
1916
+ sr = buffer_sum_tot[col] - sl;
1917
+ vright += sr * (sr / dr);
1918
+ }
1919
+
1920
+ this_gain = vleft + vright;
1921
+ if (this_gain > best_gain)
1922
+ {
1923
+ best_gain = this_gain;
1924
+ split_ix = row;
1925
+ }
1926
+ }
1927
+ }
1928
+
1929
+ else
1930
+ {
1931
+ size_t *curr_begin;
1932
+ size_t *row_end;
1933
+ size_t *curr_col;
1934
+ double *restrict Xr_this;
1935
+ size_t *cols_end = cols_use + ncols_use;
1936
+ for (size_t row = st; row < end; row++)
1937
+ {
1938
+ curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
1939
+ row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
1940
+ if (curr_begin == row_end) goto skip_sum;
1941
+ curr_col = cols_use;
1942
+ Xr_this = Xr + Xr_indptr[ix_arr[row]];
1943
+ while (curr_col < cols_end && curr_begin < row_end)
1944
+ {
1945
+ if (*curr_begin == *curr_col)
1946
+ {
1947
+ buffer_sum_left[std::distance(cols_use, curr_col)] += Xr_this[std::distance(curr_begin, row_end)];
1948
+ curr_col++;
1949
+ curr_begin++;
1950
+ }
1951
+
1952
+ else
1953
+ {
1954
+ if (*curr_begin > *curr_col)
1955
+ curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
1956
+ else
1957
+ curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
1958
+ }
1959
+ }
1960
+
1961
+ skip_sum:
1962
+ if (x_uses_ix_arr) {
1963
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
1964
+ }
1965
+ else {
1966
+ if (unlikely(x[row] == x[row+1])) continue;
1967
+ }
1968
+
1969
+ vleft = 0;
1970
+ vright = 0;
1971
+ dl = (double)(row-st+1);
1972
+ dr = (double)(end-row);
1973
+ for (size_t col = 0; col < ncols_use; col++)
1974
+ {
1975
+ sl = buffer_sum_left[col];
1976
+ vleft += sl * (sl / dl);
1977
+ sr = buffer_sum_tot[col] - sl;
1978
+ vright += sr * (sr / dr);
1979
+ }
1980
+
1981
+ this_gain = vleft + vright;
1982
+ if (this_gain > best_gain)
1983
+ {
1984
+ best_gain = this_gain;
1985
+ split_ix = row;
1986
+ }
1987
+ }
1988
+ }
1989
+ }
1990
+
1991
+ if (best_gain <= -HUGE_VAL) return best_gain;
1992
+
1993
+ if (x_uses_ix_arr)
1994
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1995
+ else
1996
+ split_point = midpoint(x[split_ix], x[split_ix+1]);
1997
+ return best_gain / (ldouble_safe)(end - st + 1);
1998
+ }
1999
+
2000
+ template <class real_t, class mapping, class ldouble_safe>
2001
+ double find_split_full_gain_weighted(real_t *restrict x, size_t st, size_t end, size_t *restrict ix_arr,
2002
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
2003
+ double *restrict X_row_major, size_t ncols,
2004
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
2005
+ double *restrict buffer_sum_left, double *restrict buffer_sum_tot,
2006
+ size_t &restrict split_ix, double &restrict split_point,
2007
+ bool x_uses_ix_arr,
2008
+ mapping &restrict w)
2009
+ {
2010
+ if (end <= st) return -HUGE_VAL;
2011
+ if (cols_use != NULL && ncols_use && (double)ncols_use / (double)ncols < 0.1)
2012
+ force_cols_use = true;
2013
+
2014
+ double wtot = 0;
2015
+ if (x_uses_ix_arr)
2016
+ {
2017
+ for (size_t row = st; row <= end; row++)
2018
+ wtot += w[ix_arr[row]];
2019
+ }
2020
+
2021
+ else
2022
+ {
2023
+ for (size_t row = st; row <= end; row++)
2024
+ wtot += w[row];
2025
+ }
2026
+
2027
+ memset(buffer_sum_tot, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
2028
+ if (Xr_indptr == NULL)
2029
+ {
2030
+ if (!force_cols_use)
2031
+ {
2032
+ if (x_uses_ix_arr)
2033
+ {
2034
+ for (size_t row = st; row <= end; row++)
2035
+ axpy1(w[ix_arr[row]], X_row_major + ix_arr[row]*ncols, buffer_sum_tot, ncols);
2036
+ }
2037
+
2038
+ else
2039
+ {
2040
+ for (size_t row = st; row <= end; row++)
2041
+ axpy1(w[row], X_row_major + ix_arr[row]*ncols, buffer_sum_tot, ncols);
2042
+ }
2043
+ }
2044
+
2045
+ else
2046
+ {
2047
+ double *restrict ptr_row;
2048
+ double w_row;
2049
+
2050
+ if (x_uses_ix_arr)
2051
+ {
2052
+ for (size_t row = st; row <= end; row++)
2053
+ {
2054
+ ptr_row = X_row_major + ix_arr[row]*ncols;
2055
+ w_row = w[ix_arr[row]];
2056
+ for (size_t col = 0; col < ncols_use; col++)
2057
+ buffer_sum_tot[col] = std::fma(w_row, ptr_row[cols_use[col]], buffer_sum_tot[col]);
2058
+ }
2059
+ }
2060
+
2061
+ else
2062
+ {
2063
+ for (size_t row = st; row <= end; row++)
2064
+ {
2065
+ ptr_row = X_row_major + ix_arr[row]*ncols;
2066
+ w_row = w[row];
2067
+ for (size_t col = 0; col < ncols_use; col++)
2068
+ buffer_sum_tot[col] = std::fma(w_row, ptr_row[cols_use[col]], buffer_sum_tot[col]);
2069
+ }
2070
+ }
2071
+ }
2072
+ }
2073
+
2074
+ else
2075
+ {
2076
+ if (!force_cols_use)
2077
+ {
2078
+ size_t ptr_this;
2079
+ if (x_uses_ix_arr)
2080
+ {
2081
+ for (size_t row = st; row <= end; row++)
2082
+ {
2083
+ ptr_this = Xr_indptr[ix_arr[row]];
2084
+ axpy1(w[ix_arr[row]], Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_tot);
2085
+ }
2086
+ }
2087
+
2088
+ else
2089
+ {
2090
+ for (size_t row = st; row <= end; row++)
2091
+ {
2092
+ ptr_this = Xr_indptr[ix_arr[row]];
2093
+ axpy1(w[row], Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_tot);
2094
+ }
2095
+ }
2096
+ }
2097
+
2098
+ else
2099
+ {
2100
+ size_t *curr_begin;
2101
+ size_t *row_end;
2102
+ size_t *curr_col;
2103
+ double *restrict Xr_this;
2104
+ size_t *cols_end = cols_use + ncols_use;
2105
+ double w_row;
2106
+ for (size_t row = st; row <= end; row++)
2107
+ {
2108
+ curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
2109
+ row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
2110
+ if (curr_begin == row_end) continue;
2111
+ curr_col = cols_use;
2112
+ Xr_this = Xr + Xr_indptr[ix_arr[row]];
2113
+ w_row = w[x_uses_ix_arr? ix_arr[row] : row];
2114
+ size_t dtemp;
2115
+
2116
+ while (curr_col < cols_end && curr_begin < row_end)
2117
+ {
2118
+ if (*curr_begin == *curr_col)
2119
+ {
2120
+ dtemp = std::distance(cols_use, curr_col);
2121
+ buffer_sum_tot[dtemp]
2122
+ =
2123
+ std::fma(w_row, Xr_this[std::distance(curr_begin, row_end)], buffer_sum_tot[dtemp]);
2124
+ curr_col++;
2125
+ curr_begin++;
2126
+ }
2127
+
2128
+ else
2129
+ {
2130
+ if (*curr_begin > *curr_col)
2131
+ curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
2132
+ else
2133
+ curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
2134
+ }
2135
+ }
2136
+ }
2137
+ }
2138
+ }
2139
+
2140
+ double best_gain = -HUGE_VAL;
2141
+ double this_gain;
2142
+ double sl, sr;
2143
+ double vleft, vright;
2144
+ double wleft = 0;
2145
+ double w_row;
2146
+ double wright;
2147
+ memset(buffer_sum_left, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
2148
+ if (Xr_indptr == NULL)
2149
+ {
2150
+ if (!force_cols_use)
2151
+ {
2152
+ for (size_t row = st; row < end; row++)
2153
+ {
2154
+ w_row = w[x_uses_ix_arr? ix_arr[row] : row];
2155
+ wleft += w_row;
2156
+ axpy1(w_row, X_row_major + ix_arr[row]*ncols, buffer_sum_left, ncols);
2157
+ if (x_uses_ix_arr) {
2158
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
2159
+ }
2160
+ else {
2161
+ if (unlikely(x[row] == x[row+1])) continue;
2162
+ }
2163
+
2164
+ vleft = 0;
2165
+ vright = 0;
2166
+ wright = wtot - wleft;
2167
+ for (size_t col = 0; col < ncols; col++)
2168
+ {
2169
+ sl = buffer_sum_left[col];
2170
+ vleft += sl * (sl / wleft);
2171
+ sr = buffer_sum_tot[col] - sl;
2172
+ vright += sr * (sr / wright);
2173
+ }
2174
+
2175
+ this_gain = vleft + vright;
2176
+ if (this_gain > best_gain)
2177
+ {
2178
+ best_gain = this_gain;
2179
+ split_ix = row;
2180
+ }
2181
+ }
2182
+ }
2183
+
2184
+ else
2185
+ {
2186
+ double *restrict ptr_row;
2187
+ double w_row;
2188
+ for (size_t row = st; row < end; row++)
2189
+ {
2190
+ w_row = w[x_uses_ix_arr? ix_arr[row] : row];
2191
+ wleft += w_row;
2192
+
2193
+ ptr_row = X_row_major + ix_arr[row]*ncols;
2194
+ for (size_t col = 0; col < ncols_use; col++)
2195
+ buffer_sum_left[col] = std::fma(w_row, ptr_row[cols_use[col]], buffer_sum_left[col]);
2196
+ if (x_uses_ix_arr) {
2197
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
2198
+ }
2199
+ else {
2200
+ if (unlikely(x[row] == x[row+1])) continue;
2201
+ }
2202
+
2203
+ vleft = 0;
2204
+ vright = 0;
2205
+ wright = wtot - wleft;
2206
+ for (size_t col = 0; col < ncols_use; col++)
2207
+ {
2208
+ sl = buffer_sum_left[col];
2209
+ vleft += sl * (sl / wleft);
2210
+ sr = buffer_sum_tot[col] - sl;
2211
+ vright += sr * (sr / wright);
2212
+ }
2213
+
2214
+ this_gain = vleft + vright;
2215
+ if (this_gain > best_gain)
2216
+ {
2217
+ best_gain = this_gain;
2218
+ split_ix = row;
2219
+ }
2220
+ }
2221
+ }
2222
+ }
2223
+
2224
+ else
2225
+ {
2226
+ if (!force_cols_use)
2227
+ {
2228
+ size_t ptr_this;
2229
+ double w_row;
2230
+ for (size_t row = st; row < end; row++)
2231
+ {
2232
+ w_row= w[x_uses_ix_arr? ix_arr[row] : row];
2233
+ wleft += w_row;
2234
+ ptr_this = Xr_indptr[ix_arr[row]];
2235
+ axpy1(w_row, Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_left);
2236
+ if (x_uses_ix_arr) {
2237
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
2238
+ }
2239
+ else {
2240
+ if (unlikely(x[row] == x[row+1])) continue;
2241
+ }
2242
+
2243
+ vleft = 0;
2244
+ vright = 0;
2245
+ wright = wtot - wleft;
2246
+ for (size_t col = 0; col < ncols; col++)
2247
+ {
2248
+ sl = buffer_sum_left[col];
2249
+ vleft += sl * (sl / wleft);
2250
+ sr = buffer_sum_tot[col] - sl;
2251
+ vright += sr * (sr / wright);
2252
+ }
2253
+
2254
+ this_gain = vleft + vright;
2255
+ if (this_gain > best_gain)
2256
+ {
2257
+ best_gain = this_gain;
2258
+ split_ix = row;
2259
+ }
2260
+ }
2261
+ }
2262
+
2263
+ else
2264
+ {
2265
+ size_t *curr_begin;
2266
+ size_t *row_end;
2267
+ size_t *curr_col;
2268
+ double *restrict Xr_this;
2269
+ size_t *cols_end = cols_use + ncols_use;
2270
+ double w_row;
2271
+ size_t dtemp;
2272
+ for (size_t row = st; row < end; row++)
2273
+ {
2274
+ w_row = w[x_uses_ix_arr? ix_arr[row] : row];
2275
+ wleft += w_row;
2276
+
2277
+ curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
2278
+ row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
2279
+ if (curr_begin == row_end) goto skip_sum;
2280
+ curr_col = cols_use;
2281
+ Xr_this = Xr + Xr_indptr[ix_arr[row]];
2282
+ while (curr_col < cols_end && curr_begin < row_end)
2283
+ {
2284
+ if (*curr_begin == *curr_col)
2285
+ {
2286
+ dtemp = std::distance(cols_use, curr_col);
2287
+ buffer_sum_left[dtemp]
2288
+ =
2289
+ std::fma(w_row, Xr_this[std::distance(curr_begin, row_end)], buffer_sum_left[dtemp]);
2290
+ curr_col++;
2291
+ curr_begin++;
2292
+ }
2293
+
2294
+ else
2295
+ {
2296
+ if (*curr_begin > *curr_col)
2297
+ curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
2298
+ else
2299
+ curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
2300
+ }
2301
+ }
2302
+
2303
+ skip_sum:
2304
+ if (x_uses_ix_arr) {
2305
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
2306
+ }
2307
+ else {
2308
+ if (unlikely(x[row] == x[row+1])) continue;
2309
+ }
2310
+
2311
+ vleft = 0;
2312
+ vright = 0;
2313
+ wright = wtot - wleft;
2314
+ for (size_t col = 0; col < ncols_use; col++)
2315
+ {
2316
+ sl = buffer_sum_left[col];
2317
+ vleft += sl * (sl / wleft);
2318
+ sr = buffer_sum_tot[col] - sl;
2319
+ vright += sr * (sr / wright);
2320
+ }
2321
+
2322
+ this_gain = vleft + vright;
2323
+ if (this_gain > best_gain)
2324
+ {
2325
+ best_gain = this_gain;
2326
+ split_ix = row;
2327
+ }
2328
+ }
2329
+ }
2330
+ }
2331
+
2332
+ if (best_gain <= -HUGE_VAL) return best_gain;
2333
+
2334
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
2335
+ return best_gain / wtot;
2336
+ }
2337
+
2338
+ template <class real_t_, class real_t>
2339
+ double find_split_dens_shortform_t(real_t *restrict x, size_t n, double &restrict split_point)
2340
+ {
2341
+ double best_gain = -HUGE_VAL;
2342
+ size_t n_minus_one = n - 1;
2343
+ real_t_ xmin = x[0];
2344
+ real_t_ xmax = x[n-1];
2345
+ real_t_ xleft, xright;
2346
+ real_t_ xmid;
2347
+ double this_gain;
2348
+ size_t split_ix = 0;
2349
+
2350
+ for (size_t ix = 0; ix < n_minus_one; ix++)
2351
+ {
2352
+ if (x[ix] == x[ix+1]) continue;
2353
+ xmid = (real_t_)x[ix] + ((real_t_)x[ix+1] - (real_t_)x[ix]) / (real_t_)2;
2354
+ xleft = xmid - xmin;
2355
+ xright = xmax - xmid;
2356
+ if (unlikely(!xleft || !xright)) continue;
2357
+ this_gain = (real_t_)square(ix+1) / xleft + (real_t_)square(n_minus_one - ix) / xright;
2358
+ if (this_gain > best_gain)
2359
+ {
2360
+ best_gain = this_gain;
2361
+ split_ix = ix;
2362
+ }
2363
+ }
2364
+
2365
+ if (best_gain <= -HUGE_VAL) return best_gain;
2366
+
2367
+ real_t_ xtot = (real_t_)xmax - (real_t_)xmin;
2368
+ real_t_ nleft = (real_t_)(split_ix+1);
2369
+ real_t_ nright = (real_t_)(n_minus_one - split_ix);
2370
+ split_point = midpoint(x[split_ix], x[split_ix+1]);
2371
+ real_t_ rpct_left = split_point / xtot;
2372
+ rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
2373
+ real_t_ rpct_right = (real_t_)1 - rpct_left;
2374
+ rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
2375
+
2376
+ real_t_ nl_sq = nleft / (real_t_)n; nl_sq = square(nl_sq);
2377
+ real_t_ nr_sq = nright / (real_t_)n; nl_sq = square(nr_sq);
2378
+
2379
+ return nl_sq / rpct_left + nr_sq / rpct_right;
2380
+ }
2381
+
2382
+ template <class real_t, class ldouble_safe>
2383
+ double find_split_dens_shortform(real_t *restrict x, size_t n, double &restrict split_point)
2384
+ {
2385
+ if (n < INT32_MAX)
2386
+ return find_split_dens_shortform_t<double, real_t>(x, n, split_point);
2387
+ else
2388
+ return find_split_dens_shortform_t<ldouble_safe, real_t>(x, n, split_point);
2389
+ }
2390
+
2391
+ template <class real_t_, class real_t, class mapping>
2392
+ double find_split_dens_shortform_weighted_t(real_t *restrict x, size_t n, double &restrict split_point, mapping &restrict w, size_t *restrict buffer_indices)
2393
+ {
2394
+ double best_gain = -HUGE_VAL;
2395
+ size_t n_minus_one = n - 1;
2396
+ real_t_ xmin = x[buffer_indices[0]];
2397
+ real_t_ xmax = x[buffer_indices[n-1]];
2398
+ real_t_ xleft, xright;
2399
+ real_t_ xmid;
2400
+ double this_gain;
2401
+
2402
+ real_t_ wtot = 0;
2403
+ for (size_t ix = 0; ix < n; ix++)
2404
+ wtot += w[buffer_indices[ix]];
2405
+ real_t_ w_left = 0;
2406
+ real_t_ w_right;
2407
+ real_t_ best_w = 0;
2408
+ size_t split_ix = 0;
2409
+
2410
+ for (size_t ix = 0; ix < n_minus_one; ix++)
2411
+ {
2412
+ w_left += w[buffer_indices[ix]];
2413
+ if (x[buffer_indices[ix]] == x[buffer_indices[ix+1]]) continue;
2414
+ xmid = (real_t_)x[buffer_indices[ix]] + ((real_t_)x[buffer_indices[ix+1]] - (real_t_)x[buffer_indices[ix]]) / (real_t_)2;
2415
+ xleft = xmid - xmin;
2416
+ xright = xmax - xmid;
2417
+ if (unlikely(!xleft || !xright)) continue;
2418
+
2419
+ w_right = wtot - w_left;
2420
+ this_gain = square(w_left) / xleft + square(w_right) / xright;
2421
+ if (this_gain > best_gain)
2422
+ {
2423
+ best_gain = this_gain;
2424
+ best_w = w_left;
2425
+ split_ix = xmid;
2426
+ }
2427
+ }
2428
+
2429
+ if (best_gain <= -HUGE_VAL) return best_gain;
2430
+
2431
+ real_t_ xtot = xmax - xmin;
2432
+ w_left = best_w;
2433
+ w_right = wtot - w_left;
2434
+ w_left = std::fmax(w_left, std::numeric_limits<double>::min());
2435
+ w_right = std::fmax(w_right, std::numeric_limits<double>::min());
2436
+ split_point = midpoint(x[buffer_indices[split_ix]], x[buffer_indices[split_ix+1]]);
2437
+ real_t_ rpct_left = split_point / xtot;
2438
+ rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
2439
+ real_t_ rpct_right = (real_t_)1 - rpct_left;
2440
+ rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
2441
+
2442
+ real_t_ wl_sq = w_left / wtot; wl_sq = square(wl_sq);
2443
+ real_t_ wr_sq = w_right / wtot; wl_sq = square(wr_sq);
2444
+
2445
+ return wl_sq / rpct_left + wr_sq / rpct_right;
2446
+ }
2447
+
2448
+ template <class real_t, class mapping, class ldouble_safe>
2449
+ double find_split_dens_shortform_weighted(real_t *restrict x, size_t n, double &restrict split_point, mapping &restrict w, size_t *restrict buffer_indices)
2450
+ {
2451
+ if (n < INT32_MAX)
2452
+ return find_split_dens_shortform_weighted_t<double, real_t, mapping>(x, n, split_point, w, buffer_indices);
2453
+ else
2454
+ return find_split_dens_shortform_weighted_t<ldouble_safe, real_t, mapping>(x, n, split_point, w, buffer_indices);
2455
+ }
2456
+
2457
+ template <class real_t>
2458
+ double find_split_dens_shortform(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2459
+ double &restrict split_point, size_t &restrict split_ix)
2460
+ {
2461
+ double best_gain = -HUGE_VAL;
2462
+ real_t xmin = x[ix_arr[st]];
2463
+ real_t xmax = x[ix_arr[end]];
2464
+ real_t xleft, xright;
2465
+ real_t xmid;
2466
+ double this_gain;
2467
+
2468
+ for (size_t row = st; row < end; row++)
2469
+ {
2470
+ if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
2471
+ xmid = x[ix_arr[row]] + (x[ix_arr[row+1]] - x[ix_arr[row]]) / (real_t)2;
2472
+ xleft = xmid - xmin;
2473
+ xright = xmax - xmid;
2474
+ if (unlikely(!xleft || !xright)) continue;
2475
+ this_gain = square(row-st+1) / xleft + square(end-row) / xright;
2476
+ if (this_gain > best_gain)
2477
+ {
2478
+ best_gain = this_gain;
2479
+ split_ix = row;
2480
+ }
2481
+ }
2482
+
2483
+ if (best_gain <= -HUGE_VAL) return best_gain;
2484
+
2485
+ double xtot = (double)xmax - (double)xmin;
2486
+ double nleft = (double)(split_ix-st+1);
2487
+ double nright = (double)(end - split_ix);
2488
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
2489
+ double rpct_left = split_point / xtot;
2490
+ rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
2491
+ double rpct_right = 1. - rpct_left;
2492
+ rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
2493
+ double ntot = (double)(end - st + 1);
2494
+
2495
+ double nl_sq = nleft / ntot; nl_sq = square(nl_sq);
2496
+ double nr_sq = nright / ntot; nl_sq = square(nr_sq);
2497
+
2498
+ return nl_sq / rpct_left + nr_sq / rpct_right;
2499
+ }
2500
+
2501
+ template <class real_t, class mapping>
2502
+ double find_split_dens_shortform_weighted(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2503
+ double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
2504
+ {
2505
+ double best_gain = -HUGE_VAL;
2506
+ real_t xmin = x[ix_arr[st]];
2507
+ real_t xmax = x[ix_arr[end]];
2508
+ real_t xleft, xright;
2509
+ real_t xmid;
2510
+ double this_gain;
2511
+
2512
+ double wtot = 0;
2513
+ for (size_t row = st; row <= end; row++)
2514
+ wtot += w[ix_arr[row]];
2515
+ double w_left = 0;
2516
+ double w_right;
2517
+ double best_w = 0;
2518
+
2519
+ for (size_t row = st; row < end; row++)
2520
+ {
2521
+ w_left += w[ix_arr[row]];
2522
+ if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
2523
+ xmid = x[ix_arr[row]] + (x[ix_arr[row+1]] - x[ix_arr[row]]) / (real_t)2;
2524
+ xleft = xmid - xmin;
2525
+ xright = xmax - xmid;
2526
+ if (unlikely(!xleft || !xright)) continue;
2527
+
2528
+ w_right = wtot - w_left;
2529
+ this_gain = square(w_left) / xleft + square(w_right) / xright;
2530
+ if (this_gain > best_gain)
2531
+ {
2532
+ best_gain = this_gain;
2533
+ best_w = w_left;
2534
+ split_ix = row;
2535
+ }
2536
+ }
2537
+
2538
+ if (best_gain <= -HUGE_VAL) return best_gain;
2539
+
2540
+ double xtot = (double)xmax - (double)xmin;
2541
+ w_left = best_w;
2542
+ w_right = wtot - w_left;
2543
+ w_left = std::fmax(w_left, std::numeric_limits<double>::min());
2544
+ w_right = std::fmax(w_right, std::numeric_limits<double>::min());
2545
+ split_point = midpoint(x[split_ix], x[split_ix+1]);
2546
+ double rpct_left = split_point / xtot;
2547
+ rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
2548
+ double rpct_right = 1. - rpct_left;
2549
+ rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
2550
+
2551
+ double wl_sq = w_left / wtot; wl_sq = square(wl_sq);
2552
+ double wr_sq = w_right / wtot; wl_sq = square(wr_sq);
2553
+
2554
+ return wl_sq / rpct_left + wr_sq / rpct_right;
2555
+ }
2556
+
2557
+ /* This is a slower but more numerically-robust form */
2558
+ template <class real_t, class ldouble_safe>
2559
+ double find_split_dens_longform(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2560
+ double &restrict split_point, size_t &restrict split_ix)
2561
+ {
2562
+ double best_gain = -HUGE_VAL;
2563
+ real_t xmin = x[ix_arr[st]];
2564
+ real_t xmax = x[ix_arr[end]];
2565
+ real_t xleft, xright;
2566
+ real_t xmid;
2567
+ ldouble_safe pct_left, pct_right;
2568
+ ldouble_safe rpct_left, rpct_right;
2569
+ ldouble_safe n_tot = end - st + 1;
2570
+ ldouble_safe xtot = (ldouble_safe)xmax - (ldouble_safe)xmin;
2571
+ ldouble_safe cnt_left;
2572
+ double this_gain;
2573
+
2574
+ for (size_t row = st; row < end; row++)
2575
+ {
2576
+ if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
2577
+ xmid = midpoint(x[ix_arr[row]], x[ix_arr[row+1]]);
2578
+ xleft = xmid - xmin;
2579
+ xright = xmax - xmid;
2580
+ if (unlikely(!xleft || !xright)) continue;
2581
+
2582
+ cnt_left = (ldouble_safe)(row-st+1);
2583
+
2584
+ xleft = std::fmax(xleft, (real_t)std::numeric_limits<real_t>::min());
2585
+ xright = std::fmax(xright, (real_t)std::numeric_limits<real_t>::min());
2586
+ pct_left = cnt_left / n_tot;
2587
+ pct_right = (ldouble_safe)1 - pct_left;
2588
+ rpct_left = (ldouble_safe)xleft / xtot;
2589
+ rpct_right = (ldouble_safe)xright / xtot;
2590
+
2591
+ this_gain = square(pct_left) / rpct_left + square(pct_right) / rpct_right;
2592
+ if (unlikely(is_na_or_inf(this_gain))) continue;
2593
+ if (this_gain > best_gain)
2594
+ {
2595
+ best_gain = this_gain;
2596
+ split_point = xmid;
2597
+ split_ix = row;
2598
+ }
2599
+ }
2600
+
2601
+ return best_gain;
2602
+ }
2603
+
2604
+ template <class real_t, class mapping, class ldouble_safe>
2605
+ double find_split_dens_longform_weighted(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2606
+ double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
2607
+ {
2608
+ double best_gain = -HUGE_VAL;
2609
+ real_t xmin = x[ix_arr[st]];
2610
+ real_t xmax = x[ix_arr[end]];
2611
+ real_t xleft, xright;
2612
+ real_t xmid;
2613
+ ldouble_safe pct_left, pct_right;
2614
+ ldouble_safe rpct_left, rpct_right;
2615
+ ldouble_safe xtot = (ldouble_safe)xmax - (ldouble_safe)xmin;
2616
+ double this_gain;
2617
+
2618
+ ldouble_safe wtot = 0;
2619
+ for (size_t row = st; row <= end; row++)
2620
+ wtot += w[ix_arr[row]];
2621
+ ldouble_safe w_left = 0;
2622
+
2623
+ for (size_t row = st; row < end; row++)
2624
+ {
2625
+ w_left += w[ix_arr[row]];
2626
+ if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
2627
+ xmid = midpoint(x[ix_arr[row]], x[ix_arr[row+1]]);
2628
+ xleft = xmid - xmin;
2629
+ xright = xmax - xmid;
2630
+ if (unlikely(!xleft || !xright)) continue;
2631
+
2632
+ xleft = std::fmax(xleft, (real_t)std::numeric_limits<real_t>::min());
2633
+ xright = std::fmax(xright, (real_t)std::numeric_limits<real_t>::min());
2634
+ pct_left = w_left / wtot;
2635
+ pct_right = (ldouble_safe)1 - pct_left;
2636
+ rpct_left = (ldouble_safe)xleft / xtot;
2637
+ rpct_right = (ldouble_safe)xright / xtot;
2638
+
2639
+ this_gain = square(pct_left) / rpct_left + square(pct_right) / rpct_right;
2640
+ if (unlikely(is_na_or_inf(this_gain))) continue;
2641
+ if (this_gain > best_gain)
2642
+ {
2643
+ best_gain = this_gain;
2644
+ split_point = xmid;
2645
+ split_ix = row;
2646
+ }
2647
+ }
2648
+
2649
+ return best_gain;
2650
+ }
2651
+
2652
+ template <class real_t, class ldouble_safe>
2653
+ double find_split_dens(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2654
+ double &restrict split_point, size_t &restrict split_ix)
2655
+ {
2656
+ if (end - st + 1 < INT32_MAX && x[ix_arr[end]] - x[ix_arr[st]] >= 1)
2657
+ return find_split_dens_shortform<real_t>(x, ix_arr, st, end, split_point, split_ix);
2658
+ else
2659
+ return find_split_dens_longform<real_t, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix);
2660
+ }
2661
+
2662
+ template <class real_t, class mapping, class ldouble_safe>
2663
+ double find_split_dens_weighted(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2664
+ double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
2665
+ {
2666
+ if (end - st + 1 < INT32_MAX && x[ix_arr[end]] - x[ix_arr[st]] >= 1)
2667
+ return find_split_dens_shortform_weighted<real_t, mapping>(x, ix_arr, st, end, split_point, split_ix, w);
2668
+ else
2669
+ return find_split_dens_longform_weighted<real_t, mapping, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix, w);
2670
+ }
2671
+
2672
+ template <class int_t, class ldouble_safe>
2673
+ double find_split_dens_longform(int *restrict x, int ncat, size_t *restrict ix_arr, size_t st, size_t end,
2674
+ CategSplit cat_split_type, MissingAction missing_action,
2675
+ int &restrict chosen_cat, signed char *restrict split_categ, int *restrict saved_cat_mode,
2676
+ size_t *restrict buffer_cnt, int_t *restrict buffer_indices)
2677
+ {
2678
+ if (st >= end || ncat <= 1) return -HUGE_VAL;
2679
+ size_t n_nas = 0;
2680
+ int xval;
2681
+
2682
+ /* count categories */
2683
+ memset(buffer_cnt, 0, sizeof(size_t) * ncat);
2684
+ if (missing_action == Fail)
2685
+ {
2686
+ for (size_t row = st; row <= end; row++)
2687
+ if (likely(x[ix_arr[row]] >= 0))
2688
+ buffer_cnt[x[ix_arr[row]]]++;
2689
+ }
2690
+
2691
+ else if (missing_action == Impute)
2692
+ {
2693
+ for (size_t row = st; row <= end; row++)
2694
+ {
2695
+ xval = x[ix_arr[row]];
2696
+ if (unlikely(xval < 0))
2697
+ n_nas++;
2698
+ else
2699
+ buffer_cnt[xval]++;
2700
+ }
2701
+
2702
+ if (unlikely(n_nas >= end-st)) return -HUGE_VAL;
2703
+
2704
+ if (n_nas)
2705
+ {
2706
+ auto idxmax = std::max_element(buffer_cnt, buffer_cnt + ncat);
2707
+ *idxmax += n_nas;
2708
+ *saved_cat_mode = (int)std::distance(buffer_cnt, idxmax);
2709
+ }
2710
+ }
2711
+
2712
+ else
2713
+ {
2714
+ for (size_t row = st; row <= end; row++)
2715
+ {
2716
+ xval = x[ix_arr[row]];
2717
+ if (likely(xval >= 0)) buffer_cnt[xval]++;
2718
+ }
2719
+ }
2720
+
2721
+ std::iota(buffer_indices, buffer_indices + ncat, (int_t)0);
2722
+ std::sort(buffer_indices, buffer_indices + ncat,
2723
+ [&buffer_cnt](const int_t a, const int_t b)
2724
+ {return buffer_cnt[a] < buffer_cnt[b];});
2725
+
2726
+ int curr = 0;
2727
+ if (split_categ != NULL)
2728
+ {
2729
+ while (buffer_cnt[buffer_indices[curr]] == 0)
2730
+ {
2731
+ split_categ[buffer_indices[curr]] = -1;
2732
+ curr++;
2733
+ }
2734
+ }
2735
+
2736
+ else
2737
+ {
2738
+ while (buffer_cnt[buffer_indices[curr]] == 0) curr++;
2739
+ }
2740
+
2741
+ int ncat_present = ncat - curr;
2742
+ if (ncat_present <= 1) return -HUGE_VAL;
2743
+ if (ncat_present == 2)
2744
+ {
2745
+ switch (cat_split_type)
2746
+ {
2747
+ case SingleCateg:
2748
+ {
2749
+ chosen_cat = buffer_indices[curr];
2750
+ break;
2751
+ }
2752
+
2753
+ case SubSet:
2754
+ {
2755
+ split_categ[buffer_indices[curr]] = 1;
2756
+ split_categ[buffer_indices[curr+1]] = 0;
2757
+ break;
2758
+ }
2759
+ }
2760
+
2761
+ ldouble_safe pct_left
2762
+ =
2763
+ (ldouble_safe)buffer_cnt[buffer_indices[curr]]
2764
+ /
2765
+ (ldouble_safe)(
2766
+ buffer_cnt[buffer_indices[curr]]
2767
+ +
2768
+ buffer_cnt[buffer_indices[curr+1]]
2769
+ );
2770
+
2771
+ return ((ldouble_safe)buffer_cnt[buffer_indices[curr]] * (2. * pct_left)
2772
+ +
2773
+ (ldouble_safe)buffer_cnt[buffer_indices[curr+1]] * (2. - 2.*pct_left))
2774
+ /
2775
+ (ldouble_safe)(buffer_cnt[buffer_indices[curr]] + buffer_cnt[buffer_indices[curr+1]]);
2776
+ }
2777
+
2778
+ size_t ntot;
2779
+ if (missing_action == Impute)
2780
+ ntot = end - st + 1;
2781
+ else
2782
+ ntot = std::accumulate(buffer_cnt, buffer_cnt + ncat, (size_t)0);
2783
+ if (unlikely(ntot <= 1)) unexpected_error();
2784
+ ldouble_safe ntot_ = (ldouble_safe)ntot;
2785
+
2786
+ switch (cat_split_type)
2787
+ {
2788
+ case SingleCateg:
2789
+ {
2790
+ double pct_one_cat = 1. / (double)ncat_present;
2791
+ double pct_left_smallest = (ldouble_safe)buffer_cnt[buffer_indices[curr]] / ntot_;
2792
+ double gain_smallest
2793
+ =
2794
+ (ldouble_safe)buffer_cnt[buffer_indices[curr]] * (pct_left_smallest / pct_one_cat)
2795
+ +
2796
+ (ldouble_safe)(ntot - buffer_cnt[buffer_indices[curr]]) * ((1. - pct_left_smallest) / (1. - pct_one_cat))
2797
+ ;
2798
+
2799
+ double pct_left_biggest = (ldouble_safe)buffer_cnt[buffer_indices[ncat-1]] / ntot_;
2800
+ double gain_biggest
2801
+ =
2802
+ (ldouble_safe)buffer_cnt[buffer_indices[ncat-1]] * (pct_left_biggest / pct_one_cat)
2803
+ +
2804
+ (ldouble_safe)(ntot - buffer_cnt[buffer_indices[ncat-1]]) * ((1. - pct_left_biggest) / (1. - pct_one_cat))
2805
+ ;
2806
+
2807
+ if (gain_smallest >= gain_biggest)
2808
+ {
2809
+ chosen_cat = buffer_indices[curr];
2810
+ return gain_smallest / ntot_;
2811
+ }
2812
+
2813
+ else
2814
+ {
2815
+ chosen_cat = buffer_indices[ncat-1];
2816
+ return gain_biggest / ntot_;
2817
+ }
2818
+ break;
2819
+ }
2820
+
2821
+ case SubSet:
2822
+ {
2823
+ size_t cnt_left = 0;
2824
+ size_t cnt_right;
2825
+ int st_cat = curr - 1;
2826
+ double this_gain;
2827
+ double best_gain = -HUGE_VAL;
2828
+ int best_cat = 0;
2829
+ ldouble_safe pct_left;
2830
+ double pct_cat_left;
2831
+ double ncat_present_ = (double)ncat_present;
2832
+ for (; curr < ncat; curr++)
2833
+ {
2834
+ cnt_left += buffer_cnt[buffer_indices[curr]];
2835
+ cnt_right = ntot - cnt_left;
2836
+ pct_left = (ldouble_safe)cnt_left / ntot_;
2837
+ pct_cat_left = (double)(curr - st_cat) / ncat_present_;
2838
+ this_gain
2839
+ =
2840
+ (ldouble_safe)cnt_left * (pct_left / pct_cat_left)
2841
+ +
2842
+ (ldouble_safe)cnt_right * (((ldouble_safe)1 - pct_left) / (1. - pct_cat_left))
2843
+ ;
2844
+ if (this_gain > best_gain)
2845
+ {
2846
+ best_gain = this_gain;
2847
+ best_cat = curr;
2848
+ }
2849
+ }
2850
+
2851
+ if (best_gain <= -HUGE_VAL) return best_gain;
2852
+ st_cat++;
2853
+ for (; st_cat <= best_cat; st_cat++)
2854
+ split_categ[buffer_indices[st_cat]] = 1;
2855
+ for (; st_cat < ncat; st_cat++)
2856
+ split_categ[buffer_indices[st_cat]] = 0;
2857
+ return best_gain / ntot_;
2858
+ break;
2859
+ }
2860
+ }
2861
+
2862
+ /* This will not be reached, but CRAN might complain otherwise */
2863
+ unreachable();
2864
+ return -HUGE_VAL;
2865
+ }
2866
+
2867
+ template <class mapping, class int_t, class ldouble_safe>
2868
+ double find_split_dens_longform_weighted(int *restrict x, int ncat, size_t *restrict ix_arr, size_t st, size_t end,
2869
+ CategSplit cat_split_type, MissingAction missing_action,
2870
+ int &restrict chosen_cat, signed char *restrict split_categ, int *restrict saved_cat_mode,
2871
+ int_t *restrict buffer_indices, mapping &restrict w)
2872
+ {
2873
+ if (st >= end || ncat <= 1) return -HUGE_VAL;
2874
+ ldouble_safe w_missing = 0;
2875
+ int xval;
2876
+ size_t ix_;
2877
+
2878
+ /* count categories */
2879
+ /* TODO: allocate this buffer externally */
2880
+ std::vector<ldouble_safe> buffer_cnt(ncat, (ldouble_safe)0);
2881
+ if (missing_action == Fail)
2882
+ {
2883
+ for (size_t row = st; row <= end; row++)
2884
+ {
2885
+ ix_ = ix_arr[row];
2886
+ if (unlikely(x[ix_]) < 0) continue;
2887
+ buffer_cnt[x[ix_]] += w[ix_];
2888
+ }
2889
+ }
2890
+
2891
+ else if (missing_action == Impute)
2892
+ {
2893
+ for (size_t row = st; row <= end; row++)
2894
+ {
2895
+ ix_ = ix_arr[row];
2896
+ xval = x[ix_];
2897
+ if (unlikely(xval < 0))
2898
+ w_missing += w[ix_];
2899
+ else
2900
+ buffer_cnt[xval] += w[ix_];
2901
+ }
2902
+
2903
+ if (w_missing)
2904
+ {
2905
+ auto idxmax = std::max_element(buffer_cnt.begin(), buffer_cnt.end());
2906
+ *idxmax += w_missing;
2907
+ *saved_cat_mode = (int)std::distance(buffer_cnt.begin(), idxmax);
2908
+ }
2909
+ }
2910
+
2911
+ else
2912
+ {
2913
+ for (size_t row = st; row <= end; row++)
2914
+ {
2915
+ ix_ = ix_arr[row];
2916
+ xval = x[ix_];
2917
+ if (likely(xval >= 0)) buffer_cnt[xval] += w[ix_];
2918
+ }
2919
+ }
2920
+
2921
+ std::iota(buffer_indices, buffer_indices + ncat, (int_t)0);
2922
+ std::sort(buffer_indices, buffer_indices + ncat,
2923
+ [&buffer_cnt](const int_t a, const int_t b)
2924
+ {return buffer_cnt[a] < buffer_cnt[b];});
2925
+
2926
+ int curr = 0;
2927
+ if (split_categ != NULL)
2928
+ {
2929
+ while (buffer_cnt[buffer_indices[curr]] == 0)
2930
+ {
2931
+ split_categ[buffer_indices[curr]] = -1;
2932
+ curr++;
2933
+ }
2934
+ }
2935
+
2936
+ else
2937
+ {
2938
+ while (buffer_cnt[buffer_indices[curr]] == 0) curr++;
2939
+ }
2940
+
2941
+ int ncat_present = ncat - curr;
2942
+ if (ncat_present <= 1) return -HUGE_VAL;
2943
+ if (ncat_present == 2)
2944
+ {
2945
+ switch (cat_split_type)
2946
+ {
2947
+ case SingleCateg:
2948
+ {
2949
+ chosen_cat = buffer_indices[curr];
2950
+ break;
2951
+ }
2952
+
2953
+ case SubSet:
2954
+ {
2955
+ split_categ[buffer_indices[curr]] = 1;
2956
+ split_categ[buffer_indices[curr+1]] = 0;
2957
+ break;
2958
+ }
2959
+ }
2960
+
2961
+ ldouble_safe pct_left
2962
+ =
2963
+ buffer_cnt[buffer_indices[curr]]
2964
+ /
2965
+ (
2966
+ buffer_cnt[buffer_indices[curr]]
2967
+ +
2968
+ buffer_cnt[buffer_indices[curr+1]]
2969
+ );
2970
+
2971
+ return (buffer_cnt[buffer_indices[curr]] * (pct_left * 2.)
2972
+ +
2973
+ buffer_cnt[buffer_indices[curr+1]] * (2. - 2.*pct_left))
2974
+ /
2975
+ (buffer_cnt[buffer_indices[curr]] + buffer_cnt[buffer_indices[curr+1]]);
2976
+ }
2977
+
2978
+ ldouble_safe ntot = std::accumulate(buffer_cnt.begin(), buffer_cnt.end(), (ldouble_safe)0);
2979
+ if (unlikely(ntot <= 0)) unexpected_error();
2980
+
2981
+ switch (cat_split_type)
2982
+ {
2983
+ case SingleCateg:
2984
+ {
2985
+ double pct_one_cat = 1. / (double)ncat_present;
2986
+ double pct_left_smallest = buffer_cnt[buffer_indices[curr]] / ntot;
2987
+ double gain_smallest
2988
+ =
2989
+ buffer_cnt[buffer_indices[curr]] * (pct_left_smallest / pct_one_cat)
2990
+ +
2991
+ (ntot - buffer_cnt[buffer_indices[curr]]) * ((1. - pct_left_smallest) / (1. - pct_one_cat))
2992
+ ;
2993
+
2994
+ double pct_left_biggest = buffer_cnt[buffer_indices[ncat-1]] / ntot;
2995
+ double gain_biggest
2996
+ =
2997
+ buffer_cnt[buffer_indices[ncat-1]] * (pct_left_biggest / pct_one_cat)
2998
+ +
2999
+ (ntot - buffer_cnt[buffer_indices[ncat-1]]) * ((1. - pct_left_biggest) / (1. - pct_one_cat))
3000
+ ;
3001
+
3002
+ if (gain_smallest >= gain_biggest)
3003
+ {
3004
+ chosen_cat = buffer_indices[curr];
3005
+ return gain_smallest / ntot;
3006
+ }
3007
+
3008
+ else
3009
+ {
3010
+ chosen_cat = buffer_indices[ncat-1];
3011
+ return gain_biggest / ntot;
3012
+ }
3013
+ break;
3014
+ }
3015
+
3016
+ case SubSet:
3017
+ {
3018
+ ldouble_safe cnt_left = 0;
3019
+ ldouble_safe cnt_right;
3020
+ int st_cat = curr - 1;
3021
+ double this_gain;
3022
+ double best_gain = -HUGE_VAL;
3023
+ int best_cat = 0;
3024
+ ldouble_safe pct_left;
3025
+ double pct_cat_left;
3026
+ double ncat_present_ = (double)ncat_present;
3027
+ for (; curr < ncat; curr++)
3028
+ {
3029
+ cnt_left += buffer_cnt[buffer_indices[curr]];
3030
+ cnt_right = ntot - cnt_left;
3031
+ pct_left = cnt_left / ntot;
3032
+ pct_cat_left = (double)(curr - st_cat) / ncat_present_;
3033
+ this_gain
3034
+ =
3035
+ (ldouble_safe)cnt_left * (pct_left / pct_cat_left)
3036
+ +
3037
+ (ldouble_safe)cnt_right * (((ldouble_safe)1 - pct_left) / (1. - pct_cat_left))
3038
+ ;
3039
+ if (this_gain > best_gain)
3040
+ {
3041
+ best_gain = this_gain;
3042
+ best_cat = curr;
3043
+ }
3044
+ }
3045
+
3046
+ if (best_gain <= -HUGE_VAL) return best_gain;
3047
+ st_cat++;
3048
+ for (; st_cat <= best_cat; st_cat++)
3049
+ split_categ[buffer_indices[st_cat]] = 1;
3050
+ for (; st_cat < ncat; st_cat++)
3051
+ split_categ[buffer_indices[st_cat]] = 0;
3052
+ return best_gain / ntot;
3053
+ break;
3054
+ }
3055
+ }
3056
+
3057
+ /* This will not be reached, but CRAN might complain otherwise */
3058
+ unreachable();
3059
+ return -HUGE_VAL;
3060
+ }
3061
+
3062
+ /* for split-criterion in hyperplanes (see below for version aimed at single-variable splits) */
3063
+ template <class ldouble_safe>
3064
+ double eval_guided_crit(double *restrict x, size_t n, GainCriterion criterion,
3065
+ double min_gain, bool as_relative_gain, double *restrict buffer_sd,
3066
+ double &restrict split_point, double &restrict xmin, double &restrict xmax,
3067
+ size_t *restrict ix_arr_plus_st,
3068
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3069
+ double *restrict X_row_major, size_t ncols,
3070
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
3071
+ {
3072
+ /* Note: the input 'x' is supposed to be a linear combination of standardized variables, so
3073
+ all numbers are assumed to be small and in the same scale */
3074
+ double gain = 0;
3075
+ if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
3076
+
3077
+ /* here it's assumed the 'x' vector matches exactly with 'ix_arr' + 'st' */
3078
+ if (unlikely(n == 2))
3079
+ {
3080
+ if (x[0] == x[1]) return -HUGE_VAL;
3081
+ split_point = midpoint_with_reorder(x[0], x[1]);
3082
+ gain = 1.;
3083
+ if (gain > min_gain)
3084
+ return gain;
3085
+ else
3086
+ return 0.;
3087
+ }
3088
+
3089
+ if (criterion == FullGain)
3090
+ {
3091
+ /* TODO: these buffers should be allocated externally */
3092
+ std::vector<size_t> argsorted(n);
3093
+ std::iota(argsorted.begin(), argsorted.end(), (size_t)0);
3094
+ std::sort(argsorted.begin(), argsorted.end(),
3095
+ [&x](const size_t a, const size_t b){return x[a] < x[b];});
3096
+ if (x[argsorted[0]] == x[argsorted[n-1]]) return -HUGE_VAL;
3097
+ std::vector<double> temp_buffer(n + mult2(ncols));
3098
+ for (size_t ix = 0; ix < n; ix++) temp_buffer[ix] = x[argsorted[ix]];
3099
+ for (size_t ix = 0; ix < n; ix++)
3100
+ argsorted[ix] = ix_arr_plus_st[argsorted[ix]];
3101
+ size_t ignored;
3102
+ return find_split_full_gain<double, ldouble_safe>(
3103
+ temp_buffer.data(), (size_t)0, n-1, argsorted.data(),
3104
+ cols_use, ncols_use, force_cols_use,
3105
+ X_row_major, ncols,
3106
+ Xr, Xr_ind, Xr_indptr,
3107
+ temp_buffer.data() + n, temp_buffer.data() + n + ncols,
3108
+ ignored, split_point,
3109
+ false);
3110
+ }
3111
+
3112
+ /* sort in ascending order */
3113
+ std::sort(x, x + n);
3114
+ xmin = x[0]; xmax = x[n-1];
3115
+ if (x[0] == x[n-1]) return -HUGE_VAL;
3116
+
3117
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3118
+ gain = find_split_rel_gain<double, ldouble_safe>(x, n, split_point);
3119
+ else if (criterion == Pooled || criterion == Averaged)
3120
+ gain = find_split_std_gain<double, ldouble_safe>(x, n, buffer_sd, criterion, min_gain, split_point);
3121
+ else if (criterion == DensityCrit)
3122
+ gain = find_split_dens_shortform<double, ldouble_safe>(x, n, split_point);
3123
+ /* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
3124
+ return std::fmax(0., gain);
3125
+ }
3126
+
3127
+ template <class ldouble_safe>
3128
+ double eval_guided_crit_weighted(double *restrict x, size_t n, GainCriterion criterion,
3129
+ double min_gain, bool as_relative_gain, double *restrict buffer_sd,
3130
+ double &restrict split_point, double &restrict xmin, double &restrict xmax,
3131
+ double *restrict w, size_t *restrict buffer_indices,
3132
+ size_t *restrict ix_arr_plus_st,
3133
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3134
+ double *restrict X_row_major, size_t ncols,
3135
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
3136
+ {
3137
+ /* Note: the input 'x' is supposed to be a linear combination of standardized variables, so
3138
+ all numbers are assumed to be small and in the same scale */
3139
+ double gain = 0;
3140
+ if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
3141
+
3142
+ /* here it's assumed the 'x' vector matches exactly with 'ix_arr' + 'st' */
3143
+ if (unlikely(n == 2))
3144
+ {
3145
+ if (x[0] == x[1]) return -HUGE_VAL;
3146
+ split_point = midpoint_with_reorder(x[0], x[1]);
3147
+ gain = 1.;
3148
+ if (gain > min_gain)
3149
+ return gain;
3150
+ else
3151
+ return 0.;
3152
+ }
3153
+
3154
+ /* sort in ascending order */
3155
+ std::iota(buffer_indices, buffer_indices + n, (size_t)0);
3156
+ std::sort(buffer_indices, buffer_indices + n,
3157
+ [&x](const size_t a, const size_t b){return x[a] < x[b];});
3158
+ xmin = x[buffer_indices[0]]; xmax = x[buffer_indices[n-1]];
3159
+ if (xmin == xmax) return -HUGE_VAL;
3160
+
3161
+ if (criterion == Pooled || criterion == Averaged)
3162
+ gain = find_split_std_gain_weighted<double, ldouble_safe>(x, n, buffer_sd, criterion, min_gain, split_point, w, buffer_indices);
3163
+ else if (criterion == DensityCrit)
3164
+ gain = find_split_dens_shortform_weighted<double, double *restrict, ldouble_safe>(x, n, split_point, w, buffer_indices);
3165
+ else if (criterion == FullGain)
3166
+ {
3167
+ std::vector<size_t> argsorted(n);
3168
+ std::iota(argsorted.begin(), argsorted.end(), (size_t)0);
3169
+ std::sort(argsorted.begin(), argsorted.end(),
3170
+ [&x](const size_t a, const size_t b){return x[a] < x[b];});
3171
+ if (x[argsorted[0]] == x[argsorted[n-1]]) return -HUGE_VAL;
3172
+ std::vector<double> temp_buffer(n + mult2(ncols));
3173
+ for (size_t ix = 0; ix < n; ix++) temp_buffer[ix] = x[argsorted[ix]];
3174
+ for (size_t ix = 0; ix < n; ix++)
3175
+ argsorted[ix] = ix_arr_plus_st[argsorted[ix]];
3176
+ size_t ignored;
3177
+ gain = find_split_full_gain_weighted<double, double *restrict, ldouble_safe>(
3178
+ temp_buffer.data(), (size_t)0, n-1, argsorted.data(),
3179
+ cols_use, ncols_use, force_cols_use,
3180
+ X_row_major, ncols,
3181
+ Xr, Xr_ind, Xr_indptr,
3182
+ temp_buffer.data() + n, temp_buffer.data() + n + ncols,
3183
+ ignored, split_point,
3184
+ false,
3185
+ w);
3186
+ }
3187
+ /* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
3188
+ return std::fmax(0., gain);
3189
+ }
3190
+
3191
+ /* for split-criterion in single-variable splits */
3192
+ template <class real_t_, class ldouble_safe>
3193
+ double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, real_t_ *restrict x,
3194
+ double *restrict buffer_sd, bool as_relative_gain,
3195
+ double *restrict buffer_imputed_x, double *restrict saved_xmedian,
3196
+ size_t &split_ix, double &restrict split_point, double &restrict xmin, double &restrict xmax,
3197
+ GainCriterion criterion, double min_gain, MissingAction missing_action,
3198
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3199
+ double *restrict X_row_major, size_t ncols,
3200
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
3201
+ {
3202
+ size_t st_orig = st;
3203
+ double gain = 0;
3204
+ if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
3205
+
3206
+ /* move NAs to the front if there's any, exclude them from calculations */
3207
+ if (missing_action != Fail)
3208
+ st = move_NAs_to_front(ix_arr, st, end, x);
3209
+
3210
+ if (unlikely(st >= end)) return -HUGE_VAL;
3211
+ else if (unlikely(st == (end-1)))
3212
+ {
3213
+ if (x[ix_arr[st]] == x[ix_arr[end]])
3214
+ return -HUGE_VAL;
3215
+ split_point = midpoint_with_reorder(x[ix_arr[st]], x[ix_arr[end]]);
3216
+ split_ix = st;
3217
+ gain = 1.;
3218
+ if (gain > min_gain)
3219
+ return gain;
3220
+ else
3221
+ return 0.;
3222
+ }
3223
+
3224
+ /* sort in ascending order */
3225
+ std::sort(ix_arr + st, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
3226
+ if (x[ix_arr[st]] == x[ix_arr[end]]) return -HUGE_VAL;
3227
+ xmin = x[ix_arr[st]]; xmax = x[ix_arr[end]];
3228
+
3229
+ /* unlike the previous case for the extended model, the data here has not been centered,
3230
+ which could make the standard deviations have poor precision. It's nevertheless not
3231
+ necessary for this mean to have good precision, since it's only meant for centering,
3232
+ so it can be calculated inexactly with simd instructions. */
3233
+ real_t_ xmean = 0;
3234
+ if (criterion == Pooled || criterion == Averaged)
3235
+ {
3236
+ for (size_t ix = st; ix <= end; ix++)
3237
+ xmean += x[ix_arr[ix]];
3238
+ xmean /= (real_t_)(end - st + 1);
3239
+ }
3240
+
3241
+ if (missing_action == Impute && st > st_orig)
3242
+ {
3243
+ missing_action = Fail;
3244
+ fill_NAs_with_median(ix_arr, st_orig, st, end, x, buffer_imputed_x, saved_xmedian);
3245
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3246
+ gain = find_split_rel_gain<double, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, split_point, split_ix);
3247
+ else if (criterion == Pooled || criterion == Averaged)
3248
+ gain = find_split_std_gain<double, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, buffer_sd, criterion, min_gain, split_point, split_ix);
3249
+ else if (criterion == DensityCrit)
3250
+ gain = find_split_dens<double, ldouble_safe>(buffer_imputed_x, ix_arr, st_orig, end, split_point, split_ix);
3251
+ else if (criterion == FullGain)
3252
+ {
3253
+ /* TODO: this buffer should be allocated from outside */
3254
+ std::vector<double> temp_buffer(mult2(ncols));
3255
+ gain = find_split_full_gain<double, ldouble_safe>(
3256
+ buffer_imputed_x, st_orig, end, ix_arr,
3257
+ cols_use, ncols_use, force_cols_use,
3258
+ X_row_major, ncols,
3259
+ Xr, Xr_ind, Xr_indptr,
3260
+ temp_buffer.data(), temp_buffer.data() + ncols,
3261
+ split_ix, split_point, true);
3262
+ }
3263
+
3264
+ /* Note: in theory, it should be possible to use a faster version assuming a contiguous array for 'x',
3265
+ but such an approach might give inexact split points. Better to avoid such inexactness at the
3266
+ expense of more computations. */
3267
+ }
3268
+
3269
+ else
3270
+ {
3271
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3272
+ gain = find_split_rel_gain<real_t_, ldouble_safe>(x, xmean, ix_arr, st, end, split_point, split_ix);
3273
+ else if (criterion == Pooled || criterion == Averaged)
3274
+ gain = find_split_std_gain<real_t_, ldouble_safe>(x, xmean, ix_arr, st, end, buffer_sd, criterion, min_gain, split_point, split_ix);
3275
+ else if (criterion == DensityCrit)
3276
+ gain = find_split_dens<real_t_, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix);
3277
+ else if (criterion == FullGain)
3278
+ {
3279
+ /* TODO: this buffer should be allocated from outside */
3280
+ std::vector<double> temp_buffer(mult2(ncols));
3281
+ gain = find_split_full_gain<real_t_, ldouble_safe>(
3282
+ x, st, end, ix_arr,
3283
+ cols_use, ncols_use, force_cols_use,
3284
+ X_row_major, ncols,
3285
+ Xr, Xr_ind, Xr_indptr,
3286
+ temp_buffer.data(), temp_buffer.data() + ncols,
3287
+ split_ix, split_point, true);
3288
+ }
3289
+ }
3290
+
3291
+ /* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
3292
+ return std::fmax(0., gain);
3293
+ }
3294
+
3295
+ template <class real_t_, class mapping, class ldouble_safe>
3296
+ double eval_guided_crit_weighted(size_t *restrict ix_arr, size_t st, size_t end, real_t_ *restrict x,
3297
+ double *restrict buffer_sd, bool as_relative_gain,
3298
+ double *restrict buffer_imputed_x, double *restrict saved_xmedian,
3299
+ size_t &split_ix, double &restrict split_point, double &restrict xmin, double &restrict xmax,
3300
+ GainCriterion criterion, double min_gain, MissingAction missing_action,
3301
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3302
+ double *restrict X_row_major, size_t ncols,
3303
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
3304
+ mapping &restrict w)
3305
+ {
3306
+ size_t st_orig = st;
3307
+ double gain = 0;
3308
+ if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
3309
+
3310
+ /* move NAs to the front if there's any, exclude them from calculations */
3311
+ if (missing_action != Fail)
3312
+ st = move_NAs_to_front(ix_arr, st, end, x);
3313
+
3314
+ if (unlikely(st >= end)) return -HUGE_VAL;
3315
+ else if (unlikely(st == (end-1)))
3316
+ {
3317
+ if (x[ix_arr[st]] == x[ix_arr[end]])
3318
+ return -HUGE_VAL;
3319
+ split_point = midpoint_with_reorder(x[ix_arr[st]], x[ix_arr[end]]);
3320
+ split_ix = st;
3321
+ gain = 1.;
3322
+ if (gain > min_gain)
3323
+ return gain;
3324
+ else
3325
+ return 0.;
3326
+ }
3327
+
3328
+ /* sort in ascending order */
3329
+ std::sort(ix_arr + st, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
3330
+ if (x[ix_arr[st]] == x[ix_arr[end]]) return -HUGE_VAL;
3331
+ xmin = x[ix_arr[st]]; xmax = x[ix_arr[end]];
3332
+
3333
+ /* unlike the previous case for the extended model, the data here has not been centered,
3334
+ which could make the standard deviations have poor precision. It's nevertheless not
3335
+ necessary for this mean to have good precision, since it's only meant for centering,
3336
+ so it can be calculated inexactly with simd instructions. */
3337
+ real_t_ xmean = 0;
3338
+ real_t_ cnt = 0;
3339
+ if (criterion == Pooled || criterion == Averaged)
3340
+ {
3341
+ for (size_t ix = st; ix <= end; ix++)
3342
+ {
3343
+ xmean += x[ix_arr[ix]];
3344
+ cnt += w[ix_arr[ix]];
3345
+ }
3346
+ xmean /= cnt;
3347
+ }
3348
+
3349
+ if (missing_action == Impute && st > st_orig)
3350
+ {
3351
+ missing_action = Fail;
3352
+ fill_NAs_with_median(ix_arr, st_orig, st, end, x, buffer_imputed_x, saved_xmedian);
3353
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3354
+ gain = find_split_rel_gain_weighted<double, mapping, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, split_point, split_ix, w);
3355
+ else if (criterion == Pooled || criterion == Averaged)
3356
+ gain = find_split_std_gain_weighted<double, mapping, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, buffer_sd, criterion, min_gain, split_point, split_ix, w);
3357
+ else if (criterion == DensityCrit)
3358
+ gain = find_split_dens_weighted<double, mapping, ldouble_safe>(buffer_imputed_x, ix_arr, st_orig, end, split_point, split_ix, w);
3359
+ else if (criterion == FullGain)
3360
+ {
3361
+ std::vector<double> temp_buffer(mult2(ncols));
3362
+ gain = find_split_full_gain_weighted<double, mapping, ldouble_safe>(
3363
+ buffer_imputed_x, st_orig, end, ix_arr,
3364
+ cols_use, ncols_use, force_cols_use,
3365
+ X_row_major, ncols,
3366
+ Xr, Xr_ind, Xr_indptr,
3367
+ temp_buffer.data(), temp_buffer.data() + ncols,
3368
+ split_ix, split_point, true,
3369
+ w);
3370
+ }
3371
+ }
3372
+
3373
+ else
3374
+ {
3375
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3376
+ gain = find_split_rel_gain_weighted<real_t_, mapping, ldouble_safe>(x, xmean, ix_arr, st, end, split_point, split_ix, w);
3377
+ else if (criterion == Pooled || criterion == Averaged)
3378
+ gain = find_split_std_gain_weighted<real_t_, mapping, ldouble_safe>(x, xmean, ix_arr, st, end, buffer_sd, criterion, min_gain, split_point, split_ix, w);
3379
+ else if (criterion == DensityCrit)
3380
+ gain = find_split_dens_weighted<real_t_, mapping, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix, w);
3381
+ else if (criterion == FullGain)
3382
+ {
3383
+ std::vector<double> temp_buffer(mult2(ncols));
3384
+ gain = find_split_full_gain_weighted<real_t_, mapping, ldouble_safe>(
3385
+ x, st, end, ix_arr,
3386
+ cols_use, ncols_use, force_cols_use,
3387
+ X_row_major, ncols,
3388
+ Xr, Xr_ind, Xr_indptr,
3389
+ temp_buffer.data(), temp_buffer.data() + ncols,
3390
+ split_ix, split_point, true,
3391
+ w);
3392
+ }
3393
+ }
3394
+
3395
+ /* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
3396
+ return std::fmax(0., gain);
3397
+ }
3398
+
3399
+ /* TODO: here it should only need to look at the non-zero entries. It can then use the
3400
+ same algorithm as above, but putting an extra check to see where do the zeros fit in
3401
+ the sorted order of the non-zero entries while calculating gains and SDs, and then
3402
+ call the 'divide_subset' function after-the-fact to reach the same end result.
3403
+ It should be much faster than this if the non-zero entries are few. */
3404
+ template <class real_t_, class sparse_ix, class ldouble_safe>
3405
+ double eval_guided_crit(size_t ix_arr[], size_t st, size_t end,
3406
+ size_t col_num, real_t_ Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
3407
+ double buffer_arr[], size_t buffer_pos[], bool as_relative_gain,
3408
+ double *restrict saved_xmedian,
3409
+ double &split_point, double &xmin, double &xmax,
3410
+ GainCriterion criterion, double min_gain, MissingAction missing_action,
3411
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3412
+ double *restrict X_row_major, size_t ncols,
3413
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
3414
+ {
3415
+ size_t ignored;
3416
+
3417
+
3418
+ todense(ix_arr, st, end,
3419
+ col_num, Xc, Xc_ind, Xc_indptr,
3420
+ buffer_arr);
3421
+ size_t tot = end - st + 1;
3422
+ std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
3423
+
3424
+ if (missing_action == Impute)
3425
+ {
3426
+ missing_action = Fail;
3427
+ for (size_t ix = 0; ix < tot; ix++)
3428
+ {
3429
+ if (unlikely(is_na_or_inf(buffer_arr[ix])))
3430
+ {
3431
+ goto fill_missing;
3432
+ }
3433
+ }
3434
+ goto no_nas;
3435
+
3436
+ fill_missing:
3437
+ {
3438
+ size_t idx_half = div2(tot);
3439
+ std::nth_element(buffer_pos, buffer_pos + idx_half, buffer_pos + tot,
3440
+ [&buffer_arr](const size_t a, const size_t b){return buffer_arr[a] < buffer_arr[b];});
3441
+ *saved_xmedian = buffer_arr[buffer_pos[idx_half]];
3442
+
3443
+ if ((tot % 2) == 0)
3444
+ {
3445
+ double xlow = *std::max_element(buffer_pos, buffer_pos + idx_half);
3446
+ *saved_xmedian = xlow + ((*saved_xmedian)-xlow)/2.;
3447
+ }
3448
+
3449
+ for (size_t ix = 0; ix < tot; ix++)
3450
+ buffer_arr[ix] = is_na_or_inf(buffer_arr[ix])? (*saved_xmedian) : buffer_arr[ix];
3451
+ std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
3452
+ }
3453
+ }
3454
+
3455
+ no_nas:
3456
+ return eval_guided_crit<double, ldouble_safe>(
3457
+ buffer_pos, 0, end - st, buffer_arr, buffer_arr + tot,
3458
+ as_relative_gain, saved_xmedian, (double*)NULL, ignored, split_point,
3459
+ xmin, xmax, criterion, min_gain, missing_action,
3460
+ cols_use, ncols_use, force_cols_use,
3461
+ X_row_major, ncols,
3462
+ Xr, Xr_ind, Xr_indptr);
3463
+ }
3464
+
3465
+ template <class real_t_, class sparse_ix, class mapping, class ldouble_safe>
3466
+ double eval_guided_crit_weighted(size_t ix_arr[], size_t st, size_t end,
3467
+ size_t col_num, real_t_ Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
3468
+ double buffer_arr[], size_t buffer_pos[], bool as_relative_gain,
3469
+ double *restrict saved_xmedian,
3470
+ double &restrict split_point, double &restrict xmin, double &restrict xmax,
3471
+ GainCriterion criterion, double min_gain, MissingAction missing_action,
3472
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3473
+ double *restrict X_row_major, size_t ncols,
3474
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
3475
+ mapping &restrict w)
3476
+ {
3477
+ size_t ignored;
3478
+
3479
+
3480
+ todense(ix_arr, st, end,
3481
+ col_num, Xc, Xc_ind, Xc_indptr,
3482
+ buffer_arr);
3483
+ size_t tot = end - st + 1;
3484
+ std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
3485
+
3486
+
3487
+ if (missing_action == Impute)
3488
+ {
3489
+ missing_action = Fail;
3490
+ for (size_t ix = 0; ix < tot; ix++)
3491
+ {
3492
+ if (unlikely(is_na_or_inf(buffer_arr[ix])))
3493
+ {
3494
+ goto fill_missing;
3495
+ }
3496
+ }
3497
+ goto no_nas;
3498
+
3499
+ fill_missing:
3500
+ {
3501
+ size_t idx_half = div2(tot);
3502
+ std::nth_element(buffer_pos, buffer_pos + idx_half, buffer_pos + tot,
3503
+ [&buffer_arr](const size_t a, const size_t b){return buffer_arr[a] < buffer_arr[b];});
3504
+ *saved_xmedian = buffer_arr[buffer_pos[idx_half]];
3505
+
3506
+ if ((tot % 2) == 0)
3507
+ {
3508
+ double xlow = *std::max_element(buffer_pos, buffer_pos + idx_half);
3509
+ *saved_xmedian = xlow + ((*saved_xmedian)-xlow)/2.;
3510
+ }
3511
+
3512
+ for (size_t ix = 0; ix < tot; ix++)
3513
+ buffer_arr[ix] = is_na_or_inf(buffer_arr[ix])? (*saved_xmedian) : buffer_arr[ix];
3514
+ std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
3515
+ }
3516
+ }
3517
+
3518
+
3519
+ no_nas:
3520
+ /* TODO: allocate this buffer externally */
3521
+ std::vector<double> buffer_w(tot);
3522
+ for (size_t row = st; row <= end; row++)
3523
+ buffer_w[row-st] = w[ix_arr[row]];
3524
+ /* TODO: in this case, as the weights match with the order of the indices, could use a faster version
3525
+ with a weighted rel_gain function instead (not yet implemented). */
3526
+ return eval_guided_crit_weighted<double, std::vector<double>, ldouble_safe>(
3527
+ buffer_pos, 0, end - st, buffer_arr, buffer_arr + tot,
3528
+ as_relative_gain, saved_xmedian, (double*)NULL, ignored, split_point,
3529
+ xmin, xmax, criterion, min_gain, missing_action,
3530
+ cols_use, ncols_use, force_cols_use,
3531
+ X_row_major, ncols,
3532
+ Xr, Xr_ind, Xr_indptr,
3533
+ buffer_w);
3534
+ }
3535
+
3536
+ /* How this works:
3537
+ - For Averaged criterion, will take the expected standard deviation that would be gotten with the category counts
3538
+ if each category got assigned a real number at random ~ Unif(0,1) and the data were thus converted to
3539
+ numerical. In such case, the best split (highest sd gain) is always putting the second-highest count in one
3540
+ branch, so there is no point in doing a full search over other permutations. In order to get more reasonable
3541
+ splits, when using the option to split by subsets of categories, it will sort the counts and evaluate only
3542
+ splits in which the categories are grouped in sorted order - in such cases it tends to pick either the
3543
+ smallest or the largest category to assign to one branch, but sometimes picks groups too.
3544
+ - For Pooled criterion, will take shannon entropy, which tends to make a more even split. In the case of splitting
3545
+ by a single category, it always puts the largest category in a separate branch. In the case of subsets,
3546
+ it can either evaluate possible splits over all permutations (not feasible if there are too many categories),
3547
+ or look up for splits in sorted order just like for Averaged criterion.
3548
+ Splitting by averaged Gini gain (like with Averaged) also selects always the second-largest category to put in one branch,
3549
+ while splitting by weighted Gini (like with Pooled) usually selects the largest category to put in one branch. The
3550
+ Gini gain is not easily comparable to that of numerical columns, so it's not offered as an option here.
3551
+ */
3552
+ /* https://math.stackexchange.com/questions/3343384/expected-variance-and-kurtosis-from-pmf-in-which-possible-discrete-values-are-dr */
3553
+ /* TODO: 'buffer_pos' doesn't need to be 'size_t', 'int' would suffice */
3554
+ template <class ldouble_safe>
3555
+ double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int ncat,
3556
+ int *restrict saved_cat_mode,
3557
+ size_t *restrict buffer_cnt, size_t *restrict buffer_pos, double *restrict buffer_prob,
3558
+ int &restrict chosen_cat, signed char *restrict split_categ, signed char *restrict buffer_split,
3559
+ GainCriterion criterion, double min_gain, bool all_perm,
3560
+ MissingAction missing_action, CategSplit cat_split_type)
3561
+ {
3562
+ if (criterion == DensityCrit)
3563
+ return find_split_dens_longform<size_t, ldouble_safe>(
3564
+ x, ncat, ix_arr, st, end,
3565
+ cat_split_type, missing_action,
3566
+ chosen_cat, split_categ, saved_cat_mode,
3567
+ buffer_cnt, buffer_pos);
3568
+ if (st >= end) return -HUGE_VAL;
3569
+ size_t n_nas = 0;
3570
+ int xval;
3571
+
3572
+ /* count categories */
3573
+ memset(buffer_cnt, 0, sizeof(size_t) * ncat);
3574
+ if (missing_action == Fail)
3575
+ {
3576
+ for (size_t row = st; row <= end; row++)
3577
+ if (likely(x[ix_arr[row]] >= 0))
3578
+ buffer_cnt[x[ix_arr[row]]]++;
3579
+ }
3580
+
3581
+ else if (missing_action == Impute)
3582
+ {
3583
+ for (size_t row = st; row <= end; row++)
3584
+ {
3585
+ xval = x[ix_arr[row]];
3586
+ if (unlikely(xval < 0))
3587
+ n_nas++;
3588
+ else
3589
+ buffer_cnt[xval]++;
3590
+ }
3591
+
3592
+ if (unlikely(n_nas >= end-st)) return -HUGE_VAL;
3593
+
3594
+ if (n_nas)
3595
+ {
3596
+ auto idxmax = std::max_element(buffer_cnt, buffer_cnt + ncat);
3597
+ *idxmax += n_nas;
3598
+ *saved_cat_mode = (int)std::distance(buffer_cnt, idxmax);
3599
+ }
3600
+ }
3601
+
3602
+ else
3603
+ {
3604
+ for (size_t row = st; row <= end; row++)
3605
+ {
3606
+ xval = x[ix_arr[row]];
3607
+ if (likely(xval >= 0)) buffer_cnt[xval]++;
3608
+ }
3609
+ }
3610
+
3611
+ double this_gain = -HUGE_VAL;
3612
+ double best_gain = -HUGE_VAL;
3613
+ std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
3614
+ size_t st_pos = 0;
3615
+
3616
+ switch (cat_split_type)
3617
+ {
3618
+ case SingleCateg:
3619
+ {
3620
+ size_t cnt = end - st + 1;
3621
+ ldouble_safe cnt_l = (ldouble_safe) cnt;
3622
+ size_t ncat_present = 0;
3623
+
3624
+ switch(criterion)
3625
+ {
3626
+ case Averaged:
3627
+ {
3628
+ /* move zero-counts to the beginning */
3629
+ size_t temp;
3630
+ for (int cat = 0; cat < ncat; cat++)
3631
+ {
3632
+ if (buffer_cnt[cat])
3633
+ {
3634
+ ncat_present++;
3635
+ buffer_prob[cat] = (ldouble_safe) buffer_cnt[cat] / cnt_l;
3636
+ }
3637
+
3638
+ else
3639
+ {
3640
+ temp = buffer_pos[st_pos];
3641
+ buffer_pos[st_pos] = buffer_pos[cat];
3642
+ buffer_pos[cat] = temp;
3643
+ st_pos++;
3644
+ }
3645
+ }
3646
+
3647
+ if (ncat_present <= 1) return -HUGE_VAL;
3648
+
3649
+ double sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
3650
+
3651
+ /* try isolating each category one at a time */
3652
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
3653
+ {
3654
+ this_gain = sd_gain(sd_full,
3655
+ 0.0,
3656
+ (expected_sd_cat_single<size_t, size_t, ldouble_safe>(buffer_cnt, buffer_prob, ncat_present, buffer_pos + st_pos, pos - st_pos, cnt)));
3657
+ if (this_gain > min_gain && this_gain > best_gain)
3658
+ {
3659
+ best_gain = this_gain;
3660
+ chosen_cat = buffer_pos[pos];
3661
+ }
3662
+ }
3663
+ break;
3664
+ }
3665
+
3666
+ case Pooled:
3667
+ {
3668
+ /* here it will always pick the largest one */
3669
+ size_t ncat_present = 0;
3670
+ size_t cnt_max = 0;
3671
+ for (int cat = 0; cat < ncat; cat++)
3672
+ {
3673
+ if (buffer_cnt[cat])
3674
+ {
3675
+ ncat_present++;
3676
+ if (cnt_max < buffer_cnt[cat])
3677
+ {
3678
+ cnt_max = buffer_cnt[cat];
3679
+ chosen_cat = cat;
3680
+ }
3681
+ }
3682
+ }
3683
+
3684
+ if (ncat_present <= 1) return -HUGE_VAL;
3685
+
3686
+ ldouble_safe cnt_left = (ldouble_safe)((end - st + 1) - cnt_max);
3687
+ this_gain = (
3688
+ (ldouble_safe)cnt * std::log((ldouble_safe)cnt)
3689
+ - cnt_left * std::log(cnt_left)
3690
+ - (ldouble_safe)cnt_max * std::log((ldouble_safe)cnt_max)
3691
+ ) / cnt;
3692
+ best_gain = (this_gain > min_gain)? this_gain : best_gain;
3693
+ break;
3694
+ }
3695
+
3696
+ default:
3697
+ {
3698
+ unexpected_error();
3699
+ break;
3700
+ }
3701
+ }
3702
+ break;
3703
+ }
3704
+
3705
+ case SubSet:
3706
+ {
3707
+ /* sort by counts */
3708
+ std::sort(buffer_pos, buffer_pos + ncat, [&buffer_cnt](const size_t a, const size_t b){return buffer_cnt[a] < buffer_cnt[b];});
3709
+
3710
+ /* set split as: (1):left (0):right (-1):not_present */
3711
+ memset(buffer_split, 0, ncat * sizeof(signed char));
3712
+
3713
+ ldouble_safe cnt = (ldouble_safe)(end - st + 1);
3714
+
3715
+ switch(criterion)
3716
+ {
3717
+ case Averaged:
3718
+ {
3719
+ /* determine first non-zero and convert to probabilities */
3720
+ double sd_full;
3721
+ for (int cat = 0; cat < ncat; cat++)
3722
+ {
3723
+ if (buffer_cnt[buffer_pos[cat]])
3724
+ {
3725
+ buffer_prob[buffer_pos[cat]] = (ldouble_safe)buffer_cnt[buffer_pos[cat]] / cnt;
3726
+ }
3727
+
3728
+ else
3729
+ {
3730
+ buffer_split[buffer_pos[cat]] = -1;
3731
+ st_pos++;
3732
+ }
3733
+ }
3734
+
3735
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
3736
+
3737
+ /* calculate full SD assuming they take values randomly ~Unif(0, 1) */
3738
+ size_t ncat_present = (size_t)ncat - st_pos;
3739
+ sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
3740
+ if (ncat_present >= log2ceil(SIZE_MAX)) all_perm = false;
3741
+
3742
+ /* move categories one at a time */
3743
+ for (size_t pos = st_pos; pos < ((size_t)ncat - st_pos - 1); pos++)
3744
+ {
3745
+ buffer_split[buffer_pos[pos]] = 1;
3746
+ this_gain = sd_gain(sd_full,
3747
+ (expected_sd_cat<size_t, size_t, ldouble_safe>(buffer_cnt, buffer_prob, pos - st_pos + 1, buffer_pos + st_pos)),
3748
+ (expected_sd_cat<size_t, size_t, ldouble_safe>(buffer_cnt, buffer_prob, (size_t)ncat - pos - 1, buffer_pos + pos + 1))
3749
+ );
3750
+ if (this_gain > min_gain && this_gain > best_gain)
3751
+ {
3752
+ best_gain = this_gain;
3753
+ memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
3754
+ }
3755
+ }
3756
+
3757
+ break;
3758
+ }
3759
+
3760
+ case Pooled:
3761
+ {
3762
+ ldouble_safe s = 0;
3763
+
3764
+ /* determine first non-zero and get base info */
3765
+ for (int cat = 0; cat < ncat; cat++)
3766
+ {
3767
+ if (buffer_cnt[buffer_pos[cat]])
3768
+ {
3769
+ s += (buffer_cnt[buffer_pos[cat]] <= 1)?
3770
+ 0 : ((ldouble_safe) buffer_cnt[buffer_pos[cat]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[cat]]));
3771
+ }
3772
+
3773
+ else
3774
+ {
3775
+ buffer_split[buffer_pos[cat]] = -1;
3776
+ st_pos++;
3777
+ }
3778
+ }
3779
+
3780
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
3781
+
3782
+ /* calculate base info */
3783
+ ldouble_safe base_info = cnt * std::log(cnt) - s;
3784
+
3785
+ if (all_perm)
3786
+ {
3787
+ size_t cnt_left, cnt_right;
3788
+ double s_left, s_right;
3789
+ size_t ncat_present = (size_t)ncat - st_pos;
3790
+ size_t ncomb = pow2(ncat_present) - 1;
3791
+ size_t best_combin;
3792
+
3793
+ for (size_t combin = 1; combin < ncomb; combin++)
3794
+ {
3795
+ cnt_left = 0; cnt_right = 0;
3796
+ s_left = 0; s_right = 0;
3797
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
3798
+ {
3799
+ if (extract_bit(combin, pos))
3800
+ {
3801
+ cnt_left += buffer_cnt[buffer_pos[pos]];
3802
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
3803
+ 0 : ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
3804
+ * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
3805
+ }
3806
+
3807
+ else
3808
+ {
3809
+ cnt_right += buffer_cnt[buffer_pos[pos]];
3810
+ s_right += (buffer_cnt[buffer_pos[pos]] <= 1)?
3811
+ 0 : ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
3812
+ * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
3813
+ }
3814
+ }
3815
+
3816
+ this_gain = categ_gain<size_t, ldouble_safe>(
3817
+ cnt_left, cnt_right,
3818
+ s_left, s_right,
3819
+ base_info, cnt);
3820
+
3821
+ if (this_gain > min_gain && this_gain > best_gain)
3822
+ {
3823
+ best_gain = this_gain;
3824
+ best_combin = combin;
3825
+ }
3826
+
3827
+ }
3828
+
3829
+ if (best_gain > min_gain)
3830
+ for (size_t pos = 0; pos < ncat_present; pos++)
3831
+ split_categ[buffer_pos[st_pos + pos]] = extract_bit(best_combin, pos);
3832
+
3833
+ }
3834
+
3835
+ else
3836
+ {
3837
+ /* try moving the categories one at a time */
3838
+ size_t cnt_left = 0;
3839
+ size_t cnt_right = end - st + 1;
3840
+ double s_left = 0;
3841
+ double s_right = s;
3842
+
3843
+ for (size_t pos = st_pos; pos < (ncat - st_pos - 1); pos++)
3844
+ {
3845
+ buffer_split[buffer_pos[pos]] = 1;
3846
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
3847
+ 0 : ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[pos]]));
3848
+ s_right -= (buffer_cnt[buffer_pos[pos]] <= 1)?
3849
+ 0 : ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[pos]]));
3850
+ cnt_left += buffer_cnt[buffer_pos[pos]];
3851
+ cnt_right -= buffer_cnt[buffer_pos[pos]];
3852
+
3853
+ this_gain = categ_gain<size_t, ldouble_safe>(
3854
+ cnt_left, cnt_right,
3855
+ s_left, s_right,
3856
+ base_info, cnt);
3857
+
3858
+ if (this_gain > min_gain && this_gain > best_gain)
3859
+ {
3860
+ best_gain = this_gain;
3861
+ memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
3862
+ }
3863
+ }
3864
+ }
3865
+
3866
+ break;
3867
+ }
3868
+
3869
+ default:
3870
+ {
3871
+ unexpected_error();
3872
+ break;
3873
+ }
3874
+ }
3875
+ }
3876
+ }
3877
+
3878
+ if (st == (end-1)) return 0;
3879
+
3880
+ if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
3881
+ return 0;
3882
+ else
3883
+ return best_gain;
3884
+ }
3885
+
3886
+
3887
+ template <class mapping, class ldouble_safe>
3888
+ double eval_guided_crit_weighted(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int ncat,
3889
+ int *restrict saved_cat_mode,
3890
+ size_t *restrict buffer_pos, double *restrict buffer_prob,
3891
+ int &restrict chosen_cat, signed char *restrict split_categ, signed char *restrict buffer_split,
3892
+ GainCriterion criterion, double min_gain, bool all_perm,
3893
+ MissingAction missing_action, CategSplit cat_split_type,
3894
+ mapping &restrict w)
3895
+ {
3896
+ if (criterion == DensityCrit)
3897
+ return find_split_dens_longform_weighted<mapping, size_t, ldouble_safe>(
3898
+ x, ncat, ix_arr, st, end,
3899
+ cat_split_type, missing_action,
3900
+ chosen_cat, split_categ, saved_cat_mode,
3901
+ buffer_pos, w);
3902
+ if (st >= end) return -HUGE_VAL;
3903
+ ldouble_safe w_missing = 0;
3904
+ int xval;
3905
+ size_t ix_;
3906
+
3907
+ /* count categories */
3908
+ /* TODO: allocate this buffer externally */
3909
+ std::vector<ldouble_safe> buffer_cnt(ncat, (ldouble_safe)0);
3910
+ if (missing_action == Fail)
3911
+ {
3912
+ for (size_t row = st; row <= end; row++)
3913
+ {
3914
+ ix_ = ix_arr[row];
3915
+ if (unlikely(x[ix_]) < 0) continue;
3916
+ buffer_cnt[x[ix_]] += w[ix_];
3917
+ }
3918
+ }
3919
+
3920
+ else if (missing_action == Impute)
3921
+ {
3922
+ for (size_t row = st; row <= end; row++)
3923
+ {
3924
+ ix_ = ix_arr[row];
3925
+ xval = x[ix_];
3926
+ if (unlikely(xval < 0))
3927
+ w_missing += w[ix_];
3928
+ else
3929
+ buffer_cnt[xval] += w[ix_];
3930
+ }
3931
+
3932
+ if (w_missing)
3933
+ {
3934
+ auto idxmax = std::max_element(buffer_cnt.begin(), buffer_cnt.end());
3935
+ *idxmax += w_missing;
3936
+ *saved_cat_mode = (int)std::distance(buffer_cnt.begin(), idxmax);
3937
+ }
3938
+ }
3939
+
3940
+ else
3941
+ {
3942
+ for (size_t row = st; row <= end; row++)
3943
+ {
3944
+ ix_ = ix_arr[row];
3945
+ xval = x[ix_];
3946
+ if (likely(xval >= 0)) buffer_cnt[xval] += w[ix_];
3947
+ }
3948
+ }
3949
+
3950
+ ldouble_safe cnt = std::accumulate(buffer_cnt.begin(), buffer_cnt.end(), (ldouble_safe)0);
3951
+
3952
+ double this_gain = -HUGE_VAL;
3953
+ double best_gain = -HUGE_VAL;
3954
+ std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
3955
+ size_t st_pos = 0;
3956
+
3957
+ switch(cat_split_type)
3958
+ {
3959
+ case SingleCateg:
3960
+ {
3961
+ size_t ncat_present = 0;
3962
+
3963
+ switch(criterion)
3964
+ {
3965
+ case Averaged:
3966
+ {
3967
+ /* move zero-counts to the beginning */
3968
+ size_t temp;
3969
+ for (int cat = 0; cat < ncat; cat++)
3970
+ {
3971
+ if (buffer_cnt[cat])
3972
+ {
3973
+ ncat_present++;
3974
+ buffer_prob[cat] = buffer_cnt[cat] / cnt;
3975
+ }
3976
+
3977
+ else
3978
+ {
3979
+ temp = buffer_pos[st_pos];
3980
+ buffer_pos[st_pos] = buffer_pos[cat];
3981
+ buffer_pos[cat] = temp;
3982
+ st_pos++;
3983
+ }
3984
+ }
3985
+
3986
+ if (ncat_present <= 1) return -HUGE_VAL;
3987
+
3988
+ double sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
3989
+
3990
+ /* try isolating each category one at a time */
3991
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
3992
+ {
3993
+ this_gain = sd_gain(sd_full,
3994
+ 0.0,
3995
+ (expected_sd_cat_single<ldouble_safe, size_t, ldouble_safe>(buffer_cnt.data(), buffer_prob, ncat_present, buffer_pos + st_pos, pos - st_pos, cnt))
3996
+ );
3997
+ if (this_gain > min_gain && this_gain > best_gain)
3998
+ {
3999
+ best_gain = this_gain;
4000
+ chosen_cat = buffer_pos[pos];
4001
+ }
4002
+ }
4003
+ break;
4004
+ }
4005
+
4006
+ case Pooled:
4007
+ {
4008
+ /* here it will always pick the largest one */
4009
+ size_t ncat_present = 0;
4010
+ ldouble_safe cnt_max = 0;
4011
+ for (int cat = 0; cat < ncat; cat++)
4012
+ {
4013
+ if (buffer_cnt[cat])
4014
+ {
4015
+ ncat_present++;
4016
+ if (cnt_max < buffer_cnt[cat])
4017
+ {
4018
+ cnt_max = buffer_cnt[cat];
4019
+ chosen_cat = cat;
4020
+ }
4021
+ }
4022
+ }
4023
+
4024
+ if (ncat_present <= 1) return -HUGE_VAL;
4025
+
4026
+ ldouble_safe cnt_left = (ldouble_safe)(cnt - cnt_max);
4027
+
4028
+ /* TODO: think of a better way of dealing with numbers between zero and one */
4029
+ this_gain = (
4030
+ std::fmax((ldouble_safe)1, cnt) * std::log(std::fmax((ldouble_safe)1, cnt))
4031
+ - std::fmax((ldouble_safe)1, cnt_left) * std::log(std::fmax((ldouble_safe)1, cnt_left))
4032
+ - std::fmax((ldouble_safe)1, cnt_max) * std::log(std::fmax((ldouble_safe)1, cnt_max))
4033
+ ) / std::fmax((ldouble_safe)1, cnt);
4034
+ best_gain = (this_gain > min_gain)? this_gain : best_gain;
4035
+ break;
4036
+ }
4037
+
4038
+ default:
4039
+ {
4040
+ unexpected_error();
4041
+ break;
4042
+ }
4043
+ }
4044
+ break;
4045
+ }
4046
+
4047
+ case SubSet:
4048
+ {
4049
+ /* sort by counts */
4050
+ std::sort(buffer_pos, buffer_pos + ncat, [&buffer_cnt](const size_t a, const size_t b){return buffer_cnt[a] < buffer_cnt[b];});
4051
+
4052
+ /* set split as: (1):left (0):right (-1):not_present */
4053
+ memset(buffer_split, 0, ncat * sizeof(signed char));
4054
+
4055
+
4056
+ switch(criterion)
4057
+ {
4058
+ case Averaged:
4059
+ {
4060
+ /* determine first non-zero and convert to probabilities */
4061
+ double sd_full;
4062
+ for (int cat = 0; cat < ncat; cat++)
4063
+ {
4064
+ if (buffer_cnt[buffer_pos[cat]])
4065
+ {
4066
+ buffer_prob[buffer_pos[cat]] = (ldouble_safe)buffer_cnt[buffer_pos[cat]] / cnt;
4067
+ }
4068
+
4069
+ else
4070
+ {
4071
+ buffer_split[buffer_pos[cat]] = -1;
4072
+ st_pos++;
4073
+ }
4074
+ }
4075
+
4076
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
4077
+
4078
+ /* calculate full SD assuming they take values randomly ~Unif(0, 1) */
4079
+ size_t ncat_present = (size_t)ncat - st_pos;
4080
+ sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
4081
+ if (ncat_present >= log2ceil(SIZE_MAX)) all_perm = false;
4082
+
4083
+ /* move categories one at a time */
4084
+ for (size_t pos = st_pos; pos < ((size_t)ncat - st_pos - 1); pos++)
4085
+ {
4086
+ buffer_split[buffer_pos[pos]] = 1;
4087
+ /* TODO: is this correct? */
4088
+ this_gain = sd_gain(sd_full,
4089
+ (expected_sd_cat<ldouble_safe, size_t, ldouble_safe>(buffer_cnt.data(), buffer_prob, pos - st_pos + 1, buffer_pos + st_pos)),
4090
+ (expected_sd_cat<ldouble_safe, size_t, ldouble_safe>(buffer_cnt.data(), buffer_prob, (size_t)ncat - pos - 1, buffer_pos + pos + 1))
4091
+ );
4092
+ if (this_gain > min_gain && this_gain > best_gain)
4093
+ {
4094
+ best_gain = this_gain;
4095
+ memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
4096
+ }
4097
+ }
4098
+
4099
+ break;
4100
+ }
4101
+
4102
+ case Pooled:
4103
+ {
4104
+ ldouble_safe s = 0;
4105
+
4106
+ /* determine first non-zero and get base info */
4107
+ for (int cat = 0; cat < ncat; cat++)
4108
+ {
4109
+ if (buffer_cnt[buffer_pos[cat]])
4110
+ {
4111
+ s += (buffer_cnt[buffer_pos[cat]] <= 1)?
4112
+ (ldouble_safe)0
4113
+ :
4114
+ ((ldouble_safe) buffer_cnt[buffer_pos[cat]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[cat]]));
4115
+ }
4116
+
4117
+ else
4118
+ {
4119
+ buffer_split[buffer_pos[cat]] = -1;
4120
+ st_pos++;
4121
+ }
4122
+ }
4123
+
4124
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
4125
+
4126
+ /* calculate base info */
4127
+ ldouble_safe base_info = std::fmax((ldouble_safe)1, cnt) * std::log(std::fmax((ldouble_safe)1, cnt)) - s;
4128
+
4129
+ if (all_perm)
4130
+ {
4131
+ size_t cnt_left, cnt_right;
4132
+ double s_left, s_right;
4133
+ size_t ncat_present = (size_t)ncat - st_pos;
4134
+ size_t ncomb = pow2(ncat_present) - 1;
4135
+ size_t best_combin;
4136
+
4137
+ for (size_t combin = 1; combin < ncomb; combin++)
4138
+ {
4139
+ cnt_left = 0; cnt_right = 0;
4140
+ s_left = 0; s_right = 0;
4141
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
4142
+ {
4143
+ if (extract_bit(combin, pos))
4144
+ {
4145
+ cnt_left += buffer_cnt[buffer_pos[pos]];
4146
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
4147
+ (ldouble_safe)0
4148
+ :
4149
+ ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
4150
+ * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
4151
+ }
4152
+
4153
+ else
4154
+ {
4155
+ cnt_right += buffer_cnt[buffer_pos[pos]];
4156
+ s_right += (buffer_cnt[buffer_pos[pos]] <= 1)?
4157
+ (ldouble_safe)0
4158
+ :
4159
+ ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
4160
+ * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
4161
+ }
4162
+ }
4163
+
4164
+ this_gain = categ_gain<size_t, ldouble_safe>(
4165
+ cnt_left, cnt_right,
4166
+ s_left, s_right,
4167
+ base_info, cnt);
4168
+
4169
+ if (this_gain > min_gain && this_gain > best_gain)
4170
+ {
4171
+ best_gain = this_gain;
4172
+ best_combin = combin;
4173
+ }
4174
+
4175
+ }
4176
+
4177
+ if (best_gain > min_gain)
4178
+ for (size_t pos = 0; pos < ncat_present; pos++)
4179
+ split_categ[buffer_pos[st_pos + pos]] = extract_bit(best_combin, pos);
4180
+
4181
+ }
4182
+
4183
+ else
4184
+ {
4185
+ /* try moving the categories one at a time */
4186
+ size_t cnt_left = 0;
4187
+ size_t cnt_right = end - st + 1;
4188
+ double s_left = 0;
4189
+ double s_right = s;
4190
+
4191
+ for (size_t pos = st_pos; pos < (ncat - st_pos - 1); pos++)
4192
+ {
4193
+ buffer_split[buffer_pos[pos]] = 1;
4194
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
4195
+ (ldouble_safe)0
4196
+ :
4197
+ ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
4198
+ s_right -= (buffer_cnt[buffer_pos[pos]] <= 1)?
4199
+ (ldouble_safe)0
4200
+ :
4201
+ ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
4202
+ cnt_left += buffer_cnt[buffer_pos[pos]];
4203
+ cnt_right -= buffer_cnt[buffer_pos[pos]];
4204
+
4205
+ this_gain = categ_gain<size_t, ldouble_safe>(
4206
+ cnt_left, cnt_right,
4207
+ s_left, s_right,
4208
+ base_info, cnt);
4209
+
4210
+ if (this_gain > min_gain && this_gain > best_gain)
4211
+ {
4212
+ best_gain = this_gain;
4213
+ memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
4214
+ }
4215
+ }
4216
+ }
4217
+
4218
+ break;
4219
+ }
4220
+
4221
+ default:
4222
+ {
4223
+ unexpected_error();
4224
+ break;
4225
+ }
4226
+ }
4227
+ }
4228
+ }
4229
+
4230
+ if (st == (end-1)) return 0;
4231
+
4232
+ if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
4233
+ return 0;
4234
+ else
4235
+ return best_gain;
4236
+ }