isotree 0.2.2 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (151) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -1
  3. data/LICENSE.txt +2 -2
  4. data/README.md +32 -14
  5. data/ext/isotree/ext.cpp +144 -31
  6. data/ext/isotree/extconf.rb +7 -7
  7. data/lib/isotree/isolation_forest.rb +110 -30
  8. data/lib/isotree/version.rb +1 -1
  9. data/vendor/isotree/LICENSE +1 -1
  10. data/vendor/isotree/README.md +165 -27
  11. data/vendor/isotree/include/isotree.hpp +2111 -0
  12. data/vendor/isotree/include/isotree_oop.hpp +394 -0
  13. data/vendor/isotree/inst/COPYRIGHTS +62 -0
  14. data/vendor/isotree/src/RcppExports.cpp +525 -52
  15. data/vendor/isotree/src/Rwrapper.cpp +1931 -268
  16. data/vendor/isotree/src/c_interface.cpp +953 -0
  17. data/vendor/isotree/src/crit.hpp +4232 -0
  18. data/vendor/isotree/src/dist.hpp +1886 -0
  19. data/vendor/isotree/src/exp_depth_table.hpp +134 -0
  20. data/vendor/isotree/src/extended.hpp +1444 -0
  21. data/vendor/isotree/src/external_facing_generic.hpp +399 -0
  22. data/vendor/isotree/src/fit_model.hpp +2401 -0
  23. data/vendor/isotree/src/{dealloc.cpp → headers_joined.hpp} +38 -22
  24. data/vendor/isotree/src/helpers_iforest.hpp +813 -0
  25. data/vendor/isotree/src/{impute.cpp → impute.hpp} +353 -122
  26. data/vendor/isotree/src/indexer.cpp +515 -0
  27. data/vendor/isotree/src/instantiate_template_headers.cpp +118 -0
  28. data/vendor/isotree/src/instantiate_template_headers.hpp +240 -0
  29. data/vendor/isotree/src/isoforest.hpp +1659 -0
  30. data/vendor/isotree/src/isotree.hpp +1804 -392
  31. data/vendor/isotree/src/isotree_exportable.hpp +99 -0
  32. data/vendor/isotree/src/merge_models.cpp +159 -16
  33. data/vendor/isotree/src/mult.hpp +1321 -0
  34. data/vendor/isotree/src/oop_interface.cpp +842 -0
  35. data/vendor/isotree/src/oop_interface.hpp +278 -0
  36. data/vendor/isotree/src/other_helpers.hpp +219 -0
  37. data/vendor/isotree/src/predict.hpp +1932 -0
  38. data/vendor/isotree/src/python_helpers.hpp +134 -0
  39. data/vendor/isotree/src/ref_indexer.hpp +154 -0
  40. data/vendor/isotree/src/robinmap/LICENSE +21 -0
  41. data/vendor/isotree/src/robinmap/README.md +483 -0
  42. data/vendor/isotree/src/robinmap/include/tsl/robin_growth_policy.h +406 -0
  43. data/vendor/isotree/src/robinmap/include/tsl/robin_hash.h +1620 -0
  44. data/vendor/isotree/src/robinmap/include/tsl/robin_map.h +807 -0
  45. data/vendor/isotree/src/robinmap/include/tsl/robin_set.h +660 -0
  46. data/vendor/isotree/src/serialize.cpp +4300 -139
  47. data/vendor/isotree/src/sql.cpp +141 -59
  48. data/vendor/isotree/src/subset_models.cpp +174 -0
  49. data/vendor/isotree/src/utils.hpp +3808 -0
  50. data/vendor/isotree/src/xoshiro.hpp +467 -0
  51. data/vendor/isotree/src/ziggurat.hpp +405 -0
  52. metadata +38 -104
  53. data/vendor/cereal/LICENSE +0 -24
  54. data/vendor/cereal/README.md +0 -85
  55. data/vendor/cereal/include/cereal/access.hpp +0 -351
  56. data/vendor/cereal/include/cereal/archives/adapters.hpp +0 -163
  57. data/vendor/cereal/include/cereal/archives/binary.hpp +0 -169
  58. data/vendor/cereal/include/cereal/archives/json.hpp +0 -1019
  59. data/vendor/cereal/include/cereal/archives/portable_binary.hpp +0 -334
  60. data/vendor/cereal/include/cereal/archives/xml.hpp +0 -956
  61. data/vendor/cereal/include/cereal/cereal.hpp +0 -1089
  62. data/vendor/cereal/include/cereal/details/helpers.hpp +0 -422
  63. data/vendor/cereal/include/cereal/details/polymorphic_impl.hpp +0 -796
  64. data/vendor/cereal/include/cereal/details/polymorphic_impl_fwd.hpp +0 -65
  65. data/vendor/cereal/include/cereal/details/static_object.hpp +0 -127
  66. data/vendor/cereal/include/cereal/details/traits.hpp +0 -1411
  67. data/vendor/cereal/include/cereal/details/util.hpp +0 -84
  68. data/vendor/cereal/include/cereal/external/base64.hpp +0 -134
  69. data/vendor/cereal/include/cereal/external/rapidjson/allocators.h +0 -284
  70. data/vendor/cereal/include/cereal/external/rapidjson/cursorstreamwrapper.h +0 -78
  71. data/vendor/cereal/include/cereal/external/rapidjson/document.h +0 -2652
  72. data/vendor/cereal/include/cereal/external/rapidjson/encodedstream.h +0 -299
  73. data/vendor/cereal/include/cereal/external/rapidjson/encodings.h +0 -716
  74. data/vendor/cereal/include/cereal/external/rapidjson/error/en.h +0 -74
  75. data/vendor/cereal/include/cereal/external/rapidjson/error/error.h +0 -161
  76. data/vendor/cereal/include/cereal/external/rapidjson/filereadstream.h +0 -99
  77. data/vendor/cereal/include/cereal/external/rapidjson/filewritestream.h +0 -104
  78. data/vendor/cereal/include/cereal/external/rapidjson/fwd.h +0 -151
  79. data/vendor/cereal/include/cereal/external/rapidjson/internal/biginteger.h +0 -290
  80. data/vendor/cereal/include/cereal/external/rapidjson/internal/diyfp.h +0 -271
  81. data/vendor/cereal/include/cereal/external/rapidjson/internal/dtoa.h +0 -245
  82. data/vendor/cereal/include/cereal/external/rapidjson/internal/ieee754.h +0 -78
  83. data/vendor/cereal/include/cereal/external/rapidjson/internal/itoa.h +0 -308
  84. data/vendor/cereal/include/cereal/external/rapidjson/internal/meta.h +0 -186
  85. data/vendor/cereal/include/cereal/external/rapidjson/internal/pow10.h +0 -55
  86. data/vendor/cereal/include/cereal/external/rapidjson/internal/regex.h +0 -740
  87. data/vendor/cereal/include/cereal/external/rapidjson/internal/stack.h +0 -232
  88. data/vendor/cereal/include/cereal/external/rapidjson/internal/strfunc.h +0 -69
  89. data/vendor/cereal/include/cereal/external/rapidjson/internal/strtod.h +0 -290
  90. data/vendor/cereal/include/cereal/external/rapidjson/internal/swap.h +0 -46
  91. data/vendor/cereal/include/cereal/external/rapidjson/istreamwrapper.h +0 -128
  92. data/vendor/cereal/include/cereal/external/rapidjson/memorybuffer.h +0 -70
  93. data/vendor/cereal/include/cereal/external/rapidjson/memorystream.h +0 -71
  94. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/inttypes.h +0 -316
  95. data/vendor/cereal/include/cereal/external/rapidjson/msinttypes/stdint.h +0 -300
  96. data/vendor/cereal/include/cereal/external/rapidjson/ostreamwrapper.h +0 -81
  97. data/vendor/cereal/include/cereal/external/rapidjson/pointer.h +0 -1414
  98. data/vendor/cereal/include/cereal/external/rapidjson/prettywriter.h +0 -277
  99. data/vendor/cereal/include/cereal/external/rapidjson/rapidjson.h +0 -656
  100. data/vendor/cereal/include/cereal/external/rapidjson/reader.h +0 -2230
  101. data/vendor/cereal/include/cereal/external/rapidjson/schema.h +0 -2497
  102. data/vendor/cereal/include/cereal/external/rapidjson/stream.h +0 -223
  103. data/vendor/cereal/include/cereal/external/rapidjson/stringbuffer.h +0 -121
  104. data/vendor/cereal/include/cereal/external/rapidjson/writer.h +0 -709
  105. data/vendor/cereal/include/cereal/external/rapidxml/license.txt +0 -52
  106. data/vendor/cereal/include/cereal/external/rapidxml/manual.html +0 -406
  107. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml.hpp +0 -2624
  108. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_iterators.hpp +0 -175
  109. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_print.hpp +0 -428
  110. data/vendor/cereal/include/cereal/external/rapidxml/rapidxml_utils.hpp +0 -123
  111. data/vendor/cereal/include/cereal/macros.hpp +0 -154
  112. data/vendor/cereal/include/cereal/specialize.hpp +0 -139
  113. data/vendor/cereal/include/cereal/types/array.hpp +0 -79
  114. data/vendor/cereal/include/cereal/types/atomic.hpp +0 -55
  115. data/vendor/cereal/include/cereal/types/base_class.hpp +0 -203
  116. data/vendor/cereal/include/cereal/types/bitset.hpp +0 -176
  117. data/vendor/cereal/include/cereal/types/boost_variant.hpp +0 -164
  118. data/vendor/cereal/include/cereal/types/chrono.hpp +0 -72
  119. data/vendor/cereal/include/cereal/types/common.hpp +0 -129
  120. data/vendor/cereal/include/cereal/types/complex.hpp +0 -56
  121. data/vendor/cereal/include/cereal/types/concepts/pair_associative_container.hpp +0 -73
  122. data/vendor/cereal/include/cereal/types/deque.hpp +0 -62
  123. data/vendor/cereal/include/cereal/types/forward_list.hpp +0 -68
  124. data/vendor/cereal/include/cereal/types/functional.hpp +0 -43
  125. data/vendor/cereal/include/cereal/types/list.hpp +0 -62
  126. data/vendor/cereal/include/cereal/types/map.hpp +0 -36
  127. data/vendor/cereal/include/cereal/types/memory.hpp +0 -425
  128. data/vendor/cereal/include/cereal/types/optional.hpp +0 -66
  129. data/vendor/cereal/include/cereal/types/polymorphic.hpp +0 -483
  130. data/vendor/cereal/include/cereal/types/queue.hpp +0 -132
  131. data/vendor/cereal/include/cereal/types/set.hpp +0 -103
  132. data/vendor/cereal/include/cereal/types/stack.hpp +0 -76
  133. data/vendor/cereal/include/cereal/types/string.hpp +0 -61
  134. data/vendor/cereal/include/cereal/types/tuple.hpp +0 -123
  135. data/vendor/cereal/include/cereal/types/unordered_map.hpp +0 -36
  136. data/vendor/cereal/include/cereal/types/unordered_set.hpp +0 -99
  137. data/vendor/cereal/include/cereal/types/utility.hpp +0 -47
  138. data/vendor/cereal/include/cereal/types/valarray.hpp +0 -89
  139. data/vendor/cereal/include/cereal/types/variant.hpp +0 -109
  140. data/vendor/cereal/include/cereal/types/vector.hpp +0 -112
  141. data/vendor/cereal/include/cereal/version.hpp +0 -52
  142. data/vendor/isotree/src/Makevars +0 -4
  143. data/vendor/isotree/src/crit.cpp +0 -912
  144. data/vendor/isotree/src/dist.cpp +0 -749
  145. data/vendor/isotree/src/extended.cpp +0 -790
  146. data/vendor/isotree/src/fit_model.cpp +0 -1090
  147. data/vendor/isotree/src/helpers_iforest.cpp +0 -324
  148. data/vendor/isotree/src/isoforest.cpp +0 -771
  149. data/vendor/isotree/src/mult.cpp +0 -607
  150. data/vendor/isotree/src/predict.cpp +0 -853
  151. data/vendor/isotree/src/utils.cpp +0 -1566
@@ -0,0 +1,4232 @@
1
+ /* Isolation forests and variations thereof, with adjustments for incorporation
2
+ * of categorical variables and missing values.
3
+ * Written for C++11 standard and aimed at being used in R and Python.
4
+ *
5
+ * This library is based on the following works:
6
+ * [1] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
7
+ * "Isolation forest."
8
+ * 2008 Eighth IEEE International Conference on Data Mining. IEEE, 2008.
9
+ * [2] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
10
+ * "Isolation-based anomaly detection."
11
+ * ACM Transactions on Knowledge Discovery from Data (TKDD) 6.1 (2012): 3.
12
+ * [3] Hariri, Sahand, Matias Carrasco Kind, and Robert J. Brunner.
13
+ * "Extended Isolation Forest."
14
+ * arXiv preprint arXiv:1811.02141 (2018).
15
+ * [4] Liu, Fei Tony, Kai Ming Ting, and Zhi-Hua Zhou.
16
+ * "On detecting clustered anomalies using SCiForest."
17
+ * Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Berlin, Heidelberg, 2010.
18
+ * [5] https://sourceforge.net/projects/iforest/
19
+ * [6] https://math.stackexchange.com/questions/3388518/expected-number-of-paths-required-to-separate-elements-in-a-binary-tree
20
+ * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
21
+ * [8] Cortes, David.
22
+ * "Distance approximation using Isolation Forests."
23
+ * arXiv preprint arXiv:1910.12362 (2019).
24
+ * [9] Cortes, David.
25
+ * "Imputing missing values with unsupervised random trees."
26
+ * arXiv preprint arXiv:1911.06646 (2019).
27
+ * [10] https://math.stackexchange.com/questions/3333220/expected-average-depth-in-random-binary-tree-constructed-top-to-bottom
28
+ * [11] Cortes, David.
29
+ * "Revisiting randomized choices in isolation forests."
30
+ * arXiv preprint arXiv:2110.13402 (2021).
31
+ * [12] Guha, Sudipto, et al.
32
+ * "Robust random cut forest based anomaly detection on streams."
33
+ * International conference on machine learning. PMLR, 2016.
34
+ * [13] Cortes, David.
35
+ * "Isolation forests: looking beyond tree depth."
36
+ * arXiv preprint arXiv:2111.11639 (2021).
37
+ * [14] Ting, Kai Ming, Yue Zhu, and Zhi-Hua Zhou.
38
+ * "Isolation kernel and its effect on SVM"
39
+ * Proceedings of the 24th ACM SIGKDD
40
+ * International Conference on Knowledge Discovery & Data Mining. 2018.
41
+ *
42
+ * BSD 2-Clause License
43
+ * Copyright (c) 2019-2022, David Cortes
44
+ * All rights reserved.
45
+ * Redistribution and use in source and binary forms, with or without
46
+ * modification, are permitted provided that the following conditions are met:
47
+ * * Redistributions of source code must retain the above copyright notice, this
48
+ * list of conditions and the following disclaimer.
49
+ * * Redistributions in binary form must reproduce the above copyright notice,
50
+ * this list of conditions and the following disclaimer in the documentation
51
+ * and/or other materials provided with the distribution.
52
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
53
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
54
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
55
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
56
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
57
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
58
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
59
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
60
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
61
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
62
+ */
63
+ #include "isotree.hpp"
64
+
65
+ /* TODO: should the kurtosis calculation impute values when using ndim=1 + missing_action=Impute?
66
+ It should be the theoretically correct approach, but will cause the kurtosis to increase
67
+ significantly if there is a large number of missing values, which would lead to preferring
68
+ splits on columns with mostly missing values. */
69
+
70
+ /* TODO: this kurtosis caps the minimum values to zero, but the minimum achievable value is 1,
71
+ check how imprecise results are used outside of the function in the different kinds of calculations
72
+ that use kurtosis and maybe change the logic. */
73
+
74
+ #define pw1(x) ((x))
75
+ #define pw2(x) ((x) * (x))
76
+ #define pw3(x) ((x) * (x) * (x))
77
+ #define pw4(x) ((x) * (x) * (x) * (x))
78
+
79
+ /* https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics */
80
+ template <class real_t, class ldouble_safe>
81
+ double calc_kurtosis(size_t ix_arr[], size_t st, size_t end, real_t x[], MissingAction missing_action)
82
+ {
83
+ ldouble_safe m = 0;
84
+ ldouble_safe M2 = 0, M3 = 0, M4 = 0;
85
+ ldouble_safe delta, delta_s, delta_div;
86
+ ldouble_safe diff, n;
87
+ ldouble_safe out;
88
+
89
+ if (missing_action == Fail)
90
+ {
91
+ for (size_t row = st; row <= end; row++)
92
+ {
93
+ n = (ldouble_safe)(row - st + 1);
94
+
95
+ delta = x[ix_arr[row]] - m;
96
+ delta_div = delta / n;
97
+ delta_s = delta_div * delta_div;
98
+ diff = delta * (delta_div * (ldouble_safe)(row - st));
99
+
100
+ m += delta_div;
101
+ M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
102
+ M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
103
+ M2 += diff;
104
+ }
105
+
106
+ if (unlikely(!is_na_or_inf(M2) && M2 <= 0))
107
+ {
108
+ if (!check_more_than_two_unique_values(ix_arr, st, end, x, missing_action))
109
+ return -HUGE_VAL;
110
+ }
111
+
112
+ out = ( M4 / M2 ) * ( (ldouble_safe)(end - st + 1) / M2 );
113
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
114
+ }
115
+
116
+ else
117
+ {
118
+ size_t cnt = 0;
119
+ for (size_t row = st; row <= end; row++)
120
+ {
121
+ if (likely(!is_na_or_inf(x[ix_arr[row]])))
122
+ {
123
+ cnt++;
124
+ n = (ldouble_safe) cnt;
125
+
126
+ delta = x[ix_arr[row]] - m;
127
+ delta_div = delta / n;
128
+ delta_s = delta_div * delta_div;
129
+ diff = delta * (delta_div * (ldouble_safe)(cnt - 1));
130
+
131
+ m += delta_div;
132
+ M4 += diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
133
+ M3 += diff * delta_div * (n - 2) - 3 * delta_div * M2;
134
+ M2 += diff;
135
+ }
136
+ }
137
+
138
+ if (unlikely(cnt == 0)) return -HUGE_VAL;
139
+ if (unlikely(!is_na_or_inf(M2) && M2 <= 0))
140
+ {
141
+ if (!check_more_than_two_unique_values(ix_arr, st, end, x, missing_action))
142
+ return -HUGE_VAL;
143
+ }
144
+
145
+ out = ( M4 / M2 ) * ( (ldouble_safe)cnt / M2 );
146
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
147
+ }
148
+ }
149
+
150
+ template <class real_t, class ldouble_safe>
151
+ double calc_kurtosis(real_t x[], size_t n, MissingAction missing_action)
152
+ {
153
+ ldouble_safe m = 0;
154
+ ldouble_safe M2 = 0, M3 = 0, M4 = 0;
155
+ ldouble_safe delta, delta_s, delta_div;
156
+ ldouble_safe diff, n_;
157
+ ldouble_safe out;
158
+
159
+ if (missing_action == Fail)
160
+ {
161
+ for (size_t row = 0; row < n; row++)
162
+ {
163
+ n_ = (ldouble_safe)(row + 1);
164
+
165
+ delta = x[row] - m;
166
+ delta_div = delta / n_;
167
+ delta_s = delta_div * delta_div;
168
+ diff = delta * (delta_div * (ldouble_safe)row);
169
+
170
+ m += delta_div;
171
+ M4 += diff * delta_s * (n_ * n_ - 3 * n_ + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
172
+ M3 += diff * delta_div * (n_ - 2) - 3 * delta_div * M2;
173
+ M2 += diff;
174
+ }
175
+
176
+ out = ( M4 / M2 ) * ( (ldouble_safe)n / M2 );
177
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
178
+ }
179
+
180
+ else
181
+ {
182
+ size_t cnt = 0;
183
+ for (size_t row = 0; row < n; row++)
184
+ {
185
+ if (likely(!is_na_or_inf(x[row])))
186
+ {
187
+ cnt++;
188
+ n_ = (ldouble_safe) cnt;
189
+
190
+ delta = x[row] - m;
191
+ delta_div = delta / n_;
192
+ delta_s = delta_div * delta_div;
193
+ diff = delta * (delta_div * (ldouble_safe)(cnt - 1));
194
+
195
+ m += delta_div;
196
+ M4 += diff * delta_s * (n_ * n_ - 3 * n_ + 3) + 6 * delta_s * M2 - 4 * delta_div * M3;
197
+ M3 += diff * delta_div * (n_ - 2) - 3 * delta_div * M2;
198
+ M2 += diff;
199
+ }
200
+ }
201
+
202
+ if (unlikely(cnt == 0)) return -HUGE_VAL;
203
+
204
+ out = ( M4 / M2 ) * ( (ldouble_safe)cnt / M2 );
205
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
206
+ }
207
+ }
208
+
209
+ /* TODO: is this algorithm correct? */
210
+ template <class real_t, class mapping, class ldouble_safe>
211
+ double calc_kurtosis_weighted(size_t ix_arr[], size_t st, size_t end, real_t x[],
212
+ MissingAction missing_action, mapping &restrict w)
213
+ {
214
+ ldouble_safe m = 0;
215
+ ldouble_safe M2 = 0, M3 = 0, M4 = 0;
216
+ ldouble_safe delta, delta_s, delta_div;
217
+ ldouble_safe diff;
218
+ ldouble_safe n = 0;
219
+ ldouble_safe out;
220
+ ldouble_safe n_prev = 0.;
221
+ ldouble_safe w_this;
222
+
223
+ for (size_t row = st; row <= end; row++)
224
+ {
225
+ if (likely(!is_na_or_inf(x[ix_arr[row]])))
226
+ {
227
+ w_this = w[ix_arr[row]];
228
+ n += w_this;
229
+
230
+ delta = x[ix_arr[row]] - m;
231
+ delta_div = delta / n;
232
+ delta_s = delta_div * delta_div;
233
+ diff = delta * (delta_div * n_prev);
234
+ n_prev = n;
235
+
236
+ m += w_this * (delta_div);
237
+ M4 += w_this * (diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3);
238
+ M3 += w_this * (diff * delta_div * (n - 2) - 3 * delta_div * M2);
239
+ M2 += w_this * (diff);
240
+ }
241
+ }
242
+
243
+ if (unlikely(n <= 0)) return -HUGE_VAL;
244
+ if (unlikely(!is_na_or_inf(M2) && M2 <= std::numeric_limits<double>::epsilon()))
245
+ {
246
+ if (!check_more_than_two_unique_values(ix_arr, st, end, x, missing_action))
247
+ return -HUGE_VAL;
248
+ }
249
+
250
+ out = ( M4 / M2 ) * ( n / M2 );
251
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
252
+ }
253
+
254
+ template <class real_t, class ldouble_safe>
255
+ double calc_kurtosis_weighted(real_t *restrict x, size_t n_, MissingAction missing_action, real_t *restrict w)
256
+ {
257
+ ldouble_safe m = 0;
258
+ ldouble_safe M2 = 0, M3 = 0, M4 = 0;
259
+ ldouble_safe delta, delta_s, delta_div;
260
+ ldouble_safe diff;
261
+ ldouble_safe n = 0;
262
+ ldouble_safe out;
263
+ ldouble_safe n_prev = 0.;
264
+ ldouble_safe w_this;
265
+
266
+ for (size_t row = 0; row < n_; row++)
267
+ {
268
+ if (likely(!is_na_or_inf(x[row])))
269
+ {
270
+ w_this = w[row];
271
+ n += w_this;
272
+
273
+ delta = x[row] - m;
274
+ delta_div = delta / n;
275
+ delta_s = delta_div * delta_div;
276
+ diff = delta * (delta_div * n_prev);
277
+ n_prev = n;
278
+
279
+ m += w_this * (delta_div);
280
+ M4 += w_this * (diff * delta_s * (n * n - 3 * n + 3) + 6 * delta_s * M2 - 4 * delta_div * M3);
281
+ M3 += w_this * (diff * delta_div * (n - 2) - 3 * delta_div * M2);
282
+ M2 += w_this * (diff);
283
+ }
284
+ }
285
+
286
+ if (unlikely(n <= 0)) return -HUGE_VAL;
287
+
288
+ out = ( M4 / M2 ) * ( n / M2 );
289
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
290
+ }
291
+
292
+
293
+ /* TODO: make these compensated sums */
294
+ /* TODO: can this use the same algorithm as above but with a correction at the end,
295
+ like it was done for the variance? */
296
+ template <class real_t, class sparse_ix, class ldouble_safe>
297
+ double calc_kurtosis(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
298
+ real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
299
+ MissingAction missing_action)
300
+ {
301
+ /* ix_arr must be already sorted beforehand */
302
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
303
+ return -HUGE_VAL;
304
+
305
+ ldouble_safe s1 = 0;
306
+ ldouble_safe s2 = 0;
307
+ ldouble_safe s3 = 0;
308
+ ldouble_safe s4 = 0;
309
+ ldouble_safe x_sq;
310
+ size_t cnt = end - st + 1;
311
+
312
+ if (unlikely(cnt <= 1)) return -HUGE_VAL;
313
+
314
+ size_t st_col = Xc_indptr[col_num];
315
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
316
+ size_t curr_pos = st_col;
317
+ size_t ind_end_col = Xc_ind[end_col];
318
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
319
+
320
+ ldouble_safe xval;
321
+
322
+ if (missing_action != Fail)
323
+ {
324
+ for (size_t *row = ptr_st;
325
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
326
+ )
327
+ {
328
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
329
+ {
330
+ xval = Xc[curr_pos];
331
+ if (unlikely(is_na_or_inf(xval)))
332
+ {
333
+ cnt--;
334
+ }
335
+
336
+ else
337
+ {
338
+ /* TODO: is it safe to use FMA here? some calculations rely on assuming that
339
+ some of these 's' are larger than the others. Would this procedure be guaranteed
340
+ to preserve such differences if done with a mixture of sums and FMAs? */
341
+ x_sq = square(xval);
342
+ s1 += xval;
343
+ s2 = std::fma(xval, xval, s2);
344
+ s3 = std::fma(x_sq, xval, s3);
345
+ s4 = std::fma(x_sq, x_sq, s4);
346
+ // s1 += pw1(xval);
347
+ // s2 += pw2(xval);
348
+ // s3 += pw3(xval);
349
+ // s4 += pw4(xval);
350
+ }
351
+
352
+ if (row == ix_arr + end || curr_pos == end_col) break;
353
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
354
+ }
355
+
356
+ else
357
+ {
358
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
359
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
360
+ else
361
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
362
+ }
363
+ }
364
+
365
+ if (unlikely(cnt <= (end - st + 1) - (Xc_indptr[col_num+1] - Xc_indptr[col_num]))) return -HUGE_VAL;
366
+ }
367
+
368
+ else
369
+ {
370
+ for (size_t *row = ptr_st;
371
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
372
+ )
373
+ {
374
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
375
+ {
376
+ xval = Xc[curr_pos];
377
+ x_sq = square(xval);
378
+ s1 += xval;
379
+ s2 = std::fma(xval, xval, s2);
380
+ s3 = std::fma(x_sq, xval, s3);
381
+ s4 = std::fma(x_sq, x_sq, s4);
382
+ // s1 += pw1(xval);
383
+ // s2 += pw2(xval);
384
+ // s3 += pw3(xval);
385
+ // s4 += pw4(xval);
386
+
387
+ if (row == ix_arr + end || curr_pos == end_col) break;
388
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
389
+ }
390
+
391
+ else
392
+ {
393
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
394
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
395
+ else
396
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
397
+ }
398
+ }
399
+ }
400
+
401
+ if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
402
+ ldouble_safe cnt_l = (ldouble_safe) cnt;
403
+ ldouble_safe sn = s1 / cnt_l;
404
+ ldouble_safe v = s2 / cnt_l - pw2(sn);
405
+ if (unlikely(std::isnan(v))) return -HUGE_VAL;
406
+ if (
407
+ v <= std::numeric_limits<double>::epsilon() &&
408
+ !check_more_than_two_unique_values(ix_arr, st, end, col_num,
409
+ Xc_indptr, Xc_ind, Xc,
410
+ missing_action)
411
+ )
412
+ return -HUGE_VAL;
413
+ if (unlikely(v <= 0)) return 0.;
414
+ ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt_l * pw4(sn)) / (cnt_l * pw2(v));
415
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
416
+ }
417
+
418
+ template <class real_t, class sparse_ix, class ldouble_safe>
419
+ double calc_kurtosis(size_t col_num, size_t nrows,
420
+ real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
421
+ MissingAction missing_action)
422
+ {
423
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
424
+ return -HUGE_VAL;
425
+
426
+ ldouble_safe s1 = 0;
427
+ ldouble_safe s2 = 0;
428
+ ldouble_safe s3 = 0;
429
+ ldouble_safe s4 = 0;
430
+ ldouble_safe x_sq;
431
+ size_t cnt = nrows;
432
+
433
+ if (unlikely(cnt <= 1)) return -HUGE_VAL;
434
+
435
+ ldouble_safe xval;
436
+
437
+ if (missing_action != Fail)
438
+ {
439
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num+1]; ix++)
440
+ {
441
+ xval = Xc[ix];
442
+ if (unlikely(is_na_or_inf(xval)))
443
+ {
444
+ cnt--;
445
+ }
446
+
447
+ else
448
+ {
449
+ x_sq = square(xval);
450
+ s1 += xval;
451
+ s2 = std::fma(xval, xval, s2);
452
+ s3 = std::fma(x_sq, xval, s3);
453
+ s4 = std::fma(x_sq, x_sq, s4);
454
+ // s1 += pw1(xval);
455
+ // s2 += pw2(xval);
456
+ // s3 += pw3(xval);
457
+ // s4 += pw4(xval);
458
+ }
459
+ }
460
+
461
+ if (cnt <= (nrows) - (Xc_indptr[col_num+1] - Xc_indptr[col_num])) return -HUGE_VAL;
462
+ }
463
+
464
+ else
465
+ {
466
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num+1]; ix++)
467
+ {
468
+ xval = Xc[ix];
469
+ x_sq = square(xval);
470
+ s1 += xval;
471
+ s2 = std::fma(xval, xval, s2);
472
+ s3 = std::fma(x_sq, xval, s3);
473
+ s4 = std::fma(x_sq, x_sq, s4);
474
+ // s1 += pw1(xval);
475
+ // s2 += pw2(xval);
476
+ // s3 += pw3(xval);
477
+ // s4 += pw4(xval);
478
+ }
479
+ }
480
+
481
+ if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
482
+ ldouble_safe cnt_l = (ldouble_safe) cnt;
483
+ ldouble_safe sn = s1 / cnt_l;
484
+ ldouble_safe v = s2 / cnt_l - pw2(sn);
485
+ if (unlikely(std::isnan(v))) return -HUGE_VAL;
486
+ if (
487
+ v <= std::numeric_limits<double>::epsilon() &&
488
+ !check_more_than_two_unique_values(nrows, col_num,
489
+ Xc_indptr, Xc_ind, Xc,
490
+ missing_action)
491
+ )
492
+ return -HUGE_VAL;
493
+ if (unlikely(v <= 0)) return 0.;
494
+ ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt_l * pw4(sn)) / (cnt_l * pw2(v));
495
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
496
+ }
497
+
498
+
499
+ template <class real_t, class sparse_ix, class mapping, class ldouble_safe>
500
+ double calc_kurtosis_weighted(size_t *restrict ix_arr, size_t st, size_t end, size_t col_num,
501
+ real_t Xc[], sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
502
+ MissingAction missing_action, mapping &restrict w)
503
+ {
504
+ /* ix_arr must be already sorted beforehand */
505
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
506
+ return -HUGE_VAL;
507
+
508
+ ldouble_safe s1 = 0;
509
+ ldouble_safe s2 = 0;
510
+ ldouble_safe s3 = 0;
511
+ ldouble_safe s4 = 0;
512
+ ldouble_safe x_sq;
513
+ ldouble_safe w_this;
514
+ ldouble_safe cnt = 0;
515
+ for (size_t row = st; row <= end; row++)
516
+ cnt += w[ix_arr[row]];
517
+
518
+ if (unlikely(cnt <= 0)) return -HUGE_VAL;
519
+
520
+ size_t st_col = Xc_indptr[col_num];
521
+ size_t end_col = Xc_indptr[col_num + 1] - 1;
522
+ size_t curr_pos = st_col;
523
+ size_t ind_end_col = Xc_ind[end_col];
524
+ size_t *ptr_st = std::lower_bound(ix_arr + st, ix_arr + end + 1, Xc_ind[st_col]);
525
+
526
+ ldouble_safe xval;
527
+
528
+ if (missing_action != Fail)
529
+ {
530
+ for (size_t *row = ptr_st;
531
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
532
+ )
533
+ {
534
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
535
+ {
536
+ w_this = w[*row];
537
+ xval = Xc[curr_pos];
538
+
539
+ if (unlikely(is_na_or_inf(xval)))
540
+ {
541
+ cnt -= w_this;
542
+ }
543
+
544
+ else
545
+ {
546
+ x_sq = xval * xval;
547
+ s1 = std::fma(w_this, xval, s1);
548
+ s2 = std::fma(w_this, x_sq, s2);
549
+ s3 = std::fma(w_this, x_sq*xval, s3);
550
+ s4 = std::fma(w_this, x_sq*x_sq, s4);
551
+ // s1 += w_this * pw1(xval);
552
+ // s2 += w_this * pw2(xval);
553
+ // s3 += w_this * pw3(xval);
554
+ // s4 += w_this * pw4(xval);
555
+ }
556
+
557
+ if (row == ix_arr + end || curr_pos == end_col) break;
558
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
559
+ }
560
+
561
+ else
562
+ {
563
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
564
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
565
+ else
566
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
567
+ }
568
+ }
569
+
570
+ if (unlikely(cnt <= 0)) return -HUGE_VAL;
571
+ }
572
+
573
+ else
574
+ {
575
+ for (size_t *row = ptr_st;
576
+ row != ix_arr + end + 1 && curr_pos != end_col + 1 && ind_end_col >= *row;
577
+ )
578
+ {
579
+ if (Xc_ind[curr_pos] == (sparse_ix)(*row))
580
+ {
581
+ w_this = w[*row];
582
+ xval = Xc[curr_pos];
583
+
584
+ x_sq = xval * xval;
585
+ s1 = std::fma(w_this, xval, s1);
586
+ s2 = std::fma(w_this, x_sq, s2);
587
+ s3 = std::fma(w_this, x_sq*xval, s3);
588
+ s4 = std::fma(w_this, x_sq*x_sq, s4);
589
+ // s1 += w_this * pw1(xval);
590
+ // s2 += w_this * pw2(xval);
591
+ // s3 += w_this * pw3(xval);
592
+ // s4 += w_this * pw4(xval);
593
+
594
+ if (row == ix_arr + end || curr_pos == end_col) break;
595
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *(++row)) - Xc_ind;
596
+ }
597
+
598
+ else
599
+ {
600
+ if (Xc_ind[curr_pos] > (sparse_ix)(*row))
601
+ row = std::lower_bound(row + 1, ix_arr + end + 1, Xc_ind[curr_pos]);
602
+ else
603
+ curr_pos = std::lower_bound(Xc_ind + curr_pos + 1, Xc_ind + end_col + 1, *row) - Xc_ind;
604
+ }
605
+ }
606
+ }
607
+
608
+ if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
609
+ ldouble_safe sn = s1 / cnt;
610
+ ldouble_safe v = s2 / cnt - pw2(sn);
611
+ if (unlikely(std::isnan(v))) return -HUGE_VAL;
612
+ if (
613
+ v <= std::numeric_limits<double>::epsilon() &&
614
+ !check_more_than_two_unique_values(ix_arr, st, end, col_num,
615
+ Xc_indptr, Xc_ind, Xc,
616
+ missing_action)
617
+ )
618
+ return -HUGE_VAL;
619
+ if (v <= 0) return 0.;
620
+ ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt * pw4(sn)) / (cnt * pw2(v));
621
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
622
+ }
623
+
624
+ template <class real_t, class sparse_ix, class ldouble_safe>
625
+ double calc_kurtosis_weighted(size_t col_num, size_t nrows,
626
+ real_t *restrict Xc, sparse_ix *restrict Xc_ind, sparse_ix *restrict Xc_indptr,
627
+ MissingAction missing_action, real_t *restrict w)
628
+ {
629
+ if (Xc_indptr[col_num] == Xc_indptr[col_num + 1])
630
+ return -HUGE_VAL;
631
+
632
+ ldouble_safe s1 = 0;
633
+ ldouble_safe s2 = 0;
634
+ ldouble_safe s3 = 0;
635
+ ldouble_safe s4 = 0;
636
+ ldouble_safe x_sq;
637
+ ldouble_safe w_this;
638
+ ldouble_safe cnt = nrows - (Xc_indptr[col_num + 1] - Xc_indptr[col_num]);
639
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num + 1]; ix++)
640
+ cnt += w[Xc_ind[ix]];
641
+
642
+ if (unlikely(cnt <= 0)) return -HUGE_VAL;
643
+
644
+ ldouble_safe xval;
645
+
646
+ if (missing_action != Fail)
647
+ {
648
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num + 1]; ix++)
649
+ {
650
+ w_this = w[Xc_ind[ix]];
651
+ xval = Xc[ix];
652
+
653
+ if (unlikely(is_na_or_inf(xval)))
654
+ {
655
+ cnt -= w_this;
656
+ }
657
+
658
+ else
659
+ {
660
+ x_sq = xval * xval;
661
+ s1 = std::fma(w_this, xval, s1);
662
+ s2 = std::fma(w_this, x_sq, s2);
663
+ s3 = std::fma(w_this, x_sq*xval, s3);
664
+ s4 = std::fma(w_this, x_sq*x_sq, s4);
665
+ // s1 += w_this * pw1(xval);
666
+ // s2 += w_this * pw2(xval);
667
+ // s3 += w_this * pw3(xval);
668
+ // s4 += w_this * pw4(xval);
669
+ }
670
+ }
671
+
672
+ if (cnt <= 0) return -HUGE_VAL;
673
+ }
674
+
675
+ else
676
+ {
677
+ for (auto ix = Xc_indptr[col_num]; ix < Xc_indptr[col_num + 1]; ix++)
678
+ {
679
+ w_this = w[Xc_ind[ix]];
680
+ xval = Xc[ix];
681
+
682
+ x_sq = xval * xval;
683
+ s1 = std::fma(w_this, xval, s1);
684
+ s2 = std::fma(w_this, x_sq, s2);
685
+ s3 = std::fma(w_this, x_sq*xval, s3);
686
+ s4 = std::fma(w_this, x_sq*x_sq, s4);
687
+ // s1 += w_this * pw1(xval);
688
+ // s2 += w_this * pw2(xval);
689
+ // s3 += w_this * pw3(xval);
690
+ // s4 += w_this * pw4(xval);
691
+ }
692
+ }
693
+
694
+ if (unlikely(cnt <= 1 || s2 == 0 || s2 == pw2(s1))) return -HUGE_VAL;
695
+ ldouble_safe sn = s1 / cnt;
696
+ ldouble_safe v = s2 / cnt - pw2(sn);
697
+ if (unlikely(std::isnan(v))) return -HUGE_VAL;
698
+ if (
699
+ v <= std::numeric_limits<double>::epsilon() &&
700
+ !check_more_than_two_unique_values(nrows, col_num,
701
+ Xc_indptr, Xc_ind, Xc,
702
+ missing_action)
703
+ )
704
+ return -HUGE_VAL;
705
+ if (unlikely(v <= 0)) return -HUGE_VAL;
706
+ ldouble_safe out = (s4 - 4 * s3 * sn + 6 * s2 * pw2(sn) - 4 * s1 * pw3(sn) + cnt * pw4(sn)) / (cnt * pw2(v));
707
+ return (!is_na_or_inf(out))? std::fmax((double)out, 0.) : (-HUGE_VAL);
708
+ }
709
+
710
+
711
+ template <class ldouble_safe>
712
+ double calc_kurtosis_internal(size_t cnt, int x[], int ncat, size_t buffer_cnt[], double buffer_prob[],
713
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
714
+ {
715
+ /* This calculation proceeds as follows:
716
+ - If splitting by subsets, it will assign a random weight ~Unif(0,1) to
717
+ each category, and approximate kurtosis by sampling from such distribution
718
+ with the same probabilities as given by the current counts.
719
+ - If splitting by isolating one category, will binarize at each categorical level,
720
+ assume the values are zero or one, and output the average assuming each categorical
721
+ level has equal probability of being picked.
722
+ (Note that both are misleading heuristics, but might be better than random)
723
+ */
724
+ double sum_kurt = 0;
725
+
726
+ cnt -= buffer_cnt[ncat];
727
+ if (cnt <= 1) return -HUGE_VAL;
728
+ ldouble_safe cnt_l = (ldouble_safe) cnt;
729
+ for (int cat = 0; cat < ncat; cat++)
730
+ buffer_prob[cat] = buffer_cnt[cat] / cnt_l;
731
+
732
+ switch (cat_split_type)
733
+ {
734
+ case SubSet:
735
+ {
736
+ ldouble_safe temp_v;
737
+ ldouble_safe s1, s2, s3, s4;
738
+ ldouble_safe coef;
739
+ ldouble_safe coef2;
740
+ ldouble_safe w_this;
741
+ UniformUnitInterval runif(0, 1);
742
+ size_t ntry = 50;
743
+ for (size_t iternum = 0; iternum < 50; iternum++)
744
+ {
745
+ s1 = 0; s2 = 0; s3 = 0; s4 = 0;
746
+ for (int cat = 0; cat < ncat; cat++)
747
+ {
748
+ coef = runif(rnd_generator);
749
+ coef2 = coef * coef;
750
+ w_this = buffer_prob[cat];
751
+ s1 = std::fma(w_this, coef, s1);
752
+ s2 = std::fma(w_this, coef2, s2);
753
+ s3 = std::fma(w_this, coef2*coef, s3);
754
+ s4 = std::fma(w_this, coef2*coef2, s4);
755
+ // s1 += buffer_prob[cat] * pw1(coef);
756
+ // s2 += buffer_prob[cat] * pw2(coef);
757
+ // s3 += buffer_prob[cat] * pw3(coef);
758
+ // s4 += buffer_prob[cat] * pw4(coef);
759
+ }
760
+ temp_v = s2 - pw2(s1);
761
+ if (temp_v <= 0)
762
+ ntry--;
763
+ else
764
+ sum_kurt += (s4 - 4 * s3 * pw1(s1) + 6 * s2 * pw2(s1) - 4 * s1 * pw3(s1) + pw4(s1)) / pw2(temp_v);
765
+ }
766
+ if (unlikely(!ntry))
767
+ return -HUGE_VAL;
768
+ else if (unlikely(is_na_or_inf(sum_kurt)))
769
+ return -HUGE_VAL;
770
+ else
771
+ return std::fmax(sum_kurt, 0.) / (double)ntry;
772
+ }
773
+
774
+ case SingleCateg:
775
+ {
776
+ double p;
777
+ int ncat_present = ncat;
778
+ for (int cat = 0; cat < ncat; cat++)
779
+ {
780
+ p = buffer_prob[cat];
781
+ if (p == 0)
782
+ ncat_present--;
783
+ else
784
+ sum_kurt += (p - 4 * p * pw1(p) + 6 * p * pw2(p) - 4 * p * pw3(p) + pw4(p)) / pw2(p - pw2(p));
785
+ }
786
+ if (ncat_present <= 1)
787
+ return -HUGE_VAL;
788
+ else if (unlikely(is_na_or_inf(sum_kurt)))
789
+ return -HUGE_VAL;
790
+ else
791
+ return std::fmax(sum_kurt, 0.) / (double)ncat_present;
792
+ }
793
+ }
794
+
795
+ return -1; /* this will never be reached, but CRAN complains otherwise */
796
+ }
797
+
798
+ template <class ldouble_safe>
799
+ double calc_kurtosis(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat, size_t *restrict buffer_cnt, double buffer_prob[],
800
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
801
+ {
802
+ /* This calculation proceeds as follows:
803
+ - If splitting by subsets, it will assign a random weight ~Unif(0,1) to
804
+ each category, and approximate kurtosis by sampling from such distribution
805
+ with the same probabilities as given by the current counts.
806
+ - If splitting by isolating one category, will binarize at each categorical level,
807
+ assume the values are zero or one, and output the average assuming each categorical
808
+ level has equal probability of being picked.
809
+ (Note that both are misleading heuristics, but might be better than random)
810
+ */
811
+ size_t cnt = end - st + 1;
812
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
813
+
814
+ if (missing_action == Fail)
815
+ {
816
+ for (size_t row = st; row <= end; row++)
817
+ buffer_cnt[x[ix_arr[row]]]++;
818
+ }
819
+
820
+ else
821
+ {
822
+ for (size_t row = st; row <= end; row++)
823
+ {
824
+ if (likely(x[ix_arr[row]] >= 0))
825
+ buffer_cnt[x[ix_arr[row]]]++;
826
+ else
827
+ buffer_cnt[ncat]++;
828
+ }
829
+ }
830
+
831
+ return calc_kurtosis_internal<ldouble_safe>(
832
+ cnt, x, ncat, buffer_cnt, buffer_prob,
833
+ missing_action, cat_split_type, rnd_generator);
834
+ }
835
+
836
+ template <class ldouble_safe>
837
+ double calc_kurtosis(size_t nrows, int x[], int ncat, size_t buffer_cnt[], double buffer_prob[],
838
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator)
839
+ {
840
+ size_t cnt = nrows;
841
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
842
+
843
+ if (missing_action == Fail)
844
+ {
845
+ for (size_t row = 0; row < nrows; row++)
846
+ buffer_cnt[x[row]]++;
847
+ }
848
+
849
+ else
850
+ {
851
+ for (size_t row = 0; row < nrows; row++)
852
+ {
853
+ if (likely(x[row] >= 0))
854
+ buffer_cnt[x[row]]++;
855
+ else
856
+ buffer_cnt[ncat]++;
857
+ }
858
+ }
859
+
860
+ return calc_kurtosis_internal<ldouble_safe>(
861
+ cnt, x, ncat, buffer_cnt, buffer_prob,
862
+ missing_action, cat_split_type, rnd_generator);
863
+ }
864
+
865
+
866
+ /* TODO: this one should get a buffer preallocated from outside */
867
+ template <class mapping, class ldouble_safe>
868
+ double calc_kurtosis_weighted_internal(std::vector<ldouble_safe> &buffer_cnt, int x[], int ncat,
869
+ double buffer_prob[], MissingAction missing_action, CategSplit cat_split_type,
870
+ RNG_engine &rnd_generator, mapping &restrict w)
871
+ {
872
+ double sum_kurt = 0;
873
+
874
+ ldouble_safe cnt = std::accumulate(buffer_cnt.begin(), buffer_cnt.end(), (ldouble_safe)0);
875
+
876
+ cnt -= buffer_cnt[ncat];
877
+ if (unlikely(cnt <= 1)) return -HUGE_VAL;
878
+ for (int cat = 0; cat < ncat; cat++)
879
+ buffer_prob[cat] = buffer_cnt[cat] / cnt;
880
+
881
+ switch (cat_split_type)
882
+ {
883
+ case SubSet:
884
+ {
885
+ ldouble_safe temp_v;
886
+ ldouble_safe s1, s2, s3, s4;
887
+ ldouble_safe coef, coef2;
888
+ ldouble_safe w_this;
889
+ UniformUnitInterval runif(0, 1);
890
+ size_t ntry = 50;
891
+ for (size_t iternum = 0; iternum < 50; iternum++)
892
+ {
893
+ s1 = 0; s2 = 0; s3 = 0; s4 = 0;
894
+ for (int cat = 0; cat < ncat; cat++)
895
+ {
896
+ coef = runif(rnd_generator);
897
+ coef2 = coef * coef;
898
+ w_this = buffer_prob[cat];
899
+ s1 = std::fma(w_this, coef, s1);
900
+ s2 = std::fma(w_this, coef2, s2);
901
+ s3 = std::fma(w_this, coef2*coef, s3);
902
+ s4 = std::fma(w_this, coef2*coef2, s4);
903
+ // s1 += buffer_prob[cat] * pw1(coef);
904
+ // s2 += buffer_prob[cat] * pw2(coef);
905
+ // s3 += buffer_prob[cat] * pw3(coef);
906
+ // s4 += buffer_prob[cat] * pw4(coef);
907
+ }
908
+ temp_v = s2 - pw2(s1);
909
+ if (unlikely(temp_v <= 0))
910
+ ntry--;
911
+ else
912
+ sum_kurt += (s4 - 4 * s3 * pw1(s1) + 6 * s2 * pw2(s1) - 4 * s1 * pw3(s1) + pw4(s1)) / pw2(temp_v);
913
+ }
914
+ if (unlikely(!ntry))
915
+ return -HUGE_VAL;
916
+ else if (unlikely(is_na_or_inf(sum_kurt)))
917
+ return -HUGE_VAL;
918
+ else
919
+ return std::fmax(sum_kurt, 0.) / (double)ntry;
920
+ }
921
+
922
+ case SingleCateg:
923
+ {
924
+ double p;
925
+ int ncat_present = ncat;
926
+ for (int cat = 0; cat < ncat; cat++)
927
+ {
928
+ p = buffer_prob[cat];
929
+ if (p == 0)
930
+ ncat_present--;
931
+ else
932
+ sum_kurt += (p - 4 * p * pw1(p) + 6 * p * pw2(p) - 4 * p * pw3(p) + pw4(p)) / pw2(p - pw2(p));
933
+ }
934
+ if (ncat_present <= 1)
935
+ return -HUGE_VAL;
936
+ else if (unlikely(is_na_or_inf(sum_kurt)))
937
+ return -HUGE_VAL;
938
+ else
939
+ return std::fmax(sum_kurt, 0.) / (double)ncat_present;
940
+ }
941
+ }
942
+
943
+ return -1; /* this will never be reached, but CRAN complains otherwise */
944
+ }
945
+
946
+ template <class mapping, class ldouble_safe>
947
+ double calc_kurtosis_weighted(size_t ix_arr[], size_t st, size_t end, int x[], int ncat, double buffer_prob[],
948
+ MissingAction missing_action, CategSplit cat_split_type, RNG_engine &rnd_generator,
949
+ mapping &restrict w)
950
+ {
951
+ std::vector<ldouble_safe> buffer_cnt(ncat+1, 0.);
952
+ ldouble_safe w_this;
953
+
954
+ for (size_t row = st; row <= end; row++)
955
+ {
956
+ w_this = w[ix_arr[row]];
957
+ if (likely(x[ix_arr[row]] >= 0))
958
+ buffer_cnt[x[ix_arr[row]]] += w_this;
959
+ else
960
+ buffer_cnt[ncat] += w_this;
961
+ }
962
+
963
+ return calc_kurtosis_weighted_internal<mapping, ldouble_safe>(
964
+ buffer_cnt, x, ncat,
965
+ buffer_prob, missing_action, cat_split_type,
966
+ rnd_generator, w);
967
+ }
968
+
969
+ template <class real_t, class ldouble_safe>
970
+ double calc_kurtosis_weighted(size_t nrows, int x[], int ncat, double *restrict buffer_prob,
971
+ MissingAction missing_action, CategSplit cat_split_type,
972
+ RNG_engine &rnd_generator, real_t *restrict w)
973
+ {
974
+ std::vector<ldouble_safe> buffer_cnt(ncat+1, 0.);
975
+ ldouble_safe w_this;
976
+
977
+ for (size_t row = 0; row < nrows; row++)
978
+ {
979
+ w_this = w[row];
980
+ if (likely(x[row] >= 0))
981
+ buffer_cnt[x[row]] += w_this;
982
+ else
983
+ buffer_cnt[ncat] += w_this;
984
+ }
985
+
986
+ return calc_kurtosis_weighted_internal<real_t *restrict, ldouble_safe>(
987
+ buffer_cnt, x, ncat,
988
+ buffer_prob, missing_action, cat_split_type,
989
+ rnd_generator, w);
990
+ }
991
+
992
/* Evaluates a closed-form variance expression over the proportions
   'p[pos[0..n-1]]' and returns its square root.

   Each selected category contributes 'p/3 - p^2/3', and every pair of selected
   categories subtracts 'p_i * p_j / 2'. A negative accumulated value is clamped
   to zero before taking the root. Returns 0 when 'n <= 1'. */
template <class int_t, class ldouble_safe>
double expected_sd_cat(double p[], size_t n, int_t pos[])
{
    if (n <= 1) return 0;

    /* first two categories, including their cross term */
    ldouble_safe variance = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[1]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[1]]) / 3.0 + p[pos[1]] / 3.0;

    for (size_t later = 2; later < n; later++)
    {
        variance += p[pos[later]] / 3.0 - square(p[pos[later]]) / 3.0;
        for (size_t earlier = 0; earlier < later; earlier++)
            variance -= p[pos[later]] * p[pos[earlier]] / 2.0;
    }

    return std::sqrt(std::fmax(variance, (ldouble_safe)0));
}
1006
+
1007
+ template <class number, class int_t, class ldouble_safe>
1008
+ double expected_sd_cat(number *restrict counts, double *restrict p, size_t n, int_t *restrict pos)
1009
+ {
1010
+ if (n <= 1) return 0;
1011
+
1012
+ number tot = std::accumulate(pos, pos + n, (number)0, [&counts](number tot, const size_t ix){return tot + counts[ix];});
1013
+ ldouble_safe cnt_div = (ldouble_safe) tot;
1014
+ for (size_t cat = 0; cat < n; cat++)
1015
+ p[pos[cat]] = (ldouble_safe)counts[pos[cat]] / cnt_div;
1016
+
1017
+ return expected_sd_cat<int_t, ldouble_safe>(p, n, pos);
1018
+ }
1019
+
1020
+ template <class number, class int_t, class ldouble_safe>
1021
+ double expected_sd_cat_single(number *restrict counts, double *restrict p, size_t n, int_t *restrict pos, size_t cat_exclude, number cnt)
1022
+ {
1023
+ if (cat_exclude == 0)
1024
+ return expected_sd_cat<number, int_t, ldouble_safe>(counts, p, n-1, pos + 1);
1025
+
1026
+ else if (cat_exclude == (n-1))
1027
+ return expected_sd_cat<number, int_t, ldouble_safe>(counts, p, n-1, pos);
1028
+
1029
+ size_t ix_exclude = pos[cat_exclude];
1030
+
1031
+ ldouble_safe cnt_div = (ldouble_safe) (cnt - counts[ix_exclude]);
1032
+ for (size_t cat = 0; cat < n; cat++)
1033
+ p[pos[cat]] = (ldouble_safe)counts[pos[cat]] / cnt_div;
1034
+
1035
+ ldouble_safe cum_var;
1036
+ if (cat_exclude != 1)
1037
+ cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[1]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[1]]) / 3.0 + p[pos[1]] / 3.0;
1038
+ else
1039
+ cum_var = -square(p[pos[0]]) / 3.0 - p[pos[0]] * p[pos[2]] / 2.0 + p[pos[0]] / 3.0 - square(p[pos[2]]) / 3.0 + p[pos[2]] / 3.0;
1040
+ for (size_t cat1 = (cat_exclude == 1)? 3 : 2; cat1 < n; cat1++)
1041
+ {
1042
+ if (pos[cat1] == ix_exclude) continue;
1043
+ cum_var += p[pos[cat1]] / 3.0 - square(p[pos[cat1]]) / 3.0;
1044
+ for (size_t cat2 = 0; cat2 < cat1; cat2++)
1045
+ {
1046
+ if (pos[cat2] == ix_exclude) continue;
1047
+ cum_var -= p[pos[cat1]] * p[pos[cat2]] / 2.0;
1048
+ }
1049
+
1050
+ }
1051
+ return std::sqrt(std::fmax(cum_var, (ldouble_safe)0));
1052
+ }
1053
+
1054
+ template <class number, class int_t, class ldouble_safe>
1055
+ double expected_sd_cat_internal(int ncat, number *restrict buffer_cnt, ldouble_safe cnt_l,
1056
+ int_t *restrict buffer_pos, double *restrict buffer_prob)
1057
+ {
1058
+ /* move zero-valued to the beginning */
1059
+ std::iota(buffer_pos, buffer_pos + ncat, (int_t)0);
1060
+ int_t st_pos = 0;
1061
+ int ncat_present = 0;
1062
+ int_t temp;
1063
+ for (int cat = 0; cat < ncat; cat++)
1064
+ {
1065
+ if (buffer_cnt[cat])
1066
+ {
1067
+ ncat_present++;
1068
+ buffer_prob[cat] = (ldouble_safe) buffer_cnt[cat] / cnt_l;
1069
+ }
1070
+
1071
+ else
1072
+ {
1073
+ temp = buffer_pos[st_pos];
1074
+ buffer_pos[st_pos] = buffer_pos[cat];
1075
+ buffer_pos[cat] = temp;
1076
+ st_pos++;
1077
+ }
1078
+ }
1079
+
1080
+ if (ncat_present <= 1) return 0;
1081
+ return expected_sd_cat<int_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
1082
+ }
1083
+
1084
+
1085
+ template <class int_t, class ldouble_safe>
1086
+ double expected_sd_cat(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat,
1087
+ MissingAction missing_action,
1088
+ size_t *restrict buffer_cnt, int_t *restrict buffer_pos, double buffer_prob[])
1089
+ {
1090
+ /* generate counts */
1091
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, (size_t)0);
1092
+ size_t cnt = end - st + 1;
1093
+
1094
+ if (missing_action != Fail)
1095
+ {
1096
+ int xval;
1097
+ for (size_t row = st; row <= end; row++)
1098
+ {
1099
+ xval = x[ix_arr[row]];
1100
+ if (unlikely(xval < 0))
1101
+ buffer_cnt[ncat]++;
1102
+ else
1103
+ buffer_cnt[xval]++;
1104
+ }
1105
+ cnt -= buffer_cnt[ncat];
1106
+ if (cnt == 0) return 0;
1107
+ }
1108
+
1109
+ else
1110
+ {
1111
+ for (size_t row = st; row <= end; row++)
1112
+ {
1113
+ if (likely(x[ix_arr[row]] >= 0)) buffer_cnt[x[ix_arr[row]]]++;
1114
+ }
1115
+ }
1116
+
1117
+ return expected_sd_cat_internal<size_t, int_t, ldouble_safe>(ncat, buffer_cnt, cnt, buffer_pos, buffer_prob);
1118
+ }
1119
+
1120
+ template <class mapping, class int_t, class ldouble_safe>
1121
+ double expected_sd_cat_weighted(size_t *restrict ix_arr, size_t st, size_t end, int x[], int ncat,
1122
+ MissingAction missing_action, mapping &restrict w,
1123
+ double *restrict buffer_cnt, int_t *restrict buffer_pos, double *restrict buffer_prob)
1124
+ {
1125
+ /* generate counts */
1126
+ std::fill(buffer_cnt, buffer_cnt + ncat + 1, 0.);
1127
+ ldouble_safe cnt = 0;
1128
+
1129
+ if (missing_action != Fail)
1130
+ {
1131
+ int xval;
1132
+ double w_this;
1133
+ for (size_t row = st; row <= end; row++)
1134
+ {
1135
+ xval = x[ix_arr[row]];
1136
+ w_this = w[ix_arr[row]];
1137
+
1138
+ if (unlikely(xval < 0)) {
1139
+ buffer_cnt[ncat] += w_this;
1140
+ }
1141
+ else {
1142
+ buffer_cnt[xval] += w_this;
1143
+ cnt += w_this;
1144
+ }
1145
+ }
1146
+ if (cnt == 0) return 0;
1147
+ }
1148
+
1149
+ else
1150
+ {
1151
+ for (size_t row = st; row <= end; row++)
1152
+ {
1153
+ if (likely(x[ix_arr[row]] >= 0))
1154
+ {
1155
+ buffer_cnt[x[ix_arr[row]]] += w[ix_arr[row]];
1156
+ }
1157
+ }
1158
+ for (int cat = 0; cat < ncat; cat++)
1159
+ cnt += buffer_cnt[cat];
1160
+ if (unlikely(cnt == 0)) return 0;
1161
+ }
1162
+
1163
+ return expected_sd_cat_internal<double, int_t, ldouble_safe>(ncat, buffer_cnt, cnt, buffer_pos, buffer_prob);
1164
+ }
1165
+
1166
/* Note: this isn't exactly comparable to the pooled gain from numeric variables,
   but among all the possible options, this is what happens to end up in the most
   similar scale when considering standardized gain. */
/* Gain of a categorical split.

   For each branch computes 'n * log(n) - s' (with the 'n * log(n)' part taken as
   zero when the branch count is <= 1), subtracts both branch values from
   'base_info', and divides the remainder by 'cnt'.

   Parameters
   - cnt_left, cnt_right: counts of the two branches.
   - s_left, s_right: values subtracted from each branch's 'n * log(n)'.
   - base_info: value for the node before splitting.
   - cnt: divisor applied to the result. */
template <class number, class ldouble_safe>
double categ_gain(number cnt_left, number cnt_right,
                  ldouble_safe s_left, ldouble_safe s_right,
                  ldouble_safe base_info, ldouble_safe cnt)
{
    ldouble_safe nlogn_left = 0;
    if (!(cnt_left <= 1))
        nlogn_left = (ldouble_safe)cnt_left * std::log((ldouble_safe)cnt_left);

    ldouble_safe nlogn_right = 0;
    if (!(cnt_right <= 1))
        nlogn_right = (ldouble_safe)cnt_right * std::log((ldouble_safe)cnt_right);

    const ldouble_safe info_left = nlogn_left - s_left;
    const ldouble_safe info_right = nlogn_right - s_right;
    return (base_info - info_left - info_right) / cnt;
}
1180
+
1181
+
1182
+ /* A couple notes about gain calculation:
1183
+
1184
+ Here one wants to find the best split point, maximizing either:
1185
+ (1/sigma) * (sigma - (1/n)*(n_left*sigma_left + n_right*sigma_right))
1186
+ or:
1187
+ (1/sigma) * (sigma - (1/2)*(sigma_left + sigma_right))
1188
+
1189
+ All the algorithms here use the sorted-indices approach, which is
1190
+ an exact method (note that there's still room for optimization by adding the
1191
+ unsorted approach for small sample sizes and for sparse matrices).
1192
+
1193
+ A naive approach would move observations one at a time from right
1194
+ to left using this formula:
1195
+ sigma = (ssq - s^2/n) / n
1196
+ ssq = sum(x^2)
1197
+ s = sum(x)
1198
+ But such approach has poor numerical precision, and this library is
1199
+ aimed precisely at cases in which there are outliers in the data.
1200
+ It's possible to improve the numerical precision by standardizing the
1201
+ data beforehand, but this library uses instead a more exact two-pass
1202
+ sigma calculation observation-by-observation (from left to right and
1203
+ from right to left, keeping the calculations of the first pass in an
1204
+ array and calculating gain in the second pass), but there's
1205
+ other methods too.
1206
+
1207
+ If one is aiming at maximizing the pooled gain, it's possible to
1208
+ simplify either the gain or the increase in gain without involving
1209
+ 'ssq'. Assuming one already has 'ssq' and 's' calculated for the left and
1210
+ right partitions, and one wants to move one observation from right to left,
1211
+ the following hold:
1212
+ s_right = s - s_left
1213
+ ssq_right = ssq - ssq_left
1214
+ n_right = n - n_left
1215
+ If one then moves observation x, these are updated as follows:
1216
+ s_left_new = s_left + x
1217
+ s_right_new = s - s_left - x
1218
+ ssq_left_new = ssq_left + x^2
1219
+ ssq_right_new = ssq - ssq_left - x^2
1220
+ n_left_new = n_left + 1
1221
+ n_right_new = n - n_left - 1
1222
+ Gain is then:
1223
+ (1/sigma) * (sigma - (1/n)*({ssq_left_new - s_left_new^2/n_left_new} + {ssq_right_new - s_right_new^2/n_right_new}))
1224
+ Which simplifies to:
1225
+ 1 - (1/(sigma*n))(ssq - ( (s_left + x)^2/(n_left+1) + (s - (s_left + x))^2/(n - (n_left+1)) ))
1226
+ Since 'sigma', 'n', and 'ssq' are constant, they can be ignored when determining the
1227
+ maximum gain - that is, one is interested in finding the point that maximizes:
1228
+ (s_left+x)^2/(n_left+1) + (s-(s_left+x))^2/(n-(n_left+1))
1229
+ And this calculation will be robust-enough when dealing with numbers that were
1230
+ already standardized beforehand, as the extended model does at each step.
1231
+ Note however that, when fitting this model, one is usually interested in evaluating
1232
+ the actual gain, standardized by the standard deviation, as it will try different
1233
+ linear combinations which will give different standard deviations, so this simpler
1234
+ formula cannot be applied unless only one linear combination is probed.
1235
+
1236
+ One can also look at:
1237
+ diff_gain = (1/sigma) * (gain_new - gain)
1238
+ Which can be simplified to something that doesn't include sums of squares:
1239
+ (1/(sigma*n))*( -s_left^2/n_left - (s-s_left)^2/(n-n_left) + (s_left+x)^2/(n_left+1) + (s-(s_left+x))^2/(n-(n_left+1)) )
1240
+ And this calculation would in theory allow getting the actual standardized gain.
1241
+ In practice however, this calculation can have poor numerical precision when the
1242
+ sample size is large, so the functions here do not even attempt at calculating it,
1243
+ and this is the reason why the two-pass approach is preferred.
1244
+
1245
+ The averaged SD formula unfortunately doesn't reduce to something that would involve
1246
+ only sums.
1247
+ */
1248
+
1249
+ /* TODO: maybe it's not a good idea to use the two-pass approach with un-standardized
1250
+ variables at large sample sizes (ndim=1), considering that they come in sorted order.
1251
+ Maybe it should instead use sums of centered squares: sigma = sqrt(sum((x-mean(x))^2)/n)
1252
+ The sums of centered squares method is also likely to be more precise. */
1253
+
1254
+ template <class real_t>
1255
+ double midpoint(real_t x, real_t y)
1256
+ {
1257
+ real_t m = x + (y-x)/(real_t)2;
1258
+ if (likely((double)m < (double)y))
1259
+ return m;
1260
+ else {
1261
+ m = std::nextafter(m, y);
1262
+ if (m > x && m < y)
1263
+ return m;
1264
+ else
1265
+ return x;
1266
+ }
1267
+ }
1268
+
1269
/* Calls 'midpoint' with the smaller argument first: '(x, y)' when 'x < y',
   and '(y, x)' in every other case. */
template <class real_t>
double midpoint_with_reorder(real_t x, real_t y)
{
    return (x < y)? midpoint(x, y) : midpoint(y, x);
}
1277
+
1278
/* Averaged gain: one minus the mean of the two branch standard deviations
   divided by the standard deviation before the split. */
+ #define sd_gain(sd, sd_left, sd_right) (1. - ((sd_left) + (sd_right)) / (2. * (sd)))
1279
/* Pooled gain: one minus the count-weighted mean of the two branch standard
   deviations divided by the standard deviation before the split. The expansion
   casts the branch counts to 'real_t', so a type with that name must be visible
   wherever this macro is used. */
+ #define pooled_gain(sd, cnt, sd_left, sd_right, cnt_left, cnt_right) \
1280
+ (1. - (1./(sd))*( ( ((real_t)(cnt_left))/(cnt) )*(sd_left) + ( ((real_t)(cnt_right)/(cnt)) )*(sd_right) ))
1281
+
1282
+
1283
+ /* TODO: make this a compensated sum */
1284
+ template <class real_t, class real_t_>
1285
+ double find_split_rel_gain_t(real_t_ *restrict x, size_t n, double &restrict split_point)
1286
+ {
1287
+ real_t this_gain;
1288
+ real_t best_gain = -HUGE_VAL;
1289
+ real_t x1 = 0, x2 = 0;
1290
+ real_t sum_left = 0, sum_right = 0, sum_tot = 0;
1291
+ for (size_t row = 0; row < n; row++)
1292
+ sum_tot += x[row];
1293
+ for (size_t row = 0; row < n-1; row++)
1294
+ {
1295
+ sum_left += x[row];
1296
+ if (x[row] == x[row+1])
1297
+ continue;
1298
+
1299
+ sum_right = sum_tot - sum_left;
1300
+ this_gain = sum_left * (sum_left / (real_t)(row+1))
1301
+ + sum_right * (sum_right / (real_t)(n-row-1));
1302
+ if (this_gain > best_gain)
1303
+ {
1304
+ best_gain = this_gain;
1305
+ x1 = x[row]; x2 = x[row+1];
1306
+ }
1307
+ }
1308
+
1309
+ if (best_gain <= -HUGE_VAL)
1310
+ return best_gain;
1311
+ split_point = midpoint(x1, x2);
1312
+ return std::fmax((double)best_gain, std::numeric_limits<double>::epsilon());
1313
+ }
1314
+
1315
+ template <class real_t_, class ldouble_safe>
1316
+ double find_split_rel_gain(real_t_ *restrict x, size_t n, double &restrict split_point)
1317
+ {
1318
+ if (n < THRESHOLD_LONG_DOUBLE)
1319
+ return find_split_rel_gain_t<double, real_t_>(x, n, split_point);
1320
+ else
1321
+ return find_split_rel_gain_t<ldouble_safe, real_t_>(x, n, split_point);
1322
+ }
1323
+
1324
+ /* Note: there is no 'weighted' version of 'find_split_rel_gain' with unindexed 'x', because calling it would
1325
+ imply having to argsort the 'x' values in order to sort the weights, which is less efficient. */
1326
+
1327
+ template <class real_t, class real_t_>
1328
+ double find_split_rel_gain_t(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &split_point, size_t &restrict split_ix)
1329
+ {
1330
+ real_t this_gain;
1331
+ real_t best_gain = -HUGE_VAL;
1332
+ split_ix = 0; /* <- avoid out-of-bounds at the end */
1333
+ real_t sum_left = 0, sum_right = 0, sum_tot = 0;
1334
+ for (size_t row = st; row <= end; row++)
1335
+ sum_tot += x[ix_arr[row]] - xmean;
1336
+ for (size_t row = st; row < end; row++)
1337
+ {
1338
+ sum_left += x[ix_arr[row]] - xmean;
1339
+ if (x[ix_arr[row]] == x[ix_arr[row+1]])
1340
+ continue;
1341
+
1342
+ sum_right = sum_tot - sum_left;
1343
+ this_gain = sum_left * (sum_left / (real_t)(row - st + 1))
1344
+ + sum_right * (sum_right / (real_t)(end - row));
1345
+ if (this_gain > best_gain)
1346
+ {
1347
+ best_gain = this_gain;
1348
+ split_ix = row;
1349
+ }
1350
+ }
1351
+
1352
+ if (best_gain <= -HUGE_VAL)
1353
+ return best_gain;
1354
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1355
+ return std::fmax((double)best_gain, std::numeric_limits<double>::epsilon());
1356
+ }
1357
+
1358
+ template <class real_t_, class ldouble_safe>
1359
+ double find_split_rel_gain(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &restrict split_point, size_t &restrict split_ix)
1360
+ {
1361
+ if ((end-st+1) < THRESHOLD_LONG_DOUBLE)
1362
+ return find_split_rel_gain_t<double, real_t_>(x, xmean, ix_arr, st, end, split_point, split_ix);
1363
+ else
1364
+ return find_split_rel_gain_t<ldouble_safe, real_t_>(x, xmean, ix_arr, st, end, split_point, split_ix);
1365
+ }
1366
+
1367
+ template <class real_t, class real_t_, class mapping>
1368
+ double find_split_rel_gain_weighted_t(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &split_point, size_t &restrict split_ix, mapping &restrict w)
1369
+ {
1370
+ real_t this_gain;
1371
+ real_t best_gain = -HUGE_VAL;
1372
+ split_ix = 0; /* <- avoid out-of-bounds at the end */
1373
+ real_t sum_left = 0, sum_right = 0, sum_tot = 0, sumw = 0, sumw_tot = 0;
1374
+
1375
+ for (size_t row = st; row <= end; row++)
1376
+ sumw_tot += w[ix_arr[row]];
1377
+ for (size_t row = st; row <= end; row++)
1378
+ sum_tot += x[ix_arr[row]] - xmean;
1379
+ for (size_t row = st; row < end; row++)
1380
+ {
1381
+ sumw += w[ix_arr[row]];
1382
+ sum_left += x[ix_arr[row]] - xmean;
1383
+ if (x[ix_arr[row]] == x[ix_arr[row+1]])
1384
+ continue;
1385
+
1386
+ sum_right = sum_tot - sum_left;
1387
+ this_gain = sum_left * (sum_left / sumw)
1388
+ + sum_right * (sum_right / (sumw_tot - sumw));
1389
+ if (this_gain > best_gain)
1390
+ {
1391
+ best_gain = this_gain;
1392
+ split_ix = row;
1393
+ }
1394
+ }
1395
+
1396
+ if (best_gain <= -HUGE_VAL)
1397
+ return best_gain;
1398
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1399
+ return std::fmax((double)best_gain, std::numeric_limits<double>::epsilon());
1400
+ }
1401
+
1402
+ template <class real_t_, class mapping, class ldouble_safe>
1403
+ double find_split_rel_gain_weighted(real_t_ *restrict x, real_t_ xmean, size_t *restrict ix_arr, size_t st, size_t end, double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
1404
+ {
1405
+ if ((end-st+1) < THRESHOLD_LONG_DOUBLE)
1406
+ return find_split_rel_gain_weighted_t<double, real_t_, mapping>(x, xmean, ix_arr, st, end, split_point, split_ix, w);
1407
+ else
1408
+ return find_split_rel_gain_weighted_t<ldouble_safe, real_t_, mapping>(x, xmean, ix_arr, st, end, split_point, split_ix, w);
1409
+ }
1410
+
1411
+
1412
+ template <class real_t, class real_t_>
1413
+ real_t calc_sd_right_to_left(real_t_ *restrict x, size_t n, double *restrict sd_arr)
1414
+ {
1415
+ real_t running_mean = 0;
1416
+ real_t running_ssq = 0;
1417
+ real_t mean_prev = x[n-1];
1418
+ for (size_t row = 0; row < n-1; row++)
1419
+ {
1420
+ running_mean += (x[n-row-1] - running_mean) / (real_t)(row+1);
1421
+ running_ssq += (x[n-row-1] - running_mean) * (x[n-row-1] - mean_prev);
1422
+ mean_prev = running_mean;
1423
+ sd_arr[n-row-1] = (row == 0)? 0. : std::sqrt(running_ssq / (real_t)(row+1));
1424
+ }
1425
+ running_mean += (x[0] - running_mean) / (real_t)n;
1426
+ running_ssq += (x[0] - running_mean) * (x[0] - mean_prev);
1427
+ return std::sqrt(running_ssq / (real_t)n);
1428
+ }
1429
+
1430
+ template <class real_t_, class ldouble_safe>
1431
+ ldouble_safe calc_sd_right_to_left_weighted(real_t_ *restrict x, size_t n, double *restrict sd_arr,
1432
+ double *restrict w, ldouble_safe &cumw, size_t *restrict sorted_ix)
1433
+ {
1434
+ ldouble_safe running_mean = 0;
1435
+ ldouble_safe running_ssq = 0;
1436
+ ldouble_safe mean_prev = x[sorted_ix[n-1]];
1437
+ ldouble_safe cnt = 0;
1438
+ double w_this;
1439
+ for (size_t row = 0; row < n-1; row++)
1440
+ {
1441
+ w_this = w[sorted_ix[n-row-1]];
1442
+ cnt += w_this;
1443
+ running_mean += w_this * (x[sorted_ix[n-row-1]] - running_mean) / cnt;
1444
+ running_ssq += w_this * ((x[sorted_ix[n-row-1]] - running_mean) * (x[sorted_ix[n-row-1]] - mean_prev));
1445
+ mean_prev = running_mean;
1446
+ sd_arr[n-row-1] = (row == 0)? 0. : std::sqrt(running_ssq / cnt);
1447
+ }
1448
+ w_this = w[sorted_ix[0]];
1449
+ cnt += w_this;
1450
+ running_mean += (x[sorted_ix[0]] - running_mean) / cnt;
1451
+ running_ssq += w_this * ((x[sorted_ix[0]] - running_mean) * (x[sorted_ix[0]] - mean_prev));
1452
+ cumw = cnt;
1453
+ return std::sqrt(running_ssq / cnt);
1454
+ }
1455
+
1456
+ template <class real_t, class real_t_>
1457
+ real_t calc_sd_right_to_left(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr)
1458
+ {
1459
+ real_t running_mean = 0;
1460
+ real_t running_ssq = 0;
1461
+ real_t mean_prev = x[ix_arr[end]] - xmean;
1462
+ size_t n = end - st + 1;
1463
+ for (size_t row = 0; row < n-1; row++)
1464
+ {
1465
+ running_mean += ((x[ix_arr[end-row]] - xmean) - running_mean) / (real_t)(row+1);
1466
+ running_ssq += ((x[ix_arr[end-row]] - xmean) - running_mean) * ((x[ix_arr[end-row]] - xmean) - mean_prev);
1467
+ mean_prev = running_mean;
1468
+ sd_arr[n-row-1] = (row == 0)? 0. : std::sqrt(running_ssq / (real_t)(row+1));
1469
+ }
1470
+ running_mean += ((x[ix_arr[st]] - xmean) - running_mean) / (real_t)n;
1471
+ running_ssq += ((x[ix_arr[st]] - xmean) - running_mean) * ((x[ix_arr[st]] - xmean) - mean_prev);
1472
+ return std::sqrt(running_ssq / (real_t)n);
1473
+ }
1474
+
1475
+ template <class real_t_, class mapping, class ldouble_safe>
1476
+ ldouble_safe calc_sd_right_to_left_weighted(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end,
1477
+ double *restrict sd_arr, mapping &restrict w, ldouble_safe &cumw)
1478
+ {
1479
+ ldouble_safe running_mean = 0;
1480
+ ldouble_safe running_ssq = 0;
1481
+ real_t_ mean_prev = x[ix_arr[end]] - xmean;
1482
+ size_t n = end - st + 1;
1483
+ ldouble_safe cnt = 0;
1484
+ double w_this;
1485
+ for (size_t row = 0; row < n-1; row++)
1486
+ {
1487
+ w_this = w[ix_arr[end-row]];
1488
+ cnt += w_this;
1489
+ running_mean += w_this * ((x[ix_arr[end-row]] - xmean) - running_mean) / cnt;
1490
+ running_ssq += w_this * (((x[ix_arr[end-row]] - xmean) - running_mean) * ((x[ix_arr[end-row]] - xmean) - mean_prev));
1491
+ mean_prev = running_mean;
1492
+ sd_arr[n-row-1] = (row == 0)? 0. : std::sqrt(running_ssq / cnt);
1493
+ }
1494
+ w_this = w[ix_arr[st]];
1495
+ cnt += w_this;
1496
+ running_mean += ((x[ix_arr[st]] - xmean) - running_mean) / cnt;
1497
+ running_ssq += w_this * (((x[ix_arr[st]] - xmean) - running_mean) * ((x[ix_arr[st]] - xmean) - mean_prev));
1498
+ cumw = cnt;
1499
+ return std::sqrt(running_ssq / cnt);
1500
+ }
1501
+
1502
+ template <class real_t, class real_t_>
1503
+ double find_split_std_gain_t(real_t_ *restrict x, size_t n, double *restrict sd_arr,
1504
+ GainCriterion criterion, double min_gain, double &restrict split_point)
1505
+ {
1506
+ real_t full_sd = calc_sd_right_to_left<real_t>(x, n, sd_arr);
1507
+ real_t running_mean = 0;
1508
+ real_t running_ssq = 0;
1509
+ real_t mean_prev = x[0];
1510
+ real_t best_gain = -HUGE_VAL;
1511
+ real_t this_sd, this_gain;
1512
+ real_t n_ = (real_t)n;
1513
+ size_t best_ix = 0;
1514
+ for (size_t row = 0; row < n-1; row++)
1515
+ {
1516
+ running_mean += (x[row] - running_mean) / (real_t)(row+1);
1517
+ running_ssq += (x[row] - running_mean) * (x[row] - mean_prev);
1518
+ mean_prev = running_mean;
1519
+ if (x[row] == x[row+1])
1520
+ continue;
1521
+
1522
+ this_sd = (row == 0)? 0. : std::sqrt(running_ssq / (real_t)(row+1));
1523
+ this_gain = (criterion == Pooled)?
1524
+ pooled_gain(full_sd, n_, this_sd, sd_arr[row+1], row+1, n-row-1)
1525
+ :
1526
+ sd_gain(full_sd, this_sd, sd_arr[row+1]);
1527
+ if (this_gain > best_gain && this_gain > min_gain)
1528
+ {
1529
+ best_gain = this_gain;
1530
+ best_ix = row;
1531
+ }
1532
+ }
1533
+
1534
+ if (best_gain > -HUGE_VAL)
1535
+ split_point = midpoint(x[best_ix], x[best_ix+1]);
1536
+
1537
+ return best_gain;
1538
+ }
1539
+
1540
+ template <class real_t_, class ldouble_safe>
1541
+ double find_split_std_gain(real_t_ *restrict x, size_t n, double *restrict sd_arr,
1542
+ GainCriterion criterion, double min_gain, double &restrict split_point)
1543
+ {
1544
+ if (n < THRESHOLD_LONG_DOUBLE)
1545
+ return find_split_std_gain_t<double, real_t_>(x, n, sd_arr, criterion, min_gain, split_point);
1546
+ else
1547
+ return find_split_std_gain_t<ldouble_safe, real_t_>(x, n, sd_arr, criterion, min_gain, split_point);
1548
+ }
1549
+
1550
+ template <class real_t, class ldouble_safe>
1551
+ double find_split_std_gain_weighted(real_t *restrict x, size_t n, double *restrict sd_arr,
1552
+ GainCriterion criterion, double min_gain, double &restrict split_point,
1553
+ double *restrict w, size_t *restrict sorted_ix)
1554
+ {
1555
+ ldouble_safe cumw;
1556
+ double full_sd = calc_sd_right_to_left_weighted(x, n, sd_arr, w, cumw, sorted_ix);
1557
+ ldouble_safe running_mean = 0;
1558
+ ldouble_safe running_ssq = 0;
1559
+ ldouble_safe mean_prev = x[sorted_ix[0]];
1560
+ double best_gain = -HUGE_VAL;
1561
+ double this_sd, this_gain;
1562
+ double w_this;
1563
+ ldouble_safe currw = 0;
1564
+ size_t best_ix = 0;
1565
+
1566
+ for (size_t row = 0; row < n-1; row++)
1567
+ {
1568
+ w_this = w[sorted_ix[row]];
1569
+ currw += w_this;
1570
+ running_mean += w_this * (x[sorted_ix[row]] - running_mean) / currw;
1571
+ running_ssq += w_this * ((x[sorted_ix[row]] - running_mean) * (x[sorted_ix[row]] - mean_prev));
1572
+ mean_prev = running_mean;
1573
+ if (x[sorted_ix[row]] == x[sorted_ix[row+1]])
1574
+ continue;
1575
+
1576
+ this_sd = (row == 0)? 0. : std::sqrt(running_ssq / currw);
1577
+ this_gain = (criterion == Pooled)?
1578
+ pooled_gain(full_sd, cumw, this_sd, sd_arr[row+1], currw, cumw-currw)
1579
+ :
1580
+ sd_gain(full_sd, this_sd, sd_arr[row+1]);
1581
+ if (this_gain > best_gain && this_gain > min_gain)
1582
+ {
1583
+ best_gain = this_gain;
1584
+ best_ix = row;
1585
+ }
1586
+ }
1587
+
1588
+ if (best_gain > -HUGE_VAL)
1589
+ split_point = midpoint(x[sorted_ix[best_ix]], x[sorted_ix[best_ix+1]]);
1590
+
1591
+ return best_gain;
1592
+ }
1593
+
1594
+ template <class real_t, class real_t_>
1595
+ double find_split_std_gain_t(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr,
1596
+ GainCriterion criterion, double min_gain, double &restrict split_point, size_t &restrict split_ix)
1597
+ {
1598
+ real_t full_sd = calc_sd_right_to_left<real_t>(x, xmean, ix_arr, st, end, sd_arr);
1599
+ real_t running_mean = 0;
1600
+ real_t running_ssq = 0;
1601
+ real_t mean_prev = x[ix_arr[st]] - xmean;
1602
+ real_t best_gain = -HUGE_VAL;
1603
+ real_t n = (real_t)(end - st + 1);
1604
+ real_t this_sd, this_gain;
1605
+ split_ix = st;
1606
+ for (size_t row = st; row < end; row++)
1607
+ {
1608
+ running_mean += ((x[ix_arr[row]] - xmean) - running_mean) / (real_t)(row-st+1);
1609
+ running_ssq += ((x[ix_arr[row]] - xmean) - running_mean) * ((x[ix_arr[row]] - xmean) - mean_prev);
1610
+ mean_prev = running_mean;
1611
+ if (x[ix_arr[row]] == x[ix_arr[row+1]])
1612
+ continue;
1613
+
1614
+ this_sd = (row == st)? 0. : std::sqrt(running_ssq / (real_t)(row-st+1));
1615
+ this_gain = (criterion == Pooled)?
1616
+ pooled_gain(full_sd, n, this_sd, sd_arr[row-st+1], row-st+1, end-row)
1617
+ :
1618
+ sd_gain(full_sd, this_sd, sd_arr[row-st+1]);
1619
+ if (this_gain > best_gain && this_gain > min_gain)
1620
+ {
1621
+ best_gain = this_gain;
1622
+ split_ix = row;
1623
+ }
1624
+ }
1625
+
1626
+ if (best_gain > -HUGE_VAL)
1627
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1628
+
1629
+ return best_gain;
1630
+ }
1631
+
1632
+ template <class real_t_, class ldouble_safe>
1633
+ double find_split_std_gain(real_t_ *restrict x, real_t_ xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr,
1634
+ GainCriterion criterion, double min_gain, double &restrict split_point, size_t &restrict split_ix)
1635
+ {
1636
+ if ((end-st+1) < THRESHOLD_LONG_DOUBLE)
1637
+ return find_split_std_gain_t<double, real_t_>(x, xmean, ix_arr, st, end, sd_arr, criterion, min_gain, split_point, split_ix);
1638
+ else
1639
+ return find_split_std_gain_t<ldouble_safe, real_t_>(x, xmean, ix_arr, st, end, sd_arr, criterion, min_gain, split_point, split_ix);
1640
+ }
1641
+
1642
+ template <class real_t, class mapping, class ldouble_safe>
1643
+ double find_split_std_gain_weighted(real_t *restrict x, real_t xmean, size_t ix_arr[], size_t st, size_t end, double *restrict sd_arr,
1644
+ GainCriterion criterion, double min_gain, double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
1645
+ {
1646
+ ldouble_safe cumw;
1647
+ double full_sd = calc_sd_right_to_left_weighted(x, xmean, ix_arr, st, end, sd_arr, w, cumw);
1648
+ ldouble_safe running_mean = 0;
1649
+ ldouble_safe running_ssq = 0;
1650
+ ldouble_safe mean_prev = x[ix_arr[st]] - xmean;
1651
+ double best_gain = -HUGE_VAL;
1652
+ ldouble_safe currw = 0;
1653
+ double this_sd, this_gain;
1654
+ double w_this;
1655
+ split_ix = st;
1656
+
1657
+ for (size_t row = st; row < end; row++)
1658
+ {
1659
+ w_this = w[ix_arr[row]];
1660
+ currw += w_this;
1661
+ running_mean += w_this * ((x[ix_arr[row]] - xmean) - running_mean) / currw;
1662
+ running_ssq += w_this * (((x[ix_arr[row]] - xmean) - running_mean) * ((x[ix_arr[row]] - xmean) - mean_prev));
1663
+ mean_prev = running_mean;
1664
+ if (x[ix_arr[row]] == x[ix_arr[row+1]])
1665
+ continue;
1666
+
1667
+ this_sd = (row == st)? 0. : std::sqrt(running_ssq / currw);
1668
+ this_gain = (criterion == Pooled)?
1669
+ pooled_gain(full_sd, cumw, this_sd, sd_arr[row-st+1], currw, cumw-currw)
1670
+ :
1671
+ sd_gain(full_sd, this_sd, sd_arr[row-st+1]);
1672
+ if (this_gain > best_gain && this_gain > min_gain)
1673
+ {
1674
+ best_gain = this_gain;
1675
+ split_ix = row;
1676
+ }
1677
+ }
1678
+
1679
+ if (best_gain > -HUGE_VAL)
1680
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1681
+
1682
+ return best_gain;
1683
+ }
1684
+
1685
+ #ifndef _FOR_R
1686
+ #if defined(__clang__)
1687
+ #pragma clang diagnostic push
1688
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
1689
+ #endif
1690
+ #endif
1691
+
1692
+ #ifndef _FOR_R
1693
+ [[gnu::optimize("Ofast")]]
1694
+ #endif
1695
+ static inline void xpy1(double *restrict x, double *restrict y, size_t n)
1696
+ {
1697
+ for (size_t ix = 0; ix < n; ix++) y[ix] += x[ix];
1698
+ }
1699
+
1700
+ #ifndef _FOR_R
1701
+ [[gnu::optimize("Ofast")]]
1702
+ #endif
1703
+ static inline void axpy1(const double a, double *restrict x, double *restrict y, size_t n)
1704
+ {
1705
+ for (size_t ix = 0; ix < n; ix++) y[ix] = std::fma(a, x[ix], y[ix]);
1706
+ }
1707
+
1708
+ #ifndef _FOR_R
1709
+ [[gnu::optimize("Ofast")]]
1710
+ #endif
1711
+ static inline void xpy1(double *restrict xval, size_t ind[], size_t nnz, double *restrict y)
1712
+ {
1713
+ for (size_t ix = 0; ix < nnz; ix++) y[ind[ix]] += xval[ix];
1714
+ }
1715
+
1716
+ #ifndef _FOR_R
1717
+ [[gnu::optimize("Ofast")]]
1718
+ #endif
1719
+ static inline void axpy1(const double a, double *restrict xval, size_t ind[], size_t nnz, double *restrict y)
1720
+ {
1721
+ for (size_t ix = 0; ix < nnz; ix++) y[ind[ix]] = std::fma(a, xval[ix], y[ind[ix]]);
1722
+ }
1723
+
1724
+ #ifndef _FOR_R
1725
+ #if defined(__clang__)
1726
+ #pragma clang diagnostic pop
1727
+ #endif
1728
+ #endif
1729
+
1730
+ template <class real_t, class ldouble_safe>
1731
+ double find_split_full_gain(real_t *restrict x, size_t st, size_t end, size_t *restrict ix_arr,
1732
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
1733
+ double *restrict X_row_major, size_t ncols,
1734
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
1735
+ double *restrict buffer_sum_left, double *restrict buffer_sum_tot,
1736
+ size_t &restrict split_ix, double &restrict split_point,
1737
+ bool x_uses_ix_arr)
1738
+ {
1739
+ if (end <= st) return -HUGE_VAL;
1740
+ if (cols_use != NULL && ncols_use && (double)ncols_use / (double)ncols < 0.1)
1741
+ force_cols_use = true;
1742
+
1743
+ memset(buffer_sum_tot, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
1744
+ if (Xr_indptr == NULL)
1745
+ {
1746
+ if (force_cols_use)
1747
+ {
1748
+ double *restrict ptr_row;
1749
+ for (size_t row = st; row <= end; row++)
1750
+ {
1751
+ ptr_row = X_row_major + ix_arr[row]*ncols;
1752
+ for (size_t col = 0; col < ncols_use; col++)
1753
+ buffer_sum_tot[col] += ptr_row[cols_use[col]];
1754
+ }
1755
+ }
1756
+
1757
+ else
1758
+ {
1759
+ for (size_t row = st; row <= end; row++)
1760
+ xpy1(X_row_major + ix_arr[row]*ncols, buffer_sum_tot, ncols);
1761
+ }
1762
+ }
1763
+
1764
+ else
1765
+ {
1766
+ if (force_cols_use)
1767
+ {
1768
+ size_t *curr_begin;
1769
+ size_t *row_end;
1770
+ size_t *curr_col;
1771
+ double *restrict Xr_this;
1772
+ size_t *cols_end = cols_use + ncols_use;
1773
+ for (size_t row = st; row <= end; row++)
1774
+ {
1775
+ curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
1776
+ row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
1777
+ if (curr_begin == row_end) continue;
1778
+ curr_col = cols_use;
1779
+ Xr_this = Xr + Xr_indptr[ix_arr[row]];
1780
+
1781
+ while (curr_col < cols_end && curr_begin < row_end)
1782
+ {
1783
+ if (*curr_begin == *curr_col)
1784
+ {
1785
+ buffer_sum_tot[std::distance(cols_use, curr_col)] += Xr_this[std::distance(curr_begin, row_end)];
1786
+ curr_col++;
1787
+ curr_begin++;
1788
+ }
1789
+
1790
+ else
1791
+ {
1792
+ if (*curr_begin > *curr_col)
1793
+ curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
1794
+ else
1795
+ curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
1796
+ }
1797
+ }
1798
+ }
1799
+ }
1800
+
1801
+ else
1802
+ {
1803
+ size_t ptr_this;
1804
+ for (size_t row = st; row <= end; row++)
1805
+ {
1806
+ ptr_this = Xr_indptr[ix_arr[row]];
1807
+ xpy1(Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_tot);
1808
+ }
1809
+ }
1810
+ }
1811
+
1812
+ double best_gain = -HUGE_VAL;
1813
+ double this_gain;
1814
+ double sl, sr;
1815
+ double dl, dr;
1816
+ double vleft, vright;
1817
+ memset(buffer_sum_left, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
1818
+ if (Xr_indptr == NULL)
1819
+ {
1820
+ if (!force_cols_use)
1821
+ {
1822
+ for (size_t row = st; row < end; row++)
1823
+ {
1824
+ xpy1(X_row_major + ix_arr[row]*ncols, buffer_sum_left, ncols);
1825
+ if (x_uses_ix_arr) {
1826
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
1827
+ }
1828
+ else {
1829
+ if (unlikely(x[row] == x[row+1])) continue;
1830
+ }
1831
+
1832
+ vleft = 0;
1833
+ vright = 0;
1834
+ dl = (double)(row-st+1);
1835
+ dr = (double)(end-row);
1836
+ for (size_t col = 0; col < ncols; col++)
1837
+ {
1838
+ sl = buffer_sum_left[col];
1839
+ vleft += sl * (sl / dl);
1840
+ sr = buffer_sum_tot[col] - sl;
1841
+ vright += sr * (sr / dr);
1842
+ }
1843
+
1844
+ this_gain = vleft + vright;
1845
+ if (this_gain > best_gain)
1846
+ {
1847
+ best_gain = this_gain;
1848
+ split_ix = row;
1849
+ }
1850
+ }
1851
+ }
1852
+
1853
+ else
1854
+ {
1855
+ double *restrict ptr_row;
1856
+ for (size_t row = st; row < end; row++)
1857
+ {
1858
+ ptr_row = X_row_major + ix_arr[row]*ncols;
1859
+ for (size_t col = 0; col < ncols_use; col++)
1860
+ buffer_sum_left[col] += ptr_row[cols_use[col]];
1861
+ if (x_uses_ix_arr) {
1862
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
1863
+ }
1864
+ else {
1865
+ if (unlikely(x[row] == x[row+1])) continue;
1866
+ }
1867
+
1868
+ vleft = 0;
1869
+ vright = 0;
1870
+ dl = (double)(row-st+1);
1871
+ dr = (double)(end-row);
1872
+ for (size_t col = 0; col < ncols_use; col++)
1873
+ {
1874
+ sl = buffer_sum_left[col];
1875
+ vleft += sl * (sl / dl);
1876
+ sr = buffer_sum_tot[col] - sl;
1877
+ vright += sr * (sr / dr);
1878
+ }
1879
+
1880
+ this_gain = vleft + vright;
1881
+ if (this_gain > best_gain)
1882
+ {
1883
+ best_gain = this_gain;
1884
+ split_ix = row;
1885
+ }
1886
+ }
1887
+ }
1888
+ }
1889
+
1890
+ else
1891
+ {
1892
+ if (!force_cols_use)
1893
+ {
1894
+ size_t ptr_this;
1895
+ for (size_t row = st; row < end; row++)
1896
+ {
1897
+ ptr_this = Xr_indptr[ix_arr[row]];
1898
+ xpy1(Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_left);
1899
+ if (x_uses_ix_arr) {
1900
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
1901
+ }
1902
+ else {
1903
+ if (unlikely(x[row] == x[row+1])) continue;
1904
+ }
1905
+
1906
+ vleft = 0;
1907
+ vright = 0;
1908
+ dl = (double)(row-st+1);
1909
+ dr = (double)(end-row);
1910
+ for (size_t col = 0; col < ncols; col++)
1911
+ {
1912
+ sl = buffer_sum_left[col];
1913
+ vleft += sl * (sl / dl);
1914
+ sr = buffer_sum_tot[col] - sl;
1915
+ vright += sr * (sr / dr);
1916
+ }
1917
+
1918
+ this_gain = vleft + vright;
1919
+ if (this_gain > best_gain)
1920
+ {
1921
+ best_gain = this_gain;
1922
+ split_ix = row;
1923
+ }
1924
+ }
1925
+ }
1926
+
1927
+ else
1928
+ {
1929
+ size_t *curr_begin;
1930
+ size_t *row_end;
1931
+ size_t *curr_col;
1932
+ double *restrict Xr_this;
1933
+ size_t *cols_end = cols_use + ncols_use;
1934
+ for (size_t row = st; row < end; row++)
1935
+ {
1936
+ curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
1937
+ row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
1938
+ if (curr_begin == row_end) goto skip_sum;
1939
+ curr_col = cols_use;
1940
+ Xr_this = Xr + Xr_indptr[ix_arr[row]];
1941
+ while (curr_col < cols_end && curr_begin < row_end)
1942
+ {
1943
+ if (*curr_begin == *curr_col)
1944
+ {
1945
+ buffer_sum_left[std::distance(cols_use, curr_col)] += Xr_this[std::distance(curr_begin, row_end)];
1946
+ curr_col++;
1947
+ curr_begin++;
1948
+ }
1949
+
1950
+ else
1951
+ {
1952
+ if (*curr_begin > *curr_col)
1953
+ curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
1954
+ else
1955
+ curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
1956
+ }
1957
+ }
1958
+
1959
+ skip_sum:
1960
+ if (x_uses_ix_arr) {
1961
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
1962
+ }
1963
+ else {
1964
+ if (unlikely(x[row] == x[row+1])) continue;
1965
+ }
1966
+
1967
+ vleft = 0;
1968
+ vright = 0;
1969
+ dl = (double)(row-st+1);
1970
+ dr = (double)(end-row);
1971
+ for (size_t col = 0; col < ncols_use; col++)
1972
+ {
1973
+ sl = buffer_sum_left[col];
1974
+ vleft += sl * (sl / dl);
1975
+ sr = buffer_sum_tot[col] - sl;
1976
+ vright += sr * (sr / dr);
1977
+ }
1978
+
1979
+ this_gain = vleft + vright;
1980
+ if (this_gain > best_gain)
1981
+ {
1982
+ best_gain = this_gain;
1983
+ split_ix = row;
1984
+ }
1985
+ }
1986
+ }
1987
+ }
1988
+
1989
+ if (best_gain <= -HUGE_VAL) return best_gain;
1990
+
1991
+ if (x_uses_ix_arr)
1992
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
1993
+ else
1994
+ split_point = midpoint(x[split_ix], x[split_ix+1]);
1995
+ return best_gain / (ldouble_safe)(end - st + 1);
1996
+ }
1997
+
1998
+ template <class real_t, class mapping, class ldouble_safe>
1999
+ double find_split_full_gain_weighted(real_t *restrict x, size_t st, size_t end, size_t *restrict ix_arr,
2000
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
2001
+ double *restrict X_row_major, size_t ncols,
2002
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
2003
+ double *restrict buffer_sum_left, double *restrict buffer_sum_tot,
2004
+ size_t &restrict split_ix, double &restrict split_point,
2005
+ bool x_uses_ix_arr,
2006
+ mapping &restrict w)
2007
+ {
2008
+ if (end <= st) return -HUGE_VAL;
2009
+ if (cols_use != NULL && ncols_use && (double)ncols_use / (double)ncols < 0.1)
2010
+ force_cols_use = true;
2011
+
2012
+ double wtot = 0;
2013
+ if (x_uses_ix_arr)
2014
+ {
2015
+ for (size_t row = st; row <= end; row++)
2016
+ wtot += w[ix_arr[row]];
2017
+ }
2018
+
2019
+ else
2020
+ {
2021
+ for (size_t row = st; row <= end; row++)
2022
+ wtot += w[row];
2023
+ }
2024
+
2025
+ memset(buffer_sum_tot, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
2026
+ if (Xr_indptr == NULL)
2027
+ {
2028
+ if (!force_cols_use)
2029
+ {
2030
+ if (x_uses_ix_arr)
2031
+ {
2032
+ for (size_t row = st; row <= end; row++)
2033
+ axpy1(w[ix_arr[row]], X_row_major + ix_arr[row]*ncols, buffer_sum_tot, ncols);
2034
+ }
2035
+
2036
+ else
2037
+ {
2038
+ for (size_t row = st; row <= end; row++)
2039
+ axpy1(w[row], X_row_major + ix_arr[row]*ncols, buffer_sum_tot, ncols);
2040
+ }
2041
+ }
2042
+
2043
+ else
2044
+ {
2045
+ double *restrict ptr_row;
2046
+ double w_row;
2047
+
2048
+ if (x_uses_ix_arr)
2049
+ {
2050
+ for (size_t row = st; row <= end; row++)
2051
+ {
2052
+ ptr_row = X_row_major + ix_arr[row]*ncols;
2053
+ w_row = w[ix_arr[row]];
2054
+ for (size_t col = 0; col < ncols_use; col++)
2055
+ buffer_sum_tot[col] = std::fma(w_row, ptr_row[cols_use[col]], buffer_sum_tot[col]);
2056
+ }
2057
+ }
2058
+
2059
+ else
2060
+ {
2061
+ for (size_t row = st; row <= end; row++)
2062
+ {
2063
+ ptr_row = X_row_major + ix_arr[row]*ncols;
2064
+ w_row = w[row];
2065
+ for (size_t col = 0; col < ncols_use; col++)
2066
+ buffer_sum_tot[col] = std::fma(w_row, ptr_row[cols_use[col]], buffer_sum_tot[col]);
2067
+ }
2068
+ }
2069
+ }
2070
+ }
2071
+
2072
+ else
2073
+ {
2074
+ if (!force_cols_use)
2075
+ {
2076
+ size_t ptr_this;
2077
+ if (x_uses_ix_arr)
2078
+ {
2079
+ for (size_t row = st; row <= end; row++)
2080
+ {
2081
+ ptr_this = Xr_indptr[ix_arr[row]];
2082
+ axpy1(w[ix_arr[row]], Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_tot);
2083
+ }
2084
+ }
2085
+
2086
+ else
2087
+ {
2088
+ for (size_t row = st; row <= end; row++)
2089
+ {
2090
+ ptr_this = Xr_indptr[ix_arr[row]];
2091
+ axpy1(w[row], Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_tot);
2092
+ }
2093
+ }
2094
+ }
2095
+
2096
+ else
2097
+ {
2098
+ size_t *curr_begin;
2099
+ size_t *row_end;
2100
+ size_t *curr_col;
2101
+ double *restrict Xr_this;
2102
+ size_t *cols_end = cols_use + ncols_use;
2103
+ double w_row;
2104
+ for (size_t row = st; row <= end; row++)
2105
+ {
2106
+ curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
2107
+ row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
2108
+ if (curr_begin == row_end) continue;
2109
+ curr_col = cols_use;
2110
+ Xr_this = Xr + Xr_indptr[ix_arr[row]];
2111
+ w_row = w[x_uses_ix_arr? ix_arr[row] : row];
2112
+ size_t dtemp;
2113
+
2114
+ while (curr_col < cols_end && curr_begin < row_end)
2115
+ {
2116
+ if (*curr_begin == *curr_col)
2117
+ {
2118
+ dtemp = std::distance(cols_use, curr_col);
2119
+ buffer_sum_tot[dtemp]
2120
+ =
2121
+ std::fma(w_row, Xr_this[std::distance(curr_begin, row_end)], buffer_sum_tot[dtemp]);
2122
+ curr_col++;
2123
+ curr_begin++;
2124
+ }
2125
+
2126
+ else
2127
+ {
2128
+ if (*curr_begin > *curr_col)
2129
+ curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
2130
+ else
2131
+ curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
2132
+ }
2133
+ }
2134
+ }
2135
+ }
2136
+ }
2137
+
2138
+ double best_gain = -HUGE_VAL;
2139
+ double this_gain;
2140
+ double sl, sr;
2141
+ double vleft, vright;
2142
+ double wleft = 0;
2143
+ double w_row;
2144
+ double wright;
2145
+ memset(buffer_sum_left, 0, (force_cols_use? ncols_use : ncols)*sizeof(double));
2146
+ if (Xr_indptr == NULL)
2147
+ {
2148
+ if (!force_cols_use)
2149
+ {
2150
+ for (size_t row = st; row < end; row++)
2151
+ {
2152
+ w_row = w[x_uses_ix_arr? ix_arr[row] : row];
2153
+ wleft += w_row;
2154
+ axpy1(w_row, X_row_major + ix_arr[row]*ncols, buffer_sum_left, ncols);
2155
+ if (x_uses_ix_arr) {
2156
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
2157
+ }
2158
+ else {
2159
+ if (unlikely(x[row] == x[row+1])) continue;
2160
+ }
2161
+
2162
+ vleft = 0;
2163
+ vright = 0;
2164
+ wright = wtot - wleft;
2165
+ for (size_t col = 0; col < ncols; col++)
2166
+ {
2167
+ sl = buffer_sum_left[col];
2168
+ vleft += sl * (sl / wleft);
2169
+ sr = buffer_sum_tot[col] - sl;
2170
+ vright += sr * (sr / wright);
2171
+ }
2172
+
2173
+ this_gain = vleft + vright;
2174
+ if (this_gain > best_gain)
2175
+ {
2176
+ best_gain = this_gain;
2177
+ split_ix = row;
2178
+ }
2179
+ }
2180
+ }
2181
+
2182
+ else
2183
+ {
2184
+ double *restrict ptr_row;
2185
+ double w_row;
2186
+ for (size_t row = st; row < end; row++)
2187
+ {
2188
+ w_row = w[x_uses_ix_arr? ix_arr[row] : row];
2189
+ wleft += w_row;
2190
+
2191
+ ptr_row = X_row_major + ix_arr[row]*ncols;
2192
+ for (size_t col = 0; col < ncols_use; col++)
2193
+ buffer_sum_left[col] = std::fma(w_row, ptr_row[cols_use[col]], buffer_sum_left[col]);
2194
+ if (x_uses_ix_arr) {
2195
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
2196
+ }
2197
+ else {
2198
+ if (unlikely(x[row] == x[row+1])) continue;
2199
+ }
2200
+
2201
+ vleft = 0;
2202
+ vright = 0;
2203
+ wright = wtot - wleft;
2204
+ for (size_t col = 0; col < ncols_use; col++)
2205
+ {
2206
+ sl = buffer_sum_left[col];
2207
+ vleft += sl * (sl / wleft);
2208
+ sr = buffer_sum_tot[col] - sl;
2209
+ vright += sr * (sr / wright);
2210
+ }
2211
+
2212
+ this_gain = vleft + vright;
2213
+ if (this_gain > best_gain)
2214
+ {
2215
+ best_gain = this_gain;
2216
+ split_ix = row;
2217
+ }
2218
+ }
2219
+ }
2220
+ }
2221
+
2222
+ else
2223
+ {
2224
+ if (!force_cols_use)
2225
+ {
2226
+ size_t ptr_this;
2227
+ double w_row;
2228
+ for (size_t row = st; row < end; row++)
2229
+ {
2230
+ w_row= w[x_uses_ix_arr? ix_arr[row] : row];
2231
+ wleft += w_row;
2232
+ ptr_this = Xr_indptr[ix_arr[row]];
2233
+ axpy1(w_row, Xr + ptr_this, Xr_ind + ptr_this, Xr_indptr[ix_arr[row]+1] - ptr_this, buffer_sum_left);
2234
+ if (x_uses_ix_arr) {
2235
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
2236
+ }
2237
+ else {
2238
+ if (unlikely(x[row] == x[row+1])) continue;
2239
+ }
2240
+
2241
+ vleft = 0;
2242
+ vright = 0;
2243
+ wright = wtot - wleft;
2244
+ for (size_t col = 0; col < ncols; col++)
2245
+ {
2246
+ sl = buffer_sum_left[col];
2247
+ vleft += sl * (sl / wleft);
2248
+ sr = buffer_sum_tot[col] - sl;
2249
+ vright += sr * (sr / wright);
2250
+ }
2251
+
2252
+ this_gain = vleft + vright;
2253
+ if (this_gain > best_gain)
2254
+ {
2255
+ best_gain = this_gain;
2256
+ split_ix = row;
2257
+ }
2258
+ }
2259
+ }
2260
+
2261
+ else
2262
+ {
2263
+ size_t *curr_begin;
2264
+ size_t *row_end;
2265
+ size_t *curr_col;
2266
+ double *restrict Xr_this;
2267
+ size_t *cols_end = cols_use + ncols_use;
2268
+ double w_row;
2269
+ size_t dtemp;
2270
+ for (size_t row = st; row < end; row++)
2271
+ {
2272
+ w_row = w[x_uses_ix_arr? ix_arr[row] : row];
2273
+ wleft += w_row;
2274
+
2275
+ curr_begin = Xr_ind + Xr_indptr[ix_arr[row]];
2276
+ row_end = Xr_ind + Xr_indptr[ix_arr[row] + 1];
2277
+ if (curr_begin == row_end) goto skip_sum;
2278
+ curr_col = cols_use;
2279
+ Xr_this = Xr + Xr_indptr[ix_arr[row]];
2280
+ while (curr_col < cols_end && curr_begin < row_end)
2281
+ {
2282
+ if (*curr_begin == *curr_col)
2283
+ {
2284
+ dtemp = std::distance(cols_use, curr_col);
2285
+ buffer_sum_left[dtemp]
2286
+ =
2287
+ std::fma(w_row, Xr_this[std::distance(curr_begin, row_end)], buffer_sum_left[dtemp]);
2288
+ curr_col++;
2289
+ curr_begin++;
2290
+ }
2291
+
2292
+ else
2293
+ {
2294
+ if (*curr_begin > *curr_col)
2295
+ curr_col = std::lower_bound(curr_col, cols_end, *curr_begin);
2296
+ else
2297
+ curr_begin = std::lower_bound(curr_begin, row_end, *curr_col);
2298
+ }
2299
+ }
2300
+
2301
+ skip_sum:
2302
+ if (x_uses_ix_arr) {
2303
+ if (unlikely(x[ix_arr[row]] == x[ix_arr[row+1]])) continue;
2304
+ }
2305
+ else {
2306
+ if (unlikely(x[row] == x[row+1])) continue;
2307
+ }
2308
+
2309
+ vleft = 0;
2310
+ vright = 0;
2311
+ wright = wtot - wleft;
2312
+ for (size_t col = 0; col < ncols_use; col++)
2313
+ {
2314
+ sl = buffer_sum_left[col];
2315
+ vleft += sl * (sl / wleft);
2316
+ sr = buffer_sum_tot[col] - sl;
2317
+ vright += sr * (sr / wright);
2318
+ }
2319
+
2320
+ this_gain = vleft + vright;
2321
+ if (this_gain > best_gain)
2322
+ {
2323
+ best_gain = this_gain;
2324
+ split_ix = row;
2325
+ }
2326
+ }
2327
+ }
2328
+ }
2329
+
2330
+ if (best_gain <= -HUGE_VAL) return best_gain;
2331
+
2332
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
2333
+ return best_gain / wtot;
2334
+ }
2335
+
2336
+ template <class real_t_, class real_t>
2337
+ double find_split_dens_shortform_t(real_t *restrict x, size_t n, double &restrict split_point)
2338
+ {
2339
+ double best_gain = -HUGE_VAL;
2340
+ size_t n_minus_one = n - 1;
2341
+ real_t_ xmin = x[0];
2342
+ real_t_ xmax = x[n-1];
2343
+ real_t_ xleft, xright;
2344
+ real_t_ xmid;
2345
+ double this_gain;
2346
+ size_t split_ix = 0;
2347
+
2348
+ for (size_t ix = 0; ix < n_minus_one; ix++)
2349
+ {
2350
+ if (x[ix] == x[ix+1]) continue;
2351
+ xmid = (real_t_)x[ix] + ((real_t_)x[ix+1] - (real_t_)x[ix]) / (real_t_)2;
2352
+ xleft = xmid - xmin;
2353
+ xright = xmax - xmid;
2354
+ if (unlikely(!xleft || !xright)) continue;
2355
+ this_gain = (real_t_)square(ix+1) / xleft + (real_t_)square(n_minus_one - ix) / xright;
2356
+ if (this_gain > best_gain)
2357
+ {
2358
+ best_gain = this_gain;
2359
+ split_ix = ix;
2360
+ }
2361
+ }
2362
+
2363
+ if (best_gain <= -HUGE_VAL) return best_gain;
2364
+
2365
+ real_t_ xtot = (real_t_)xmax - (real_t_)xmin;
2366
+ real_t_ nleft = (real_t_)(split_ix+1);
2367
+ real_t_ nright = (real_t_)(n_minus_one - split_ix);
2368
+ split_point = midpoint(x[split_ix], x[split_ix+1]);
2369
+ real_t_ rpct_left = split_point / xtot;
2370
+ rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
2371
+ real_t_ rpct_right = (real_t_)1 - rpct_left;
2372
+ rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
2373
+
2374
+ real_t_ nl_sq = nleft / (real_t_)n; nl_sq = square(nl_sq);
2375
+ real_t_ nr_sq = nright / (real_t_)n; nl_sq = square(nr_sq);
2376
+
2377
+ return nl_sq / rpct_left + nr_sq / rpct_right;
2378
+ }
2379
+
2380
+ template <class real_t, class ldouble_safe>
2381
+ double find_split_dens_shortform(real_t *restrict x, size_t n, double &restrict split_point)
2382
+ {
2383
+ if (n < INT32_MAX)
2384
+ return find_split_dens_shortform_t<double, real_t>(x, n, split_point);
2385
+ else
2386
+ return find_split_dens_shortform_t<ldouble_safe, real_t>(x, n, split_point);
2387
+ }
2388
+
2389
+ template <class real_t_, class real_t, class mapping>
2390
+ double find_split_dens_shortform_weighted_t(real_t *restrict x, size_t n, double &restrict split_point, mapping &restrict w, size_t *restrict buffer_indices)
2391
+ {
2392
+ double best_gain = -HUGE_VAL;
2393
+ size_t n_minus_one = n - 1;
2394
+ real_t_ xmin = x[buffer_indices[0]];
2395
+ real_t_ xmax = x[buffer_indices[n-1]];
2396
+ real_t_ xleft, xright;
2397
+ real_t_ xmid;
2398
+ double this_gain;
2399
+
2400
+ real_t_ wtot = 0;
2401
+ for (size_t ix = 0; ix < n; ix++)
2402
+ wtot += w[buffer_indices[ix]];
2403
+ real_t_ w_left = 0;
2404
+ real_t_ w_right;
2405
+ real_t_ best_w = 0;
2406
+ size_t split_ix = 0;
2407
+
2408
+ for (size_t ix = 0; ix < n_minus_one; ix++)
2409
+ {
2410
+ w_left += w[buffer_indices[ix]];
2411
+ if (x[buffer_indices[ix]] == x[buffer_indices[ix+1]]) continue;
2412
+ xmid = (real_t_)x[buffer_indices[ix]] + ((real_t_)x[buffer_indices[ix+1]] - (real_t_)x[buffer_indices[ix]]) / (real_t_)2;
2413
+ xleft = xmid - xmin;
2414
+ xright = xmax - xmid;
2415
+ if (unlikely(!xleft || !xright)) continue;
2416
+
2417
+ w_right = wtot - w_left;
2418
+ this_gain = square(w_left) / xleft + square(w_right) / xright;
2419
+ if (this_gain > best_gain)
2420
+ {
2421
+ best_gain = this_gain;
2422
+ best_w = w_left;
2423
+ split_ix = xmid;
2424
+ }
2425
+ }
2426
+
2427
+ if (best_gain <= -HUGE_VAL) return best_gain;
2428
+
2429
+ real_t_ xtot = xmax - xmin;
2430
+ w_left = best_w;
2431
+ w_right = wtot - w_left;
2432
+ w_left = std::fmax(w_left, std::numeric_limits<double>::min());
2433
+ w_right = std::fmax(w_right, std::numeric_limits<double>::min());
2434
+ split_point = midpoint(x[buffer_indices[split_ix]], x[buffer_indices[split_ix+1]]);
2435
+ real_t_ rpct_left = split_point / xtot;
2436
+ rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
2437
+ real_t_ rpct_right = (real_t_)1 - rpct_left;
2438
+ rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
2439
+
2440
+ real_t_ wl_sq = w_left / wtot; wl_sq = square(wl_sq);
2441
+ real_t_ wr_sq = w_right / wtot; wl_sq = square(wr_sq);
2442
+
2443
+ return wl_sq / rpct_left + wr_sq / rpct_right;
2444
+ }
2445
+
2446
+ template <class real_t, class mapping, class ldouble_safe>
2447
+ double find_split_dens_shortform_weighted(real_t *restrict x, size_t n, double &restrict split_point, mapping &restrict w, size_t *restrict buffer_indices)
2448
+ {
2449
+ if (n < INT32_MAX)
2450
+ return find_split_dens_shortform_weighted_t<double, real_t, mapping>(x, n, split_point, w, buffer_indices);
2451
+ else
2452
+ return find_split_dens_shortform_weighted_t<ldouble_safe, real_t, mapping>(x, n, split_point, w, buffer_indices);
2453
+ }
2454
+
2455
+ template <class real_t>
2456
+ double find_split_dens_shortform(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2457
+ double &restrict split_point, size_t &restrict split_ix)
2458
+ {
2459
+ double best_gain = -HUGE_VAL;
2460
+ real_t xmin = x[ix_arr[st]];
2461
+ real_t xmax = x[ix_arr[end]];
2462
+ real_t xleft, xright;
2463
+ real_t xmid;
2464
+ double this_gain;
2465
+
2466
+ for (size_t row = st; row < end; row++)
2467
+ {
2468
+ if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
2469
+ xmid = x[ix_arr[row]] + (x[ix_arr[row+1]] - x[ix_arr[row]]) / (real_t)2;
2470
+ xleft = xmid - xmin;
2471
+ xright = xmax - xmid;
2472
+ if (unlikely(!xleft || !xright)) continue;
2473
+ this_gain = square(row-st+1) / xleft + square(end-row) / xright;
2474
+ if (this_gain > best_gain)
2475
+ {
2476
+ best_gain = this_gain;
2477
+ split_ix = row;
2478
+ }
2479
+ }
2480
+
2481
+ if (best_gain <= -HUGE_VAL) return best_gain;
2482
+
2483
+ double xtot = (double)xmax - (double)xmin;
2484
+ double nleft = (double)(split_ix-st+1);
2485
+ double nright = (double)(end - split_ix);
2486
+ split_point = midpoint(x[ix_arr[split_ix]], x[ix_arr[split_ix+1]]);
2487
+ double rpct_left = split_point / xtot;
2488
+ rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
2489
+ double rpct_right = 1. - rpct_left;
2490
+ rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
2491
+ double ntot = (double)(end - st + 1);
2492
+
2493
+ double nl_sq = nleft / ntot; nl_sq = square(nl_sq);
2494
+ double nr_sq = nright / ntot; nl_sq = square(nr_sq);
2495
+
2496
+ return nl_sq / rpct_left + nr_sq / rpct_right;
2497
+ }
2498
+
2499
+ template <class real_t, class mapping>
2500
+ double find_split_dens_shortform_weighted(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2501
+ double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
2502
+ {
2503
+ double best_gain = -HUGE_VAL;
2504
+ real_t xmin = x[ix_arr[st]];
2505
+ real_t xmax = x[ix_arr[end]];
2506
+ real_t xleft, xright;
2507
+ real_t xmid;
2508
+ double this_gain;
2509
+
2510
+ double wtot = 0;
2511
+ for (size_t row = st; row <= end; row++)
2512
+ wtot += w[ix_arr[row]];
2513
+ double w_left = 0;
2514
+ double w_right;
2515
+ double best_w = 0;
2516
+
2517
+ for (size_t row = st; row < end; row++)
2518
+ {
2519
+ w_left += w[ix_arr[row]];
2520
+ if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
2521
+ xmid = x[ix_arr[row]] + (x[ix_arr[row+1]] - x[ix_arr[row]]) / (real_t)2;
2522
+ xleft = xmid - xmin;
2523
+ xright = xmax - xmid;
2524
+ if (unlikely(!xleft || !xright)) continue;
2525
+
2526
+ w_right = wtot - w_left;
2527
+ this_gain = square(w_left) / xleft + square(w_right) / xright;
2528
+ if (this_gain > best_gain)
2529
+ {
2530
+ best_gain = this_gain;
2531
+ best_w = w_left;
2532
+ split_ix = row;
2533
+ }
2534
+ }
2535
+
2536
+ if (best_gain <= -HUGE_VAL) return best_gain;
2537
+
2538
+ double xtot = (double)xmax - (double)xmin;
2539
+ w_left = best_w;
2540
+ w_right = wtot - w_left;
2541
+ w_left = std::fmax(w_left, std::numeric_limits<double>::min());
2542
+ w_right = std::fmax(w_right, std::numeric_limits<double>::min());
2543
+ split_point = midpoint(x[split_ix], x[split_ix+1]);
2544
+ double rpct_left = split_point / xtot;
2545
+ rpct_left = std::fmax(rpct_left, std::numeric_limits<double>::min());
2546
+ double rpct_right = 1. - rpct_left;
2547
+ rpct_right = std::fmax(rpct_right, std::numeric_limits<double>::min());
2548
+
2549
+ double wl_sq = w_left / wtot; wl_sq = square(wl_sq);
2550
+ double wr_sq = w_right / wtot; wl_sq = square(wr_sq);
2551
+
2552
+ return wl_sq / rpct_left + wr_sq / rpct_right;
2553
+ }
2554
+
2555
+ /* This is a slower but more numerically-robust form */
2556
+ template <class real_t, class ldouble_safe>
2557
+ double find_split_dens_longform(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2558
+ double &restrict split_point, size_t &restrict split_ix)
2559
+ {
2560
+ double best_gain = -HUGE_VAL;
2561
+ real_t xmin = x[ix_arr[st]];
2562
+ real_t xmax = x[ix_arr[end]];
2563
+ real_t xleft, xright;
2564
+ real_t xmid;
2565
+ ldouble_safe pct_left, pct_right;
2566
+ ldouble_safe rpct_left, rpct_right;
2567
+ ldouble_safe n_tot = end - st + 1;
2568
+ ldouble_safe xtot = (ldouble_safe)xmax - (ldouble_safe)xmin;
2569
+ ldouble_safe cnt_left;
2570
+ double this_gain;
2571
+
2572
+ for (size_t row = st; row < end; row++)
2573
+ {
2574
+ if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
2575
+ xmid = midpoint(x[ix_arr[row]], x[ix_arr[row+1]]);
2576
+ xleft = xmid - xmin;
2577
+ xright = xmax - xmid;
2578
+ if (unlikely(!xleft || !xright)) continue;
2579
+
2580
+ cnt_left = (ldouble_safe)(row-st+1);
2581
+
2582
+ xleft = std::fmax(xleft, (real_t)std::numeric_limits<real_t>::min());
2583
+ xright = std::fmax(xright, (real_t)std::numeric_limits<real_t>::min());
2584
+ pct_left = cnt_left / n_tot;
2585
+ pct_right = (ldouble_safe)1 - pct_left;
2586
+ rpct_left = (ldouble_safe)xleft / xtot;
2587
+ rpct_right = (ldouble_safe)xright / xtot;
2588
+
2589
+ this_gain = square(pct_left) / rpct_left + square(pct_right) / rpct_right;
2590
+ if (unlikely(is_na_or_inf(this_gain))) continue;
2591
+ if (this_gain > best_gain)
2592
+ {
2593
+ best_gain = this_gain;
2594
+ split_point = xmid;
2595
+ split_ix = row;
2596
+ }
2597
+ }
2598
+
2599
+ return best_gain;
2600
+ }
2601
+
2602
+ template <class real_t, class mapping, class ldouble_safe>
2603
+ double find_split_dens_longform_weighted(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2604
+ double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
2605
+ {
2606
+ double best_gain = -HUGE_VAL;
2607
+ real_t xmin = x[ix_arr[st]];
2608
+ real_t xmax = x[ix_arr[end]];
2609
+ real_t xleft, xright;
2610
+ real_t xmid;
2611
+ ldouble_safe pct_left, pct_right;
2612
+ ldouble_safe rpct_left, rpct_right;
2613
+ ldouble_safe xtot = (ldouble_safe)xmax - (ldouble_safe)xmin;
2614
+ double this_gain;
2615
+
2616
+ ldouble_safe wtot = 0;
2617
+ for (size_t row = st; row <= end; row++)
2618
+ wtot += w[ix_arr[row]];
2619
+ ldouble_safe w_left = 0;
2620
+
2621
+ for (size_t row = st; row < end; row++)
2622
+ {
2623
+ w_left += w[ix_arr[row]];
2624
+ if (x[ix_arr[row]] == x[ix_arr[row+1]]) continue;
2625
+ xmid = midpoint(x[ix_arr[row]], x[ix_arr[row+1]]);
2626
+ xleft = xmid - xmin;
2627
+ xright = xmax - xmid;
2628
+ if (unlikely(!xleft || !xright)) continue;
2629
+
2630
+ xleft = std::fmax(xleft, (real_t)std::numeric_limits<real_t>::min());
2631
+ xright = std::fmax(xright, (real_t)std::numeric_limits<real_t>::min());
2632
+ pct_left = w_left / wtot;
2633
+ pct_right = (ldouble_safe)1 - pct_left;
2634
+ rpct_left = (ldouble_safe)xleft / xtot;
2635
+ rpct_right = (ldouble_safe)xright / xtot;
2636
+
2637
+ this_gain = square(pct_left) / rpct_left + square(pct_right) / rpct_right;
2638
+ if (unlikely(is_na_or_inf(this_gain))) continue;
2639
+ if (this_gain > best_gain)
2640
+ {
2641
+ best_gain = this_gain;
2642
+ split_point = xmid;
2643
+ split_ix = row;
2644
+ }
2645
+ }
2646
+
2647
+ return best_gain;
2648
+ }
2649
+
2650
+ template <class real_t, class ldouble_safe>
2651
+ double find_split_dens(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2652
+ double &restrict split_point, size_t &restrict split_ix)
2653
+ {
2654
+ if (end - st + 1 < INT32_MAX && x[ix_arr[end]] - x[ix_arr[st]] >= 1)
2655
+ return find_split_dens_shortform<real_t>(x, ix_arr, st, end, split_point, split_ix);
2656
+ else
2657
+ return find_split_dens_longform<real_t, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix);
2658
+ }
2659
+
2660
+ template <class real_t, class mapping, class ldouble_safe>
2661
+ double find_split_dens_weighted(real_t *restrict x, size_t *restrict ix_arr, size_t st, size_t end,
2662
+ double &restrict split_point, size_t &restrict split_ix, mapping &restrict w)
2663
+ {
2664
+ if (end - st + 1 < INT32_MAX && x[ix_arr[end]] - x[ix_arr[st]] >= 1)
2665
+ return find_split_dens_shortform_weighted<real_t, mapping>(x, ix_arr, st, end, split_point, split_ix, w);
2666
+ else
2667
+ return find_split_dens_longform_weighted<real_t, mapping, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix, w);
2668
+ }
2669
+
2670
+ template <class int_t, class ldouble_safe>
2671
+ double find_split_dens_longform(int *restrict x, int ncat, size_t *restrict ix_arr, size_t st, size_t end,
2672
+ CategSplit cat_split_type, MissingAction missing_action,
2673
+ int &restrict chosen_cat, signed char *restrict split_categ, int *restrict saved_cat_mode,
2674
+ size_t *restrict buffer_cnt, int_t *restrict buffer_indices)
2675
+ {
2676
+ if (st >= end || ncat <= 1) return -HUGE_VAL;
2677
+ size_t n_nas = 0;
2678
+ int xval;
2679
+
2680
+ /* count categories */
2681
+ memset(buffer_cnt, 0, sizeof(size_t) * ncat);
2682
+ if (missing_action == Fail)
2683
+ {
2684
+ for (size_t row = st; row <= end; row++)
2685
+ if (likely(x[ix_arr[row]] >= 0))
2686
+ buffer_cnt[x[ix_arr[row]]]++;
2687
+ }
2688
+
2689
+ else if (missing_action == Impute)
2690
+ {
2691
+ for (size_t row = st; row <= end; row++)
2692
+ {
2693
+ xval = x[ix_arr[row]];
2694
+ if (unlikely(xval < 0))
2695
+ n_nas++;
2696
+ else
2697
+ buffer_cnt[xval]++;
2698
+ }
2699
+
2700
+ if (unlikely(n_nas >= end-st)) return -HUGE_VAL;
2701
+
2702
+ if (n_nas)
2703
+ {
2704
+ auto idxmax = std::max_element(buffer_cnt, buffer_cnt + ncat);
2705
+ *idxmax += n_nas;
2706
+ *saved_cat_mode = (int)std::distance(buffer_cnt, idxmax);
2707
+ }
2708
+ }
2709
+
2710
+ else
2711
+ {
2712
+ for (size_t row = st; row <= end; row++)
2713
+ {
2714
+ xval = x[ix_arr[row]];
2715
+ if (likely(xval >= 0)) buffer_cnt[xval]++;
2716
+ }
2717
+ }
2718
+
2719
+ std::iota(buffer_indices, buffer_indices + ncat, (int_t)0);
2720
+ std::sort(buffer_indices, buffer_indices + ncat,
2721
+ [&buffer_cnt](const int_t a, const int_t b)
2722
+ {return buffer_cnt[a] < buffer_cnt[b];});
2723
+
2724
+ int curr = 0;
2725
+ if (split_categ != NULL)
2726
+ {
2727
+ while (buffer_cnt[buffer_indices[curr]] == 0)
2728
+ {
2729
+ split_categ[buffer_indices[curr]] = -1;
2730
+ curr++;
2731
+ }
2732
+ }
2733
+
2734
+ else
2735
+ {
2736
+ while (buffer_cnt[buffer_indices[curr]] == 0) curr++;
2737
+ }
2738
+
2739
+ int ncat_present = ncat - curr;
2740
+ if (ncat_present <= 1) return -HUGE_VAL;
2741
+ if (ncat_present == 2)
2742
+ {
2743
+ switch (cat_split_type)
2744
+ {
2745
+ case SingleCateg:
2746
+ {
2747
+ chosen_cat = buffer_indices[curr];
2748
+ break;
2749
+ }
2750
+
2751
+ case SubSet:
2752
+ {
2753
+ split_categ[buffer_indices[curr]] = 1;
2754
+ split_categ[buffer_indices[curr+1]] = 0;
2755
+ break;
2756
+ }
2757
+ }
2758
+
2759
+ ldouble_safe pct_left
2760
+ =
2761
+ (ldouble_safe)buffer_cnt[buffer_indices[curr]]
2762
+ /
2763
+ (ldouble_safe)(
2764
+ buffer_cnt[buffer_indices[curr]]
2765
+ +
2766
+ buffer_cnt[buffer_indices[curr+1]]
2767
+ );
2768
+
2769
+ return ((ldouble_safe)buffer_cnt[buffer_indices[curr]] * (2. * pct_left)
2770
+ +
2771
+ (ldouble_safe)buffer_cnt[buffer_indices[curr+1]] * (2. - 2.*pct_left))
2772
+ /
2773
+ (ldouble_safe)(buffer_cnt[buffer_indices[curr]] + buffer_cnt[buffer_indices[curr+1]]);
2774
+ }
2775
+
2776
+ size_t ntot;
2777
+ if (missing_action == Impute)
2778
+ ntot = end - st + 1;
2779
+ else
2780
+ ntot = std::accumulate(buffer_cnt, buffer_cnt + ncat, (size_t)0);
2781
+ if (unlikely(ntot <= 1)) unexpected_error();
2782
+ ldouble_safe ntot_ = (ldouble_safe)ntot;
2783
+
2784
+ switch (cat_split_type)
2785
+ {
2786
+ case SingleCateg:
2787
+ {
2788
+ double pct_one_cat = 1. / (double)ncat_present;
2789
+ double pct_left_smallest = (ldouble_safe)buffer_cnt[buffer_indices[curr]] / ntot_;
2790
+ double gain_smallest
2791
+ =
2792
+ (ldouble_safe)buffer_cnt[buffer_indices[curr]] * (pct_left_smallest / pct_one_cat)
2793
+ +
2794
+ (ldouble_safe)(ntot - buffer_cnt[buffer_indices[curr]]) * ((1. - pct_left_smallest) / (1. - pct_one_cat))
2795
+ ;
2796
+
2797
+ double pct_left_biggest = (ldouble_safe)buffer_cnt[buffer_indices[ncat-1]] / ntot_;
2798
+ double gain_biggest
2799
+ =
2800
+ (ldouble_safe)buffer_cnt[buffer_indices[ncat-1]] * (pct_left_biggest / pct_one_cat)
2801
+ +
2802
+ (ldouble_safe)(ntot - buffer_cnt[buffer_indices[ncat-1]]) * ((1. - pct_left_biggest) / (1. - pct_one_cat))
2803
+ ;
2804
+
2805
+ if (gain_smallest >= gain_biggest)
2806
+ {
2807
+ chosen_cat = buffer_indices[curr];
2808
+ return gain_smallest / ntot_;
2809
+ }
2810
+
2811
+ else
2812
+ {
2813
+ chosen_cat = buffer_indices[ncat-1];
2814
+ return gain_biggest / ntot_;
2815
+ }
2816
+ break;
2817
+ }
2818
+
2819
+ case SubSet:
2820
+ {
2821
+ size_t cnt_left = 0;
2822
+ size_t cnt_right;
2823
+ int st_cat = curr - 1;
2824
+ double this_gain;
2825
+ double best_gain = -HUGE_VAL;
2826
+ int best_cat = 0;
2827
+ ldouble_safe pct_left;
2828
+ double pct_cat_left;
2829
+ double ncat_present_ = (double)ncat_present;
2830
+ for (; curr < ncat; curr++)
2831
+ {
2832
+ cnt_left += buffer_cnt[buffer_indices[curr]];
2833
+ cnt_right = ntot - cnt_left;
2834
+ pct_left = (ldouble_safe)cnt_left / ntot_;
2835
+ pct_cat_left = (double)(curr - st_cat) / ncat_present_;
2836
+ this_gain
2837
+ =
2838
+ (ldouble_safe)cnt_left * (pct_left / pct_cat_left)
2839
+ +
2840
+ (ldouble_safe)cnt_right * (((ldouble_safe)1 - pct_left) / (1. - pct_cat_left))
2841
+ ;
2842
+ if (this_gain > best_gain)
2843
+ {
2844
+ best_gain = this_gain;
2845
+ best_cat = curr;
2846
+ }
2847
+ }
2848
+
2849
+ if (best_gain <= -HUGE_VAL) return best_gain;
2850
+ st_cat++;
2851
+ for (; st_cat <= best_cat; st_cat++)
2852
+ split_categ[buffer_indices[st_cat]] = 1;
2853
+ for (; st_cat < ncat; st_cat++)
2854
+ split_categ[buffer_indices[st_cat]] = 0;
2855
+ return best_gain / ntot_;
2856
+ break;
2857
+ }
2858
+ }
2859
+
2860
+ /* This will not be reached, but CRAN might complain otherwise */
2861
+ return -HUGE_VAL;
2862
+ }
2863
+
2864
+ template <class mapping, class int_t, class ldouble_safe>
2865
+ double find_split_dens_longform_weighted(int *restrict x, int ncat, size_t *restrict ix_arr, size_t st, size_t end,
2866
+ CategSplit cat_split_type, MissingAction missing_action,
2867
+ int &restrict chosen_cat, signed char *restrict split_categ, int *restrict saved_cat_mode,
2868
+ int_t *restrict buffer_indices, mapping &restrict w)
2869
+ {
2870
+ if (st >= end || ncat <= 1) return -HUGE_VAL;
2871
+ ldouble_safe w_missing = 0;
2872
+ int xval;
2873
+ size_t ix_;
2874
+
2875
+ /* count categories */
2876
+ /* TODO: allocate this buffer externally */
2877
+ std::vector<ldouble_safe> buffer_cnt(ncat, (ldouble_safe)0);
2878
+ if (missing_action == Fail)
2879
+ {
2880
+ for (size_t row = st; row <= end; row++)
2881
+ {
2882
+ ix_ = ix_arr[row];
2883
+ if (unlikely(x[ix_]) < 0) continue;
2884
+ buffer_cnt[x[ix_]] += w[ix_];
2885
+ }
2886
+ }
2887
+
2888
+ else if (missing_action == Impute)
2889
+ {
2890
+ for (size_t row = st; row <= end; row++)
2891
+ {
2892
+ ix_ = ix_arr[row];
2893
+ xval = x[ix_];
2894
+ if (unlikely(xval < 0))
2895
+ w_missing += w[ix_];
2896
+ else
2897
+ buffer_cnt[xval] += w[ix_];
2898
+ }
2899
+
2900
+ if (w_missing)
2901
+ {
2902
+ auto idxmax = std::max_element(buffer_cnt.begin(), buffer_cnt.end());
2903
+ *idxmax += w_missing;
2904
+ *saved_cat_mode = (int)std::distance(buffer_cnt.begin(), idxmax);
2905
+ }
2906
+ }
2907
+
2908
+ else
2909
+ {
2910
+ for (size_t row = st; row <= end; row++)
2911
+ {
2912
+ ix_ = ix_arr[row];
2913
+ xval = x[ix_];
2914
+ if (likely(xval >= 0)) buffer_cnt[xval] += w[ix_];
2915
+ }
2916
+ }
2917
+
2918
+ std::iota(buffer_indices, buffer_indices + ncat, (int_t)0);
2919
+ std::sort(buffer_indices, buffer_indices + ncat,
2920
+ [&buffer_cnt](const int_t a, const int_t b)
2921
+ {return buffer_cnt[a] < buffer_cnt[b];});
2922
+
2923
+ int curr = 0;
2924
+ if (split_categ != NULL)
2925
+ {
2926
+ while (buffer_cnt[buffer_indices[curr]] == 0)
2927
+ {
2928
+ split_categ[buffer_indices[curr]] = -1;
2929
+ curr++;
2930
+ }
2931
+ }
2932
+
2933
+ else
2934
+ {
2935
+ while (buffer_cnt[buffer_indices[curr]] == 0) curr++;
2936
+ }
2937
+
2938
+ int ncat_present = ncat - curr;
2939
+ if (ncat_present <= 1) return -HUGE_VAL;
2940
+ if (ncat_present == 2)
2941
+ {
2942
+ switch (cat_split_type)
2943
+ {
2944
+ case SingleCateg:
2945
+ {
2946
+ chosen_cat = buffer_indices[curr];
2947
+ break;
2948
+ }
2949
+
2950
+ case SubSet:
2951
+ {
2952
+ split_categ[buffer_indices[curr]] = 1;
2953
+ split_categ[buffer_indices[curr+1]] = 0;
2954
+ break;
2955
+ }
2956
+ }
2957
+
2958
+ ldouble_safe pct_left
2959
+ =
2960
+ buffer_cnt[buffer_indices[curr]]
2961
+ /
2962
+ (
2963
+ buffer_cnt[buffer_indices[curr]]
2964
+ +
2965
+ buffer_cnt[buffer_indices[curr+1]]
2966
+ );
2967
+
2968
+ return (buffer_cnt[buffer_indices[curr]] * (pct_left * 2.)
2969
+ +
2970
+ buffer_cnt[buffer_indices[curr+1]] * (2. - 2.*pct_left))
2971
+ /
2972
+ (buffer_cnt[buffer_indices[curr]] + buffer_cnt[buffer_indices[curr+1]]);
2973
+ }
2974
+
2975
+ ldouble_safe ntot = std::accumulate(buffer_cnt.begin(), buffer_cnt.end(), (ldouble_safe)0);
2976
+ if (unlikely(ntot <= 0)) unexpected_error();
2977
+
2978
+ switch (cat_split_type)
2979
+ {
2980
+ case SingleCateg:
2981
+ {
2982
+ double pct_one_cat = 1. / (double)ncat_present;
2983
+ double pct_left_smallest = buffer_cnt[buffer_indices[curr]] / ntot;
2984
+ double gain_smallest
2985
+ =
2986
+ buffer_cnt[buffer_indices[curr]] * (pct_left_smallest / pct_one_cat)
2987
+ +
2988
+ (ntot - buffer_cnt[buffer_indices[curr]]) * ((1. - pct_left_smallest) / (1. - pct_one_cat))
2989
+ ;
2990
+
2991
+ double pct_left_biggest = buffer_cnt[buffer_indices[ncat-1]] / ntot;
2992
+ double gain_biggest
2993
+ =
2994
+ buffer_cnt[buffer_indices[ncat-1]] * (pct_left_biggest / pct_one_cat)
2995
+ +
2996
+ (ntot - buffer_cnt[buffer_indices[ncat-1]]) * ((1. - pct_left_biggest) / (1. - pct_one_cat))
2997
+ ;
2998
+
2999
+ if (gain_smallest >= gain_biggest)
3000
+ {
3001
+ chosen_cat = buffer_indices[curr];
3002
+ return gain_smallest / ntot;
3003
+ }
3004
+
3005
+ else
3006
+ {
3007
+ chosen_cat = buffer_indices[ncat-1];
3008
+ return gain_biggest / ntot;
3009
+ }
3010
+ break;
3011
+ }
3012
+
3013
+ case SubSet:
3014
+ {
3015
+ ldouble_safe cnt_left = 0;
3016
+ ldouble_safe cnt_right;
3017
+ int st_cat = curr - 1;
3018
+ double this_gain;
3019
+ double best_gain = -HUGE_VAL;
3020
+ int best_cat = 0;
3021
+ ldouble_safe pct_left;
3022
+ double pct_cat_left;
3023
+ double ncat_present_ = (double)ncat_present;
3024
+ for (; curr < ncat; curr++)
3025
+ {
3026
+ cnt_left += buffer_cnt[buffer_indices[curr]];
3027
+ cnt_right = ntot - cnt_left;
3028
+ pct_left = cnt_left / ntot;
3029
+ pct_cat_left = (double)(curr - st_cat) / ncat_present_;
3030
+ this_gain
3031
+ =
3032
+ (ldouble_safe)cnt_left * (pct_left / pct_cat_left)
3033
+ +
3034
+ (ldouble_safe)cnt_right * (((ldouble_safe)1 - pct_left) / (1. - pct_cat_left))
3035
+ ;
3036
+ if (this_gain > best_gain)
3037
+ {
3038
+ best_gain = this_gain;
3039
+ best_cat = curr;
3040
+ }
3041
+ }
3042
+
3043
+ if (best_gain <= -HUGE_VAL) return best_gain;
3044
+ st_cat++;
3045
+ for (; st_cat <= best_cat; st_cat++)
3046
+ split_categ[buffer_indices[st_cat]] = 1;
3047
+ for (; st_cat < ncat; st_cat++)
3048
+ split_categ[buffer_indices[st_cat]] = 0;
3049
+ return best_gain / ntot;
3050
+ break;
3051
+ }
3052
+ }
3053
+
3054
+ /* This will not be reached, but CRAN might complain otherwise */
3055
+ return -HUGE_VAL;
3056
+ }
3057
+
3058
+ /* for split-criterion in hyperplanes (see below for version aimed at single-variable splits) */
3059
+ template <class ldouble_safe>
3060
+ double eval_guided_crit(double *restrict x, size_t n, GainCriterion criterion,
3061
+ double min_gain, bool as_relative_gain, double *restrict buffer_sd,
3062
+ double &restrict split_point, double &restrict xmin, double &restrict xmax,
3063
+ size_t *restrict ix_arr_plus_st,
3064
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3065
+ double *restrict X_row_major, size_t ncols,
3066
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
3067
+ {
3068
+ /* Note: the input 'x' is supposed to be a linear combination of standardized variables, so
3069
+ all numbers are assumed to be small and in the same scale */
3070
+ double gain = 0;
3071
+ if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
3072
+
3073
+ /* here it's assumed the 'x' vector matches exactly with 'ix_arr' + 'st' */
3074
+ if (unlikely(n == 2))
3075
+ {
3076
+ if (x[0] == x[1]) return -HUGE_VAL;
3077
+ split_point = midpoint_with_reorder(x[0], x[1]);
3078
+ gain = 1.;
3079
+ if (gain > min_gain)
3080
+ return gain;
3081
+ else
3082
+ return 0.;
3083
+ }
3084
+
3085
+ if (criterion == FullGain)
3086
+ {
3087
+ /* TODO: these buffers should be allocated externally */
3088
+ std::vector<size_t> argsorted(n);
3089
+ std::iota(argsorted.begin(), argsorted.end(), (size_t)0);
3090
+ std::sort(argsorted.begin(), argsorted.end(),
3091
+ [&x](const size_t a, const size_t b){return x[a] < x[b];});
3092
+ if (x[argsorted[0]] == x[argsorted[n-1]]) return -HUGE_VAL;
3093
+ std::vector<double> temp_buffer(n + mult2(ncols));
3094
+ for (size_t ix = 0; ix < n; ix++) temp_buffer[ix] = x[argsorted[ix]];
3095
+ for (size_t ix = 0; ix < n; ix++)
3096
+ argsorted[ix] = ix_arr_plus_st[argsorted[ix]];
3097
+ size_t ignored;
3098
+ return find_split_full_gain<double, ldouble_safe>(
3099
+ temp_buffer.data(), (size_t)0, n-1, argsorted.data(),
3100
+ cols_use, ncols_use, force_cols_use,
3101
+ X_row_major, ncols,
3102
+ Xr, Xr_ind, Xr_indptr,
3103
+ temp_buffer.data() + n, temp_buffer.data() + n + ncols,
3104
+ ignored, split_point,
3105
+ false);
3106
+ }
3107
+
3108
+ /* sort in ascending order */
3109
+ std::sort(x, x + n);
3110
+ xmin = x[0]; xmax = x[n-1];
3111
+ if (x[0] == x[n-1]) return -HUGE_VAL;
3112
+
3113
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3114
+ gain = find_split_rel_gain<double, ldouble_safe>(x, n, split_point);
3115
+ else if (criterion == Pooled || criterion == Averaged)
3116
+ gain = find_split_std_gain<double, ldouble_safe>(x, n, buffer_sd, criterion, min_gain, split_point);
3117
+ else if (criterion == DensityCrit)
3118
+ gain = find_split_dens_shortform<double, ldouble_safe>(x, n, split_point);
3119
+ /* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
3120
+ return std::fmax(0., gain);
3121
+ }
3122
+
3123
+ template <class ldouble_safe>
3124
+ double eval_guided_crit_weighted(double *restrict x, size_t n, GainCriterion criterion,
3125
+ double min_gain, bool as_relative_gain, double *restrict buffer_sd,
3126
+ double &restrict split_point, double &restrict xmin, double &restrict xmax,
3127
+ double *restrict w, size_t *restrict buffer_indices,
3128
+ size_t *restrict ix_arr_plus_st,
3129
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3130
+ double *restrict X_row_major, size_t ncols,
3131
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
3132
+ {
3133
+ /* Note: the input 'x' is supposed to be a linear combination of standardized variables, so
3134
+ all numbers are assumed to be small and in the same scale */
3135
+ double gain = 0;
3136
+ if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
3137
+
3138
+ /* here it's assumed the 'x' vector matches exactly with 'ix_arr' + 'st' */
3139
+ if (unlikely(n == 2))
3140
+ {
3141
+ if (x[0] == x[1]) return -HUGE_VAL;
3142
+ split_point = midpoint_with_reorder(x[0], x[1]);
3143
+ gain = 1.;
3144
+ if (gain > min_gain)
3145
+ return gain;
3146
+ else
3147
+ return 0.;
3148
+ }
3149
+
3150
+ /* sort in ascending order */
3151
+ std::iota(buffer_indices, buffer_indices + n, (size_t)0);
3152
+ std::sort(buffer_indices, buffer_indices + n,
3153
+ [&x](const size_t a, const size_t b){return x[a] < x[b];});
3154
+ xmin = x[buffer_indices[0]]; xmax = x[buffer_indices[n-1]];
3155
+ if (xmin == xmax) return -HUGE_VAL;
3156
+
3157
+ if (criterion == Pooled || criterion == Averaged)
3158
+ gain = find_split_std_gain_weighted<double, ldouble_safe>(x, n, buffer_sd, criterion, min_gain, split_point, w, buffer_indices);
3159
+ else if (criterion == DensityCrit)
3160
+ gain = find_split_dens_shortform_weighted<double, double *restrict, ldouble_safe>(x, n, split_point, w, buffer_indices);
3161
+ else if (criterion == FullGain)
3162
+ {
3163
+ std::vector<size_t> argsorted(n);
3164
+ std::iota(argsorted.begin(), argsorted.end(), (size_t)0);
3165
+ std::sort(argsorted.begin(), argsorted.end(),
3166
+ [&x](const size_t a, const size_t b){return x[a] < x[b];});
3167
+ if (x[argsorted[0]] == x[argsorted[n-1]]) return -HUGE_VAL;
3168
+ std::vector<double> temp_buffer(n + mult2(ncols));
3169
+ for (size_t ix = 0; ix < n; ix++) temp_buffer[ix] = x[argsorted[ix]];
3170
+ for (size_t ix = 0; ix < n; ix++)
3171
+ argsorted[ix] = ix_arr_plus_st[argsorted[ix]];
3172
+ size_t ignored;
3173
+ gain = find_split_full_gain_weighted<double, double *restrict, ldouble_safe>(
3174
+ temp_buffer.data(), (size_t)0, n-1, argsorted.data(),
3175
+ cols_use, ncols_use, force_cols_use,
3176
+ X_row_major, ncols,
3177
+ Xr, Xr_ind, Xr_indptr,
3178
+ temp_buffer.data() + n, temp_buffer.data() + n + ncols,
3179
+ ignored, split_point,
3180
+ false,
3181
+ w);
3182
+ }
3183
+ /* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
3184
+ return std::fmax(0., gain);
3185
+ }
3186
+
3187
+ /* for split-criterion in single-variable splits */
3188
+ template <class real_t_, class ldouble_safe>
3189
+ double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, real_t_ *restrict x,
3190
+ double *restrict buffer_sd, bool as_relative_gain,
3191
+ double *restrict buffer_imputed_x, double *restrict saved_xmedian,
3192
+ size_t &split_ix, double &restrict split_point, double &restrict xmin, double &restrict xmax,
3193
+ GainCriterion criterion, double min_gain, MissingAction missing_action,
3194
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3195
+ double *restrict X_row_major, size_t ncols,
3196
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
3197
+ {
3198
+ size_t st_orig = st;
3199
+ double gain = 0;
3200
+ if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
3201
+
3202
+ /* move NAs to the front if there's any, exclude them from calculations */
3203
+ if (missing_action != Fail)
3204
+ st = move_NAs_to_front(ix_arr, st, end, x);
3205
+
3206
+ if (unlikely(st >= end)) return -HUGE_VAL;
3207
+ else if (unlikely(st == (end-1)))
3208
+ {
3209
+ if (x[ix_arr[st]] == x[ix_arr[end]])
3210
+ return -HUGE_VAL;
3211
+ split_point = midpoint_with_reorder(x[ix_arr[st]], x[ix_arr[end]]);
3212
+ split_ix = st;
3213
+ gain = 1.;
3214
+ if (gain > min_gain)
3215
+ return gain;
3216
+ else
3217
+ return 0.;
3218
+ }
3219
+
3220
+ /* sort in ascending order */
3221
+ std::sort(ix_arr + st, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
3222
+ if (x[ix_arr[st]] == x[ix_arr[end]]) return -HUGE_VAL;
3223
+ xmin = x[ix_arr[st]]; xmax = x[ix_arr[end]];
3224
+
3225
+ /* unlike the previous case for the extended model, the data here has not been centered,
3226
+ which could make the standard deviations have poor precision. It's nevertheless not
3227
+ necessary for this mean to have good precision, since it's only meant for centering,
3228
+ so it can be calculated inexactly with simd instructions. */
3229
+ real_t_ xmean = 0;
3230
+ if (criterion == Pooled || criterion == Averaged)
3231
+ {
3232
+ for (size_t ix = st; ix <= end; ix++)
3233
+ xmean += x[ix_arr[ix]];
3234
+ xmean /= (real_t_)(end - st + 1);
3235
+ }
3236
+
3237
+ if (missing_action == Impute && st > st_orig)
3238
+ {
3239
+ missing_action = Fail;
3240
+ fill_NAs_with_median(ix_arr, st_orig, st, end, x, buffer_imputed_x, saved_xmedian);
3241
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3242
+ gain = find_split_rel_gain<double, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, split_point, split_ix);
3243
+ else if (criterion == Pooled || criterion == Averaged)
3244
+ gain = find_split_std_gain<double, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, buffer_sd, criterion, min_gain, split_point, split_ix);
3245
+ else if (criterion == DensityCrit)
3246
+ gain = find_split_dens<double, ldouble_safe>(buffer_imputed_x, ix_arr, st_orig, end, split_point, split_ix);
3247
+ else if (criterion == FullGain)
3248
+ {
3249
+ /* TODO: this buffer should be allocated from outside */
3250
+ std::vector<double> temp_buffer(mult2(ncols));
3251
+ gain = find_split_full_gain<double, ldouble_safe>(
3252
+ buffer_imputed_x, st_orig, end, ix_arr,
3253
+ cols_use, ncols_use, force_cols_use,
3254
+ X_row_major, ncols,
3255
+ Xr, Xr_ind, Xr_indptr,
3256
+ temp_buffer.data(), temp_buffer.data() + ncols,
3257
+ split_ix, split_point, true);
3258
+ }
3259
+
3260
+ /* Note: in theory, it should be possible to use a faster version assuming a contiguous array for 'x',
3261
+ but such an approach might give inexact split points. Better to avoid such inexactness at the
3262
+ expense of more computations. */
3263
+ }
3264
+
3265
+ else
3266
+ {
3267
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3268
+ gain = find_split_rel_gain<real_t_, ldouble_safe>(x, xmean, ix_arr, st, end, split_point, split_ix);
3269
+ else if (criterion == Pooled || criterion == Averaged)
3270
+ gain = find_split_std_gain<real_t_, ldouble_safe>(x, xmean, ix_arr, st, end, buffer_sd, criterion, min_gain, split_point, split_ix);
3271
+ else if (criterion == DensityCrit)
3272
+ gain = find_split_dens<real_t_, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix);
3273
+ else if (criterion == FullGain)
3274
+ {
3275
+ /* TODO: this buffer should be allocated from outside */
3276
+ std::vector<double> temp_buffer(mult2(ncols));
3277
+ gain = find_split_full_gain<real_t_, ldouble_safe>(
3278
+ x, st, end, ix_arr,
3279
+ cols_use, ncols_use, force_cols_use,
3280
+ X_row_major, ncols,
3281
+ Xr, Xr_ind, Xr_indptr,
3282
+ temp_buffer.data(), temp_buffer.data() + ncols,
3283
+ split_ix, split_point, true);
3284
+ }
3285
+ }
3286
+
3287
+ /* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
3288
+ return std::fmax(0., gain);
3289
+ }
3290
+
3291
+ template <class real_t_, class mapping, class ldouble_safe>
3292
+ double eval_guided_crit_weighted(size_t *restrict ix_arr, size_t st, size_t end, real_t_ *restrict x,
3293
+ double *restrict buffer_sd, bool as_relative_gain,
3294
+ double *restrict buffer_imputed_x, double *restrict saved_xmedian,
3295
+ size_t &split_ix, double &restrict split_point, double &restrict xmin, double &restrict xmax,
3296
+ GainCriterion criterion, double min_gain, MissingAction missing_action,
3297
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3298
+ double *restrict X_row_major, size_t ncols,
3299
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
3300
+ mapping &restrict w)
3301
+ {
3302
+ size_t st_orig = st;
3303
+ double gain = 0;
3304
+ if (criterion == DensityCrit || criterion == FullGain) min_gain = 0;
3305
+
3306
+ /* move NAs to the front if there's any, exclude them from calculations */
3307
+ if (missing_action != Fail)
3308
+ st = move_NAs_to_front(ix_arr, st, end, x);
3309
+
3310
+ if (unlikely(st >= end)) return -HUGE_VAL;
3311
+ else if (unlikely(st == (end-1)))
3312
+ {
3313
+ if (x[ix_arr[st]] == x[ix_arr[end]])
3314
+ return -HUGE_VAL;
3315
+ split_point = midpoint_with_reorder(x[ix_arr[st]], x[ix_arr[end]]);
3316
+ split_ix = st;
3317
+ gain = 1.;
3318
+ if (gain > min_gain)
3319
+ return gain;
3320
+ else
3321
+ return 0.;
3322
+ }
3323
+
3324
+ /* sort in ascending order */
3325
+ std::sort(ix_arr + st, ix_arr + end + 1, [&x](const size_t a, const size_t b){return x[a] < x[b];});
3326
+ if (x[ix_arr[st]] == x[ix_arr[end]]) return -HUGE_VAL;
3327
+ xmin = x[ix_arr[st]]; xmax = x[ix_arr[end]];
3328
+
3329
+ /* unlike the previous case for the extended model, the data here has not been centered,
3330
+ which could make the standard deviations have poor precision. It's nevertheless not
3331
+ necessary for this mean to have good precision, since it's only meant for centering,
3332
+ so it can be calculated inexactly with simd instructions. */
3333
+ real_t_ xmean = 0;
3334
+ real_t_ cnt = 0;
3335
+ if (criterion == Pooled || criterion == Averaged)
3336
+ {
3337
+ for (size_t ix = st; ix <= end; ix++)
3338
+ {
3339
+ xmean += x[ix_arr[ix]];
3340
+ cnt += w[ix_arr[ix]];
3341
+ }
3342
+ xmean /= cnt;
3343
+ }
3344
+
3345
+ if (missing_action == Impute && st > st_orig)
3346
+ {
3347
+ missing_action = Fail;
3348
+ fill_NAs_with_median(ix_arr, st_orig, st, end, x, buffer_imputed_x, saved_xmedian);
3349
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3350
+ gain = find_split_rel_gain_weighted<double, mapping, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, split_point, split_ix, w);
3351
+ else if (criterion == Pooled || criterion == Averaged)
3352
+ gain = find_split_std_gain_weighted<double, mapping, ldouble_safe>(buffer_imputed_x, (double)xmean, ix_arr, st_orig, end, buffer_sd, criterion, min_gain, split_point, split_ix, w);
3353
+ else if (criterion == DensityCrit)
3354
+ gain = find_split_dens_weighted<double, mapping, ldouble_safe>(buffer_imputed_x, ix_arr, st_orig, end, split_point, split_ix, w);
3355
+ else if (criterion == FullGain)
3356
+ {
3357
+ std::vector<double> temp_buffer(mult2(ncols));
3358
+ gain = find_split_full_gain_weighted<double, mapping, ldouble_safe>(
3359
+ buffer_imputed_x, st_orig, end, ix_arr,
3360
+ cols_use, ncols_use, force_cols_use,
3361
+ X_row_major, ncols,
3362
+ Xr, Xr_ind, Xr_indptr,
3363
+ temp_buffer.data(), temp_buffer.data() + ncols,
3364
+ split_ix, split_point, true,
3365
+ w);
3366
+ }
3367
+ }
3368
+
3369
+ else
3370
+ {
3371
+ if (criterion == Pooled && as_relative_gain && min_gain <= 0)
3372
+ gain = find_split_rel_gain_weighted<real_t_, mapping, ldouble_safe>(x, xmean, ix_arr, st, end, split_point, split_ix, w);
3373
+ else if (criterion == Pooled || criterion == Averaged)
3374
+ gain = find_split_std_gain_weighted<real_t_, mapping, ldouble_safe>(x, xmean, ix_arr, st, end, buffer_sd, criterion, min_gain, split_point, split_ix, w);
3375
+ else if (criterion == DensityCrit)
3376
+ gain = find_split_dens_weighted<real_t_, mapping, ldouble_safe>(x, ix_arr, st, end, split_point, split_ix, w);
3377
+ else if (criterion == FullGain)
3378
+ {
3379
+ std::vector<double> temp_buffer(mult2(ncols));
3380
+ gain = find_split_full_gain_weighted<real_t_, mapping, ldouble_safe>(
3381
+ x, st, end, ix_arr,
3382
+ cols_use, ncols_use, force_cols_use,
3383
+ X_row_major, ncols,
3384
+ Xr, Xr_ind, Xr_indptr,
3385
+ temp_buffer.data(), temp_buffer.data() + ncols,
3386
+ split_ix, split_point, true,
3387
+ w);
3388
+ }
3389
+ }
3390
+
3391
+ /* Note: a gain of -Inf signals that the data is unsplittable. Zero signals it's below the minimum. */
3392
+ return std::fmax(0., gain);
3393
+ }
3394
+
3395
+ /* TODO: here it should only need to look at the non-zero entries. It can then use the
3396
+ same algorithm as above, but adding an extra check to see where the zeros fit in
3397
+ the sorted order of the non-zero entries while calculating gains and SDs, and then
3398
+ call the 'divide_subset' function after-the-fact to reach the same end result.
3399
+ It should be much faster than this if the non-zero entries are few. */
3400
+ template <class real_t_, class sparse_ix, class ldouble_safe>
3401
+ double eval_guided_crit(size_t ix_arr[], size_t st, size_t end,
3402
+ size_t col_num, real_t_ Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
3403
+ double buffer_arr[], size_t buffer_pos[], bool as_relative_gain,
3404
+ double *restrict saved_xmedian,
3405
+ double &split_point, double &xmin, double &xmax,
3406
+ GainCriterion criterion, double min_gain, MissingAction missing_action,
3407
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3408
+ double *restrict X_row_major, size_t ncols,
3409
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr)
3410
+ {
3411
+ size_t ignored;
3412
+
3413
+
3414
+ todense(ix_arr, st, end,
3415
+ col_num, Xc, Xc_ind, Xc_indptr,
3416
+ buffer_arr);
3417
+ size_t tot = end - st + 1;
3418
+ std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
3419
+
3420
+ if (missing_action == Impute)
3421
+ {
3422
+ missing_action = Fail;
3423
+ for (size_t ix = 0; ix < tot; ix++)
3424
+ {
3425
+ if (unlikely(is_na_or_inf(buffer_arr[ix])))
3426
+ {
3427
+ goto fill_missing;
3428
+ }
3429
+ }
3430
+ goto no_nas;
3431
+
3432
+ fill_missing:
3433
+ {
3434
+ size_t idx_half = div2(tot);
3435
+ std::nth_element(buffer_pos, buffer_pos + idx_half, buffer_pos + tot,
3436
+ [&buffer_arr](const size_t a, const size_t b){return buffer_arr[a] < buffer_arr[b];});
3437
+ *saved_xmedian = buffer_arr[buffer_pos[idx_half]];
3438
+
3439
+ if ((tot % 2) == 0)
3440
+ {
3441
+ double xlow = *std::max_element(buffer_pos, buffer_pos + idx_half);
3442
+ *saved_xmedian = xlow + ((*saved_xmedian)-xlow)/2.;
3443
+ }
3444
+
3445
+ for (size_t ix = 0; ix < tot; ix++)
3446
+ buffer_arr[ix] = is_na_or_inf(buffer_arr[ix])? (*saved_xmedian) : buffer_arr[ix];
3447
+ std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
3448
+ }
3449
+ }
3450
+
3451
+ no_nas:
3452
+ return eval_guided_crit<double, ldouble_safe>(
3453
+ buffer_pos, 0, end - st, buffer_arr, buffer_arr + tot,
3454
+ as_relative_gain, saved_xmedian, (double*)NULL, ignored, split_point,
3455
+ xmin, xmax, criterion, min_gain, missing_action,
3456
+ cols_use, ncols_use, force_cols_use,
3457
+ X_row_major, ncols,
3458
+ Xr, Xr_ind, Xr_indptr);
3459
+ }
3460
+
3461
+ template <class real_t_, class sparse_ix, class mapping, class ldouble_safe>
3462
+ double eval_guided_crit_weighted(size_t ix_arr[], size_t st, size_t end,
3463
+ size_t col_num, real_t_ Xc[], sparse_ix Xc_ind[], sparse_ix Xc_indptr[],
3464
+ double buffer_arr[], size_t buffer_pos[], bool as_relative_gain,
3465
+ double *restrict saved_xmedian,
3466
+ double &restrict split_point, double &restrict xmin, double &restrict xmax,
3467
+ GainCriterion criterion, double min_gain, MissingAction missing_action,
3468
+ size_t *restrict cols_use, size_t ncols_use, bool force_cols_use,
3469
+ double *restrict X_row_major, size_t ncols,
3470
+ double *restrict Xr, size_t *restrict Xr_ind, size_t *restrict Xr_indptr,
3471
+ mapping &restrict w)
3472
+ {
3473
+ size_t ignored;
3474
+
3475
+
3476
+ todense(ix_arr, st, end,
3477
+ col_num, Xc, Xc_ind, Xc_indptr,
3478
+ buffer_arr);
3479
+ size_t tot = end - st + 1;
3480
+ std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
3481
+
3482
+
3483
+ if (missing_action == Impute)
3484
+ {
3485
+ missing_action = Fail;
3486
+ for (size_t ix = 0; ix < tot; ix++)
3487
+ {
3488
+ if (unlikely(is_na_or_inf(buffer_arr[ix])))
3489
+ {
3490
+ goto fill_missing;
3491
+ }
3492
+ }
3493
+ goto no_nas;
3494
+
3495
+ fill_missing:
3496
+ {
3497
+ size_t idx_half = div2(tot);
3498
+ std::nth_element(buffer_pos, buffer_pos + idx_half, buffer_pos + tot,
3499
+ [&buffer_arr](const size_t a, const size_t b){return buffer_arr[a] < buffer_arr[b];});
3500
+ *saved_xmedian = buffer_arr[buffer_pos[idx_half]];
3501
+
3502
+ if ((tot % 2) == 0)
3503
+ {
3504
+ double xlow = *std::max_element(buffer_pos, buffer_pos + idx_half);
3505
+ *saved_xmedian = xlow + ((*saved_xmedian)-xlow)/2.;
3506
+ }
3507
+
3508
+ for (size_t ix = 0; ix < tot; ix++)
3509
+ buffer_arr[ix] = is_na_or_inf(buffer_arr[ix])? (*saved_xmedian) : buffer_arr[ix];
3510
+ std::iota(buffer_pos, buffer_pos + tot, (size_t)0);
3511
+ }
3512
+ }
3513
+
3514
+
3515
+ no_nas:
3516
+ /* TODO: allocate this buffer externally */
3517
+ std::vector<double> buffer_w(tot);
3518
+ for (size_t row = st; row <= end; row++)
3519
+ buffer_w[row-st] = w[ix_arr[row]];
3520
+ /* TODO: in this case, as the weights match with the order of the indices, could use a faster version
3521
+ with a weighted rel_gain function instead (not yet implemented). */
3522
+ return eval_guided_crit_weighted<double, std::vector<double>, ldouble_safe>(
3523
+ buffer_pos, 0, end - st, buffer_arr, buffer_arr + tot,
3524
+ as_relative_gain, saved_xmedian, (double*)NULL, ignored, split_point,
3525
+ xmin, xmax, criterion, min_gain, missing_action,
3526
+ cols_use, ncols_use, force_cols_use,
3527
+ X_row_major, ncols,
3528
+ Xr, Xr_ind, Xr_indptr,
3529
+ buffer_w);
3530
+ }
3531
+
3532
+ /* How this works:
3533
+ - For Averaged criterion, will take the expected standard deviation that would be gotten with the category counts
3534
+ if each category got assigned a real number at random ~ Unif(0,1) and the data were thus converted to
3535
+ numerical. In such case, the best split (highest sd gain) is always putting the second-highest count in one
3536
+ branch, so there is no point in doing a full search over other permutations. In order to get more reasonable
3537
+ splits, when using the option to split by subsets of categories, it will sort the counts and evaluate only
3538
+ splits in which the categories are grouped in sorted order - in such cases it tends to pick either the
3539
+ smallest or the largest category to assign to one branch, but sometimes picks groups too.
3540
+ - For Pooled criterion, will take shannon entropy, which tends to make a more even split. In the case of splitting
3541
+ by a single category, it always puts the largest category in a separate branch. In the case of subsets,
3542
+ it can either evaluate possible splits over all permutations (not feasible if there are too many categories),
3543
+ or look for splits in sorted order just like for the Averaged criterion.
3544
+ Splitting by averaged Gini gain (like with Averaged) also selects always the second-largest category to put in one branch,
3545
+ while splitting by weighted Gini (like with Pooled) usually selects the largest category to put in one branch. The
3546
+ Gini gain is not easily comparable to that of numerical columns, so it's not offered as an option here.
3547
+ */
3548
+ /* https://math.stackexchange.com/questions/3343384/expected-variance-and-kurtosis-from-pmf-in-which-possible-discrete-values-are-dr */
3549
+ /* TODO: 'buffer_pos' doesn't need to be 'size_t', 'int' would suffice */
3550
+ template <class ldouble_safe>
3551
+ double eval_guided_crit(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int ncat,
3552
+ int *restrict saved_cat_mode,
3553
+ size_t *restrict buffer_cnt, size_t *restrict buffer_pos, double *restrict buffer_prob,
3554
+ int &restrict chosen_cat, signed char *restrict split_categ, signed char *restrict buffer_split,
3555
+ GainCriterion criterion, double min_gain, bool all_perm,
3556
+ MissingAction missing_action, CategSplit cat_split_type)
3557
+ {
3558
+ if (criterion == DensityCrit)
3559
+ return find_split_dens_longform<size_t, ldouble_safe>(
3560
+ x, ncat, ix_arr, st, end,
3561
+ cat_split_type, missing_action,
3562
+ chosen_cat, split_categ, saved_cat_mode,
3563
+ buffer_cnt, buffer_pos);
3564
+ if (st >= end) return -HUGE_VAL;
3565
+ size_t n_nas = 0;
3566
+ int xval;
3567
+
3568
+ /* count categories */
3569
+ memset(buffer_cnt, 0, sizeof(size_t) * ncat);
3570
+ if (missing_action == Fail)
3571
+ {
3572
+ for (size_t row = st; row <= end; row++)
3573
+ if (likely(x[ix_arr[row]] >= 0))
3574
+ buffer_cnt[x[ix_arr[row]]]++;
3575
+ }
3576
+
3577
+ else if (missing_action == Impute)
3578
+ {
3579
+ for (size_t row = st; row <= end; row++)
3580
+ {
3581
+ xval = x[ix_arr[row]];
3582
+ if (unlikely(xval < 0))
3583
+ n_nas++;
3584
+ else
3585
+ buffer_cnt[xval]++;
3586
+ }
3587
+
3588
+ if (unlikely(n_nas >= end-st)) return -HUGE_VAL;
3589
+
3590
+ if (n_nas)
3591
+ {
3592
+ auto idxmax = std::max_element(buffer_cnt, buffer_cnt + ncat);
3593
+ *idxmax += n_nas;
3594
+ *saved_cat_mode = (int)std::distance(buffer_cnt, idxmax);
3595
+ }
3596
+ }
3597
+
3598
+ else
3599
+ {
3600
+ for (size_t row = st; row <= end; row++)
3601
+ {
3602
+ xval = x[ix_arr[row]];
3603
+ if (likely(xval >= 0)) buffer_cnt[xval]++;
3604
+ }
3605
+ }
3606
+
3607
+ double this_gain = -HUGE_VAL;
3608
+ double best_gain = -HUGE_VAL;
3609
+ std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
3610
+ size_t st_pos = 0;
3611
+
3612
+ switch (cat_split_type)
3613
+ {
3614
+ case SingleCateg:
3615
+ {
3616
+ size_t cnt = end - st + 1;
3617
+ ldouble_safe cnt_l = (ldouble_safe) cnt;
3618
+ size_t ncat_present = 0;
3619
+
3620
+ switch(criterion)
3621
+ {
3622
+ case Averaged:
3623
+ {
3624
+ /* move zero-counts to the beginning */
3625
+ size_t temp;
3626
+ for (int cat = 0; cat < ncat; cat++)
3627
+ {
3628
+ if (buffer_cnt[cat])
3629
+ {
3630
+ ncat_present++;
3631
+ buffer_prob[cat] = (ldouble_safe) buffer_cnt[cat] / cnt_l;
3632
+ }
3633
+
3634
+ else
3635
+ {
3636
+ temp = buffer_pos[st_pos];
3637
+ buffer_pos[st_pos] = buffer_pos[cat];
3638
+ buffer_pos[cat] = temp;
3639
+ st_pos++;
3640
+ }
3641
+ }
3642
+
3643
+ if (ncat_present <= 1) return -HUGE_VAL;
3644
+
3645
+ double sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
3646
+
3647
+ /* try isolating each category one at a time */
3648
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
3649
+ {
3650
+ this_gain = sd_gain(sd_full,
3651
+ 0.0,
3652
+ (expected_sd_cat_single<size_t, size_t, ldouble_safe>(buffer_cnt, buffer_prob, ncat_present, buffer_pos + st_pos, pos - st_pos, cnt)));
3653
+ if (this_gain > min_gain && this_gain > best_gain)
3654
+ {
3655
+ best_gain = this_gain;
3656
+ chosen_cat = buffer_pos[pos];
3657
+ }
3658
+ }
3659
+ break;
3660
+ }
3661
+
3662
+ case Pooled:
3663
+ {
3664
+ /* here it will always pick the largest one */
3665
+ size_t ncat_present = 0;
3666
+ size_t cnt_max = 0;
3667
+ for (int cat = 0; cat < ncat; cat++)
3668
+ {
3669
+ if (buffer_cnt[cat])
3670
+ {
3671
+ ncat_present++;
3672
+ if (cnt_max < buffer_cnt[cat])
3673
+ {
3674
+ cnt_max = buffer_cnt[cat];
3675
+ chosen_cat = cat;
3676
+ }
3677
+ }
3678
+ }
3679
+
3680
+ if (ncat_present <= 1) return -HUGE_VAL;
3681
+
3682
+ ldouble_safe cnt_left = (ldouble_safe)((end - st + 1) - cnt_max);
3683
+ this_gain = (
3684
+ (ldouble_safe)cnt * std::log((ldouble_safe)cnt)
3685
+ - cnt_left * std::log(cnt_left)
3686
+ - (ldouble_safe)cnt_max * std::log((ldouble_safe)cnt_max)
3687
+ ) / cnt;
3688
+ best_gain = (this_gain > min_gain)? this_gain : best_gain;
3689
+ break;
3690
+ }
3691
+
3692
+ default:
3693
+ {
3694
+ unexpected_error();
3695
+ break;
3696
+ }
3697
+ }
3698
+ break;
3699
+ }
3700
+
3701
+ case SubSet:
3702
+ {
3703
+ /* sort by counts */
3704
+ std::sort(buffer_pos, buffer_pos + ncat, [&buffer_cnt](const size_t a, const size_t b){return buffer_cnt[a] < buffer_cnt[b];});
3705
+
3706
+ /* set split as: (1):left (0):right (-1):not_present */
3707
+ memset(buffer_split, 0, ncat * sizeof(signed char));
3708
+
3709
+ ldouble_safe cnt = (ldouble_safe)(end - st + 1);
3710
+
3711
+ switch(criterion)
3712
+ {
3713
+ case Averaged:
3714
+ {
3715
+ /* determine first non-zero and convert to probabilities */
3716
+ double sd_full;
3717
+ for (int cat = 0; cat < ncat; cat++)
3718
+ {
3719
+ if (buffer_cnt[buffer_pos[cat]])
3720
+ {
3721
+ buffer_prob[buffer_pos[cat]] = (ldouble_safe)buffer_cnt[buffer_pos[cat]] / cnt;
3722
+ }
3723
+
3724
+ else
3725
+ {
3726
+ buffer_split[buffer_pos[cat]] = -1;
3727
+ st_pos++;
3728
+ }
3729
+ }
3730
+
3731
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
3732
+
3733
+ /* calculate full SD assuming they take values randomly ~Unif(0, 1) */
3734
+ size_t ncat_present = (size_t)ncat - st_pos;
3735
+ sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
3736
+ if (ncat_present >= log2ceil(SIZE_MAX)) all_perm = false;
3737
+
3738
+ /* move categories one at a time */
3739
+ for (size_t pos = st_pos; pos < ((size_t)ncat - st_pos - 1); pos++)
3740
+ {
3741
+ buffer_split[buffer_pos[pos]] = 1;
3742
+ this_gain = sd_gain(sd_full,
3743
+ (expected_sd_cat<size_t, size_t, ldouble_safe>(buffer_cnt, buffer_prob, pos - st_pos + 1, buffer_pos + st_pos)),
3744
+ (expected_sd_cat<size_t, size_t, ldouble_safe>(buffer_cnt, buffer_prob, (size_t)ncat - pos - 1, buffer_pos + pos + 1))
3745
+ );
3746
+ if (this_gain > min_gain && this_gain > best_gain)
3747
+ {
3748
+ best_gain = this_gain;
3749
+ memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
3750
+ }
3751
+ }
3752
+
3753
+ break;
3754
+ }
3755
+
3756
+ case Pooled:
3757
+ {
3758
+ ldouble_safe s = 0;
3759
+
3760
+ /* determine first non-zero and get base info */
3761
+ for (int cat = 0; cat < ncat; cat++)
3762
+ {
3763
+ if (buffer_cnt[buffer_pos[cat]])
3764
+ {
3765
+ s += (buffer_cnt[buffer_pos[cat]] <= 1)?
3766
+ 0 : ((ldouble_safe) buffer_cnt[buffer_pos[cat]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[cat]]));
3767
+ }
3768
+
3769
+ else
3770
+ {
3771
+ buffer_split[buffer_pos[cat]] = -1;
3772
+ st_pos++;
3773
+ }
3774
+ }
3775
+
3776
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
3777
+
3778
+ /* calculate base info */
3779
+ ldouble_safe base_info = cnt * std::log(cnt) - s;
3780
+
3781
+ if (all_perm)
3782
+ {
3783
+ size_t cnt_left, cnt_right;
3784
+ double s_left, s_right;
3785
+ size_t ncat_present = (size_t)ncat - st_pos;
3786
+ size_t ncomb = pow2(ncat_present) - 1;
3787
+ size_t best_combin;
3788
+
3789
+ for (size_t combin = 1; combin < ncomb; combin++)
3790
+ {
3791
+ cnt_left = 0; cnt_right = 0;
3792
+ s_left = 0; s_right = 0;
3793
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
3794
+ {
3795
+ if (extract_bit(combin, pos))
3796
+ {
3797
+ cnt_left += buffer_cnt[buffer_pos[pos]];
3798
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
3799
+ 0 : ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
3800
+ * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
3801
+ }
3802
+
3803
+ else
3804
+ {
3805
+ cnt_right += buffer_cnt[buffer_pos[pos]];
3806
+ s_right += (buffer_cnt[buffer_pos[pos]] <= 1)?
3807
+ 0 : ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
3808
+ * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
3809
+ }
3810
+ }
3811
+
3812
+ this_gain = categ_gain<size_t, ldouble_safe>(
3813
+ cnt_left, cnt_right,
3814
+ s_left, s_right,
3815
+ base_info, cnt);
3816
+
3817
+ if (this_gain > min_gain && this_gain > best_gain)
3818
+ {
3819
+ best_gain = this_gain;
3820
+ best_combin = combin;
3821
+ }
3822
+
3823
+ }
3824
+
3825
+ if (best_gain > min_gain)
3826
+ for (size_t pos = 0; pos < ncat_present; pos++)
3827
+ split_categ[buffer_pos[st_pos + pos]] = extract_bit(best_combin, pos);
3828
+
3829
+ }
3830
+
3831
+ else
3832
+ {
3833
+ /* try moving the categories one at a time */
3834
+ size_t cnt_left = 0;
3835
+ size_t cnt_right = end - st + 1;
3836
+ double s_left = 0;
3837
+ double s_right = s;
3838
+
3839
+ for (size_t pos = st_pos; pos < (ncat - st_pos - 1); pos++)
3840
+ {
3841
+ buffer_split[buffer_pos[pos]] = 1;
3842
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
3843
+ 0 : ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[pos]]));
3844
+ s_right -= (buffer_cnt[buffer_pos[pos]] <= 1)?
3845
+ 0 : ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[pos]]));
3846
+ cnt_left += buffer_cnt[buffer_pos[pos]];
3847
+ cnt_right -= buffer_cnt[buffer_pos[pos]];
3848
+
3849
+ this_gain = categ_gain<size_t, ldouble_safe>(
3850
+ cnt_left, cnt_right,
3851
+ s_left, s_right,
3852
+ base_info, cnt);
3853
+
3854
+ if (this_gain > min_gain && this_gain > best_gain)
3855
+ {
3856
+ best_gain = this_gain;
3857
+ memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
3858
+ }
3859
+ }
3860
+ }
3861
+
3862
+ break;
3863
+ }
3864
+
3865
+ default:
3866
+ {
3867
+ unexpected_error();
3868
+ break;
3869
+ }
3870
+ }
3871
+ }
3872
+ }
3873
+
3874
+ if (st == (end-1)) return 0;
3875
+
3876
+ if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
3877
+ return 0;
3878
+ else
3879
+ return best_gain;
3880
+ }
3881
+
3882
+
3883
+ template <class mapping, class ldouble_safe>
3884
+ double eval_guided_crit_weighted(size_t *restrict ix_arr, size_t st, size_t end, int *restrict x, int ncat,
3885
+ int *restrict saved_cat_mode,
3886
+ size_t *restrict buffer_pos, double *restrict buffer_prob,
3887
+ int &restrict chosen_cat, signed char *restrict split_categ, signed char *restrict buffer_split,
3888
+ GainCriterion criterion, double min_gain, bool all_perm,
3889
+ MissingAction missing_action, CategSplit cat_split_type,
3890
+ mapping &restrict w)
3891
+ {
3892
+ if (criterion == DensityCrit)
3893
+ return find_split_dens_longform_weighted<mapping, size_t, ldouble_safe>(
3894
+ x, ncat, ix_arr, st, end,
3895
+ cat_split_type, missing_action,
3896
+ chosen_cat, split_categ, saved_cat_mode,
3897
+ buffer_pos, w);
3898
+ if (st >= end) return -HUGE_VAL;
3899
+ ldouble_safe w_missing = 0;
3900
+ int xval;
3901
+ size_t ix_;
3902
+
3903
+ /* count categories */
3904
+ /* TODO: allocate this buffer externally */
3905
+ std::vector<ldouble_safe> buffer_cnt(ncat, (ldouble_safe)0);
3906
+ if (missing_action == Fail)
3907
+ {
3908
+ for (size_t row = st; row <= end; row++)
3909
+ {
3910
+ ix_ = ix_arr[row];
3911
+ if (unlikely(x[ix_]) < 0) continue;
3912
+ buffer_cnt[x[ix_]] += w[ix_];
3913
+ }
3914
+ }
3915
+
3916
+ else if (missing_action == Impute)
3917
+ {
3918
+ for (size_t row = st; row <= end; row++)
3919
+ {
3920
+ ix_ = ix_arr[row];
3921
+ xval = x[ix_];
3922
+ if (unlikely(xval < 0))
3923
+ w_missing += w[ix_];
3924
+ else
3925
+ buffer_cnt[xval] += w[ix_];
3926
+ }
3927
+
3928
+ if (w_missing)
3929
+ {
3930
+ auto idxmax = std::max_element(buffer_cnt.begin(), buffer_cnt.end());
3931
+ *idxmax += w_missing;
3932
+ *saved_cat_mode = (int)std::distance(buffer_cnt.begin(), idxmax);
3933
+ }
3934
+ }
3935
+
3936
+ else
3937
+ {
3938
+ for (size_t row = st; row <= end; row++)
3939
+ {
3940
+ ix_ = ix_arr[row];
3941
+ xval = x[ix_];
3942
+ if (likely(xval >= 0)) buffer_cnt[xval] += w[ix_];
3943
+ }
3944
+ }
3945
+
3946
+ ldouble_safe cnt = std::accumulate(buffer_cnt.begin(), buffer_cnt.end(), (ldouble_safe)0);
3947
+
3948
+ double this_gain = -HUGE_VAL;
3949
+ double best_gain = -HUGE_VAL;
3950
+ std::iota(buffer_pos, buffer_pos + ncat, (size_t)0);
3951
+ size_t st_pos = 0;
3952
+
3953
+ switch(cat_split_type)
3954
+ {
3955
+ case SingleCateg:
3956
+ {
3957
+ size_t ncat_present = 0;
3958
+
3959
+ switch(criterion)
3960
+ {
3961
+ case Averaged:
3962
+ {
3963
+ /* move zero-counts to the beginning */
3964
+ size_t temp;
3965
+ for (int cat = 0; cat < ncat; cat++)
3966
+ {
3967
+ if (buffer_cnt[cat])
3968
+ {
3969
+ ncat_present++;
3970
+ buffer_prob[cat] = buffer_cnt[cat] / cnt;
3971
+ }
3972
+
3973
+ else
3974
+ {
3975
+ temp = buffer_pos[st_pos];
3976
+ buffer_pos[st_pos] = buffer_pos[cat];
3977
+ buffer_pos[cat] = temp;
3978
+ st_pos++;
3979
+ }
3980
+ }
3981
+
3982
+ if (ncat_present <= 1) return -HUGE_VAL;
3983
+
3984
+ double sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
3985
+
3986
+ /* try isolating each category one at a time */
3987
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
3988
+ {
3989
+ this_gain = sd_gain(sd_full,
3990
+ 0.0,
3991
+ (expected_sd_cat_single<ldouble_safe, size_t, ldouble_safe>(buffer_cnt.data(), buffer_prob, ncat_present, buffer_pos + st_pos, pos - st_pos, cnt))
3992
+ );
3993
+ if (this_gain > min_gain && this_gain > best_gain)
3994
+ {
3995
+ best_gain = this_gain;
3996
+ chosen_cat = buffer_pos[pos];
3997
+ }
3998
+ }
3999
+ break;
4000
+ }
4001
+
4002
+ case Pooled:
4003
+ {
4004
+ /* here it will always pick the largest one */
4005
+ size_t ncat_present = 0;
4006
+ ldouble_safe cnt_max = 0;
4007
+ for (int cat = 0; cat < ncat; cat++)
4008
+ {
4009
+ if (buffer_cnt[cat])
4010
+ {
4011
+ ncat_present++;
4012
+ if (cnt_max < buffer_cnt[cat])
4013
+ {
4014
+ cnt_max = buffer_cnt[cat];
4015
+ chosen_cat = cat;
4016
+ }
4017
+ }
4018
+ }
4019
+
4020
+ if (ncat_present <= 1) return -HUGE_VAL;
4021
+
4022
+ ldouble_safe cnt_left = (ldouble_safe)(cnt - cnt_max);
4023
+
4024
+ /* TODO: think of a better way of dealing with numbers between zero and one */
4025
+ this_gain = (
4026
+ std::fmax((ldouble_safe)1, cnt) * std::log(std::fmax((ldouble_safe)1, cnt))
4027
+ - std::fmax((ldouble_safe)1, cnt_left) * std::log(std::fmax((ldouble_safe)1, cnt_left))
4028
+ - std::fmax((ldouble_safe)1, cnt_max) * std::log(std::fmax((ldouble_safe)1, cnt_max))
4029
+ ) / std::fmax((ldouble_safe)1, cnt);
4030
+ best_gain = (this_gain > min_gain)? this_gain : best_gain;
4031
+ break;
4032
+ }
4033
+
4034
+ default:
4035
+ {
4036
+ unexpected_error();
4037
+ break;
4038
+ }
4039
+ }
4040
+ break;
4041
+ }
4042
+
4043
+ case SubSet:
4044
+ {
4045
+ /* sort by counts */
4046
+ std::sort(buffer_pos, buffer_pos + ncat, [&buffer_cnt](const size_t a, const size_t b){return buffer_cnt[a] < buffer_cnt[b];});
4047
+
4048
+ /* set split as: (1):left (0):right (-1):not_present */
4049
+ memset(buffer_split, 0, ncat * sizeof(signed char));
4050
+
4051
+
4052
+ switch(criterion)
4053
+ {
4054
+ case Averaged:
4055
+ {
4056
+ /* determine first non-zero and convert to probabilities */
4057
+ double sd_full;
4058
+ for (int cat = 0; cat < ncat; cat++)
4059
+ {
4060
+ if (buffer_cnt[buffer_pos[cat]])
4061
+ {
4062
+ buffer_prob[buffer_pos[cat]] = (ldouble_safe)buffer_cnt[buffer_pos[cat]] / cnt;
4063
+ }
4064
+
4065
+ else
4066
+ {
4067
+ buffer_split[buffer_pos[cat]] = -1;
4068
+ st_pos++;
4069
+ }
4070
+ }
4071
+
4072
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
4073
+
4074
+ /* calculate full SD assuming they take values randomly ~Unif(0, 1) */
4075
+ size_t ncat_present = (size_t)ncat - st_pos;
4076
+ sd_full = expected_sd_cat<size_t, ldouble_safe>(buffer_prob, ncat_present, buffer_pos + st_pos);
4077
+ if (ncat_present >= log2ceil(SIZE_MAX)) all_perm = false;
4078
+
4079
+ /* move categories one at a time */
4080
+ for (size_t pos = st_pos; pos < ((size_t)ncat - st_pos - 1); pos++)
4081
+ {
4082
+ buffer_split[buffer_pos[pos]] = 1;
4083
+ /* TODO: is this correct? */
4084
+ this_gain = sd_gain(sd_full,
4085
+ (expected_sd_cat<ldouble_safe, size_t, ldouble_safe>(buffer_cnt.data(), buffer_prob, pos - st_pos + 1, buffer_pos + st_pos)),
4086
+ (expected_sd_cat<ldouble_safe, size_t, ldouble_safe>(buffer_cnt.data(), buffer_prob, (size_t)ncat - pos - 1, buffer_pos + pos + 1))
4087
+ );
4088
+ if (this_gain > min_gain && this_gain > best_gain)
4089
+ {
4090
+ best_gain = this_gain;
4091
+ memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
4092
+ }
4093
+ }
4094
+
4095
+ break;
4096
+ }
4097
+
4098
+ case Pooled:
4099
+ {
4100
+ ldouble_safe s = 0;
4101
+
4102
+ /* determine first non-zero and get base info */
4103
+ for (int cat = 0; cat < ncat; cat++)
4104
+ {
4105
+ if (buffer_cnt[buffer_pos[cat]])
4106
+ {
4107
+ s += (buffer_cnt[buffer_pos[cat]] <= 1)?
4108
+ (ldouble_safe)0
4109
+ :
4110
+ ((ldouble_safe) buffer_cnt[buffer_pos[cat]] * std::log((ldouble_safe)buffer_cnt[buffer_pos[cat]]));
4111
+ }
4112
+
4113
+ else
4114
+ {
4115
+ buffer_split[buffer_pos[cat]] = -1;
4116
+ st_pos++;
4117
+ }
4118
+ }
4119
+
4120
+ if ((int)st_pos >= (ncat-1)) return -HUGE_VAL;
4121
+
4122
+ /* calculate base info */
4123
+ ldouble_safe base_info = std::fmax((ldouble_safe)1, cnt) * std::log(std::fmax((ldouble_safe)1, cnt)) - s;
4124
+
4125
+ if (all_perm)
4126
+ {
4127
+ size_t cnt_left, cnt_right;
4128
+ double s_left, s_right;
4129
+ size_t ncat_present = (size_t)ncat - st_pos;
4130
+ size_t ncomb = pow2(ncat_present) - 1;
4131
+ size_t best_combin;
4132
+
4133
+ for (size_t combin = 1; combin < ncomb; combin++)
4134
+ {
4135
+ cnt_left = 0; cnt_right = 0;
4136
+ s_left = 0; s_right = 0;
4137
+ for (size_t pos = st_pos; (int)pos < ncat; pos++)
4138
+ {
4139
+ if (extract_bit(combin, pos))
4140
+ {
4141
+ cnt_left += buffer_cnt[buffer_pos[pos]];
4142
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
4143
+ (ldouble_safe)0
4144
+ :
4145
+ ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
4146
+ * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
4147
+ }
4148
+
4149
+ else
4150
+ {
4151
+ cnt_right += buffer_cnt[buffer_pos[pos]];
4152
+ s_right += (buffer_cnt[buffer_pos[pos]] <= 1)?
4153
+ (ldouble_safe)0
4154
+ :
4155
+ ((ldouble_safe) buffer_cnt[buffer_pos[pos]]
4156
+ * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
4157
+ }
4158
+ }
4159
+
4160
+ this_gain = categ_gain<size_t, ldouble_safe>(
4161
+ cnt_left, cnt_right,
4162
+ s_left, s_right,
4163
+ base_info, cnt);
4164
+
4165
+ if (this_gain > min_gain && this_gain > best_gain)
4166
+ {
4167
+ best_gain = this_gain;
4168
+ best_combin = combin;
4169
+ }
4170
+
4171
+ }
4172
+
4173
+ if (best_gain > min_gain)
4174
+ for (size_t pos = 0; pos < ncat_present; pos++)
4175
+ split_categ[buffer_pos[st_pos + pos]] = extract_bit(best_combin, pos);
4176
+
4177
+ }
4178
+
4179
+ else
4180
+ {
4181
+ /* try moving the categories one at a time */
4182
+ size_t cnt_left = 0;
4183
+ size_t cnt_right = end - st + 1;
4184
+ double s_left = 0;
4185
+ double s_right = s;
4186
+
4187
+ for (size_t pos = st_pos; pos < (ncat - st_pos - 1); pos++)
4188
+ {
4189
+ buffer_split[buffer_pos[pos]] = 1;
4190
+ s_left += (buffer_cnt[buffer_pos[pos]] <= 1)?
4191
+ (ldouble_safe)0
4192
+ :
4193
+ ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
4194
+ s_right -= (buffer_cnt[buffer_pos[pos]] <= 1)?
4195
+ (ldouble_safe)0
4196
+ :
4197
+ ((ldouble_safe)buffer_cnt[buffer_pos[pos]] * std::log((ldouble_safe) buffer_cnt[buffer_pos[pos]]));
4198
+ cnt_left += buffer_cnt[buffer_pos[pos]];
4199
+ cnt_right -= buffer_cnt[buffer_pos[pos]];
4200
+
4201
+ this_gain = categ_gain<size_t, ldouble_safe>(
4202
+ cnt_left, cnt_right,
4203
+ s_left, s_right,
4204
+ base_info, cnt);
4205
+
4206
+ if (this_gain > min_gain && this_gain > best_gain)
4207
+ {
4208
+ best_gain = this_gain;
4209
+ memcpy(split_categ, buffer_split, ncat * sizeof(signed char));
4210
+ }
4211
+ }
4212
+ }
4213
+
4214
+ break;
4215
+ }
4216
+
4217
+ default:
4218
+ {
4219
+ unexpected_error();
4220
+ break;
4221
+ }
4222
+ }
4223
+ }
4224
+ }
4225
+
4226
+ if (st == (end-1)) return 0;
4227
+
4228
+ if (best_gain <= -HUGE_VAL && this_gain <= min_gain && this_gain > -HUGE_VAL)
4229
+ return 0;
4230
+ else
4231
+ return best_gain;
4232
+ }